Repository: NVIDIA/spark-rapids-examples Branch: main Commit: 162959461bf8 Files: 277 Total size: 3.5 MB Directory structure: gitextract_pa_r2orm/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ └── bug_report.md │ └── workflows/ │ ├── add-to-project.yml │ ├── license-header-check.yml │ ├── markdown-links-check/ │ │ └── markdown-links-check-config.json │ ├── markdown-links-check.yml │ ├── shell-check.yml │ └── signoff-check.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── dockerfile/ │ ├── Dockerfile │ └── gpu_executor_template.yaml ├── docs/ │ ├── get-started/ │ │ └── xgboost-examples/ │ │ ├── building-sample-apps/ │ │ │ ├── python.md │ │ │ └── scala.md │ │ ├── csp/ │ │ │ ├── aws/ │ │ │ │ └── ec2.md │ │ │ ├── databricks/ │ │ │ │ ├── databricks.md │ │ │ │ └── init.sh │ │ │ └── dataproc/ │ │ │ └── gcp.md │ │ ├── dataset/ │ │ │ └── mortgage.md │ │ ├── notebook/ │ │ │ ├── python-notebook.md │ │ │ ├── spylon.md │ │ │ └── toree.md │ │ ├── on-prem-cluster/ │ │ │ ├── kubernetes-scala.md │ │ │ ├── standalone-python.md │ │ │ ├── standalone-scala.md │ │ │ ├── yarn-python.md │ │ │ └── yarn-scala.md │ │ └── prepare-package-data/ │ │ ├── preparation-python.md │ │ └── preparation-scala.md │ └── trouble-shooting/ │ └── xgboost-examples-trouble-shooting.md ├── examples/ │ ├── MIG-Support/ │ │ ├── README.md │ │ ├── device-plugins/ │ │ │ └── gpu-mig/ │ │ │ ├── README.md │ │ │ ├── pom.xml │ │ │ ├── scripts/ │ │ │ │ └── getMIGGPUs │ │ │ └── src/ │ │ │ ├── main/ │ │ │ │ └── java/ │ │ │ │ └── com/ │ │ │ │ └── nvidia/ │ │ │ │ └── spark/ │ │ │ │ └── NvidiaGPUMigPluginForRuntimeV2.java │ │ │ └── test/ │ │ │ └── java/ │ │ │ └── com/ │ │ │ └── nvidia/ │ │ │ └── spark/ │ │ │ └── TestNvidiaGPUMigPluginForRuntimeV2.java │ │ ├── resource-types/ │ │ │ └── gpu-mig/ │ │ │ ├── README.md │ │ │ ├── yarn312MIG.patch │ │ │ ├── yarn313to315MIG.patch │ │ │ └── yarn321to323MIG.patch │ │ └── yarn-unpatched/ │ │ ├── README.md │ │ └── scripts/ │ │ ├── mig2gpu.sh │ │ ├── nvidia-container-cli-wrapper.sh │ │ └── nvidia-smi │ ├── ML+DL-Examples/ │ │ ├── Optuna-Spark/ │ │ │ ├── README.md │ │ │ └── optuna-examples/ │ │ │ ├── databricks/ │ │ │ │ ├── init_optuna.sh │ │ │ │ └── start_cluster.sh │ │ │ ├── optuna-dataframe.ipynb │ │ │ └── optuna-joblibspark.ipynb │ │ ├── Spark-DL/ │ │ │ └── dl_inference/ │ │ │ ├── README.md │ │ │ ├── databricks/ │ │ │ │ ├── README.md │ │ │ │ └── setup/ │ │ │ │ ├── init_spark_dl.sh │ │ │ │ └── start_cluster.sh │ │ │ ├── dataproc/ │ │ │ │ ├── README.md │ │ │ │ └── setup/ │ │ │ │ ├── init_spark_dl.sh │ │ │ │ └── start_cluster.sh │ │ │ ├── huggingface/ │ │ │ │ ├── conditional_generation_tf.ipynb │ │ │ │ ├── conditional_generation_torch.ipynb │ │ │ │ ├── deepseek-r1_torch.ipynb │ │ │ │ ├── gemma-7b_torch.ipynb │ │ │ │ ├── pipelines_tf.ipynb │ │ │ │ ├── pipelines_torch.ipynb │ │ │ │ ├── qwen-2.5-7b_torch.ipynb │ │ │ │ └── sentence_transformers_torch.ipynb │ │ │ ├── pytorch/ │ │ │ │ ├── housing_regression_torch.ipynb │ │ │ │ └── image_classification_torch.ipynb │ │ │ ├── requirements.txt │ │ │ ├── server_utils.py │ │ │ ├── tensorflow/ │ │ │ │ ├── image_classification_tf.ipynb │ │ │ │ ├── keras_preprocessing_tf.ipynb │ │ │ │ ├── keras_resnet50_tf.ipynb │ │ │ │ └── text_classification_tf.ipynb │ │ │ ├── tf_requirements.txt │ │ │ ├── torch_requirements.txt │ │ │ ├── vllm/ │ │ │ │ ├── qwen-2.5-14b-tensor-parallel_vllm.ipynb │ │ │ │ └── qwen-2.5-7b_vllm.ipynb │ │ │ └── vllm_requirements.txt │ │ └── Spark-Rapids-ML/ │ │ └── pca/ │ │ ├── README.md │ │ ├── notebooks/ │ │ │ └── pca.ipynb │ │ └── 
start-spark-rapids.sh │ ├── SQL+DF-Examples/ │ │ ├── customer-churn/ │ │ │ ├── README.md │ │ │ └── notebooks/ │ │ │ └── python/ │ │ │ ├── README.md │ │ │ ├── augment.ipynb │ │ │ ├── churn/ │ │ │ │ ├── augment.py │ │ │ │ ├── eda.py │ │ │ │ └── etl.py │ │ │ └── etl.ipynb │ │ ├── demo/ │ │ │ ├── Spark_get_json_object.ipynb │ │ │ └── Spark_parquet_microkernels.ipynb │ │ ├── micro-benchmarks/ │ │ │ ├── README.md │ │ │ └── notebooks/ │ │ │ ├── micro-benchmarks-cpu.ipynb │ │ │ └── micro-benchmarks-gpu.ipynb │ │ ├── retail-analytics/ │ │ │ ├── README.md │ │ │ └── notebooks/ │ │ │ └── python/ │ │ │ ├── retail-analytic.ipynb │ │ │ └── retail-datagen.ipynb │ │ └── tpcds/ │ │ ├── README.md │ │ └── notebooks/ │ │ └── TPCDS-SF10.ipynb │ ├── UDF-Examples/ │ │ └── RAPIDS-accelerated-UDFs/ │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── clone-cudf-repo.sh │ │ ├── conftest.py │ │ ├── extract-cudf-libs.sh │ │ ├── pom.xml │ │ ├── pytest.ini │ │ ├── run_pyspark_from_build.sh │ │ ├── runtests.py │ │ └── src/ │ │ └── main/ │ │ ├── cpp/ │ │ │ ├── CMakeLists.txt │ │ │ ├── benchmarks/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── cosine_similarity/ │ │ │ │ │ └── cosine_similarity_benchmark.cpp │ │ │ │ ├── fixture/ │ │ │ │ │ └── benchmark_fixture.hpp │ │ │ │ └── synchronization/ │ │ │ │ ├── synchronization.cpp │ │ │ │ └── synchronization.hpp │ │ │ └── src/ │ │ │ ├── CosineSimilarityJni.cpp │ │ │ ├── StringWordCountJni.cpp │ │ │ ├── cosine_similarity.cu │ │ │ ├── cosine_similarity.hpp │ │ │ ├── string_word_count.cu │ │ │ └── string_word_count.hpp │ │ ├── java/ │ │ │ └── com/ │ │ │ └── nvidia/ │ │ │ └── spark/ │ │ │ └── rapids/ │ │ │ └── udf/ │ │ │ ├── hive/ │ │ │ │ ├── DecimalFraction.java │ │ │ │ ├── StringWordCount.java │ │ │ │ ├── URLDecode.java │ │ │ │ └── URLEncode.java │ │ │ └── java/ │ │ │ ├── CosineSimilarity.java │ │ │ ├── DecimalFraction.java │ │ │ ├── NativeUDFExamplesLoader.java │ │ │ ├── URLDecode.java │ │ │ └── URLEncode.java │ │ ├── python/ │ │ │ ├── asserts.py │ │ │ ├── conftest.py │ │ │ ├── data_gen.py │ │ │ ├── rapids_udf_test.py │ │ │ ├── spark_init_internal.py │ │ │ └── spark_session.py │ │ └── scala/ │ │ └── com/ │ │ └── nvidia/ │ │ └── spark/ │ │ └── rapids/ │ │ └── udf/ │ │ └── scala/ │ │ ├── URLDecode.scala │ │ └── URLEncode.scala │ ├── XGBoost-Examples/ │ │ ├── .gitignore │ │ ├── README.md │ │ ├── agaricus/ │ │ │ ├── .gitignore │ │ │ ├── notebooks/ │ │ │ │ ├── python/ │ │ │ │ │ └── agaricus-gpu.ipynb │ │ │ │ └── scala/ │ │ │ │ └── agaricus-gpu.ipynb │ │ │ ├── pom.xml │ │ │ ├── python/ │ │ │ │ └── com/ │ │ │ │ ├── __init__.py │ │ │ │ └── nvidia/ │ │ │ │ ├── __init__.py │ │ │ │ └── spark/ │ │ │ │ ├── __init__.py │ │ │ │ └── examples/ │ │ │ │ ├── __init__.py │ │ │ │ └── agaricus/ │ │ │ │ ├── __init__.py │ │ │ │ └── main.py │ │ │ └── scala/ │ │ │ └── src/ │ │ │ └── com/ │ │ │ └── nvidia/ │ │ │ └── spark/ │ │ │ └── examples/ │ │ │ └── agaricus/ │ │ │ └── Main.scala │ │ ├── aggregator/ │ │ │ └── .gitignore │ │ ├── app-parameters/ │ │ │ ├── supported_xgboost_parameters_python.md │ │ │ └── supported_xgboost_parameters_scala.md │ │ ├── assembly/ │ │ │ └── assembly-no-scala.xml │ │ ├── main.py │ │ ├── mortgage/ │ │ │ ├── .gitignore │ │ │ ├── notebooks/ │ │ │ │ ├── python/ │ │ │ │ │ ├── MortgageETL+XGBoost.ipynb │ │ │ │ │ ├── MortgageETL.ipynb │ │ │ │ │ ├── cv-mortgage-gpu.ipynb │ │ │ │ │ └── mortgage-gpu.ipynb │ │ │ │ └── scala/ │ │ │ │ ├── mortgage-ETL.ipynb │ │ │ │ ├── mortgage-gpu.ipynb │ │ │ │ └── mortgage_gpu_crossvalidation.ipynb │ │ │ ├── pom.xml │ │ │ ├── python/ │ │ │ │ └── com/ │ │ │ │ ├── __init__.py 
│ │ │ │ └── nvidia/ │ │ │ │ ├── __init__.py │ │ │ │ └── spark/ │ │ │ │ ├── __init__.py │ │ │ │ └── examples/ │ │ │ │ ├── __init__.py │ │ │ │ └── mortgage/ │ │ │ │ ├── __init__.py │ │ │ │ ├── consts.py │ │ │ │ ├── cross_validator_main.py │ │ │ │ ├── etl.py │ │ │ │ ├── etl_main.py │ │ │ │ └── main.py │ │ │ └── scala/ │ │ │ └── src/ │ │ │ └── com/ │ │ │ └── nvidia/ │ │ │ └── spark/ │ │ │ └── examples/ │ │ │ └── mortgage/ │ │ │ ├── CrossValidationMain.scala │ │ │ ├── ETLMain.scala │ │ │ ├── Main.scala │ │ │ ├── Mortgage.scala │ │ │ └── XGBoostETL.scala │ │ ├── pack_pyspark_example.sh │ │ ├── pom.xml │ │ ├── taxi/ │ │ │ ├── .gitignore │ │ │ ├── notebooks/ │ │ │ │ ├── python/ │ │ │ │ │ ├── cv-taxi-gpu.ipynb │ │ │ │ │ ├── taxi-ETL.ipynb │ │ │ │ │ └── taxi-gpu.ipynb │ │ │ │ └── scala/ │ │ │ │ ├── taxi-ETL.ipynb │ │ │ │ ├── taxi-gpu.ipynb │ │ │ │ └── taxi_gpu_crossvalidation.ipynb │ │ │ ├── pom.xml │ │ │ ├── python/ │ │ │ │ └── com/ │ │ │ │ ├── __init__.py │ │ │ │ └── nvidia/ │ │ │ │ ├── __init__.py │ │ │ │ └── spark/ │ │ │ │ ├── __init__.py │ │ │ │ └── examples/ │ │ │ │ ├── __init__.py │ │ │ │ └── taxi/ │ │ │ │ ├── __init__.py │ │ │ │ ├── consts.py │ │ │ │ ├── cross_validator_main.py │ │ │ │ ├── etl_main.py │ │ │ │ ├── main.py │ │ │ │ └── pre_process.py │ │ │ └── scala/ │ │ │ └── src/ │ │ │ └── com/ │ │ │ └── nvidia/ │ │ │ └── spark/ │ │ │ └── examples/ │ │ │ └── taxi/ │ │ │ ├── CrossValidationMain.scala │ │ │ ├── ETLMain.scala │ │ │ ├── Main.scala │ │ │ └── Taxi.scala │ │ └── utility/ │ │ ├── .gitignore │ │ ├── pom.xml │ │ ├── python/ │ │ │ └── com/ │ │ │ ├── __init__.py │ │ │ └── nvidia/ │ │ │ ├── __init__.py │ │ │ └── spark/ │ │ │ ├── __init__.py │ │ │ └── examples/ │ │ │ ├── __init__.py │ │ │ ├── main.py │ │ │ └── utility/ │ │ │ ├── __init__.py │ │ │ ├── args.py │ │ │ └── utils.py │ │ └── scala/ │ │ └── src/ │ │ └── com/ │ │ └── nvidia/ │ │ └── spark/ │ │ └── examples/ │ │ └── utility/ │ │ ├── Benchmark.scala │ │ ├── SparkSetup.scala │ │ ├── Vectorize.scala │ │ └── XGBoostArgs.scala │ └── spark-connect-gpu/ │ ├── client/ │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── docker-compose.yaml │ │ ├── nds/ │ │ │ ├── nds.ipynb │ │ │ └── query_0.sql │ │ ├── notebook/ │ │ │ ├── README.md │ │ │ ├── spark-connect-gpu-etl-ml.ipynb │ │ │ └── work/ │ │ │ ├── csv_raw_schema.ddl │ │ │ └── name_mapping.csv │ │ ├── python/ │ │ │ ├── batch-job.ipynb │ │ │ └── batch-job.py │ │ ├── requirements.txt │ │ └── scala/ │ │ ├── .gitignore │ │ ├── pom.xml │ │ ├── run.sh │ │ ├── scala-run.ipynb │ │ └── src/ │ │ └── main/ │ │ └── scala/ │ │ └── connect.scala │ └── server/ │ ├── README.md │ ├── docker-compose.yaml │ ├── proxy-service/ │ │ ├── Dockerfile │ │ └── nginx.conf │ ├── spark-connect-server/ │ │ ├── Dockerfile │ │ ├── requirements.txt │ │ ├── spark-defaults.conf │ │ └── spark-env.sh │ ├── spark-master/ │ │ ├── Dockerfile │ │ └── spark-env.sh │ └── spark-worker/ │ ├── Dockerfile │ ├── requirements.txt │ └── spark-env.sh ├── scripts/ │ ├── README.md │ ├── building/ │ │ └── python_build.sh │ ├── csp-startup-scripts/ │ │ ├── README.md │ │ └── emr/ │ │ ├── cgroup-bootstrap-action-emr6.sh │ │ ├── cgroup-bootstrap-action-emr7.sh │ │ ├── config-emr6.json │ │ ├── config-emr7.json │ │ └── emr-spark-plugin-startup.py │ ├── encoding/ │ │ └── python/ │ │ ├── .gitignore │ │ ├── com/ │ │ │ ├── __init__.py │ │ │ └── nvidia/ │ │ │ ├── __init__.py │ │ │ └── spark/ │ │ │ ├── __init__.py │ │ │ └── encoding/ │ │ │ ├── __init__.py │ │ │ ├── criteo/ │ │ │ │ ├── __init__.py │ │ │ │ ├── common.py │ │ │ │ ├── one_hot_cpu_main.py │ │ │ │ └── 
target_cpu_main.py │ │ │ ├── main.py │ │ │ └── utility/ │ │ │ ├── __init__.py │ │ │ ├── args.py │ │ │ └── utils.py │ │ └── main.py │ └── encoding-sample/ │ ├── repartition.py │ ├── run.sh │ └── truncate-model.py └── tools/ ├── databricks/ │ ├── README.md │ ├── [RAPIDS Accelerator for Apache Spark] Profiling Tool Notebook Template.ipynb │ └── [RAPIDS Accelerator for Apache Spark] Qualification Tool Notebook Template.ipynb └── emr/ ├── README.md ├── [RAPIDS Accelerator for Apache Spark] Profiling Tool Notebook Template.ipynb └── [RAPIDS Accelerator for Apache Spark] Qualification Tool Notebook Template.ipynb ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: '' assignees: GaryShen2008 --- **Describe the bug** A clear and concise description of what the bug is. **Steps/Code to reproduce bug** Please provide a list of steps or a code sample to reproduce the issue. Avoid posting private or sensitive data. **Expected behavior** A clear and concise description of what you expected to happen. **Environment details (please complete the following information)** - Environment location: [Standalone, YARN, Kubernetes, Cloud(specify cloud provider)] - Spark configuration settings related to the issue ================================================ FILE: .github/workflows/add-to-project.yml ================================================ # Copyright (c) 2024-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. name: Add new issues and pull requests to project on: issues: types: - opened pull_request_target: types: - opened jobs: Add-to-project: if: github.repository_owner == 'NVIDIA' # avoid adding issues from forks runs-on: ubuntu-latest steps: - name: add-to-project uses: NVIDIA/spark-rapids-common/add-to-project@main with: token: ${{ secrets.PROJECT_TOKEN }} ================================================ FILE: .github/workflows/license-header-check.yml ================================================ # Copyright (c) 2024-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# A workflow to check copyright/license header name: license header check on: pull_request: types: [opened, synchronize, reopened] jobs: license-header-check: runs-on: ubuntu-latest if: "!contains(github.event.pull_request.title, '[bot]')" steps: - name: Get checkout depth run: | echo "PR_FETCH_DEPTH=$(( ${{ github.event.pull_request.commits }} + 10 ))" >> $GITHUB_ENV - name: Checkout code uses: NVIDIA/spark-rapids-common/checkout@main with: fetch-depth: ${{ env.PR_FETCH_DEPTH }} - name: license-header-check uses: NVIDIA/spark-rapids-common/license-header-check@main with: included_file_patterns: | *.sh, *.java, *.py, *.pbtxt, *Dockerfile*, *Jenkinsfile*, *.yml, *.yaml, *.cpp, *.hpp, *.txt, *.cu, *.scala, *.ini, *.xml ================================================ FILE: .github/workflows/markdown-links-check/markdown-links-check-config.json ================================================ { "ignorePatterns": [ { "pattern": "/docs" }, { "pattern": "/datasets" }, { "pattern": "/dockerfile" }, { "pattern": "/examples" }, { "pattern": "^http://localhost" }, { "pattern": "^http://spark-master" }, { "pattern": "^http://spark-worker" }, { "pattern": "^http://spark-connect-server" } ], "timeout": "15s", "retryOn429": true, "retryCount":30, "aliveStatusCodes": [200, 403] } ================================================ FILE: .github/workflows/markdown-links-check.yml ================================================ # Copyright (c) 2022-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # A workflow to check if PR got broken hyperlinks name: Check Markdown links on: pull_request: types: [opened, synchronize, reopened] jobs: markdown-link-check: runs-on: ubuntu-latest steps: - name: work around permission issue run: git config --global --add safe.directory /github/workspace - name: checkout code uses: NVIDIA/spark-rapids-common/checkout@main - name: markdown link check uses: NVIDIA/spark-rapids-common/markdown-link-check@main with: max-depth: -1 use-verbose-mode: 'yes' config-file: '.github/workflows/markdown-links-check/markdown-links-check-config.json' base-branch: 'main' ================================================ FILE: .github/workflows/shell-check.yml ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# A workflow to check shell script syntax name: shell check on: pull_request: types: [opened, synchronize, reopened] jobs: shell-check: runs-on: ubuntu-latest if: "!contains(github.event.pull_request.title, '[bot]')" steps: - name: Checkout code uses: NVIDIA/spark-rapids-common/checkout@main - name: Run ShellCheck uses: NVIDIA/spark-rapids-common/shell-check@main with: excluded_codes: SC2164, SC2076, SC2054 # codes explanation: # SC2164: Use 'cd ... || exit' or 'cd ... || return' in case cd fails. # SC2076: Remove quotes from right-hand side of =~ to match as a regex rather than literally. # SC2054: Use spaces, not commas, to separate array elements. ================================================ FILE: .github/workflows/signoff-check.yml ================================================ # Copyright (c) 2021-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # A workflow to check if PR got sign-off name: signoff check on: pull_request_target: types: [opened, synchronize, reopened] jobs: signoff-check: runs-on: ubuntu-latest steps: - name: signoff uses: NVIDIA/spark-rapids-common/signoff-check@main with: owner: ${{ github.repository_owner }} repo: spark-rapids-examples pull_number: ${{ github.event.number }} token: ${{ secrets.GITHUB_TOKEN }} ================================================ FILE: .gitignore ================================================ *#*# *.#* *.iml *.ipr *.iws *.pyc *.pyo *.swp *~ .DS_Store .cache .classpath .ensime .ensime_cache/ .ensime_lucene .generated-mima* .idea/ .idea_modules/ .project .pydevproject .scala_dependencies .settings hs_err*.log target ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to Spark Examples ### Sign your work We require that all contributors sign-off on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. Any contribution which contains commits that are not signed off will not be accepted. To sign off on a commit use the `--signoff` (or `-s`) option when committing your changes: ```shell git commit -s -m "Add cool feature." ``` This will append the following to your commit message: ``` Signed-off-by: Your Name ``` The sign-off is a simple line at the end of the explanation for the patch. Your signature certifies that you wrote the patch or otherwise have the right to pass it on as an open-source patch. Use your real name, no pseudonyms or anonymous contributions. If you set your `user.name` and `user.email` git configs, you can sign your commit automatically with `git commit -s`. The signoff means you certify the below (from [developercertificate.org](https://developercertificate.org)): ``` Developer Certificate of Origin Version 1.1 Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 
1 Letterman Drive Suite D4700 San Francisco, CA, 94129 Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Developer's Certificate of Origin 1.1 By making a contribution to this project, I certify that: (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. ``` Note: This section `Sign your work` is derived from [https://github.com/NVIDIA/spark-rapids](https://github.com/NVIDIA/spark-rapids) ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2018 NVIDIA Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # spark-rapids-examples This is the [RAPIDS Accelerator for Apache Spark](https://nvidia.github.io/spark-rapids/) examples repo. RAPIDS Accelerator for Apache Spark accelerates Spark applications with no code changes. You can download the latest version of RAPIDS Accelerator [here](https://nvidia.github.io/spark-rapids/docs/download.html). This repo contains examples and applications that showcases the performance and benefits of using RAPIDS Accelerator in data processing and machine learning pipelines. There are broadly five categories of examples in this repo: 1. [SQL/Dataframe](./examples/SQL+DF-Examples) 2. [Spark XGBoost](./examples/XGBoost-Examples) 3. [Machine Learning/Deep Learning](./examples/ML+DL-Examples) 4. [RAPIDS UDF](./examples/UDF-Examples) 5. [Databricks Tools demo notebooks](./tools/databricks) For more information on each of the examples please look into respective categories. 
Here is the list of notebooks in this repo: | | Category | Notebook Name | Description | ------------- | ------------- | ------------- | ------------- | 1 | SQL/DF | Microbenchmark | Spark SQL operations such as expand, hash aggregate, windowing, and cross joins with up to 20x performance benefits | 2 | SQL/DF | Customer Churn | Data federation for modeling customer Churn with a sample telco customer data | 3 | XGBoost | Agaricus (Scala) | Uses XGBoost classifier function to create model that can accurately differentiate between edible and poisonous mushrooms with the [agaricus dataset](https://archive.ics.uci.edu/ml/datasets/mushroom) | 4 | XGBoost | Mortgage (Scala) | End-to-end ETL + XGBoost example to predict mortgage default with [Fannie Mae Single-Family Loan Performance Data](https://capitalmarkets.fanniemae.com/credit-risk-transfer/single-family-credit-risk-transfer/fannie-mae-single-family-loan-performance-data) | 5 | XGBoost | Taxi (Scala) | End-to-end ETL + XGBoost example to predict taxi trip fare amount with [NYC taxi trips data set](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page) | 6 | ML/DL | PCA | [Spark-Rapids-ML](https://github.com/NVIDIA/spark-rapids-ml) based PCA example to train and transform with a synthetic dataset | 7 | ML/DL | DL Inference | Several notebooks demonstrating distributed model inference on Spark using the `predict_batch_udf` across various frameworks: PyTorch, HuggingFace, vLLM, and TensorFlow | 8 | SQL/DF + MLlib | GPU-Accelerated Spark Connect | End-to-end SQL/DF + MLlib acceleration to predict mortgage default with [Fannie Mae Single-Family Loan Performance Data](https://capitalmarkets.fanniemae.com/credit-risk-transfer/single-family-credit-risk-transfer/fannie-mae-single-family-loan-performance-data) using the lightweight Spark Connect integration for Apache Spark 4.0+ | 9 | SQL/DF | [TPC-DS](https://www.tpc.org/tpcds/) Scale Factor 10 | Comparison of Spark SQL CPU vs GPU. 
Easy to run locally and on Google Colab Here is the list of Apache Spark applications (Scala and PySpark) that can be built for running on GPU with RAPIDS Accelerator in this repo: | | Category | Notebook Name | Description | ------------- | ------------- | ------------- | ------------- | 1 | XGBoost | Agaricus (Scala) | Uses XGBoost classifier function to create model that can accurately differentiate between edible and poisonous mushrooms with the [agaricus dataset](https://archive.ics.uci.edu/ml/datasets/mushroom) | 2 | XGBoost | Mortgage (Scala) | End-to-end ETL + XGBoost example to predict mortgage default with [Fannie Mae Single-Family Loan Performance Data](https://capitalmarkets.fanniemae.com/credit-risk-transfer/single-family-credit-risk-transfer/fannie-mae-single-family-loan-performance-data) | 3 | XGBoost | Taxi (Scala) | End-to-end ETL + XGBoost example to predict taxi trip fare amount with [NYC taxi trips data set](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page) | 4 | ML/DL | PCA | [Spark-Rapids-ML](https://github.com/NVIDIA/spark-rapids-ml) based PCA example to train and transform with a synthetic dataset | 5 | UDF | URL Decode | Decodes URL-encoded strings using the [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy/) | 6 | UDF | URL Encode | URL-encodes strings using the [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy/) | 7 | UDF | [CosineSimilarity](./examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/java/CosineSimilarity.java) | Computes the cosine similarity between two float vectors using [native code](./examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src) | 8 | UDF | [StringWordCount](./examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/hive/StringWordCount.java) | Implements a Hive simple UDF using [native code](./examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src) to count words in strings ================================================ FILE: dockerfile/Dockerfile ================================================ # Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # FROM nvidia/cuda:11.8.0-devel-ubuntu18.04 ARG spark_uid=185 # Install java dependencies RUN apt-get update && apt-get install -y --no-install-recommends openjdk-8-jdk openjdk-8-jre ENV JAVA_HOME /usr/lib/jvm/java-1.8.0-openjdk-amd64 ENV PATH $PATH:/usr/lib/jvm/java-1.8.0-openjdk-amd64/jre/bin:/usr/lib/jvm/java-1.8.0-openjdk-amd64/bin # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. 
# If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: # docker build -t spark:latest -f kubernetes/dockerfiles/spark/Dockerfile . RUN set -ex && \ ln -s /lib /lib64 && \ mkdir -p /opt/spark && \ mkdir -p /opt/spark/examples && \ mkdir -p /opt/spark/work-dir && \ touch /opt/spark/RELEASE && \ rm /bin/sh && \ ln -sv /bin/bash /bin/sh && \ echo "auth required pam_wheel.so use_uid" >> /etc/pam.d/su && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd ENV DEBIAN_FRONTEND noninteractive RUN apt-get update && apt-get install -y --no-install-recommends apt-utils \ && apt-get install -y --no-install-recommends python libgomp1 \ && rm -rf /var/lib/apt/lists/* COPY jars /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY kubernetes/dockerfiles/spark/entrypoint.sh /opt/ COPY examples /opt/spark/examples COPY kubernetes/tests /opt/spark/tests COPY data /opt/spark/data ENV SPARK_HOME /opt/spark WORKDIR /opt/spark/work-dir RUN chmod g+w /opt/spark/work-dir ENV TINI_VERSION v0.18.0 ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini /sbin/tini RUN chmod +rx /sbin/tini ENTRYPOINT [ "/opt/entrypoint.sh" ] # Specify the User that the actual main process will run as USER ${spark_uid} ================================================ FILE: dockerfile/gpu_executor_template.yaml ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. apiVersion: v1 kind: Pod spec: containers: - name: executor resources: limits: nvidia.com/gpu: 1 ================================================ FILE: docs/get-started/xgboost-examples/building-sample-apps/python.md ================================================ # Build XGBoost Python Examples ## Build Follow these steps to package the Python zip file: ``` bash git clone https://github.com/NVIDIA/spark-rapids-examples.git cd spark-rapids-examples/scripts/building sh python_build.sh ``` ## Files Required by PySpark Two files are required by PySpark: + *samples.zip* the package including all example code. Executing the above build commands generates the samples.zip file in 'spark-rapids-examples/examples/XGBoost-Examples' folder + *main.py* entrypoint for PySpark, you can find it in 'spark-rapids-examples/examples/XGBoost-Examples' folder ================================================ FILE: docs/get-started/xgboost-examples/building-sample-apps/scala.md ================================================ # Build XGBoost Scala Examples The examples rely on [XGBoost](https://github.com/dmlc/xgboost). ## Build Follow these steps to build the Scala jars: ``` bash git clone https://github.com/NVIDIA/spark-rapids-examples.git cd spark-rapids-examples/examples/XGBoost-Examples mvn package ``` ## The generated Jars Let's assume LATEST_VERSION is **0.2.3**. 
The build process will generate two jars as follows:

+ *aggregator/target/sample_xgboost_apps-${LATEST_VERSION}.jar* only the classes for the examples are included, so it should be submitted to Spark together with the other dependent jars
+ *aggregator/target/sample_xgboost_apps-${LATEST_VERSION}-jar-with-dependencies.jar* both the classes for the examples and the classes from the dependent jars are included, except cudf and rapids.

================================================
FILE: docs/get-started/xgboost-examples/csp/aws/ec2.md
================================================
# Get Started with XGBoost4J-Spark 3.0 on AWS EC2

This is a getting started guide for Spark 3.2+ on AWS EC2. At the end of this guide, the reader will be able to run a sample Apache Spark application that runs on NVIDIA GPUs on AWS EC2.

For more details on AWS EC2 and getting started, please check the [AWS document](https://aws.amazon.com/ec2/getting-started/).

## Configure and Launch AWS EC2

Go to the AWS Management Console, select a region, e.g. Oregon, and click the EC2 service.

### Step 1: Launch New Instance

Click "Launch instance" at the EC2 Management Console, and select "Launch instance".

![Step 1: Launch New Instance](pics/ec2_step1.png)

### Step 2: Configure Instance

#### Step 2.1: Choose an Amazon Machine Image (AMI)

Search for "deep learning base ami", choose "Deep Learning Base AMI (Ubuntu 18.04)". Click "Select".

![Step 2.1: Choose an Amazon Machine Image (AMI)](pics/ec2_step2-1.png)

#### Step 2.2: Choose an Instance Type

Choose type "p3.2xlarge". Click "Next: Configure Instance Details" at the bottom right.

![Step 2.2: Choose an Instance Type](pics/ec2_step2-2.png)

#### Step 2.3: Configure Instance Details

Nothing needs to be changed here; just make sure "Number of instances" is 1. Click "Next: Add Storage" at the bottom right.

![Step 2.3: Configure Instance Details](pics/ec2_step2-3.png)

#### Step 2.4: Add Storage

Change the root disk size based on your needs; you can also add an EBS volume by clicking "Add New Volume". In this sample, we use the default 50G. Click "Next: Add Tags" at the bottom right. For more details on AWS EBS please check the [AWS document](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AmazonEBS.html).

![Step 2.4: Add Storage](pics/ec2_step2-4.png)

#### Step 2.5: Add Tags

You can add tags here or skip this step. In this sample, we will skip it. Click "Next: Configure Security Group" at the bottom right.

#### Step 2.6: Configure Security Group

For convenience, in this sample, we open all ports. You can add your own rules. Create a new security group and select the type "All traffic". Click "Review and Launch" at the bottom right. For more details on security groups, please check the [AWS document](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-security-groups.html).

![Step 2.6: Configure Security Group](pics/ec2_step2-6.png)

#### Step 2.7: Review Instance Launch

Review your configuration. Click "Launch" at the bottom right. Choose the key pair you have and launch the instance. Return to "Instances | EC2 Management Console" and you will find your instance running. (It may take a few minutes to initialize.)

![Step 2.7: Review Instance Launch](pics/ec2_step2-7.png)

## Launch EC2 and Configure Spark 3.2+

### Step 1: Launch EC2

Copy the "Public DNS (IPv4)" of your instance, then use ssh with your private key to log in to the EC2 machine as user "ubuntu":

``` bash
ssh -i "key.pem" ubuntu@xxxx.region.compute.amazonaws.com
```

### Step 2: Download Spark package

Download the Spark package and set the environment variable.

``` bash
# download spark
wget https://dlcdn.apache.org/spark/spark-3.2.1/spark-3.2.1-bin-hadoop3.2.tgz
tar zxf spark-3.2.1-bin-hadoop3.2.tgz
export SPARK_HOME=/your/spark/spark-3.2.1-bin-hadoop3.2
```

### Step 3: Download jars for S3A (optional)

If your dataset is on S3, you should download the jar files below to enable access to S3. In this sample, we will use data on S3. The jars should be placed under $SPARK_HOME/jars.

``` bash
cd $SPARK_HOME/jars
wget https://github.com/JodaOrg/joda-time/releases/download/v2.10.5/joda-time-2.10.5.jar
wget https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.2.0/hadoop-aws-3.2.0.jar
wget https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk/1.11.687/aws-java-sdk-1.11.687.jar
wget https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-core/1.11.687/aws-java-sdk-core-1.11.687.jar
wget https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-dynamodb/1.11.687/aws-java-sdk-dynamodb-1.11.687.jar
wget https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-s3/1.11.687/aws-java-sdk-s3-1.11.687.jar
```

### Step 4: Start Spark Standalone

#### Step 4.1: Edit spark-defaults.conf

cd $SPARK_HOME/conf and edit spark-defaults.conf. By default, there is only spark-defaults.conf.template in $SPARK_HOME/conf; you can edit it and rename it to spark-defaults.conf. You can find getGpusResources.sh at $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh.

``` bash
spark.worker.resource.gpu.amount 1
spark.worker.resource.gpu.discoveryScript /path/to/getGpusResources.sh
```

The gpu.amount should be <= the number of GPUs the worker has.

#### Step 4.2: Start Spark Standalone

Start Spark. The default master Spark URL is spark://$HOSTNAME:7077.

``` bash
$SPARK_HOME/sbin/start-master.sh
$SPARK_HOME/sbin/start-slave.sh
```

## Launch XGBoost-Spark examples on Spark 3.2+

### Step 1: Download Jars

Make sure you have prepared the necessary packages and dataset by following this [guide](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md).

Copy the rapids jars to `$SPARK_HOME/jars`:

``` bash
cp $RAPIDS_JAR $SPARK_HOME/jars/
```

### Step 2: Create sample running script

Create a run.sh script with the content below, making sure to change the paths in it to your own, as well as your AWS key/secret.
``` bash #!/bin/bash export SPARK_HOME=/your/path/to/spark-3.2.1-bin-hadoop3.2 export PATH=$SPARK_HOME/bin:$SPARK_HOME/sbin:$PATH export TOTAL_CORES=8 export NUM_EXECUTORS=1 export NUM_EXECUTOR_CORES=$((${TOTAL_CORES}/${NUM_EXECUTORS})) export S3A_CREDS_USR=your_aws_key export S3A_CREDS_PSW=your_aws_secret spark-submit --master spark://$HOSTNAME:7077 \ --deploy-mode client \ --driver-memory 10G \ --executor-memory 22G \ --conf spark.hadoop.fs.s3a.impl=org.apache.hadoop.fs.s3a.S3AFileSystem \ --conf spark.hadoop.fs.s3a.access.key=$S3A_CREDS_USR \ --conf spark.hadoop.fs.s3a.secret.key=$S3A_CREDS_PSW \ --conf spark.executor.memoryOverhead=28G \ --conf spark.cores.max=$TOTAL_CORES \ --conf spark.executor.cores=$NUM_EXECUTOR_CORES \ --conf spark.task.cpus=$NUM_EXECUTOR_CORES \ --conf spark.sql.files.maxPartitionBytes=4294967296 \ --conf spark.yarn.maxAppAttempts=1 \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --conf spark.rapids.memory.gpu.pooling.enabled=false \ --conf spark.executor.resource.gpu.amount=1 \ --conf spark.task.resource.gpu.amount=1 \ --class com.nvidia.spark.examples.mortgage.GPUMain \ ${SAMPLE_JAR} \ -num_workers=${NUM_EXECUTORS} \ -format=csv \ -dataPath="train::your-train-data-path" \ -dataPath="trans::your-eval-data-path" \ -numRound=100 -max_depth=8 -nthread=$NUM_EXECUTOR_CORES -showFeatures=0 \ -tree_method=gpu_hist ``` ### Step 3: Submit Sample job Run run.sh ``` bash ./run.sh ``` After running successfully, the job will print an accuracy benchmark for model prediction. ================================================ FILE: docs/get-started/xgboost-examples/csp/databricks/databricks.md ================================================ Get Started with XGBoost4J-Spark on Databricks ====================================================== This is a getting started guide to XGBoost4J-Spark on Databricks. At the end of this guide, the reader will be able to run a sample Apache Spark application that runs on NVIDIA GPUs on Databricks. Prerequisites ------------- * Apache Spark 3.x running in Databricks Runtime 10.4 ML or 11.3 ML with GPU * AWS: 10.4 LTS ML (GPU, Scala 2.12, Spark 3.2.1) or 11.3 LTS ML (GPU, Scala 2.12, Spark 3.3.0) * Azure: 10.4 LTS ML (GPU, Scala 2.12, Spark 3.2.1) or 11.3 LTS ML (GPU, Scala 2.12, Spark 3.3.0) The number of GPUs per node dictates the number of Spark executors that can run in that node. Each executor should only be allowed to run 1 task at any given time. Start A Databricks Cluster -------------------------- Before creating the cluster, we will need to create an [initialization script](https://docs.databricks.com/clusters/init-scripts.html) for the cluster to install the RAPIDS jars. Databricks recommends storing all cluster-scoped init scripts using workspace files. Each user has a Home directory configured under the /Users directory in the workspace. Navigate to your home directory in the UI and select **Create** > **File** from the menu, create an `init.sh` scripts with contents: ```bash #!/bin/bash sudo wget -O /databricks/jars/rapids-4-spark_2.12-26.02.0.jar https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar ``` 1. Select the Databricks Runtime Version from one of the supported runtimes specified in the Prerequisites section. 2. Choose the number of workers that matches the number of GPUs you want to use. 3. Select a worker type. On AWS, use nodes with 1 GPU each such as `p3.2xlarge` or `g4dn.xlarge`. For Azure, choose GPU nodes such as Standard_NC6s_v3. 
For GCP, choose N1 or A2 instance types with GPUs.
4. Select the driver type. Generally this can be set to be the same as the worker.
5. Click the “Edit” button, then navigate down to the “Advanced Options” section. Select the “Init Scripts” tab in the advanced options section, and paste the workspace path to the initialization script: `/Users/user@domain/init.sh`, then click “Add”.

![Init Script](../../../../img/databricks/initscript.png)

6. Now select the “Spark” tab, and paste the following config options into the Spark Config section. Change the config values based on the workers you choose. See Apache Spark [configuration](https://spark.apache.org/docs/latest/configuration.html) and RAPIDS Accelerator for Apache Spark [descriptions](https://nvidia.github.io/spark-rapids/docs/configs.html) for each config.

The [`spark.task.resource.gpu.amount`](https://spark.apache.org/docs/latest/configuration.html#scheduling) configuration defaults to 1 on Databricks. That means only 1 task can run on an executor with 1 GPU, which is limiting, especially for reads and writes from Parquet. Set it to 1/(number of cores per executor), which allows multiple tasks to run in parallel just as on the CPU side. A smaller value is fine as well.

Note: Please remove the `spark.task.resource.gpu.amount` config for a single-node Databricks cluster because Spark local mode does not support GPU scheduling.

```bash
spark.plugins com.nvidia.spark.SQLPlugin
spark.task.resource.gpu.amount 0.1
spark.rapids.memory.pinnedPool.size 2G
spark.rapids.sql.concurrentGpuTasks 2
```

![Spark Config](../../../../img/databricks/sparkconfig.png)

If running Pandas UDFs with GPU support from the plugin, at least the three additional options below are required. The `spark.python.daemon.module` option chooses the right Python daemon module for Databricks. On Databricks, the Python runtime requires different parameters than the Spark one, so a dedicated Python daemon module `rapids.daemon_databricks` is created and should be specified here. Set the config [`spark.rapids.sql.python.gpu.enabled`](https://nvidia.github.io/spark-rapids/docs/configs.html#sql.python.gpu.enabled) to `true` to enable GPU support for Python. Add the path of the plugin jar (supposing it is placed under `/databricks/jars/`) to the `spark.executorEnv.PYTHONPATH` option. For more details please go to [GPU Scheduling For Pandas UDF](https://nvidia.github.io/spark-rapids/docs/additional-functionality/rapids-udfs.html#gpu-support-for-pandas-udf).

```bash
spark.rapids.sql.python.gpu.enabled true
spark.python.daemon.module rapids.daemon_databricks
spark.executorEnv.PYTHONPATH /databricks/jars/rapids-4-spark_2.12-26.02.0.jar:/databricks/spark/python
```

Note that the Python memory pool requires the cudf library, so you need to install cudf on each worker node (`pip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com`) or disable the Python memory pool (`spark.rapids.python.memory.gpu.pooling.enabled=false`).

7. Click `Create Cluster`; the cluster is now enabled for GPU-accelerated Spark.

Install the xgboost4j_spark jar in the cluster
---------------------------
1. See [Libraries](https://docs.databricks.com/user-guide/libraries.html) for how to install jars from DBFS.
2. Go to the "Libraries" tab under your cluster and install dbfs:/FileStore/jars/${XGBOOST4J_SPARK_JAR} in your cluster by selecting the "DBFS" option for installing jars.

These steps will ensure you are able to import xgboost libraries in python notebooks.
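As an alternative to the UI steps above, the jar can also be uploaded and attached from a terminal. The sketch below is only an illustration and is not part of the original guide; it assumes the legacy Databricks CLI (`pip install databricks-cli`) is installed and configured against your workspace, and the jar file name, version, and cluster ID are placeholders you need to replace with your own values.

```bash
# Hypothetical sketch: upload the xgboost4j-spark jar to DBFS and attach it to an
# existing cluster using the legacy Databricks CLI. Jar name/version and cluster ID
# are placeholders; replace them with your own.
XGBOOST4J_SPARK_JAR=xgboost4j-spark-gpu_2.12-1.7.1.jar
CLUSTER_ID=<your-cluster-id>

# copy the jar from the local machine into DBFS
databricks fs cp ./${XGBOOST4J_SPARK_JAR} dbfs:/FileStore/jars/${XGBOOST4J_SPARK_JAR}

# install the uploaded jar as a cluster library (same effect as the "DBFS" option in the UI)
databricks libraries install \
  --cluster-id ${CLUSTER_ID} \
  --jar dbfs:/FileStore/jars/${XGBOOST4J_SPARK_JAR}

# check that the library shows up as INSTALLED
databricks libraries cluster-status --cluster-id ${CLUSTER_ID}
```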
Import the GPU Mortgage Example Notebook
---------------------------
1. See [Managing Notebooks](https://docs.databricks.com/user-guide/notebooks/notebook-manage.html) on how to import a notebook.
2. Import the example notebook: [XGBoost4j-Spark mortgage notebook](../../../../../examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-gpu.ipynb)
3. Inside the mortgage example notebook, update the data paths from
   "/data/datasets/mortgage-small/train" to "dbfs:/FileStore/tables/mortgage/csv/train/mortgage_train_merged.csv" and
   "/data/datasets/mortgage-small/eval" to "dbfs:/FileStore/tables/mortgage/csv/test/mortgage_eval_merged.csv"

The example notebook comes with the following configuration; you can adjust it according to your setup. See the supported configuration options here: [xgboost parameters](../../../../../examples/XGBoost-Examples/app-parameters/supported_xgboost_parameters_python.md)

``` bash
params = {
    'eta': 0.1,
    'gamma': 0.1,
    'missing': 0.0,
    'treeMethod': 'gpu_hist',
    'maxDepth': 10,
    'maxLeaves': 256,
    'growPolicy': 'depthwise',
    'minChildWeight': 30.0,
    'lambda_': 1.0,
    'scalePosWeight': 2.0,
    'subsample': 1.0,
    'nthread': 1,
    'numRound': 100,
    'numWorkers': 1,
}
```

4. Run all the cells in the notebook.
5. View the results. In cells 5 (Training), 7 (Transforming) and 8 (Accuracy of Evaluation) you will see the output.

```
--------------
==> Benchmark: Training takes 6.48 seconds
--------------
--------------
==> Benchmark: Transformation takes 3.2 seconds
--------------
------Accuracy of Evaluation------
Accuracy is 0.9980699597729774
```

Limitations
-------------
1. When selecting GPU nodes, the Databricks UI requires the driver node to be a GPU node. However, you can use the Databricks API to create a cluster with a CPU driver node. Outside of Databricks the plugin can operate with the driver as a CPU node and workers as GPU nodes.

2. Multiple executors cannot be spun up on a multi-GPU node. Even though it is possible to set `spark.executor.resource.gpu.amount=1` in the Spark Configuration tab, Databricks overrides this to `spark.executor.resource.gpu.amount=N` (where N is the number of GPUs per node). This will result in failed executors when starting the cluster.

3. Parquet rebase mode is set to "LEGACY" by default. The following Spark configurations are set to `LEGACY` by default on Databricks:

```
spark.sql.legacy.parquet.datetimeRebaseModeInWrite
spark.sql.legacy.parquet.int96RebaseModeInWrite
```

These settings will cause a CPU fallback for Parquet writes involving dates and timestamps. If you do not need `LEGACY` write semantics, set these configs to `EXCEPTION`, which is the default value in Apache Spark 3.0 and higher.

4. Databricks makes changes to the runtime without notification. Databricks makes changes to existing runtimes, applying patches, without notification. [Issue-3098](https://github.com/NVIDIA/spark-rapids/issues/3098) is one example of this. We run regular integration tests on the Databricks environment to catch these issues and fix them once detected.

================================================ FILE: docs/get-started/xgboost-examples/csp/databricks/init.sh ================================================ #!/bin/bash # Copyright (c) 2025-2026, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

sudo rm -f /databricks/jars/spark--maven-trees--ml--10.x--xgboost-gpu--ml.dmlc--xgboost4j-gpu_2.12--ml.dmlc__xgboost4j-gpu_2.12__1.5.2.jar
sudo rm -f /databricks/jars/spark--maven-trees--ml--10.x--xgboost-gpu--ml.dmlc--xgboost4j-spark-gpu_2.12--ml.dmlc__xgboost4j-spark-gpu_2.12__1.5.2.jar

sudo wget -O /databricks/jars/rapids-4-spark_2.12-26.02.0.jar https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar
sudo wget -O /databricks/jars/xgboost4j-gpu_2.12-1.7.1.jar https://repo1.maven.org/maven2/ml/dmlc/xgboost4j-gpu_2.12/1.7.1/xgboost4j-gpu_2.12-1.7.1.jar
sudo wget -O /databricks/jars/xgboost4j-spark-gpu_2.12-1.7.1.jar https://repo1.maven.org/maven2/ml/dmlc/xgboost4j-spark-gpu_2.12/1.7.1/xgboost4j-spark-gpu_2.12-1.7.1.jar
ls -ltr

mkdir -p /dbfs/FileStore/tables/
cd /dbfs/FileStore/tables/
# Note that this is just a small sample dataset for a quick hands-on run; please refer to the instructions
# to download the full dataset:
# https://github.com/NVIDIA/spark-rapids-examples/blob/main/docs/get-started/xgboost-examples/dataset/mortgage.md
wget -O mortgage.zip https://rapidsai-data.s3.us-east-2.amazonaws.com/spark/mortgage.zip
ls
unzip -o mortgage.zip
pwd
ls -ltr mortgage/csv/*

================================================
FILE: docs/get-started/xgboost-examples/csp/dataproc/gcp.md
================================================
# Getting started with PySpark + XGBoost and the RAPIDS Accelerator on GCP Dataproc

[Google Cloud Dataproc](https://cloud.google.com/dataproc) is Google Cloud's fully managed Apache Spark and Hadoop service.
Please make sure to install the gcloud CLI by following this [guide](https://cloud.google.com/sdk/docs/install) before getting started.

## Create a Dataproc Cluster using T4s

* One 16-core master node and 2 32-core worker nodes
* Two NVIDIA T4s for each worker node

```bash
export REGION=[Your Preferred GCP Region]
export GCS_BUCKET=[Your GCS Bucket]
export CLUSTER_NAME=[Your Cluster Name]
export NUM_GPUS=2
export NUM_WORKERS=2

gcloud dataproc clusters create $CLUSTER_NAME \
    --region=$REGION \
    --image-version=2.0-ubuntu18 \
    --master-machine-type=n2-standard-16 \
    --num-workers=$NUM_WORKERS \
    --worker-accelerator=type=nvidia-tesla-t4,count=$NUM_GPUS \
    --worker-machine-type=n1-highmem-32 \
    --num-worker-local-ssds=4 \
    --initialization-actions=gs://goog-dataproc-initialization-actions-${REGION}/spark-rapids/spark-rapids.sh \
    --optional-components=JUPYTER,ZEPPELIN \
    --metadata=rapids-runtime=SPARK \
    --bucket=$GCS_BUCKET \
    --enable-component-gateway \
    --subnet=default
```

Explanation of parameters:
* NUM_GPUS = number of GPUs to attach to each worker node in the cluster
* NUM_WORKERS = number of Spark worker nodes in the cluster

This takes around 10-15 minutes to complete. You can navigate to the Dataproc clusters tab in the Google Cloud Console to see the progress.

![Dataproc Cluster](../../../../img/GCP/dataproc-cluster.png)

If you'd like to further accelerate init time to 4-5 minutes, create a custom Dataproc image using [this](#build-custom-dataproc-image-to-accelerate-cluster-init-time) guide.
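Once the create command returns, you can optionally confirm that the cluster is healthy and that the workers received their accelerators. The following is a minimal sketch that reuses the variables exported above:

```bash
# Check the cluster state (should print RUNNING) and the worker accelerator config.
gcloud dataproc clusters describe ${CLUSTER_NAME} --region=${REGION} \
    --format="value(status.state)"
gcloud dataproc clusters describe ${CLUSTER_NAME} --region=${REGION} \
    --format="yaml(config.workerConfig.accelerators)"
```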
## Get Application Files, Jar and Dataset

SSH into the master node and make sure you have prepared the necessary packages and dataset by following this [guide](../../prepare-package-data/preparation-python.md).

Note: Since there is no Maven CLI on the master node, we need to install it manually.

``` bash
gcloud compute ssh your-name@your-cluster-m --zone your-zone
sudo apt-get install maven -y
```

Then create a directory in HDFS, and run the commands below.

``` bash
[xgboost4j_spark_python]$ hadoop fs -mkdir /tmp/xgboost4j_spark_python
[xgboost4j_spark_python]$ hadoop fs -copyFromLocal ${SPARK_XGBOOST_DIR}/mortgage/* /tmp/xgboost4j_spark_python
```

## Preparing libraries

Please make sure to install the XGBoost, cudf-cu11, numpy libraries on all nodes before running the XGBoost application.

``` bash
pip install xgboost
pip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com
pip install numpy
pip install scikit-learn
```

You can also create an isolated Python environment by using [Virtualenv](https://virtualenv.pypa.io/en/latest/), and then directly pass/unpack the archive file and enable the environment on executors by leveraging the --archives option or spark.archives configuration.

``` bash
# create an isolated python environment and install libraries
python -m venv pyspark_venv
source pyspark_venv/bin/activate
pip install xgboost
pip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com
pip install numpy
pip install scikit-learn
pip install venv-pack
venv-pack -o pyspark_venv.tar.gz

# enable archive python environment on executors
export PYSPARK_DRIVER_PYTHON=python # Do not set in cluster modes.
export PYSPARK_PYTHON=./environment/bin/python
spark-submit --archives pyspark_venv.tar.gz#environment app.py
```

## Run jupyter notebooks on Dataproc

SSH into the master node and start up the notebook.

```
jupyter notebook --ip=0.0.0.0 --port=8124 --no-browser
```

If you want to access the notebook remotely from your local machine, please reserve an external static IP address first:

1. Access the IP addresses page through the navigation menu: `VPC network` -> `IP addresses`
![dataproc img2](../../../../img/GCP/dataproc-img2.png)
2. Click the `RESERVE EXTERNAL STATIC ADDRESS` button
![dataproc img3](../../../../img/GCP/dataproc-img3.png)
3. Attach the static address to the master node of your cluster
![dataproc img4](../../../../img/GCP/dataproc-img4.png)
4. Then you can access and run the notebook from your local browser using the reserved address.
![dataproc img5](../../../../img/GCP/dataproc-img5.png)

Then you can run the [notebook](../../../../../examples/XGBoost-Examples/mortgage/notebooks/python/mortgage-gpu.ipynb) and get the benchmark results.

![dataproc img6](../../../../img/GCP/dataproc-img6.png)

## Build custom dataproc image to accelerate cluster init time

In order to accelerate cluster init time to 3-4 minutes, we need to build a custom Dataproc image that already has NVIDIA drivers and the CUDA toolkit installed, with RAPIDS deployed. The custom image could also be used in an air-gapped environment. In this section, we will be using [these instructions from GCP](https://cloud.google.com/dataproc/docs/guides/dataproc-images) to create a custom image. Currently, we can directly download the [spark-rapids.sh](https://github.com/GoogleCloudDataproc/initialization-actions/tree/master/spark-rapids) script to create the Dataproc image:

Google provides a `generate_custom_image.py` script that:
- Launches a temporary Compute Engine VM instance with the specified Dataproc base image.
- Then runs the customization script inside the VM instance to install custom packages and/or update configurations. - After the customization script finishes, it shuts down the VM instance and creates a Dataproc custom image from the disk of the VM instance. - The temporary VM is deleted after the custom image is created. - The custom image is saved and can be used to create Dataproc clusters. Download `spark-rapids.sh` in this repo. The script uses Google's `generate_custom_image.py` script. This step may take 20-25 minutes to complete. ```bash git clone https://github.com/GoogleCloudDataproc/custom-images cd custom-images export CUSTOMIZATION_SCRIPT=/path/to/spark-rapids.sh export ZONE=[Your Preferred GCP Zone] export GCS_BUCKET=[Your GCS Bucket] export IMAGE_NAME=sample-20-ubuntu18-gpu-t4 export DATAPROC_VERSION=2.0-ubuntu18 export GPU_NAME=nvidia-tesla-t4 export GPU_COUNT=1 python generate_custom_image.py \ --image-name $IMAGE_NAME \ --dataproc-version $DATAPROC_VERSION \ --customization-script $CUSTOMIZATION_SCRIPT \ --no-smoke-test \ --zone $ZONE \ --gcs-bucket $GCS_BUCKET \ --machine-type n1-standard-4 \ --accelerator type=$GPU_NAME,count=$GPU_COUNT \ --disk-size 200 \ --subnet default ``` See [here](https://cloud.google.com/dataproc/docs/guides/dataproc-images#running_the_code) for more details on `generate_custom_image.py` script arguments and [here](https://cloud.google.com/dataproc/docs/concepts/versioning/dataproc-versions) for dataproc version description. The image `sample-20-ubuntu18-gpu-t4` is now ready and can be viewed in the GCP console under `Compute Engine > Storage > Images`. The next step is to launch the cluster using this new image and new initialization actions (that do not install NVIDIA drivers since we are already past that step). Move this to your own bucket. Let's launch the cluster: ```bash export REGION=[Your Preferred GCP Region] export GCS_BUCKET=[Your GCS Bucket] export CLUSTER_NAME=[Your Cluster Name] export NUM_GPUS=1 export NUM_WORKERS=2 gcloud dataproc clusters create $CLUSTER_NAME \ --region=$REGION \ --image=sample-20-ubuntu18-gpu-t4 \ --master-machine-type=n1-standard-4 \ --num-workers=$NUM_WORKERS \ --worker-accelerator=type=nvidia-tesla-t4,count=$NUM_GPUS \ --worker-machine-type=n1-standard-4 \ --num-worker-local-ssds=1 \ --optional-components=JUPYTER,ZEPPELIN \ --metadata=rapids-runtime=SPARK \ --bucket=$GCS_BUCKET \ --enable-component-gateway \ --subnet=default ``` The new cluster should be up and running within 3-4 minutes! ================================================ FILE: docs/get-started/xgboost-examples/dataset/mortgage.md ================================================ # How to download the Mortgage dataset ## Steps to download the data 1. Go to the [Fannie Mae](https://capitalmarkets.fanniemae.com/credit-risk-transfer/single-family-credit-risk-transfer/fannie-mae-single-family-loan-performance-data) website 2. Click on [Single-Family Loan Performance Data](https://datadynamics.fanniemae.com/data-dynamics/?&_ga=2.181456292.2043790680.1657122341-289272350.1655822609#/reportMenu;category=HP) * Register as a new user if you are using the website for the first time * Use the credentials to login 3. Select [HP](https://datadynamics.fanniemae.com/data-dynamics/#/reportMenu;category=HP) 4. Click on **Download Data** and choose *Single-Family Loan Performance Data* 5. You will find a tabular list of 'Acquisition and Performance' files sorted based on year and quarter. Click on the file to download `Eg: 2017Q1.zip` 6. 
Unzip the downloaded file to extract the csv file `Eg: 2017Q1.csv`
7. Copy only the csv files to a new folder for the ETL to read

## Notes

1. Refer to the [Loan Performance Data Tutorial](https://capitalmarkets.fanniemae.com/media/9066/display) for more details.
2. Note that *Single-Family Loan Performance Data* has 2 components. However, the Mortgage ETL requires only the first one (the primary dataset):
   * Primary Dataset: Acquisition and Performance Files
   * HARP Dataset
3. Use the [Resources](https://datadynamics.fanniemae.com/data-dynamics/#/resources/HP) section to learn more about the dataset.

================================================
FILE: docs/get-started/xgboost-examples/notebook/python-notebook.md
================================================
Get Started with pyspark+XGBoost with Jupyter Notebook
===================================================================

This is a getting started guide to XGBoost4J-Spark using a [Jupyter notebook](https://jupyter.org/). At the end of this guide, you will be able to run a sample notebook that runs on NVIDIA GPUs.

Before you begin, please ensure that you have set up a Spark cluster (Standalone or YARN). You should change the `--master` config according to your cluster architecture; for example, set `--master yarn` for Spark on YARN.

It is assumed that the `SPARK_MASTER` and `SPARK_HOME` environment variables are defined and point to the Spark Master URL (e.g. `spark://localhost:7077`) and the home directory for Apache Spark, respectively.

1. Make sure you have [Jupyter notebook installed](https://jupyter.org/install.html). If you install it with conda, please make sure your Python version is consistent.

2. Prepare packages and dataset. Make sure you have prepared the necessary packages and dataset by following this [guide](../prepare-package-data/preparation-python.md)

3. Launch the notebook. Note: For ETL jobs, set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`.

For ETL:
``` bash
PYSPARK_DRIVER_PYTHON=jupyter \
PYSPARK_DRIVER_PYTHON_OPTS=notebook \
pyspark \
--master ${SPARK_MASTER} \
--jars ${RAPIDS_JAR} \
--py-files ${SAMPLE_ZIP} \
--conf spark.plugins=com.nvidia.spark.SQLPlugin \
--conf spark.executor.resource.gpu.amount=1 \
--conf spark.executor.cores=10 \
--conf spark.task.resource.gpu.amount=0.1 \
--conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \
--conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \
--files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh
```

For XGBoost:
``` bash
PYSPARK_DRIVER_PYTHON=jupyter \
PYSPARK_DRIVER_PYTHON_OPTS=notebook \
pyspark \
--master ${SPARK_MASTER} \
--jars ${RAPIDS_JAR} \
--py-files ${SAMPLE_ZIP} \
--conf spark.plugins=com.nvidia.spark.SQLPlugin \
--conf spark.rapids.memory.gpu.pool=NONE \
--conf spark.executor.resource.gpu.amount=1 \
--conf spark.executor.cores=10 \
--conf spark.task.resource.gpu.amount=1 \
--conf spark.sql.execution.arrow.maxRecordsPerBatch=200000 \
--conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \
--files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh
```

4. Launch ETL Part
   - Mortgage ETL Notebook: [Python](../../../../examples/XGBoost-Examples/mortgage/notebooks/python/MortgageETL.ipynb)
   - Taxi ETL Notebook: [Python](../../../../examples/XGBoost-Examples/taxi/notebooks/python/taxi-ETL.ipynb)
   - Note: Agaricus does not have an ETL part.
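If executors fail to start because no GPU addresses are found, a quick sanity check is to run the discovery script that the launch commands above ship via `--files` (a minimal sketch; it assumes `SPARK_HOME` points at your Spark installation and that the node has an NVIDIA driver installed):

```bash
# Prints the GPUs visible on this node in the JSON format Spark expects,
# e.g. {"name": "gpu", "addresses": ["0", "1"]}
bash ${SPARK_HOME}/examples/src/main/scripts/getGpusResources.sh
```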
================================================
FILE: docs/get-started/xgboost-examples/notebook/spylon.md
================================================
Get Started with XGBoost4J-Spark with Spylon Kernel Jupyter Notebook
===================================================================

This is a getting started guide to XGBoost4J-Spark using a [Spylon Kernel](https://pypi.org/project/spylon-kernel/) Jupyter notebook. At the end of this guide, the reader will be able to run a sample notebook that runs on NVIDIA GPUs.

Before you begin, please ensure that you have set up a [Spark Standalone Cluster](/docs/get-started/xgboost-examples/on-prem-cluster/standalone-scala.md).

It is assumed that the `SPARK_MASTER` and `SPARK_HOME` environment variables are defined and point to the Spark Master URL and the home directory for Apache Spark, respectively.

1. Install Jupyter Notebook with spylon-kernel.

``` bash
# Install notebook and spylon-kernel (Scala kernel for Jupyter Notebook), https://pypi.org/project/spylon-kernel/
# You can use spylon-kernel as a Scala kernel for Jupyter Notebook. Do this when you want to work with Spark in Scala with a bit of Python code mixed in.
pip3 install jupyter notebook spylon-kernel
python -m spylon_kernel install
# Latest version breaks nbconvert: https://github.com/ipython/ipykernel/issues/422
pip3 install ipykernel==5.1.1
```

2. Start Jupyter Notebook. You can then access the web UI at http://your_ip:your_port with your password.

``` bash
export JUPYTER_CONFIG_FILE=~/.jupyter/jupyter_notebook_config.py
rm -rf `dirname $JUPYTER_CONFIG_FILE` && mkdir -p `dirname $JUPYTER_CONFIG_FILE` && echo """
c.NotebookApp.ip='*'
c.NotebookApp.password = your_hashed_password
c.NotebookApp.open_browser = False
c.NotebookApp.port = your_port
""" > $JUPYTER_CONFIG_FILE
jupyter notebook --allow-root --notebook-dir=$WORKSPACE --config=$JUPYTER_CONFIG_FILE &
```

3. Prepare packages and dataset. Make sure you have prepared the necessary packages and dataset by following this [guide](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md)

4. Run the Scala notebook (e.g. [mortgage-gpu.ipynb](../../../../examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-gpu.ipynb))

``` bash
# Suppose your Scala notebook is $WORKSPACE/mortgage-gpu.ipynb
jupyter nbconvert --to notebook --stdout --execute $WORKSPACE/mortgage-gpu.ipynb

# ------- you will see output that looks like ----------------
# {
#  "cells": [
#   {
#    "cell_type": "code",
#    "execution_count": 1,
#    "id": "5ca1ae16",
#    "metadata": {
# ........
# ........
# ........
#   "language_info": {
#    "codemirror_mode": "text/x-scala",
#    "file_extension": ".scala",
#    "help_links": [
#     {
#      "text": "MetaKernel Magics",
#      "url": "https://metakernel.readthedocs.io/en/latest/source/README.html"
#     }
#    ],
#    "mimetype": "text/x-scala",
#    "name": "scala",
#    "pygments_lexer": "scala",
#    "version": "0.4.1"
#   }
#  },
#  "nbformat": 4,
#  "nbformat_minor": 5
# }
```

You can also run a Python notebook with the Spylon kernel:

``` bash
# restart Jupyter Notebook
export PYSPARK_DRIVER_PYTHON=jupyter
export PYSPARK_DRIVER_PYTHON_OPTS="notebook --allow-root --notebook-dir=$WORKSPACE --config=$JUPYTER_CONFIG_FILE"
pyspark &

# Suppose your Python notebook is $WORKSPACE/mortgage-gpu.ipynb
jupyter nbconvert --to notebook --stdout --execute $WORKSPACE/mortgage-gpu.ipynb
```

5.
Launch ETL Part - Mortgage ETL Notebook: [Scala](../../../../examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-ETL.ipynb) or [Python](../../../../examples/XGBoost-Examples/mortgage/notebooks/python/MortgageETL.ipynb) - Taxi ETL Notebook: [Scala](../../../../examples/XGBoost-Examples/taxi/notebooks/scala/taxi-ETL.ipynb) or [Python](../../../../examples/XGBoost-Examples/taxi/notebooks/python/taxi-ETL.ipynb) - Note: Agaricus does not have ETL part. 6. Launch XGBoost Part - Mortgage XGBoost Notebook: [Scala](../../../../examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-gpu.ipynb) - Taxi XGBoost Notebook: [Scala](../../../../examples/XGBoost-Examples/taxi/notebooks/scala/taxi-gpu.ipynb) - Agaricus XGBoost Notebook: [Scala](../../../../examples/XGBoost-Examples/agaricus/notebooks/scala/agaricus-gpu.ipynb) ================================================ FILE: docs/get-started/xgboost-examples/notebook/toree.md ================================================ Get Started with XGBoost4J-Spark with Apache Toree Jupyter Notebook =================================================================== This is a getting started guide to XGBoost4J-Spark using an [Apache Toree](https://toree.apache.org/) Jupyter notebook. At the end of this guide, you will be able to run a sample notebook that runs on NVIDIA GPUs. Before you begin, please ensure that you have setup a Spark Cluster(Standalone or YARN). You should change `--master` config according to your cluster architecture. For example, set `--master yarn` for spark on YARN. It is assumed that the `SPARK_MASTER` and `SPARK_HOME` environment variables are defined and point to the Spark Master URL (e.g. `spark://localhost:7077`), and the home directory for Apache Spark respectively. 1. Make sure you have jupyter notebook and [sbt](https://www.scala-sbt.org/1.x/docs/Installing-sbt-on-Linux.html) installed first. 2. Build the 'toree' locally to support scala 2.12, and install it. ``` bash # Download toree wget https://github.com/apache/incubator-toree/archive/refs/tags/v0.5.0-incubating-rc4.tar.gz tar -xvzf v0.5.0-incubating-rc4.tar.gz # Build the Toree pip package. cd incubator-toree-0.5.0-incubating-rc4 make pip-release # Install Toree pip install dist/toree-pip/toree-0.5.0.tar.gz ``` 3. Prepare packages and dataset. Make sure you have prepared the necessary packages and dataset by following this [guide](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md) 4. Install a new kernel with gpu enabled and launch the notebook Note: For ETL jobs, Set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`. 
For ETL:
``` bash
jupyter toree install \
--spark_home=${SPARK_HOME} \
--user \
--toree_opts='--nosparkcontext' \
--kernel_name="ETL-Spark" \
--spark_opts='--master ${SPARK_MASTER} \
--jars ${RAPIDS_JAR},${SAMPLE_JAR} \
--conf spark.plugins=com.nvidia.spark.SQLPlugin \
--conf spark.executor.extraClassPath=${RAPIDS_JAR} \
--conf spark.executor.cores=10 \
--conf spark.task.resource.gpu.amount=0.1 \
--conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \
--files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh'
```

For XGBoost:
``` bash
jupyter toree install \
--spark_home=${SPARK_HOME} \
--user \
--toree_opts='--nosparkcontext' \
--kernel_name="XGBoost-Spark" \
--spark_opts='--master ${SPARK_MASTER} \
--jars ${RAPIDS_JAR},${SAMPLE_JAR} \
--conf spark.plugins=com.nvidia.spark.SQLPlugin \
--conf spark.executor.extraClassPath=${RAPIDS_JAR} \
--conf spark.rapids.memory.gpu.pool=NONE \
--conf spark.executor.resource.gpu.amount=1 \
--conf spark.executor.cores=10 \
--conf spark.task.resource.gpu.amount=1 \
--conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \
--files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh'
```

Launch the notebook:
``` bash
jupyter notebook
```

5. Launch ETL Part
   - Mortgage ETL Notebook: [Scala](../../../../examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-ETL.ipynb)
   - Taxi ETL Notebook: [Scala](../../../../examples/XGBoost-Examples/taxi/notebooks/scala/taxi-ETL.ipynb)
   - Note: Agaricus does not have an ETL part.
6. Launch XGBoost Part
   - Mortgage XGBoost Notebook: [Scala](../../../../examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-gpu.ipynb)
   - Taxi XGBoost Notebook: [Scala](../../../../examples/XGBoost-Examples/taxi/notebooks/scala/taxi-gpu.ipynb)
   - Agaricus XGBoost Notebook: [Scala](../../../../examples/XGBoost-Examples/agaricus/notebooks/scala/agaricus-gpu.ipynb)

================================================
FILE: docs/get-started/xgboost-examples/on-prem-cluster/kubernetes-scala.md
================================================
Get Started with XGBoost4J-Spark on Kubernetes
==============================================

This is a getting started guide to deploying the XGBoost4J-Spark package on a Kubernetes cluster. At the end of this guide, the reader will be able to run a sample Apache Spark XGBoost application on an NVIDIA GPU Kubernetes cluster.

Prerequisites
-------------
* Apache Spark 3.2.0+ (e.g.: Spark 3.2.0)
* Hardware Requirements
  * NVIDIA Pascal™ GPU architecture or better
  * Multi-node clusters with homogenous GPU configuration
* Software Requirements
  * Ubuntu 20.04, 22.04/CentOS7, Rocky Linux 8
  * CUDA 11.0+
  * NVIDIA driver compatible with your CUDA
  * NCCL 2.7.8+
* [Kubernetes cluster with NVIDIA GPUs](https://docs.nvidia.com/datacenter/cloud-native/kubernetes/install-k8s.html)
  * See official [Spark on Kubernetes](https://spark.apache.org/docs/latest/running-on-kubernetes.html#prerequisites) instructions for detailed spark-specific cluster requirements
* kubectl installed and configured in the job submission environment
  * Required for managing jobs and retrieving logs

Build a GPU Spark Docker Image
------------------------------
Build a GPU Docker image with Spark resources in it. This Docker image must be accessible by each node in the Kubernetes cluster.

1. Locate your Spark installation. If you don't have one, you can [download](https://spark.apache.org/downloads.html) it from Apache and unzip it.
2. `export SPARK_HOME=`
3.
[Download the Dockerfile](/dockerfile/Dockerfile) into `${SPARK_HOME}`. (Here CUDA 11.0 is used as an example in the Dockerfile; you may need to update it for other CUDA versions.)
4. __(OPTIONAL)__ install any additional library jars into the `${SPARK_HOME}/jars` directory.
   * Most public cloud file systems are not natively supported -- pulling data and jar files from S3, GCS, etc. requires installing additional libraries.
5. Build and push the docker image.

``` bash
export SPARK_HOME=
export SPARK_DOCKER_IMAGE=
export SPARK_DOCKER_TAG=

pushd ${SPARK_HOME}
wget https://github.com/NVIDIA/spark-rapids-examples/raw/branch-25.08/dockerfile/Dockerfile

# Optionally install additional jars into ${SPARK_HOME}/jars/

docker build . -t ${SPARK_DOCKER_IMAGE}:${SPARK_DOCKER_TAG}
docker push ${SPARK_DOCKER_IMAGE}:${SPARK_DOCKER_TAG}
popd
```

Get Jars and Dataset
-------------------------------

Make sure you have prepared the necessary packages and dataset by following this [guide](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md).

Make sure that data and jars are accessible by each node of the Kubernetes cluster via [Kubernetes volumes](https://spark.apache.org/docs/latest/running-on-kubernetes.html#using-kubernetes-volumes), on cluster filesystems like HDFS, or in [object stores like S3 and GCS](https://spark.apache.org/docs/2.3.0/cloud-integration.html). Note that using [application dependencies](https://spark.apache.org/docs/latest/running-on-kubernetes.html#dependency-management) from the submission client’s local file system is currently not yet supported.

#### Note:
1. Mortgage and Taxi jobs have ETLs to generate the processed data.
2. For convenience, a subset of the [Taxi](/datasets/) dataset is made available in this repo that can be readily used for launching the XGBoost job. Use [ETL](#etl) to generate larger datasets for training and testing.
3. Agaricus does not have an ETL process; it is combined with XGBoost as there is just a filter operation.

Save Kubernetes Template Resources
----------------------------------

When using Spark on Kubernetes the driver and executor pods can be launched with pod templates. In the XGBoost4J-Spark use case, these template yaml files are used to allocate and isolate specific GPUs to each pod. The following is a barebones template file to allocate 1 GPU per pod.

```
apiVersion: v1
kind: Pod
spec:
  containers:
    - name: gpu-example
      resources:
        limits:
          nvidia.com/gpu: 1
```

This 1 GPU template file should be sufficient for all XGBoost jobs because each executor should only run 1 task on a single GPU. Save this yaml file to the local environment of the machine you are submitting jobs from; you will need to provide a path to it as an argument in your spark-submit command (one way to save it is sketched just before the spark-submit command below). Without the template file a pod will see every GPU on the cluster node it is allocated on and can attempt to execute using a GPU which is already in use -- causing undefined behavior and errors.

Launch Mortgage or Taxi ETL Part
---------------------------
Use the ETL app to process raw Mortgage data. You can either use this ETLed data to split into training and evaluation data or run the ETL on different subsets of the dataset to produce training and evaluation datasets.

Note: For ETL jobs, set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`.
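As referenced above, here is a minimal sketch of saving the 1-GPU pod template to the machine you submit jobs from; the destination path matches the `TEMPLATE_PATH` variable used in the XGBoost step later in this guide:

```bash
# Write the barebones 1-GPU pod template shown above to the submission machine.
cat > ${HOME}/gpu_executor_template.yaml <<'EOF'
apiVersion: v1
kind: Pod
spec:
  containers:
    - name: gpu-example
      resources:
        limits:
          nvidia.com/gpu: 1
EOF
```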
Run spark-submit ``` bash ${SPARK_HOME}/bin/spark-submit \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --conf spark.executor.resource.gpu.amount=1 \ --conf spark.executor.cores=10 \ --conf spark.task.resource.gpu.amount=0.1 \ --conf spark.rapids.sql.incompatibleDateFormats.enabled=true \ --conf spark.rapids.sql.csv.read.double.enabled=true \ --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \ --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \ --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh \ --jars ${RAPIDS_JAR} \ --master \ --deploy-mode ${SPARK_DEPLOY_MODE} \ --num-executors ${SPARK_NUM_EXECUTORS} \ --driver-memory ${SPARK_DRIVER_MEMORY} \ --executor-memory ${SPARK_EXECUTOR_MEMORY} \ --class com.nvidia.spark.examples.mortgage.ETLMain \ $SAMPLE_JAR \ -format=csv \ -dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/" \ -dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/train/" \ -dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/" # if generating eval data, change the data path to eval # -dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/" # -dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/" # -dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/" # if running Taxi ETL benchmark, change the class and data path params to # -class com.nvidia.spark.examples.taxi.ETLMain # -dataPath="raw::${SPARK_XGBOOST_DIR}/taxi/your-path" # -dataPath="out::${SPARK_XGBOOST_DIR}/taxi/your-path" ``` Launch XGBoost Part on GPU --------------------------- Variables required to run spark-submit command: ``` bash # Variables dependent on how data was made accessible to each node # Make sure to include relevant spark-submit configuration arguments # location where data was saved export DATA_PATH= # Variables independent of how data was made accessible to each node # kubernetes master URL, used as the spark master for job submission export SPARK_MASTER= # local path to the template file saved in the previous step export TEMPLATE_PATH=${HOME}/gpu_executor_template.yaml # spark docker image location export SPARK_DOCKER_IMAGE= export SPARK_DOCKER_TAG= # kubernetes service account to launch the job with export K8S_ACCOUNT= # spark deploy mode, cluster mode recommended for spark on kubernetes export SPARK_DEPLOY_MODE=cluster # run a single executor for this example to limit the number of spark tasks and # partitions to 1 as currently this number must match the number of input files export SPARK_NUM_EXECUTORS=1 # spark driver memory export SPARK_DRIVER_MEMORY=4g # spark executor memory export SPARK_EXECUTOR_MEMORY=8g # example class to use export EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.Main # or change to com.nvidia.spark.examples.taxi.Main to run Taxi Xgboost benchmark # or change to com.nvidia.spark.examples.agaricus.Main to run Agaricus Xgboost benchmark # tree construction algorithm export TREE_METHOD=gpu_hist ``` Run spark-submit: ``` bash ${SPARK_HOME}/bin/spark-submit \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --conf spark.rapids.memory.gpu.pool=NONE \ --conf spark.executor.resource.gpu.amount=1 \ --conf spark.task.resource.gpu.amount=1 \ --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \ --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh \ --jars ${RAPIDS_JAR} \ --master ${SPARK_MASTER} \ --deploy-mode ${SPARK_DEPLOY_MODE} \ --class ${EXAMPLE_CLASS} \ --conf spark.executor.instances=${SPARK_NUM_EXECUTORS} \ --conf 
spark.kubernetes.authenticate.driver.serviceAccountName=${K8S_ACCOUNT} \
--conf spark.kubernetes.container.image=${SPARK_DOCKER_IMAGE}:${SPARK_DOCKER_TAG} \
--conf spark.kubernetes.driver.podTemplateFile=${TEMPLATE_PATH} \
--conf spark.kubernetes.executor.podTemplateFile=${TEMPLATE_PATH} \
${SAMPLE_JAR} \
-dataPath=train::${SPARK_XGBOOST_DIR}/mortgage/output/train/ \
-dataPath=trans::${SPARK_XGBOOST_DIR}/mortgage/output/eval/ \
-format=parquet \
-numWorkers=${SPARK_NUM_EXECUTORS} \
-treeMethod=${TREE_METHOD} \
-numRound=100 \
-maxDepth=8

# Please make sure to change the class and data path while running Taxi or Agaricus benchmark
```

Retrieve the logs using the driver's pod name that is printed to `stdout` by spark-submit:

```
export POD_NAME=
kubectl logs -f ${POD_NAME}
```

In the driver log, you should see timings* (in seconds) and the accuracy metric (taking Mortgage as an example):

```
--------------
==> Benchmark: Elapsed time for [Mortgage GPU train csv stub Unknown Unknown Unknown]: 30.132s
--------------

--------------
==> Benchmark: Elapsed time for [Mortgage GPU transform csv stub Unknown Unknown Unknown]: 22.352s
--------------

--------------
==> Benchmark: Accuracy for [Mortgage GPU Accuracy csv stub Unknown Unknown Unknown]: 0.9869451418401349
--------------
```

\* Kubernetes logs may not be nicely formatted since `stdout` and `stderr` are not kept separately.

\* The timings in this Getting Started guide are only for illustrative purpose. Please see our [release announcement](https://medium.com/rapids-ai/nvidia-gpus-and-apache-spark-one-step-closer-2d99e37ac8fd) for official benchmarks.

================================================
FILE: docs/get-started/xgboost-examples/on-prem-cluster/standalone-python.md
================================================
Get Started with XGBoost4J-Spark on an Apache Spark Standalone Cluster
======================================================================

This is a getting started guide to XGBoost4J-Spark on an Apache Spark 3.2+ Standalone Cluster. At the end of this guide, the user can run a sample Apache Spark Python application that runs on NVIDIA GPUs.

Prerequisites
-------------
* Apache Spark 3.2.0+ (e.g.: Spark 3.2.0)
* Hardware Requirements
  * NVIDIA Pascal™ GPU architecture or better
  * Multi-node clusters with homogenous GPU configuration
* Software Requirements
  * Ubuntu 20.04, 22.04/CentOS7, Rocky Linux 8
  * CUDA 11.5+
  * NVIDIA driver compatible with your CUDA
  * NCCL 2.7.8+
  * Python 3.8 or 3.9
  * NumPy
  * XGBoost 1.7.0+
  * cudf-cu11

The number of GPUs in each host dictates the number of Spark executors that can run there. Additionally, cores per Spark executor and cores per Spark task must match, such that each executor can run 1 task at any given time. For example, if each host has 4 GPUs, there should be 4 or fewer executors running on each host, and each executor should run at most 1 task (e.g.: a total of 4 tasks running on 4 GPUs).

In Spark Standalone mode, the default configuration is for an executor to take up all the cores assigned to each Spark Worker. In this example, we will limit the number of cores to 1, to match our dataset. Please see https://spark.apache.org/docs/latest/spark-standalone.html for more documentation regarding Standalone configuration.

We use the `SPARK_HOME` environment variable to point to the Apache Spark cluster. Here are the steps to enable GPU resource discovery for Spark 3.2+.

1.
Copy the spark config file from template ``` bash cd ${SPARK_HOME}/conf/ cp spark-defaults.conf.template spark-defaults.conf ``` 2. Add the following configs to the file `spark-defaults.conf`. The number in the first config should **NOT** be larger than the actual number of the GPUs on current host. This example uses 1 as below for one GPU on the host. ```bash spark.worker.resource.gpu.amount 1 spark.worker.resource.gpu.discoveryScript ${SPARK_HOME}/examples/src/main/scripts/getGpusResources.sh ``` 3. Install the XGBoost, cudf-cu11, numpy libraries on all nodes before running XGBoost application. ``` bash pip install xgboost pip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com pip install numpy pip install scikit-learn ``` Get Application Files, Jar and Dataset ------------------------------- Make sure you have prepared the necessary packages and dataset by following this [guide](../prepare-package-data/preparation-python.md) #### Note: 1. Mortgage and Taxi jobs have ETLs to generate the processed data. 2. For convenience, a subset of [Taxi](/datasets/) dataset is made available in this repo that can be readily used for launching XGBoost job. Use [ETL](standalone-python.md#launch-mortgage-or-taxi-etl-part) to generate larger datasets for training and testing. 3. Agaricus does not have an ETL process, it is combined with XGBoost as there is just a filter operation. Launch a Standalone Spark Cluster --------------------------------- 1. Copy required jars to `$SPARK_HOME/jars` folder. ``` bash cp ${RAPIDS_JAR} $SPARK_HOME/jars/ ``` 2. Start the Spark Master process. ``` bash ${SPARK_HOME}/sbin/start-master.sh ``` Note the hostname or ip address of the Master host, so that it can be given to each Worker process, in this example the Master and Worker will run on the same host. 3. Start a spark slave process. ``` bash export SPARK_MASTER=spark://`hostname -f`:7077 export SPARK_CORES_PER_WORKER=1 ${SPARK_HOME}/sbin/start-slave.sh ${SPARK_MASTER} -c ${SPARK_CORES_PER_WORKER} ``` Note that in this example the Master and Worker processes are both running on the same host. This is not a requirement, as long as all hosts that are used to run the Spark app have access to the dataset. Launch Mortgage or Taxi ETL Part --------------------------- Use the ETL app to process raw Mortgage data. You can either use this ETLed data to split into training and evaluation data or run the ETL on different subsets of the dataset to produce training and evaluation datasets. Note: For ETL jobs, Set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`. 
### ETL on GPU ``` bash ${SPARK_HOME}/bin/spark-submit \ --master spark://$HOSTNAME:7077 \ --executor-memory 32G \ --conf spark.executor.resource.gpu.amount=1 \ --conf spark.executor.cores=10 \ --conf spark.task.resource.gpu.amount=0.1 \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --conf spark.rapids.sql.incompatibleDateFormats.enabled=true \ --conf spark.rapids.sql.csv.read.double.enabled=true \ --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \ --py-files ${SAMPLE_ZIP} \ main.py \ --mainClass='com.nvidia.spark.examples.mortgage.etl_main' \ --format=csv \ --dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/" \ --dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/train/" \ --dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/" # if generating eval data, change the data path to eval # --dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/" # --dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/" # --dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/" # if running Taxi ETL benchmark, change the class and data path params to # -class com.nvidia.spark.examples.taxi.ETLMain # -dataPath="raw::${SPARK_XGBOOST_DIR}/taxi/your-path" # -dataPath="out::${SPARK_XGBOOST_DIR}/taxi/your-path" ``` ### ETL on CPU ```bash ${SPARK_HOME}/bin/spark-submit \ --master spark://$HOSTNAME:7077 \ --executor-memory 32G \ --conf spark.executor.instances=1 \ --py-files ${SAMPLE_ZIP} \ main.py \ --mainClass='com.nvidia.spark.examples.mortgage.etl_main' \ --format=csv \ --dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/" \ --dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/train/" \ --dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/" # if generating eval data, change the data path to eval # --dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/" # --dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/" # --dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/" # if running Taxi ETL benchmark, change the class and data path params to # -class com.nvidia.spark.examples.taxi.ETLMain # -dataPath="raw::${SPARK_XGBOOST_DIR}/taxi/your-path" # -dataPath="out::${SPARK_XGBOOST_DIR}/taxi/your-path" ``` Launch XGBoost Part on GPU --------------------------- Variables required to run spark-submit command: ``` bash # this is the same master host we defined while launching the cluster export SPARK_MASTER=spark://`hostname -f`:7077 # Currently the number of tasks and executors must match the number of input files. 
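# Illustrative check (the path assumes the ETL train output location used above):
# count the input files so you can match SPARK_NUM_EXECUTORS to them, e.g.:
# ls ${SPARK_XGBOOST_DIR}/mortgage/output/train/ | wc -l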
# For this example, we will set these such that we have 1 executor, with 1 core per executor ## take up the the whole worker export SPARK_CORES_PER_EXECUTOR=${SPARK_CORES_PER_WORKER} ## run 1 executor export SPARK_NUM_EXECUTORS=1 ## cores/executor * num_executors, which in this case is also 1, limits ## the number of cores given to the application export TOTAL_CORES=$((SPARK_CORES_PER_EXECUTOR * SPARK_NUM_EXECUTORS)) # spark driver memory export SPARK_DRIVER_MEMORY=4g # spark executor memory export SPARK_EXECUTOR_MEMORY=8g # example class to use export EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.main # or change to com.nvidia.spark.examples.taxi.main to run Taxi Xgboost benchmark # or change to com.nvidia.spark.examples.agaricus.main to run Agaricus Xgboost benchmark # tree construction algorithm export TREE_METHOD=gpu_hist # if you enable archive python environment export PYSPARK_DRIVER_PYTHON=python export PYSPARK_PYTHON=./environment/bin/python ``` Run spark-submit: ``` bash ${SPARK_HOME}/bin/spark-submit \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --conf spark.rapids.memory.gpu.pool=NONE \ --conf spark.executor.resource.gpu.amount=1 \ --conf spark.task.resource.gpu.amount=1 \ --master ${SPARK_MASTER} \ --driver-memory ${SPARK_DRIVER_MEMORY} \ --executor-memory ${SPARK_EXECUTOR_MEMORY} \ --conf spark.cores.max=${TOTAL_CORES} \ --archives your_pyspark_venv.tar.gz#environment #if you enabled archive python environment \ --jars ${RAPIDS_JAR} \ --py-files ${SAMPLE_ZIP} \ ${MAIN_PY} \ --mainClass=${EXAMPLE_CLASS} \ --dataPath=train::${SPARK_XGBOOST_DIR}/mortgage/output/train/ \ --dataPath=trans::${SPARK_XGBOOST_DIR}/mortgage/output/eval/ \ --format=parquet \ --numWorkers=${SPARK_NUM_EXECUTORS} \ --treeMethod=${TREE_METHOD} \ --numRound=100 \ --maxDepth=8 # Change the format to csv if your input file is CSV format. # Please make sure to change the class and data path while running Taxi or Agaricus benchmark ``` In the `stdout` log on driver side, you should see timings* (in seconds), and the accuracy metric: ``` ---------------------------------------------------------------------------------------------------- Training takes 14.65 seconds ---------------------------------------------------------------------------------------------------- Transformation takes 12.21 seconds ---------------------------------------------------------------------------------------------------- Accuracy is 0.9873692247091792 ``` Launch XGBoost Part on CPU --------------------------- If you are running this example after running the GPU example above, please set these variables, to set both training and testing to run on the CPU exclusively: ``` bash # this is the same master host we defined while launching the cluster export SPARK_MASTER=spark://`hostname -f`:7077 # Currently the number of tasks and executors must match the number of input files. 
# For this example, we will set these such that we have 1 executor, with 1 core per executor ## take up the the whole worker export SPARK_CORES_PER_EXECUTOR=${SPARK_CORES_PER_WORKER} ## run 1 executor export SPARK_NUM_EXECUTORS=1 ## cores/executor * num_executors, which in this case is also 1, limits ## the number of cores given to the application export TOTAL_CORES=$((SPARK_CORES_PER_EXECUTOR * SPARK_NUM_EXECUTORS)) # spark driver memory export SPARK_DRIVER_MEMORY=4g # spark executor memory export SPARK_EXECUTOR_MEMORY=8g # example class to use export EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.main # Please make sure to change the class while running Taxi or Agaricus benchmark # tree construction algorithm export TREE_METHOD=hist # if you enable archive python environment export PYSPARK_DRIVER_PYTHON=python export PYSPARK_PYTHON=./environment/bin/python ``` This is the same command as for the GPU example, repeated for convenience: ``` bash ${SPARK_HOME}/bin/spark-submit \ --master ${SPARK_MASTER} \ --driver-memory ${SPARK_DRIVER_MEMORY} \ --executor-memory ${SPARK_EXECUTOR_MEMORY} \ --conf spark.cores.max=${TOTAL_CORES} \ --archives your_pyspark_venv.tar.gz#environment #if you enabled archive python environment \ --jars ${RAPIDS_JAR} \ --py-files ${SAMPLE_ZIP} \ ${SPARK_PYTHON_ENTRYPOINT} \ --mainClass=${EXAMPLE_CLASS} \ --dataPath=train::${DATA_PATH}/mortgage/output/train/ \ --dataPath=trans::${DATA_PATH}/mortgage/output/eval/ \ --format=parquet \ --numWorkers=${SPARK_NUM_EXECUTORS} \ --treeMethod=${TREE_METHOD} \ --numRound=100 \ --maxDepth=8 # Change the format to csv if your input file is CSV format. # Please make sure to change the class and data path while running Taxi or Agaricus benchmark ``` In the `stdout` log on driver side, you should see timings* (in seconds), and the accuracy metric: ``` ---------------------------------------------------------------------------------------------------- Training takes 225.7 seconds ---------------------------------------------------------------------------------------------------- Transformation takes 36.26 seconds ---------------------------------------------------------------------------------------------------- Accuracy is 0.9873709530950067 ``` * The timings in this Getting Started guide are only illustrative. Please see our [release announcement](https://medium.com/rapids-ai/nvidia-gpus-and-apache-spark-one-step-closer-2d99e37ac8fd) for official benchmarks. ================================================ FILE: docs/get-started/xgboost-examples/on-prem-cluster/standalone-scala.md ================================================ Get Started with XGBoost4J-Spark on an Apache Spark Standalone Cluster ====================================================================== This is a getting-started guide to XGBoost on an Apache Spark 3.2+ Standalone Cluster. At the end of this guide, the user can run a sample Apache Spark application that runs on NVIDIA GPUs. Prerequisites ------------- * Apache Spark 3.2.0+ Standalone Cluster (e.g.: Spark 3.2.0) * Hardware Requirements * NVIDIA Pascal™ GPU architecture or better * Multi-node clusters with homogenous GPU configuration * Software Requirements * Ubuntu 20.04, 22.04/CentOS7, Rocky Linux 8 * CUDA 11.0+ * NVIDIA driver compatible with your CUDA * NCCL 2.7.8+ The number of GPUs in each host dictates the number of Spark executors that can run there. Additionally, cores per Spark executor and cores per Spark task must match, such that each executor can run 1 task at any given time. 
For example, if each host has 4 GPUs, there should be 4 or fewer executors running on each host, and each executor should run at most 1 task (e.g.: a total of 4 tasks running on 4 GPUs). In Spark Standalone mode, the default configuration is for an executor to take up all the cores assigned to each Spark Worker. In this example, we will limit the number of cores to 1, to match our dataset. Please see https://spark.apache.org/docs/latest/spark-standalone.html for more documentation regarding Standalone configuration. We use `SPARK_HOME` environment variable to point to the Apache Spark cluster. And here are steps to enable the GPU resources discovery for Spark 3.2+. 1. Copy the spark configure file from template. ``` bash cd ${SPARK_HOME}/conf/ cp spark-defaults.conf.template spark-defaults.conf ``` 2. Add the following configs to the file `spark-defaults.conf`. The number in first config should NOT be larger than the actual number of the GPUs on current host. This example uses 1 as below for one GPU on the host. ``` bash spark.worker.resource.gpu.amount 1 spark.worker.resource.gpu.discoveryScript ${SPARK_HOME}/examples/src/main/scripts/getGpusResources.sh ``` Get Jars and Dataset ------------------------------- Make sure you have prepared the necessary packages and dataset by following this [guide](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md) #### Note: 1. Mortgage and Taxi jobs have ETLs to generate the processed data. 2. For convenience, a subset of [Taxi](/datasets/) dataset is made available in this repo that can be readily used for launching XGBoost job. Use [ETL](#etl) to generate larger datasets for trainig and testing. 3. Agaricus does not have an ETL process, it is combined with XGBoost as there is just a filter operation. Launch a Standalone Spark Cluster --------------------------------- 1. Copy required jars to `$SPARK_HOME/jars` folder. ``` bash cp $RAPIDS_JAR $SPARK_HOME/jars/ ``` 2. Start the Spark Master process. ``` bash ${SPARK_HOME}/sbin/start-master.sh ``` Note the hostname or ip address of the Master host, so that it can be given to each Worker process, in this example the Master and Worker will run on the same host. 3. Start a Spark slave process. ``` bash export SPARK_MASTER=spark://`hostname -f`:7077 export SPARK_CORES_PER_WORKER=1 ${SPARK_HOME}/sbin/start-slave.sh ${SPARK_MASTER} -c ${SPARK_CORES_PER_WORKER} ``` Note that in this example the Master and Worker processes are both running on the same host. This is not a requirement, as long as all hosts that are used to run the Spark app have access to the dataset. Launch Mortgage or Taxi ETL Part --------------------------- Use the ETL app to process raw Mortgage data. You can either use this ETLed data to split into training and evaluation data or run the ETL on different subsets of the dataset to produce training and evaluation datasets. Run spark-submit Note: For ETL jobs, Set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`. 
### ETL on GPU ``` bash ${SPARK_HOME}/bin/spark-submit \ --master spark://$HOSTNAME:7077 \ --executor-memory 32G \ --conf spark.executor.resource.gpu.amount=1 \ --conf spark.executor.cores=10 \ --conf spark.task.resource.gpu.amount=0.1 \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --conf spark.rapids.sql.incompatibleDateFormats.enabled=true \ --conf spark.rapids.sql.csv.read.double.enabled=true \ --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \ --class com.nvidia.spark.examples.mortgage.ETLMain \ $SAMPLE_JAR \ -format=csv \ -dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/" \ -dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/train/" \ -dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/" # if generating eval data, change the data path to eval # -dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/" # -dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/" # -dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/" # if running Taxi ETL benchmark, change the class and data path params to # -class com.nvidia.spark.examples.taxi.ETLMain # -dataPath="raw::${SPARK_XGBOOST_DIR}/taxi/your-path" # -dataPath="out::${SPARK_XGBOOST_DIR}/taxi/your-path" ``` ### ETL on CPU ```bash ${SPARK_HOME}/bin/spark-submit \ --master spark://$HOSTNAME:7077 \ --executor-memory 32G \ --conf spark.executor.instances=1 \ --conf spark.sql.broadcastTimeout=700 \ --class com.nvidia.spark.examples.mortgage.ETLMain \ $SAMPLE_JAR \ -format=csv \ -dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/" \ -dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/train/" \ -dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/" # if generating eval data, change the data path to eval # -dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/" # -dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/" # if running Taxi ETL benchmark, change the class and data path params to # -class com.nvidia.spark.examples.taxi.ETLMain # -dataPath="raw::${SPARK_XGBOOST_DIR}/taxi/your-path" # -dataPath="out::${SPARK_XGBOOST_DIR}/taxi/your-path" ``` Launch XGBoost Part on GPU --------------------------- Variables required to run spark-submit command: ``` bash # this is the same master host we defined while launching the cluster export SPARK_MASTER=spark://`hostname -f`:7077 # Currently the number of tasks and executors must match the number of input files. 
# For this example, we will set these such that we have 1 executor, with 1 core per executor ## take up the the whole worker export SPARK_CORES_PER_EXECUTOR=${SPARK_CORES_PER_WORKER} ## run 1 executor export SPARK_NUM_EXECUTORS=1 ## cores/executor * num_executors, which in this case is also 1, limits ## the number of cores given to the application export TOTAL_CORES=$((SPARK_CORES_PER_EXECUTOR * SPARK_NUM_EXECUTORS)) # spark driver memory export SPARK_DRIVER_MEMORY=4g # spark executor memory export SPARK_EXECUTOR_MEMORY=8g # example class to use export EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.Main # or change to com.nvidia.spark.examples.taxi.Main to run Taxi Xgboost benchmark # or change to com.nvidia.spark.examples.agaricus.Main to run Agaricus Xgboost benchmark # tree construction algorithm export TREE_METHOD=gpu_hist ``` Run spark-submit: ``` bash ${SPARK_HOME}/bin/spark-submit \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --conf spark.rapids.memory.gpu.pool=NONE \ --conf spark.executor.resource.gpu.amount=1 \ --conf spark.task.resource.gpu.amount=1 \ --master ${SPARK_MASTER} \ --driver-memory ${SPARK_DRIVER_MEMORY} \ --executor-memory ${SPARK_EXECUTOR_MEMORY} \ --conf spark.cores.max=${TOTAL_CORES} \ --class ${EXAMPLE_CLASS} \ ${SAMPLE_JAR} \ -dataPath=train::${SPARK_XGBOOST_DIR}/mortgage/output/train/ \ -dataPath=trans::${SPARK_XGBOOST_DIR}/mortgage/output/eval/ \ -format=parquet \ -numWorkers=${SPARK_NUM_EXECUTORS} \ -treeMethod=${TREE_METHOD} \ -numRound=100 \ -maxDepth=8 # Please make sure to change the class and data path while running Taxi or Agaricus benchmark ``` In `stdout` log on driver side, you should see timings* (in seconds), and the accuracy metric(take Mortgage as example): ``` -------------- ==> Benchmark: Elapsed time for [Mortgage GPU train csv stub Unknown Unknown Unknown]: 26.572s -------------- -------------- ==> Benchmark: Elapsed time for [Mortgage GPU transform csv stub Unknown Unknown Unknown]: 10.323s -------------- -------------- ==> Benchmark: Accuracy for [Mortgage GPU Accuracy csv stub Unknown Unknown Unknown]: 0.9869227318579323 -------------- ``` Launch XGBoost Part on CPU --------------------------- If you are running this example after running the GPU example above, please set these variables, to set both training and testing to run on the CPU exclusively: ``` bash # this is the same master host we defined while launching the cluster export SPARK_MASTER=spark://`hostname -f`:7077 # Currently the number of tasks and executors must match the number of input files. 
# For this example, we will set these such that we have 1 executor, with 1 core per executor ## take up the the whole worker export SPARK_CORES_PER_EXECUTOR=${SPARK_CORES_PER_WORKER} ## run 1 executor export SPARK_NUM_EXECUTORS=1 ## cores/executor * num_executors, which in this case is also 1, limits ## the number of cores given to the application export TOTAL_CORES=$((SPARK_CORES_PER_EXECUTOR * SPARK_NUM_EXECUTORS)) # spark driver memory export SPARK_DRIVER_MEMORY=4g # spark executor memory export SPARK_EXECUTOR_MEMORY=8g # example class to use export EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.Main # Please make sure to change the class while running Taxi or Agaricus benchmark # tree construction algorithm export TREE_METHOD=hist ``` This is the same command as for the GPU example, repeated for convenience: ```bash ${SPARK_HOME}/bin/spark-submit \ --master ${SPARK_MASTER} \ --driver-memory ${SPARK_DRIVER_MEMORY} \ --executor-memory ${SPARK_EXECUTOR_MEMORY} \ --conf spark.cores.max=${TOTAL_CORES} \ --class ${EXAMPLE_CLASS} \ ${SAMPLE_JAR} \ -dataPath=train::${SPARK_XGBOOST_DIR}/mortgage/output/train/ \ -dataPath=trans::${SPARK_XGBOOST_DIR}/mortgage/output/eval/ \ -format=parquet \ -numWorkers=${SPARK_NUM_EXECUTORS} \ -treeMethod=${TREE_METHOD} \ -numRound=100 \ -maxDepth=8 # Please make sure to change the class and data path while running Taxi or Agaricus benchmark ``` In the `stdout` log on driver side, you should see timings* (in seconds), and the accuracy metric(take Mortgage as example): ``` -------------- ==> Benchmark: Elapsed time for [Mortgage CPU train csv stub Unknown Unknown Unknown]: 305.535s -------------- -------------- ==> Benchmark: Elapsed time for [Mortgage CPU transform csv stub Unknown Unknown Unknown]: 52.867s -------------- -------------- ==> Benchmark: Accuracy for [Mortgage CPU Accuracy csv stub Unknown Unknown Unknown]: 0.9872234894511343 -------------- ``` * The timings in this Getting Started guide are only for illustrative purpose. Please see our [release announcement](https://medium.com/rapids-ai/nvidia-gpus-and-apache-spark-one-step-closer-2d99e37ac8fd) for official benchmarks. ================================================ FILE: docs/get-started/xgboost-examples/on-prem-cluster/yarn-python.md ================================================ Get Started with XGBoost4J-Spark on Apache Hadoop YARN ====================================================== This is a getting started guide to XGBoost4J-Spark on Apache Hadoop YARN supporting GPU scheduling. At the end of this guide, the reader will be able to run a sample Apache Spark Python application that runs on NVIDIA GPUs. Prerequisites ------------- * Apache Spark 3.2.0+ running on YARN supporting GPU scheduling. (e.g.: Spark 3.2.0, Hadoop-Yarn 3.3.0) * Hardware Requirements * NVIDIA Pascal™ GPU architecture or better * Multi-node clusters with homogenous GPU configuration * Software Requirements * Ubuntu 20.04, 22.04/CentOS7, Rocky Linux 8 * CUDA 11.5+ * NVIDIA driver compatible with your CUDA * NCCL 2.7.8+ * Python 3.8 or 3.9 * NumPy * XGBoost 1.7.0+ * cudf-cu11 The number of GPUs per NodeManager dictates the number of Spark executors that can run in that NodeManager. Additionally, cores per Spark executor and cores per Spark task must match, such that each executor can run 1 task at any given time. For example: if each NodeManager has 4 GPUs, there should be 4 or fewer executors running on each NodeManager, and each executor should run 1 task (e.g.: A total of 4 tasks running on 4 GPUs). 
In order to achieve this, you may need to adjust `spark.task.cpus` and `spark.executor.cores` to match (both set to 1 by default). Additionally, we recommend adjusting `executor-memory` to divide host memory evenly amongst the number of GPUs in each NodeManager, such that Spark will schedule as many executors as there are GPUs in each NodeManager. We use `SPARK_HOME` environment variable to point to the Apache Spark cluster. And as to how to enable GPU scheduling and isolation for Yarn, please refer to [here](https://hadoop.apache.org/docs/r3.1.0/hadoop-yarn/hadoop-yarn-site/UsingGpus.html). Please make sure to install the XGBoost, cudf-cu11, numpy libraries on all nodes before running XGBoost application. ``` bash pip install xgboost pip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com pip install numpy pip install scikit-learn ``` You can also create an isolated python environment by using [Virtualenv](https://virtualenv.pypa.io/en/latest/), and then directly pass/unpack the archive file and enable the environment on executors by leveraging the --archives option or spark.archives configuration. ``` bash # create an isolated python environment and install libraries python -m venv pyspark_venv source pyspark_venv/bin/activate pip install xgboost pip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com pip install numpy pip install scikit-learn venv-pack -o pyspark_venv.tar.gz # enable archive python environment on executors export PYSPARK_DRIVER_PYTHON=python # Do not set in cluster modes. export PYSPARK_PYTHON=./environment/bin/python spark-submit --archives pyspark_venv.tar.gz#environment app.py ``` Get Application Files, Jar and Dataset ------------------------------- Make sure you have prepared the necessary packages and dataset by following this [guide](../prepare-package-data/preparation-python.md) Then create a directory in HDFS, and run below commands, ``` bash [xgboost4j_spark_python]$ hadoop fs -mkdir /tmp/xgboost4j_spark_python [xgboost4j_spark_python]$ hadoop fs -copyFromLocal ${SPARK_XGBOOST_DIR}/mortgage/* /tmp/xgboost4j_spark_python ``` Launch Mortgage or Taxi ETL Part --------------------------- Use the ETL app to process raw Mortgage data. You can either use this ETLed data to split into training and evaluation data or run the ETL on different subsets of the dataset to produce training and evaluation datasets. Note: For ETL jobs, Set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`. 
``` bash
# location where data was downloaded
export DATA_PATH=hdfs:/tmp/xgboost4j_spark_python/

${SPARK_HOME}/bin/spark-submit \
  --master yarn \
  --deploy-mode cluster \
  --conf spark.executor.cores=10 \
  --conf spark.task.resource.gpu.amount=0.1 \
  --conf spark.rapids.sql.incompatibleDateFormats.enabled=true \
  --conf spark.rapids.sql.csv.read.double.enabled=true \
  --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \
  --jars ${RAPIDS_JAR} \
  ${MAIN_PY} \
  --mainClass='com.nvidia.spark.examples.mortgage.etl_main' \
  --format=csv \
  --dataPath="data::${DATA_PATH}/mortgage/data/mortgage/input/" \
  --dataPath="out::${DATA_PATH}/mortgage/data/mortgage/output/train/" \
  --dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/"

# if generating eval data, change the data path to eval
# --dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/"
# --dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/"
# --dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/"

# if running the Taxi ETL benchmark, change the class and data path params to
# -class com.nvidia.spark.examples.taxi.ETLMain
# -dataPath="raw::${SPARK_XGBOOST_DIR}/taxi/your-path"
# -dataPath="out::${SPARK_XGBOOST_DIR}/taxi/your-path"
```

Launch XGBoost Part on GPU
---------------------------

Variables required to run the spark-submit command:

``` bash
# location where data was downloaded
export DATA_PATH=hdfs:/tmp/xgboost4j_spark_python

# spark deploy mode (see Apache Spark documentation for more information)
export SPARK_DEPLOY_MODE=cluster

# run a single executor for this example to limit the number of spark tasks and
# partitions to 1 as currently this number must match the number of input files
export SPARK_NUM_EXECUTORS=1

# spark driver memory
export SPARK_DRIVER_MEMORY=4g

# spark executor memory
export SPARK_EXECUTOR_MEMORY=8g

# python entrypoint
export SPARK_PYTHON_ENTRYPOINT=${LIBS_PATH}/main.py

# example class to use
export EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.main
# or change to com.nvidia.spark.examples.taxi.main to run the Taxi XGBoost benchmark
# or change to com.nvidia.spark.examples.agaricus.main to run the Agaricus XGBoost benchmark

# tree construction algorithm
export TREE_METHOD=gpu_hist

# if you enabled the archive python environment
export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./environment/bin/python
```

Run spark-submit:

``` bash
${SPARK_HOME}/bin/spark-submit \
  --conf spark.plugins=com.nvidia.spark.SQLPlugin \
  --conf spark.rapids.memory.gpu.pool=NONE \
  --conf spark.executor.resource.gpu.amount=1 \
  --conf spark.task.resource.gpu.amount=1 \
  --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \
  --files ${SPARK_HOME}/examples/src/main/scripts/getGpusResources.sh \
  --master yarn \
  --deploy-mode ${SPARK_DEPLOY_MODE} \
  --archives your_pyspark_venv.tar.gz#environment \
  --num-executors ${SPARK_NUM_EXECUTORS} \
  --driver-memory ${SPARK_DRIVER_MEMORY} \
  --executor-memory ${SPARK_EXECUTOR_MEMORY} \
  --jars ${RAPIDS_JAR} \
  --py-files ${SAMPLE_ZIP} \
  ${MAIN_PY} \
  --mainClass=${EXAMPLE_CLASS} \
  --dataPath=train::${DATA_PATH}/mortgage/out/train/ \
  --dataPath=trans::${DATA_PATH}/mortgage/out/eval/ \
  --format=parquet \
  --numWorkers=${SPARK_NUM_EXECUTORS} \
  --treeMethod=${TREE_METHOD} \
  --numRound=100 \
  --maxDepth=8

# Remove the --archives option above if you did not package an archive python environment
# Change the format to csv if your input file is in CSV format
# Please make sure to change the class and data path when running the Taxi or Agaricus benchmark
```

In the `stdout` driver log, you should see the timings* (in seconds) and the accuracy metric:

```
----------------------------------------------------------------------------------------------------
Training takes 10.75 seconds
----------------------------------------------------------------------------------------------------
Transformation takes 4.38 seconds
----------------------------------------------------------------------------------------------------
Accuracy is 0.997544753891
```

Launch XGBoost Part on CPU
---------------------------

If you are running this example after running the GPU example above, please set these variables so that both training and testing run exclusively on the CPU:

``` bash
# location where data was downloaded
export DATA_PATH=hdfs:/tmp/xgboost4j_spark_python/

# spark deploy mode (see Apache Spark documentation for more information)
export SPARK_DEPLOY_MODE=cluster

# run a single executor for this example to limit the number of spark tasks and
# partitions to 1 as currently this number must match the number of input files
export SPARK_NUM_EXECUTORS=1

# spark driver memory
export SPARK_DRIVER_MEMORY=4g

# spark executor memory
export SPARK_EXECUTOR_MEMORY=8g

# example class to use
export EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.main
# or change to com.nvidia.spark.examples.taxi.main to run the Taxi XGBoost benchmark
# or change to com.nvidia.spark.examples.agaricus.main to run the Agaricus XGBoost benchmark

# tree construction algorithm
export TREE_METHOD=hist

# if you enabled the archive python environment
export PYSPARK_DRIVER_PYTHON=python
export PYSPARK_PYTHON=./environment/bin/python
```

This is the same command as for the GPU example, repeated for convenience:

``` bash
${SPARK_HOME}/bin/spark-submit \
  --master yarn \
  --archives your_pyspark_venv.tar.gz#environment \
  --deploy-mode ${SPARK_DEPLOY_MODE} \
  --num-executors ${SPARK_NUM_EXECUTORS} \
  --driver-memory ${SPARK_DRIVER_MEMORY} \
  --executor-memory ${SPARK_EXECUTOR_MEMORY} \
  --jars ${RAPIDS_JAR} \
  --py-files ${SAMPLE_ZIP} \
  ${MAIN_PY} \
  --mainClass=${EXAMPLE_CLASS} \
  --dataPath=train::${DATA_PATH}/mortgage/output/train/ \
  --dataPath=trans::${DATA_PATH}/mortgage/output/eval/ \
  --format=parquet \
  --numWorkers=${SPARK_NUM_EXECUTORS} \
  --treeMethod=${TREE_METHOD} \
  --numRound=100 \
  --maxDepth=8

# Remove the --archives option above if you did not package an archive python environment
# Please make sure to change the class and data path when running the Taxi or Agaricus benchmark
```

In the `stdout` driver log, you should see the timings* (in seconds) and the accuracy metric:

```
----------------------------------------------------------------------------------------------------
Training takes 10.76 seconds
----------------------------------------------------------------------------------------------------
Transformation takes 1.25 seconds
----------------------------------------------------------------------------------------------------
Accuracy is 0.998526852335
```

* The timings in this Getting Started guide are for illustrative purposes only. Please see our [release announcement](https://medium.com/rapids-ai/nvidia-gpus-and-apache-spark-one-step-closer-2d99e37ac8fd) for official benchmarks.
================================================
FILE: docs/get-started/xgboost-examples/on-prem-cluster/yarn-scala.md
================================================

Get Started with XGBoost4J-Spark on Apache Hadoop YARN
======================================================

This is a getting started guide to XGBoost4J-Spark on Apache Hadoop YARN supporting GPU scheduling. At the end of this guide, the reader will be able to run a sample Apache Spark application that runs on NVIDIA GPUs.

Prerequisites
-------------

* Apache Spark 3.2.0+ running on YARN supporting GPU scheduling. (e.g.: Spark 3.2.0, Hadoop-Yarn 3.3.0)
* Hardware Requirements
  * NVIDIA Pascal™ GPU architecture or better
  * Multi-node clusters with homogeneous GPU configuration
* Software Requirements
  * Ubuntu 20.04/22.04, CentOS 7, Rocky Linux 8
  * CUDA 11.0+
  * NVIDIA driver compatible with your CUDA
  * NCCL 2.7.8+

The number of GPUs per NodeManager dictates the number of Spark executors that can run in that NodeManager. Additionally, cores per Spark executor and cores per Spark task must match, such that each executor can run 1 task at any given time. For example: if each NodeManager has 4 GPUs, there should be 4 or fewer executors running on each NodeManager, and each executor should run 1 task (e.g.: a total of 4 tasks running on 4 GPUs).

In order to achieve this, you may need to adjust `spark.task.cpus` and `spark.executor.cores` to match (both are set to 1 by default). Additionally, we recommend adjusting `executor-memory` to divide host memory evenly amongst the number of GPUs in each NodeManager, such that Spark will schedule as many executors as there are GPUs in each NodeManager.

We use the `SPARK_HOME` environment variable to point to the Apache Spark installation. For how to enable GPU scheduling and isolation for YARN, please refer to the [Hadoop documentation](https://hadoop.apache.org/docs/r3.1.0/hadoop-yarn/hadoop-yarn-site/UsingGpus.html).

Get Jars and Dataset
-------------------------------

Make sure you have prepared the necessary packages and dataset by following this [guide](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md)

#### Note:
1. Mortgage and Taxi jobs have ETLs to generate the processed data.
2. For convenience, a subset of the [Taxi](/datasets/) dataset is made available in this repo and can be used readily for launching the XGBoost job. Use [ETL](#etl) to generate larger datasets for training and testing.
3. Agaricus does not have an ETL process; it is combined with XGBoost since there is just a filter operation.

Create a directory in HDFS, and copy:

``` bash
[xgboost4j_spark]$ hadoop fs -mkdir /tmp/xgboost4j_spark
[xgboost4j_spark]$ hadoop fs -copyFromLocal ${SPARK_XGBOOST_DIR}/mortgage/* /tmp/xgboost4j_spark
```

Launch Mortgage or Taxi ETL Part
---------------------------

Use the ETL app to process the raw Mortgage data. You can either split this ETLed data into training and evaluation sets, or run the ETL on different subsets of the dataset to produce separate training and evaluation datasets.

Note: For ETL jobs, set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`.
Run spark-submit ``` bash ${SPARK_HOME}/bin/spark-submit \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --conf spark.executor.resource.gpu.amount=1 \ --conf spark.executor.cores=10 \ --conf spark.task.resource.gpu.amount=0.1 \ --conf spark.rapids.sql.incompatibleDateFormats.enabled=true \ --conf spark.rapids.sql.csv.read.double.enabled=true \ --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \ --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \ --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh \ --jars ${RAPIDS_JAR} \ --master yarn \ --deploy-mode ${SPARK_DEPLOY_MODE} \ --num-executors ${SPARK_NUM_EXECUTORS} \ --driver-memory ${SPARK_DRIVER_MEMORY} \ --executor-memory ${SPARK_EXECUTOR_MEMORY} \ --class com.nvidia.spark.examples.mortgage.ETLMain \ $SAMPLE_JAR \ -format=csv \ -dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/" \ -dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/train/" \ -dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/" # if generating eval data, change the data path to eval # -dataPath="data::${SPARK_XGBOOST_DIR}/mortgage/input/" # -dataPath="out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/" # -dataPath="tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/" # if running Taxi ETL benchmark, change the class and data path params to # -class com.nvidia.spark.examples.taxi.ETLMain # -dataPath="raw::${SPARK_XGBOOST_DIR}/taxi/your-path" # -dataPath="out::${SPARK_XGBOOST_DIR}/taxi/your-path" ``` Launch XGBoost Part on GPU --------------------------- Variables required to run spark-submit command: ``` bash # location where data was downloaded export DATA_PATH=hdfs:/tmp/xgboost4j_spark/data # spark deploy mode (see Apache Spark documentation for more information) export SPARK_DEPLOY_MODE=cluster # run a single executor for this example to limit the number of spark tasks and # partitions to 1 as currently this number must match the number of input files export SPARK_NUM_EXECUTORS=1 # spark driver memory export SPARK_DRIVER_MEMORY=4g # spark executor memory export SPARK_EXECUTOR_MEMORY=8g # example class to use export EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.Main # or change to com.nvidia.spark.examples.taxi.Main to run Taxi Xgboost benchmark # or change to com.nvidia.spark.examples.agaricus.Main to run Agaricus Xgboost benchmark # tree construction algorithm export TREE_METHOD=gpu_hist ``` Run spark-submit: ``` bash ${SPARK_HOME}/bin/spark-submit \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --conf spark.rapids.memory.gpu.pool=NONE \ --conf spark.executor.resource.gpu.amount=1 \ --conf spark.task.resource.gpu.amount=1 \ --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \ --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh \ --jars ${RAPIDS_JAR} \ --master yarn \ --deploy-mode ${SPARK_DEPLOY_MODE} \ --num-executors ${SPARK_NUM_EXECUTORS} \ --driver-memory ${SPARK_DRIVER_MEMORY} \ --executor-memory ${SPARK_EXECUTOR_MEMORY} \ --class ${EXAMPLE_CLASS} \ ${SAMPLE_JAR} \ -dataPath=train::${SPARK_XGBOOST_DIR}/mortgage/output/train/ \ -dataPath=trans::${SPARK_XGBOOST_DIR}/mortgage/output/eval/ \ -format=parquet \ -numWorkers=${SPARK_NUM_EXECUTORS} \ -treeMethod=${TREE_METHOD} \ -numRound=100 \ -maxDepth=8 # Please make sure to change the class and data path while running Taxi or Agaricus benchmark ``` In the `stdout` driver log, you should see timings* (in seconds), and the accuracy metric(take Mortgage as example): ``` -------------- ==> Benchmark: Elapsed 
time for [Mortgage GPU train csv stub Unknown Unknown Unknown]: 29.642s
--------------

--------------
==> Benchmark: Elapsed time for [Mortgage GPU transform csv stub Unknown Unknown Unknown]: 21.272s
--------------

--------------
==> Benchmark: Accuracy for [Mortgage GPU Accuracy csv stub Unknown Unknown Unknown]: 0.9874184013493451
--------------
```

Launch XGBoost Part on CPU
---------------------------

If you are running this example after running the GPU example above, please set these variables so that both training and testing run exclusively on the CPU:

``` bash
# location where data was downloaded
export DATA_PATH=hdfs:/tmp/xgboost4j_spark/data

# spark deploy mode (see Apache Spark documentation for more information)
export SPARK_DEPLOY_MODE=cluster

# run a single executor for this example to limit the number of spark tasks and
# partitions to 1 as currently this number must match the number of input files
export SPARK_NUM_EXECUTORS=1

# spark driver memory
export SPARK_DRIVER_MEMORY=4g

# spark executor memory
export SPARK_EXECUTOR_MEMORY=8g

# example class to use
export EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.Main
# Please make sure to change the class when running the Taxi or Agaricus benchmark

# tree construction algorithm
export TREE_METHOD=hist
```

This is the same command as for the GPU example, repeated for convenience:

``` bash
${SPARK_HOME}/bin/spark-submit \
  --master yarn \
  --deploy-mode ${SPARK_DEPLOY_MODE} \
  --num-executors ${SPARK_NUM_EXECUTORS} \
  --driver-memory ${SPARK_DRIVER_MEMORY} \
  --executor-memory ${SPARK_EXECUTOR_MEMORY} \
  --class ${EXAMPLE_CLASS} \
  ${SAMPLE_JAR} \
  -dataPath=train::${SPARK_XGBOOST_DIR}/mortgage/output/train/ \
  -dataPath=trans::${SPARK_XGBOOST_DIR}/mortgage/output/eval/ \
  -format=parquet \
  -numWorkers=${SPARK_NUM_EXECUTORS} \
  -treeMethod=${TREE_METHOD} \
  -numRound=100 \
  -maxDepth=8

# Please make sure to change the class and data path when running the Taxi or Agaricus benchmark
```

In the `stdout` driver log, you should see the timings* (in seconds) and the accuracy metric (taking Mortgage as an example):

```
--------------
==> Benchmark: Elapsed time for [Mortgage CPU train csv stub Unknown Unknown Unknown]: 286.398s
--------------

--------------
==> Benchmark: Elapsed time for [Mortgage CPU transform csv stub Unknown Unknown Unknown]: 49.836s
--------------

--------------
==> Benchmark: Accuracy for [Mortgage CPU Accuracy csv stub Unknown Unknown Unknown]: 0.9873709530950067
--------------
```

* The timings in this Getting Started guide are for illustrative purposes only. Please see our [release announcement](https://medium.com/rapids-ai/nvidia-gpus-and-apache-spark-one-step-closer-2d99e37ac8fd) for official benchmarks.

================================================
FILE: docs/get-started/xgboost-examples/prepare-package-data/preparation-python.md
================================================

## Prepare packages and dataset for pyspark

For simplicity, export the location of these jars. All examples assume the packages and dataset are placed in the `/opt/xgboost` directory:

### Download the jars

Download the RAPIDS Accelerator for Apache Spark plugin jar

* [RAPIDS Spark Package](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)

### Build XGBoost Python Examples

Following this [guide](/docs/get-started/xgboost-examples/building-sample-apps/python.md), you can get *samples.zip* and *main.py* and copy them to `/opt/xgboost`.

### Download dataset

You need to copy the dataset to `/opt/xgboost`.
Use the following links to download the data.

1. [Mortgage dataset](/docs/get-started/xgboost-examples/dataset/mortgage.md)
2. [Taxi dataset](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page)
3. [Agaricus dataset](https://github.com/dmlc/xgboost/tree/master/demo/data)

================================================
FILE: docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md
================================================

## Prepare packages and dataset for scala

For simplicity, export the location of these jars. All examples assume the packages and dataset are placed in the `/opt/xgboost` directory:

### Download the jars

1. Download the RAPIDS Accelerator for Apache Spark plugin jar
   * [RAPIDS Spark Package](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)

### Build XGBoost Scala Examples

Following this [guide](/docs/get-started/xgboost-examples/building-sample-apps/scala.md), you can get *sample_xgboost_apps-0.2.3-jar-with-dependencies.jar* and copy it to `/opt/xgboost`.

### Download dataset

You need to copy the dataset to `/opt/xgboost`. Use the following links to download the data.

1. [Mortgage dataset](/docs/get-started/xgboost-examples/dataset/mortgage.md)
2. [Taxi dataset](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page)
3. [Agaricus dataset](https://github.com/dmlc/xgboost/tree/master/demo/data)

================================================
FILE: docs/trouble-shooting/xgboost-examples-trouble-shooting.md
================================================

## XGBoost

### 1. NCCL errors

XGBoost supports distributed GPU training, which depends on NCCL2, available at [this link](https://developer.nvidia.com/nccl).
NCCL auto-detects which network interfaces to use for inter-node communication. If some interfaces are up but unable to communicate between nodes, NCCL may try to use them anyway and therefore fail during the init functions or **even hang**.

To track NCCL errors, enable `NCCL_DEBUG` when submitting the Spark application:

``` bash
--conf spark.executorEnv.NCCL_DEBUG=INFO
```

Sometimes a node selects an inappropriate interface when connecting to another node, which may cause the XGBoost task to hang. To fix this kind of issue, specify an appropriate interface for the node with `NCCL_SOCKET_IFNAME`:

``` bash
--conf spark.executorEnv.NCCL_SOCKET_IFNAME=eth0
```

================================================
FILE: examples/MIG-Support/README.md
================================================

# Multi-Instance GPU (MIG) support in Apache Hadoop YARN

There are multiple solutions for MIG scheduling on YARN that you can choose based on your environment and deployment requirements:

- [YARN 3.3.0+ MIG GPU Plugin](/examples/MIG-Support/device-plugins/gpu-mig) for adding a Java-based plugin for MIG on top of the Pluggable Device Framework
- [YARN 3.1.2 until YARN 3.3.0 MIG GPU Support](/examples/MIG-Support/resource-types/gpu-mig) for patching and rebuilding the YARN code base to support MIG devices.
- [YARN 3.1.2+ MIG GPU Support without modifying YARN / Device Plugin Code](/examples/MIG-Support/yarn-unpatched) relying on installing NVIDIA CLI wrappers written in `bash`, but unlike the solutions above without any Java code changes.

## Limitations and Caveats

Note that there are some common caveats that apply to the solutions above.
### Single MIG GPU per Container Please see the [MIG Application Considerations](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#app-considerations) and [CUDA Device Enumeration](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html#cuda-visible-devices). It is important to note that CUDA 11 only supports enumeration of a single MIG instance. It is recommended that you configure YARN to only allow a single GPU be requested. See the YARN config `yarn.resource-types.nvidia/miggpu.maximum-allocation` for the [Pluggable Device Framework](/examples/MIG-Support/device-plugins/gpu-mig) solution and `yarn.resource-types.yarn.io/gpu.maximum-allocation` for the remainder of MIG Support options above, respectively. ### Metrics Some metrics are not and cannot be broken down by MIG device. For example, `utilization` is the aggregate utilization of the parent GPU, and there is no attribution of `temperature` to a particular MIG device. ### GPU index / address as reported by Apache Spark in logs and UI With YARN isolation using NVIDIA Container Runtime ensuring a single visible device per Docker container running a Spark Executor, each Executor will see a disjoint list comprising a single device. Therefore, the user will end up observing index 0 being used by all executors. However, they refer to different GPU/MIG instances. You can verify this by running something like the following on a YARN worker node host OS: ```bash for cid in $(sudo docker ps -q); do sudo docker exec $cid bash -c "printenv | grep VISIBLE; nvidia-smi -L"; done NVIDIA_VISIBLE_DEVICES=3 GPU 0: NVIDIA A30 (UUID: GPU-05aa99be-b706-0dc1-ab62-dd12f2227b7d) MIG 1g.6gb Device 0: (UUID: MIG-70dc024a-e8d7-587c-81dd-57ad493b1d91) NVIDIA_VISIBLE_DEVICES=1 GPU 0: NVIDIA A30 (UUID: GPU-05aa99be-b706-0dc1-ab62-dd12f2227b7d) MIG 1c.2g.12gb Device 0: (UUID: MIG-54cc2421-6f2d-59e9-b074-20707aadd71e) NVIDIA_VISIBLE_DEVICES=2 GPU 0: NVIDIA A30 (UUID: GPU-05aa99be-b706-0dc1-ab62-dd12f2227b7d) MIG 1g.6gb Device 0: (UUID: MIG-7e5552bf-d328-57a8-b091-0720d4530ffb) NVIDIA_VISIBLE_DEVICES=0 GPU 0: NVIDIA A30 (UUID: GPU-05aa99be-b706-0dc1-ab62-dd12f2227b7d) MIG 1c.2g.12gb Device 0: (UUID: MIG-e6af58f0-9af8-594f-825e-74d23e1a68c1) ``` ================================================ FILE: examples/MIG-Support/device-plugins/gpu-mig/README.md ================================================ # NVIDIA GPU Plugin for YARN with MIG support for YARN 3.3.0+ This plugin adds support for GPUs with [MIG](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/) on YARN. The built-in YARN GPU plugin does not support MIG enabled GPUs. This plugin also works with GPUs without MIG or GPUs with MIG disabled but the limitation section still applies. It supports heterogenous environments where there may be some MIG enabled GPUs and some without MIG. If you are not using MIG enabled GPUs, you should use the built-in YARN GPU plugin. ## Compatibility It works with Apache YARN 3.3.0+ versions that support the [Pluggable Device Framework](https://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/PluggableDeviceFramework.html). This plugin requires YARN to be configured with Docker using the NVIDIA Container Toolkit (nvidia-docker2). ## Limitations Please see the [MIG Application Considerations](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#app-considerations) and [CUDA Device Enumeration](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html#cuda-visible-devices). 
It is important to note that CUDA 11 only supports enumeration of a single MIG instance. This means that this plugin only supports 1 GPU per container, and the plugin will throw an exception by default if you request more.

It is recommended that you configure YARN to only allow a single GPU to be requested. See the YARN config:

```
yarn.resource-types.nvidia/miggpu.maximum-allocation
```

See [YARN Resource Configuration](https://hadoop.apache.org/docs/r3.3.1/hadoop-yarn/hadoop-yarn-site/ResourceModel.html) for more details.

If you do not configure the maximum allocation and someone requests multiple GPUs, the default behavior is to throw an exception. The user-visible exception is not very useful, as the real exception will be in the NodeManager logs. See the [Configuration](#configuration) section for options if it throws an exception.

## Building From Source

```
mvn package
```

This will create a jar `target/yarn-gpu-mig-plugin-1.0.0.jar`. This jar can be installed on your YARN cluster as a plugin.

## Installation

These instructions assume YARN is already installed and configured with Docker enabled using the NVIDIA Container Toolkit (nvidia-docker2).

Enable and configure your [GPUs with MIG](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html) on all of the nodes it applies to.

Install the jar into your Hadoop cluster, see the [Test and Use Your Own Plugin](https://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/DevelopYourOwnDevicePlugin.html) section. This recommends installing it in something like `$HADOOP_COMMON_HOME/share/hadoop/yarn`.

Configure the device plugin, see the YARN documentation on the [Pluggable Device Framework](https://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/PluggableDeviceFramework.html). After enabling the framework, enable the plugin in `yarn-site.xml`:

```
<property>
  <name>yarn.nodemanager.pluggable-device-framework.device-classes</name>
  <value>com.nvidia.spark.NvidiaGPUMigPluginForRuntimeV2</value>
</property>
```

Configure YARN to have the new resource type by modifying the `resource-types.xml` file to include:

```
<property>
  <name>yarn.resource-types</name>
  <value>nvidia/miggpu</value>
</property>
```

Restart YARN to pick up any configuration changes.

## Configuration

To change the behavior of throwing when the user allocates multiple GPUs, you can either set a config in `yarn-site.xml` or set an environment variable when launching the Spark application. The environment variable will take precedence if both are set. In either case, `true` means to throw if a user requests multiple GPUs (this is the default), and `false` means it won't throw; if the container is allocated multiple MIG devices from the same GPU, it is up to the application to know how to use them.

Config for `yarn-site.xml`:

```
<property>
  <name>com.nvidia.spark.NvidiaGPUMigPluginForRuntimeV2.throwOnMultipleGPUs</name>
  <value>true</value>
</property>
```

Environment variable for the Spark application:

```
--conf spark.executorEnv.NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS=true
```

## Using with Apache Spark on YARN

Spark supports [scheduling GPUs and other custom resources on YARN](http://spark.apache.org/docs/latest/running-on-yarn.html#resource-allocation-and-configuration-overview). There are 2 options for using this plugin with Spark to allocate GPUs with MIG support:

- Use Spark 3.2.1 or newer and remap the standard Spark `gpu` resource (i.e.: `spark.executor.resource.gpu.amount`) to be the new MIG GPU resource type using:

  ```
  --conf spark.yarn.resourceGpuDeviceName=nvidia/miggpu
  ```

  This means users don't have to change their configs if they were already using the `gpu` resource type.
- Spark applications specify the `nvidia/miggpu` resource type instead of the `gpu` resource type. For this the user has to change the resource type to `nvidia/miggpu`, update the discovery script, and specify an extra YARN config(`spark.yarn.executor.resource.nvidia/miggpu.amount`). The command would be something like below (update the amounts according to your setup): ``` --conf spark.executor.resource.nvidia/miggpu.amount=1 --conf spark.executor.resource.nvidia/miggpu.discoveryScript=./getMIGGPUs --conf spark.task.resource.nvidia/miggpu.amount=0.25 --files ./getMIGGpus --conf spark.yarn.executor.resource.nvidia/miggpu.amount=1 ``` Note the getMIGGpus discovery script would is in the `scripts` directory in this repo. It just changes the resource name returned to match `nvidia/miggpu`. ## Testing Run a Spark application using the [Rapids Accelerator for Apache Spark](https://nvidia.github.io/spark-rapids/) and request GPUs from YARN and verify they use the MIG enabled GPUs. ================================================ FILE: examples/MIG-Support/device-plugins/gpu-mig/pom.xml ================================================ 4.0.0 com.nvidia yarn-gpu-mig-plugin YARN Device Plugin that supports MIG The root project of the YARN Device Plugin that supports MIG 1.0.0 jar Apache License, Version 2.0 https://www.apache.org/licenses/LICENSE-2.0.txt repo 3.3.6 1.8 3.8.1 3.2.0 4.13.1 3.4.6 org.apache.hadoop hadoop-yarn-server-nodemanager ${yarn.version} provided junit junit ${junit.version} test org.mockito mockito-core ${mockito.core.version} test org.apache.maven.plugins maven-compiler-plugin ${maven.compiler.version} ${java.version} ${java.version} org.apache.maven.plugins maven-jar-plugin ${maven.jar.plugin.version} default-jar package jar ================================================ FILE: examples/MIG-Support/device-plugins/gpu-mig/scripts/getMIGGPUs ================================================ #!/usr/bin/env bash # Copyright (c) 2021, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # This script is a basic example script to get resource information about NVIDIA MIG GPUs. # It works with the NVIDIA GPU Plugin for YARN with MIG support and is expected to be run # in a container where the nvidia-docker-v2 plugin has taken care of mapping the MIG # devices. This is the same as the Aapche Spark script, except the resource name is changed # to match the new plugin. # # It assumes the drivers are properly installed and the nvidia-smi command is available. # It is not guaranteed to work on all setups so please test and customize as needed # for your environment. It can be passed into SPARK via the config # spark.{driver/executor}.resource.gpu.discoveryScript to allow the driver or executor to discover # the GPUs it was allocated. It assumes you are running within an isolated container where the # GPUs are allocated exclusively to that driver or executor. # It outputs a JSON formatted string that is expected by the # spark.{driver/executor}.resource.gpu.discoveryScript config. 
# # Example output: {"name": "nvidia/miggpu", "addresses":["0"]} ADDRS=`nvidia-smi --query-gpu=index --format=csv,noheader | sed -e ':a' -e 'N' -e'$!ba' -e 's/\n/","/g'` echo {\"name\": \"nvidia/miggpu\", \"addresses\":[\"$ADDRS\"]} ================================================ FILE: examples/MIG-Support/device-plugins/gpu-mig/src/main/java/com/nvidia/spark/NvidiaGPUMigPluginForRuntimeV2.java ================================================ /* * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark; import java.util.regex.Matcher; import java.util.regex.Pattern; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.util.Shell; import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.io.IOException; import java.util.HashMap; import java.util.Map; import java.util.Set; import java.util.TreeSet; /** * Nvidia GPU plugin supporting both Nvidia container runtime v2. * It supports discovering and allocating MIG devices. Currently, with CUDA 11, * only enumeration of a single MIG instance is supported. This means that * this plugin officially only supports 1 GPU per container and by default * will throw an exception if more are requested. The behavior of throwing * an exception is configurable by either setting the environment variable * {@code NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS} or by setting the YARN config * {@code com.nvidia.spark.NvidiaGPUMigPluginForRuntimeV2.throwOnMultipleGPUs} * to false. 
*/ public class NvidiaGPUMigPluginForRuntimeV2 implements DevicePlugin, DevicePluginScheduler { public static final Logger LOG = LoggerFactory.getLogger( NvidiaGPUMigPluginForRuntimeV2.class); public static final String NV_RESOURCE_NAME = "nvidia/miggpu"; private NvidiaCommandExecutor shellExecutor = new NvidiaCommandExecutor(); private Map environment = new HashMap<>(); // If this environment is set, use it directly private static final String ENV_BINARY_PATH = "NVIDIA_SMI_PATH"; private static final String DEFAULT_BINARY_NAME = "nvidia-smi"; private static final String DEV_NAME_PREFIX = "nvidia"; private static final String THROW_MULTI_CONF = "com.nvidia.spark.NvidiaGPUMigPluginForRuntimeV2.throwOnMultipleGPUs"; private static final String THROW_MULTI_ENV = "NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS"; private Boolean shouldThrowOnMultipleGPUFromConf = new Configuration().getBoolean(THROW_MULTI_CONF, true); private String shouldThrowOnMultipleGPUFromEnv = null; private String pathOfGpuBinary = null; // command should not run more than 10 sec. private static final int MAX_EXEC_TIMEOUT_MS = 10 * 1000; // When executable path not set, try to search default dirs // By default search /usr/bin, /bin, and /usr/local/nvidia/bin (when // launched by nvidia-docker. private static final String[] DEFAULT_BINARY_SEARCH_DIRS = new String[]{ "/usr/bin", "/bin", "/usr/local/nvidia/bin"}; // device id -> mig id, populated during discovery and used when launching // containers private Map migDevices = new HashMap<>(); private String migInfoOutput = null; @Override public DeviceRegisterRequest getRegisterRequestInfo() throws Exception { return DeviceRegisterRequest.Builder.newInstance() .setResourceName(NV_RESOURCE_NAME).build(); } @Override public Set getDevices() throws Exception { shellExecutor.searchBinary(); TreeSet r = new TreeSet<>(); String output; try { output = shellExecutor.getDeviceInfo(); String[] lines = output.trim().split("\n"); int id = 0; for (String oneLine : lines) { String[] tokensEachLine = oneLine.split(","); if (tokensEachLine.length != 3) { throw new Exception("Cannot parse the output to get the MIG enabled info. 
" + "output: " + oneLine + " expected index,pci.bus_id,mig.mode.current"); } String minorNumber = tokensEachLine[0].trim(); String busId = tokensEachLine[1].trim(); String migMode = tokensEachLine[2].trim(); String majorNumber = getMajorNumber(DEV_NAME_PREFIX + minorNumber); if (majorNumber != null) { if (migMode.equalsIgnoreCase("enabled")) { if (migInfoOutput == null) { // we get the mig info for all the GPUs on the host so only get it once migInfoOutput = shellExecutor.getDeviceMigInfo(); if (migInfoOutput == null) { throw new Exception("MIG device enabled but no device info found"); } } String[] linesMig = migInfoOutput.trim().split("\n"); Integer minorNumInt = Integer.parseInt(minorNumber); Integer migDevCount = 0; Integer numMigOutputLines = linesMig.length; for (int idmig = 0; idmig < numMigOutputLines; idmig++) { // first line should start with GPU // GPU 0: NVIDIA A30 (UUID: GPU-e7076666-0544-e103-4f65-a047fc18269e) // MIG 1g.6gb Device 0: (UUID: MIG-de9876e2-eef7-5b5a-9701-db694ffe8a77) if (linesMig[idmig].startsWith("GPU " + minorNumInt) && numMigOutputLines > (idmig + 1)) { // process any MIG devices, this expects all the lines to be MIG devices until // we find one that starts with GPU String nextLine = linesMig[++idmig].trim(); String regex = "MIG (.+)Device\\s+(\\d+):\\s+\\(UUID:(.*)\\)"; Pattern pattern = Pattern.compile(regex); while (nextLine.startsWith("MIG")) { Matcher matcher = pattern.matcher(nextLine); while (matcher.find()) { String devId = matcher.group(2); migDevices.put(id, devId); migDevCount++; r.add(Device.Builder.newInstance() .setId(id) .setMajorNumber(Integer.parseInt(majorNumber)) .setMinorNumber(minorNumInt) .setBusID(busId) .setDevPath("/dev/" + DEV_NAME_PREFIX + minorNumber) .setHealthy(true) .setStatus(devId) .build()); id++; if (++idmig < numMigOutputLines) { nextLine = linesMig[idmig].trim(); } else { nextLine = ""; } } } idmig = numMigOutputLines; } } if (migDevCount < 1) { throw new IOException("Error finding MIG devices on GPU with " + "MIG enabled: " + migInfoOutput); } LOG.info("Added GPU " + majorNumber + ":" + minorNumInt + " with MIG Enabled, found " + migDevCount + " MIG devices"); } else { Integer majorNumInt = Integer.parseInt(majorNumber); Integer minorNumInt = Integer.parseInt(minorNumber); r.add(Device.Builder.newInstance() .setId(id) .setMajorNumber(majorNumInt) .setMinorNumber(minorNumInt) .setBusID(busId) .setDevPath("/dev/" + DEV_NAME_PREFIX + minorNumber) .setHealthy(true) .build()); LOG.info("Added GPU " + majorNumInt + ":" + minorNumInt); id++; } } } return r; } catch (IOException e) { LOG.debug("Failed to get output from {}", pathOfGpuBinary); throw new YarnException(e); } } private Boolean shouldThrowOnMultipleGPUs() { // env setting takes highest priority if it is set if (shouldThrowOnMultipleGPUFromEnv != null) { return Boolean.parseBoolean(shouldThrowOnMultipleGPUFromEnv); } return shouldThrowOnMultipleGPUFromConf; } @Override public DeviceRuntimeSpec onDevicesAllocated(Set allocatedDevices, YarnRuntimeType yarnRuntime) throws Exception { LOG.debug("Generating runtime spec for allocated devices: {}, {}", allocatedDevices, yarnRuntime.getName()); if (allocatedDevices.size() > 1 && shouldThrowOnMultipleGPUs()) { throw new YarnException("Allocating more than 1 GPU per container is" + " not supported with use of MIG!"); } if (yarnRuntime == YarnRuntimeType.RUNTIME_DOCKER) { String nvidiaRuntime = "nvidia"; String nvidiaVisibleDevices = "NVIDIA_VISIBLE_DEVICES"; StringBuffer gpuMinorNumbersSB = new StringBuffer(); for (Device 
device : allocatedDevices) { Integer minorNum = device.getMinorNumber(); Integer id = device.getId(); if (migDevices.containsKey(id)) { gpuMinorNumbersSB.append(minorNum + ":" + migDevices.get(id) + ","); } else { gpuMinorNumbersSB.append(minorNum + ","); } } String minorNumbers = gpuMinorNumbersSB.toString(); LOG.info("Nvidia Docker v2 assigned GPU: " + minorNumbers); String deviceStr = minorNumbers.substring(0, minorNumbers.length() - 1); return DeviceRuntimeSpec.Builder.newInstance() .addEnv(nvidiaVisibleDevices, deviceStr) .setContainerRuntime(nvidiaRuntime) .build(); } return null; } @Override public void onDevicesReleased(Set releasedDevices) throws Exception { // do nothing } // Get major number from device name. private String getMajorNumber(String devName) { String output = null; // output "major:minor" in hex try { LOG.debug("Get major numbers from /dev/{}", devName); output = shellExecutor.getMajorMinorInfo(devName); String[] strs = output.trim().split(":"); output = Integer.toString(Integer.parseInt(strs[0], 16)); } catch (IOException e) { String msg = "Failed to get major number from reading /dev/" + devName; LOG.warn(msg); } catch (NumberFormatException e) { LOG.error("Failed to parse device major number from stat output"); output = null; } return output; } @Override public Set allocateDevices(Set availableDevices, int count, Map envs) { Set allocation = new TreeSet<>(); String envShouldThrow = envs.get(THROW_MULTI_ENV); if (envShouldThrow != null) { shouldThrowOnMultipleGPUFromEnv = envShouldThrow; } // Only officially support 1 GPU per container so don't worry about topology // scheduling. basicSchedule(allocation, count, availableDevices); return allocation; } public void basicSchedule(Set allocation, int count, Set availableDevices) { // Basic scheduling // allocate all available if (count == availableDevices.size()) { allocation.addAll(availableDevices); return; } int number = 0; for (Device d : availableDevices) { allocation.add(d); number++; if (number == count) { break; } } } /** * A shell wrapper class easy for test. 
*/ public class NvidiaCommandExecutor { public String getDeviceInfo() throws IOException { return Shell.execCommand(environment, new String[]{pathOfGpuBinary, "--query-gpu=index,pci.bus_id,mig.mode.current", "--format=csv,noheader"}, MAX_EXEC_TIMEOUT_MS); } public String getDeviceMigInfo() throws IOException { return Shell.execCommand(environment, new String[]{pathOfGpuBinary, "-L"}, MAX_EXEC_TIMEOUT_MS); } public String getMajorMinorInfo(String devName) throws IOException { // output "major:minor" in hex Shell.ShellCommandExecutor shexec = new Shell.ShellCommandExecutor( new String[]{"stat", "-c", "%t:%T", "/dev/" + devName}); shexec.execute(); return shexec.getOutput(); } public void searchBinary() throws Exception { if (pathOfGpuBinary != null) { LOG.info("Skip searching, the NVIDIA gpu binary is already set: " + pathOfGpuBinary); return; } // search env for the binary String envBinaryPath = System.getenv(ENV_BINARY_PATH); if (null != envBinaryPath) { if (new File(envBinaryPath).exists()) { pathOfGpuBinary = envBinaryPath; LOG.info("Use NVIDIA gpu binary: " + pathOfGpuBinary); return; } } LOG.debug("Search binary.."); // search if binary exists in default folders File binaryFile; boolean found = false; for (String dir : DEFAULT_BINARY_SEARCH_DIRS) { binaryFile = new File(dir, DEFAULT_BINARY_NAME); if (binaryFile.exists()) { found = true; pathOfGpuBinary = binaryFile.getAbsolutePath(); LOG.info("Found binary:" + pathOfGpuBinary); break; } } if (!found) { LOG.error("No binary found from env variable: " + ENV_BINARY_PATH + " or path " + DEFAULT_BINARY_SEARCH_DIRS.toString()); throw new Exception("No binary found for " + NvidiaGPUMigPluginForRuntimeV2.class); } } } // visible for testing public void setPathOfGpuBinary(String pOfGpuBinary) { this.pathOfGpuBinary = pOfGpuBinary; } // visible for testing public void setShellExecutor(NvidiaCommandExecutor shellExecutor) { this.shellExecutor = shellExecutor; } // visible for testing public void setMigDevices(Map migDevices) { this.migDevices = migDevices; } // visible for testing public void setShouldThrowOnMultipleGPUFromConf(Boolean shouldThrow) { this.shouldThrowOnMultipleGPUFromConf = shouldThrow; } } ================================================ FILE: examples/MIG-Support/device-plugins/gpu-mig/src/test/java/com/nvidia/spark/TestNvidiaGPUMigPluginForRuntimeV2.java ================================================ /* * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package com.nvidia.spark; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType; import org.junit.Assert; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.HashMap; import java.util.Map; import java.util.Set; import java.util.TreeSet; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; /** * Test case for NvidiaGPUMigPluginForRuntimeV2 device plugin. */ public class TestNvidiaGPUMigPluginForRuntimeV2 { private static final Logger LOG = LoggerFactory.getLogger(TestNvidiaGPUMigPluginForRuntimeV2.class); @Test public void testGetNvidiaDevices() throws Exception { NvidiaGPUMigPluginForRuntimeV2.NvidiaCommandExecutor mockShell = mock(NvidiaGPUMigPluginForRuntimeV2.NvidiaCommandExecutor.class); String deviceInfoShellOutput = "0, 00000000:04:00.0, [N/A]\n" + "1, 00000000:82:00.0, Enabled"; String majorMinorNumber0 = "c3:0"; String majorMinorNumber1 = "c3:1"; String deviceMigInfoShellOutput = "GPU 0: NVIDIA A100 80GB PCIe (UUID: GPU-aa72194b-fdd4-24b0-f659-17c929f46267)\n" + " MIG 1g.10gb Device 0: (UUID: MIG-aa2c982c-48a9-5046-b7f8-aa4732879e02)\n" + "GPU 1: NVIDIA A100 80GB PCIe (UUID: GPU-aa7153bf-c0ba-00ef-cdce-f861c34172f6)\n" + " MIG 1g.10gb Device 0: (UUID: MIG-aa59d467-ba39-5d0a-a085-66af03246526)\n" + " MIG 1g.10gb Device 1: (UUID: MIG-aad5cb29-8e6f-510a-8352-8e18f483dc74)" + when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput); when(mockShell.getDeviceMigInfo()).thenReturn(deviceMigInfoShellOutput); when(mockShell.getMajorMinorInfo("nvidia0")) .thenReturn(majorMinorNumber0); when(mockShell.getMajorMinorInfo("nvidia1")) .thenReturn(majorMinorNumber1); NvidiaGPUMigPluginForRuntimeV2 plugin = new NvidiaGPUMigPluginForRuntimeV2(); plugin.setShellExecutor(mockShell); plugin.setPathOfGpuBinary("/fake/nvidia-smi"); Set expectedDevices = new TreeSet<>(); expectedDevices.add(Device.Builder.newInstance() .setId(0).setHealthy(true) .setBusID("00000000:04:00.0") .setDevPath("/dev/nvidia0") .setMajorNumber(195) .setStatus("0") .setMinorNumber(0).build()); expectedDevices.add(Device.Builder.newInstance() .setId(1).setHealthy(true) .setBusID("00000000:82:00.0") .setDevPath("/dev/nvidia1") .setMajorNumber(195) .setStatus("0") .setMinorNumber(1).build()); expectedDevices.add(Device.Builder.newInstance() .setId(2).setHealthy(true) .setBusID("00000000:82:00.0") .setDevPath("/dev/nvidia1") .setMajorNumber(195) .setStatus("1") .setMinorNumber(1).build()); Set devices = plugin.getDevices(); Assert.assertEquals(expectedDevices, devices); } @Test(expected = Exception.class) public void testOnDeviceAllocatedMultiGPU() throws Exception { NvidiaGPUMigPluginForRuntimeV2 plugin = new NvidiaGPUMigPluginForRuntimeV2(); Set allocatedDevices = new TreeSet<>(); DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices, YarnRuntimeType.RUNTIME_DEFAULT); Assert.assertNull(spec); // allocate one device allocatedDevices.add(Device.Builder.newInstance() .setId(0).setHealthy(true) .setBusID("00000000:04:00.0") .setDevPath("/dev/nvidia0") .setMajorNumber(195) .setMinorNumber(0).build()); spec = plugin.onDevicesAllocated(allocatedDevices, YarnRuntimeType.RUNTIME_DOCKER); Assert.assertEquals("nvidia", spec.getContainerRuntime()); Assert.assertEquals("0", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES")); // two device allowed 
allocatedDevices.add(Device.Builder.newInstance() .setId(0).setHealthy(true) .setBusID("00000000:82:00.0") .setDevPath("/dev/nvidia1") .setMajorNumber(195) .setMinorNumber(1).build()); spec = plugin.onDevicesAllocated(allocatedDevices, YarnRuntimeType.RUNTIME_DOCKER); } @Test public void testMultiGPUsEnvPrecedence() throws Exception { NvidiaGPUMigPluginForRuntimeV2 plugin = new NvidiaGPUMigPluginForRuntimeV2(); Set allocatedDevices = new TreeSet<>(); DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices, YarnRuntimeType.RUNTIME_DEFAULT); Assert.assertNull(spec); // allocate one device allocatedDevices.add(Device.Builder.newInstance() .setId(0).setHealthy(true) .setBusID("00000000:04:00.0") .setDevPath("/dev/nvidia0") .setMajorNumber(195) .setMinorNumber(0).build()); // two device allowed allocatedDevices.add(Device.Builder.newInstance() .setId(0).setHealthy(true) .setBusID("00000000:82:00.0") .setDevPath("/dev/nvidia1") .setMajorNumber(195) .setMinorNumber(1).build()); // test that env variable takes presedence plugin.setShouldThrowOnMultipleGPUFromConf(true); Map envs = new HashMap<>(); envs.put("NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS", "false"); // note the allocated devices doesn't matter here, just the env passed in plugin.allocateDevices(allocatedDevices, 2, envs); spec = plugin.onDevicesAllocated(allocatedDevices, YarnRuntimeType.RUNTIME_DOCKER); Assert.assertEquals("nvidia", spec.getContainerRuntime()); Assert.assertEquals("0,1", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES")); } @Test public void testMultiGPUsConf() throws Exception { NvidiaGPUMigPluginForRuntimeV2 plugin = new NvidiaGPUMigPluginForRuntimeV2(); Set allocatedDevices = new TreeSet<>(); DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices, YarnRuntimeType.RUNTIME_DEFAULT); Assert.assertNull(spec); // allocate one device allocatedDevices.add(Device.Builder.newInstance() .setId(0).setHealthy(true) .setBusID("00000000:04:00.0") .setDevPath("/dev/nvidia0") .setMajorNumber(195) .setMinorNumber(0).build()); // two device allowed allocatedDevices.add(Device.Builder.newInstance() .setId(0).setHealthy(true) .setBusID("00000000:82:00.0") .setDevPath("/dev/nvidia1") .setMajorNumber(195) .setMinorNumber(1).build()); // test that env variable takes presedence plugin.setShouldThrowOnMultipleGPUFromConf(false); spec = plugin.onDevicesAllocated(allocatedDevices, YarnRuntimeType.RUNTIME_DOCKER); Assert.assertEquals("nvidia", spec.getContainerRuntime()); Assert.assertEquals("0,1", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES")); } @Test public void testOnDeviceAllocatedMig() throws Exception { NvidiaGPUMigPluginForRuntimeV2 plugin = new NvidiaGPUMigPluginForRuntimeV2(); Set allocatedDevices = new TreeSet<>(); DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices, YarnRuntimeType.RUNTIME_DEFAULT); Assert.assertNull(spec); Map testMigDevices = new HashMap<>(); testMigDevices.put(0, "0"); plugin.setMigDevices(testMigDevices); // allocate one device allocatedDevices.add(Device.Builder.newInstance() .setId(0).setHealthy(true) .setBusID("00000000:04:00.0") .setDevPath("/dev/nvidia0") .setMajorNumber(195) .setMinorNumber(0).build()); spec = plugin.onDevicesAllocated(allocatedDevices, YarnRuntimeType.RUNTIME_DOCKER); Assert.assertEquals("nvidia", spec.getContainerRuntime()); Assert.assertEquals("0:0", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES")); } @Test public void testOnDeviceAllocatedNoMig() throws Exception { NvidiaGPUMigPluginForRuntimeV2 plugin = new NvidiaGPUMigPluginForRuntimeV2(); Set 
allocatedDevices = new TreeSet<>(); DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices, YarnRuntimeType.RUNTIME_DEFAULT); Assert.assertNull(spec); // allocate one device allocatedDevices.add(Device.Builder.newInstance() .setId(0).setHealthy(true) .setBusID("00000000:04:00.0") .setDevPath("/dev/nvidia0") .setMajorNumber(195) .setMinorNumber(0).build()); spec = plugin.onDevicesAllocated(allocatedDevices, YarnRuntimeType.RUNTIME_DOCKER); Assert.assertEquals("nvidia", spec.getContainerRuntime()); Assert.assertEquals("0", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES")); } } ================================================ FILE: examples/MIG-Support/resource-types/gpu-mig/README.md ================================================ # NVIDIA Support for GPU for YARN with MIG support for YARN 3.1.2 until YARN 3.3.0 This adds support for GPUs with [MIG](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/) on YARN for versions prior to YARN 3.3.0 which don't support the pluggable device framework. Use the [GPU Plugin for YARN with MIG support](../../device-plugins/gpu-mig/README.md) for YARN 3.3.0 and newer versions. The built-in YARN GPU plugin does not support MIG enabled GPUs. This patch works with GPUs without MIG or GPUs with MIG disabled but the limitation section still applies. It supports heterogenous environments where there may be some MIG enabled GPUs and some without MIG. This requires patching YARN and rebuilding it. ## Compatibility Requires YARN 3.1.2 or newer that supports GPU scheduling. See the [supported versions](#supported-versions) section below for specific versions supported. MIG support requires YARN to be configured with Docker and using the NVIDIA Container Toolkit (nvidia-docker2) ## Limitations Please see the [MIG Application Considerations](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#app-considerations) and [CUDA Device Enumeration](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html#cuda-visible-devices). It is important to note that CUDA 11 only supports enumeration of a single MIG instance. This means that with this patch and MIG support enabled, it only supports 1 GPU per container and will throw an exception by default if you request more. It is recommended that you configure YARN to only allow a single GPU be requested. See the yarn config: ``` yarn.resource-types.yarn.io/gpu.maximum-allocation ``` See [YARN Resource Configuration](https://hadoop.apache.org/docs/r3.1.2/hadoop-yarn/hadoop-yarn-site/ResourceModel.html) for more details. If you do not configure the maximum allocation and someone requests multiple GPUs, the default behavior is to throw an exception. See the [Configuration](#configuration) section for options if it throws an exception. ## Supported Versions There are different patches available depending on the YARN version you are using: - YARN 3.1.2 use patch `yarn312MIG.patch` - YARN versions 3.1.3 to 3.1.5 (git hash cd7c34f9b4005d27886f73e58bef88e706fcccf9 since 3.1.5 was not released when this was tested) use `yarn313to315MIG.patch` - YARN 3.2.0, no patch is currently available, backport patch for YARN 3.2.1 or contact us. - YARN 3.2.1 and 3.2.3 use patch `yarn321to323MIG.patch` ## Building Apply the patch to your YARN version and build it like you would normally for your deployment. 
For example:

```
patch -p1 < yarn312MIG.patch
mvn clean package -Pdist -Dtar -DskipTests
```

Run unit tests:

```
mvn test -Pdist -Dtar -Dtest=TestGpuDiscoverer
mvn test -Pdist -Dtar -Dtest=TestNvidiaDockerV2CommandPlugin
```

## Installation

These instructions assume YARN is already installed and configured with GPU scheduling enabled using Docker and the NVIDIA Container Toolkit (nvidia-docker2). See [Using GPU on YARN](https://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/UsingGpus.html) if you need more information.

Enable and configure your [GPUs with MIG](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html) on all of the nodes it applies to.

Install the new YARN version built with the patch on your YARN cluster. Enable the MIG GPU support in the Hadoop configuration files:

```
<property>
  <name>yarn.nodemanager.resource-plugins.gpu.use-mig-enabled</name>
  <value>true</value>
</property>
```

Restart YARN if needed to pick up any configuration changes.

## Configuration

The default behavior of the GPU resource plugin on YARN is to use the `auto` discovery mode for GPUs on each NodeManager. It also allows you to manually allow certain GPU devices. This configuration was extended to support MIG devices.

The `yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices` configuration can be used to manually specify devices. GPU devices are identified by their index, minor device number, and optionally a MIG device index. A common approach to get the minor device numbers of GPUs is to run `nvidia-smi -q` and search for the `Minor Number` output (and optionally the MIG device indices). The format is `index:minor_number[:mig_index][,index:minor_number...]`. An example of manual specification is `0:0,1:1:0,1:1:1,2:2`, which allows the YARN NodeManager to manage GPU devices with indices 0/1/2 and minor numbers 0/1/2, where the GPU with index 1 has 2 MIG-enabled devices with indices 0/1.

```
<property>
  <name>yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices</name>
  <value>0:0,1:1:0,1:1:1,2:2</value>
</property>
```

The behavior of throwing when the user allocates multiple GPUs can be controlled by setting an environment variable when the Spark application is launched. Setting it to `true` means to throw if a user requests multiple GPUs (this is the default), and `false` means it won't throw; if the container is allocated multiple MIG devices from the same GPU, it is up to the application to know how to use them.

Environment variable for the Spark application:

```
--conf spark.executorEnv.NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS=false
```

## Testing

Run a Spark application using the [Rapids Accelerator for Apache Spark](https://nvidia.github.io/spark-rapids/), request GPUs from YARN, and verify they use the MIG enabled GPUs.
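For reference, a minimal submission sketch is shown below. This command is illustrative only and is not part of the patch: the application class, example jar path, and resource amounts are placeholder assumptions, and `RAPIDS_JAR` is assumed to point to the RAPIDS Accelerator jar downloaded during preparation.

```
# Illustrative sketch only: adjust versions, paths, and resource amounts for your cluster.
${SPARK_HOME}/bin/spark-submit \
  --master yarn \
  --deploy-mode cluster \
  --conf spark.plugins=com.nvidia.spark.SQLPlugin \
  --conf spark.executor.resource.gpu.amount=1 \
  --conf spark.task.resource.gpu.amount=1 \
  --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \
  --files ${SPARK_HOME}/examples/src/main/scripts/getGpusResources.sh \
  --jars ${RAPIDS_JAR} \
  --class org.apache.spark.examples.SparkPi \
  ${SPARK_HOME}/examples/jars/spark-examples_*.jar 1000
```

Once the application is running, you can check `nvidia-smi -L` inside the executor containers (as shown in the top-level MIG-Support README) to confirm that MIG instances were assigned.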
================================================ FILE: examples/MIG-Support/resource-types/gpu-mig/yarn312MIG.patch ================================================ diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java index 36fafefdbc4..e37d0a3a685 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java @@ -1574,6 +1574,10 @@ public static boolean isAclEnabled(Configuration conf) { @Private public static final String AUTOMATICALLY_DISCOVER_GPU_DEVICES = "auto"; + @Private + public static final String USE_MIG_ENABLED_GPUS = + NM_GPU_RESOURCE_PREFIX + "use-mig-enabled"; + /** * This setting controls where to how to invoke GPU binaries */ diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java index 26fd9050742..e84b920dcee 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java @@ -34,6 +34,12 @@ public AssignedGpuDevice(int index, int minorNumber, this.containerId = containerId.toString(); } + public AssignedGpuDevice(int index, int minorNumber, + int migIndex, ContainerId containerId) { + super(index, minorNumber, migIndex); + this.containerId = containerId.toString(); + } + public String getContainerId() { return containerId; } @@ -49,6 +55,7 @@ public boolean equals(Object obj) { } AssignedGpuDevice other = (AssignedGpuDevice) obj; return index == other.index && minorNumber == other.minorNumber + && migDeviceIndex == other.migDeviceIndex && containerId.equals(other.containerId); } @@ -68,12 +75,16 @@ public int compareTo(Object obj) { if (0 != result) { return result; } - return containerId.compareTo(other.containerId); + result = containerId.compareTo(other.containerId); + if (0 != result) { + return result; + } + return Integer.compare(migDeviceIndex, other.migDeviceIndex); } @Override public int hashCode() { final int prime = 47; - return prime * (prime * index + minorNumber) + containerId.hashCode(); + return prime * (prime * index + minorNumber + migDeviceIndex) + containerId.hashCode(); } } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java index bce1d9fa480..3cb42d3c58f 100644 --- 
a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java @@ -26,6 +26,7 @@ public class GpuDevice implements Serializable, Comparable { protected int index; protected int minorNumber; + protected int migDeviceIndex = -1; private static final long serialVersionUID = -6812314470754667710L; public GpuDevice(int index, int minorNumber) { @@ -33,6 +34,12 @@ public GpuDevice(int index, int minorNumber) { this.minorNumber = minorNumber; } + public GpuDevice(int index, int minorNumber, int migIndex) { + this.index = index; + this.minorNumber = minorNumber; + this.migDeviceIndex = migIndex; + } + public int getIndex() { return index; } @@ -41,13 +48,17 @@ public int getMinorNumber() { return minorNumber; } + public int getMIGIndex() { + return migDeviceIndex; + } + @Override public boolean equals(Object obj) { if (obj == null || !(obj instanceof GpuDevice)) { return false; } GpuDevice other = (GpuDevice) obj; - return index == other.index && minorNumber == other.minorNumber; + return index == other.index && minorNumber == other.minorNumber && migDeviceIndex == other.migDeviceIndex; } @Override @@ -62,17 +73,21 @@ public int compareTo(Object obj) { if (0 != result) { return result; } - return Integer.compare(minorNumber, other.minorNumber); + result = Integer.compare(minorNumber, other.minorNumber); + if (0 != result) { + return result; + } + return Integer.compare(migDeviceIndex, other.migDeviceIndex); } @Override public int hashCode() { final int prime = 47; - return prime * index + minorNumber; + return prime * index + minorNumber + migDeviceIndex; } @Override public String toString() { - return "(index=" + index + ",minor_number=" + minorNumber + ")"; + return "(index=" + index + ",minor_number=" + minorNumber + ",mig_index=" + migDeviceIndex + ")"; } } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java index 6e3cf1315ce..55f7379d4cc 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java @@ -30,6 +30,7 @@ import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformation; import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformationParser; import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuDeviceInformation; +import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuMigDevice; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -149,6 +150,10 @@ public synchronized GpuDeviceInformation getGpuDeviceInformation() YarnConfiguration.NM_GPU_ALLOWED_DEVICES, YarnConfiguration.AUTOMATICALLY_DISCOVER_GPU_DEVICES); + 
Boolean useMIGEnabledGPUs = conf.getBoolean( + YarnConfiguration.USE_MIG_ENABLED_GPUS, false); + LOG.info("Use MIG enabled is: " + useMIGEnabledGPUs); + List gpuDevices = new ArrayList<>(); if (allowedDevicesStr.equals( @@ -171,21 +176,45 @@ public synchronized GpuDeviceInformation getGpuDeviceInformation() i++) { List gpuInfos = lastDiscoveredGpuInformation.getGpus(); - gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber())); + if (useMIGEnabledGPUs && + gpuInfos.get(i).getMIGMode().getCurrentMigMode().equalsIgnoreCase("enabled")) { + LOG.info("GPU id " + i + " has MIG mode enabled."); + for (PerGpuMigDevice dev: gpuInfos.get(i).getMIGDevices()) { + gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber(), dev.getMigDeviceIndex())); + } + } else { + gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber())); + } } + LOG.info("Discovered GPU devices: " + gpuDevices); } } else{ for (String s : allowedDevicesStr.split(",")) { if (s.trim().length() > 0) { String[] kv = s.trim().split(":"); - if (kv.length != 2) { - throw new YarnException( - "Illegal format, it should be index:minor_number format, now it=" - + s); + if (useMIGEnabledGPUs) { + if (kv.length != 2 && kv.length != 3) { + throw new YarnException( + "Illegal format, it should be index:minor_number or index:minor_number:mig_device_id" + + " format, now it=" + s); + } + if (kv.length == 3) { + // assumes this is MIG enabled device + gpuDevices.add( + new GpuDevice(Integer.parseInt(kv[0]), Integer.parseInt(kv[1]), Integer.parseInt(kv[2]))); + } else { + gpuDevices.add( + new GpuDevice(Integer.parseInt(kv[0]), Integer.parseInt(kv[1]))); + } + } else { + if (kv.length != 2) { + throw new YarnException( + "Illegal format, it should be index:minor_number format, now it=" + + s); + } + gpuDevices.add( + new GpuDevice(Integer.parseInt(kv[0]), Integer.parseInt(kv[1]))); } - - gpuDevices.add( - new GpuDevice(Integer.parseInt(kv[0]), Integer.parseInt(kv[1]))); } } LOG.info("Allowed GPU devices:" + gpuDevices); diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java index 051afd6c561..996cb58ac45 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java @@ -36,7 +36,7 @@ public static DockerCommandPlugin createGpuDockerCommandPlugin( } // nvidia-docker2 if (impl.equals(YarnConfiguration.NVIDIA_DOCKER_V2)) { - return new NvidiaDockerV2CommandPlugin(); + return new NvidiaDockerV2CommandPlugin(conf); } throw new YarnException( diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java 
b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java index ff25eb6ced6..c2cc0e5a2d1 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java @@ -21,7 +21,9 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.ResourceInformation; +import org.apache.hadoop.yarn.conf.YarnConfiguration; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings; import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu.GpuResourceAllocator; @@ -45,8 +47,12 @@ private String nvidiaRuntime = "nvidia"; private String nvidiaVisibleDevices = "NVIDIA_VISIBLE_DEVICES"; + private String nvidiaMigThrowOnMultiGpus = "NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS"; + private Boolean isMigEnabled = false; - public NvidiaDockerV2CommandPlugin() {} + public NvidiaDockerV2CommandPlugin(Configuration conf) { + isMigEnabled = conf.getBoolean(YarnConfiguration.USE_MIG_ENABLED_GPUS, false); + } private Set getAssignedGpus(Container container) { ResourceMappings resourceMappings = container.getResourceMappings(); @@ -84,10 +90,23 @@ public synchronized void updateDockerRunCommand( return; } Map environment = new HashMap<>(); + if (isMigEnabled && assignedResources.size() > 1) { + Map existingEnv = container.getLaunchContext().getEnvironment(); + Boolean shouldThrowOnMultipleGpus = Boolean.parseBoolean( + existingEnv.getOrDefault(nvidiaMigThrowOnMultiGpus, "true")); + if (shouldThrowOnMultipleGpus) { + throw new ContainerExecutionException("Allocating more than 1 GPU per container is " + + "not supported with use of MIG!"); + } + } String gpuIndexList = ""; for (GpuDevice gpuDevice : assignedResources) { - gpuIndexList = gpuIndexList + gpuDevice.getIndex() + ","; - LOG.info("nvidia docker2 assigned gpu index: " + gpuDevice.getIndex()); + String deviceIndex = String.valueOf(gpuDevice.getIndex()); + if (gpuDevice.getMIGIndex() != -1) { + deviceIndex = gpuDevice.getIndex() + ":" + gpuDevice.getMIGIndex(); + } + gpuIndexList = gpuIndexList + deviceIndex + ","; + LOG.info("nvidia docker2 assigned gpu index: " + deviceIndex); } dockerRunCommand.addRuntime(nvidiaRuntime); environment.put(nvidiaVisibleDevices, diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java index 25c2e3a1f1d..15cb7eac10a 100644 --- 
a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java @@ -22,8 +22,10 @@ import org.apache.hadoop.classification.InterfaceStability; import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlElementWrapper; import javax.xml.bind.annotation.XmlRootElement; import javax.xml.bind.annotation.adapters.XmlAdapter; +import java.util.List; /** * Capture single GPU device information such as memory size, temperature, @@ -38,6 +40,8 @@ private String uuid = "N/A"; private int minorNumber = -1; + private List migDevices; + private PerGpuMigMode migMode; private PerGpuUtilizations gpuUtilizations; private PerGpuMemoryUsage gpuMemoryUsage; private PerGpuTemperature temperature; @@ -108,6 +112,25 @@ public void setUuid(String uuid) { this.uuid = uuid; } + @XmlElement(name = "mig_mode") + public PerGpuMigMode getMIGMode() { + return migMode; + } + + public void setMIGMode(PerGpuMigMode mode) { + this.migMode = mode; + } + + @XmlElementWrapper( name = "mig_devices" ) + @XmlElement(name = "mig_device") + public List getMIGDevices() { + return migDevices; + } + + public void setMIGDevices(List devices) { + this.migDevices = devices; + } + @XmlElement(name = "product_name") public String getProductName() { return productName; diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java new file mode 100644 index 00000000000..4ce7cec6e55 --- /dev/null +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; + +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlRootElement; + +/** + * GPU MIG Device Information + */ +@InterfaceAudience.Private +@InterfaceStability.Unstable +@XmlRootElement(name = "mig_device") +public class PerGpuMigDevice { + private int index; + + /** + * MIG device index + * @return MIG device index + */ + @XmlElement(name = "index") + public int getMigDeviceIndex() { + return index; + } + + public void setMigDeviceIndex(int index) { + this.index = index; + } +} diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java new file mode 100644 index 00000000000..b706df2c3bb --- /dev/null +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; + +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlRootElement; + +/** + * GPU MIG Mode + */ +@InterfaceAudience.Private +@InterfaceStability.Unstable +@XmlRootElement(name = "mig_mode") +public class PerGpuMigMode { + private String currentMigMode; + + /** + * Current MIG mode + * @return MIG mode enabled or disabled + */ + @XmlElement(name = "current_mig") + public String getCurrentMigMode() { + return currentMigMode; + } + + public void setCurrentMigMode(String migMode) { + this.currentMigMode = migMode; + } +} diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java index 4abb633a69a..404930d00c2 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java @@ -138,4 +138,47 @@ public void getNumberOfUsableGpusFromConfig() throws YarnException { Assert.assertTrue(2 == usableGpuDevices.get(2).getMinorNumber()); Assert.assertTrue(4 == usableGpuDevices.get(3).getMinorNumber()); } + + @Test + public void getNumberOfUsableGpusFromConfigMIG() throws YarnException { + Configuration conf = new Configuration(false); + conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, "true"); + + // Illegal format + conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1:2,2:2:0,3"); + GpuDiscoverer plugin = new GpuDiscoverer(); + try { + plugin.initialize(conf); + plugin.getGpusUsableByYarn(); + Assert.fail("Illegal format, should fail."); + } catch (YarnException e) { + // Expected + } + + // Valid format + conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1:0,1:1:2,2:2:0,3:4"); + plugin = new GpuDiscoverer(); + plugin.initialize(conf); + + List usableGpuDevices = plugin.getGpusUsableByYarn(); + Assert.assertEquals(5, usableGpuDevices.size()); + + Assert.assertTrue(0 == usableGpuDevices.get(0).getIndex()); + Assert.assertTrue(1 == usableGpuDevices.get(1).getIndex()); + Assert.assertTrue(1 == usableGpuDevices.get(2).getIndex()); + Assert.assertTrue(2 == usableGpuDevices.get(3).getIndex()); + Assert.assertTrue(3 == usableGpuDevices.get(4).getIndex()); + + Assert.assertTrue(0 == usableGpuDevices.get(0).getMinorNumber()); + Assert.assertTrue(1 == usableGpuDevices.get(1).getMinorNumber()); + Assert.assertTrue(1 == usableGpuDevices.get(2).getMinorNumber()); + Assert.assertTrue(2 == usableGpuDevices.get(3).getMinorNumber()); + Assert.assertTrue(4 == usableGpuDevices.get(4).getMinorNumber()); + + Assert.assertTrue(-1 == usableGpuDevices.get(0).getMIGIndex()); + Assert.assertTrue(0 == usableGpuDevices.get(1).getMIGIndex()); + Assert.assertTrue(2 == usableGpuDevices.get(2).getMIGIndex()); + Assert.assertTrue(0 == usableGpuDevices.get(3).getMIGIndex()); + Assert.assertTrue(-1 == 
usableGpuDevices.get(4).getMIGIndex()); + } } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java index b0b523360ef..798a95cb009 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java @@ -20,10 +20,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Sets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext; import org.apache.hadoop.yarn.api.records.ResourceInformation; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings; import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand; +import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException; import org.junit.Assert; import org.junit.Test; @@ -69,7 +73,13 @@ private boolean commandlinesEquals(Map> cli1, extends NvidiaDockerV2CommandPlugin { private boolean requestsGpu = false; - MyNvidiaDockerV2CommandPlugin() {} + MyNvidiaDockerV2CommandPlugin() { + super(new Configuration()); + } + + MyNvidiaDockerV2CommandPlugin(Configuration conf) { + super(conf); + } public void setRequestsGpu(boolean r) { requestsGpu = r; @@ -127,4 +137,118 @@ public void testPlugin() throws Exception { // runtime should exist Assert.assertTrue(newCommandLine.containsKey("runtime")); } -} \ No newline at end of file + + @Test + public void testPluginMIG() throws Exception { + DockerRunCommand runCommand = new DockerRunCommand("container_1", "user", + "fakeimage"); + + Map> originalCommandline = copyCommandLine( + runCommand.getDockerCommandWithArguments()); + + Configuration conf = new Configuration(); + conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, "true"); + MyNvidiaDockerV2CommandPlugin + commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf); + + Container nmContainer = mock(Container.class); + ResourceMappings resourceMappings = new ResourceMappings(); + when(nmContainer.getResourceMappings()).thenReturn(resourceMappings); + + // Assign GPU resource + ResourceMappings.AssignedResources assigned = + new ResourceMappings.AssignedResources(); + assigned.updateAssignedResources( + ImmutableList.of(new GpuDevice(0, 0, 0))); + resourceMappings.addAssignedResources(ResourceInformation.GPU_URI, + assigned); + + commandPlugin.setRequestsGpu(true); + commandPlugin.updateDockerRunCommand(runCommand, nmContainer); + Map> newCommandLine = + runCommand.getDockerCommandWithArguments(); + + // Command line will be updated + Assert.assertFalse(commandlinesEquals(originalCommandline, newCommandLine)); + // NVIDIA_VISIBLE_DEVICES will be set + Assert.assertTrue( + 
runCommand.getEnv().get("NVIDIA_VISIBLE_DEVICES").equals("0:0")); + // runtime should exist + Assert.assertTrue(newCommandLine.containsKey("runtime")); + } + + @Test(expected = ContainerExecutionException.class) + public void testPluginMIGThrowsMulti() throws Exception { + DockerRunCommand runCommand = new DockerRunCommand("container_1", "user", + "fakeimage"); + + Map> originalCommandline = copyCommandLine( + runCommand.getDockerCommandWithArguments()); + + Configuration conf = new Configuration(); + conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, "true"); + MyNvidiaDockerV2CommandPlugin + commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf); + + Container nmContainer = mock(Container.class); + ResourceMappings resourceMappings = new ResourceMappings(); + Map env = new HashMap<>(); + env.put("NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS", "true"); + when(nmContainer.getResourceMappings()).thenReturn(resourceMappings); + ContainerLaunchContext launchCtx = mock(ContainerLaunchContext.class); + when(nmContainer.getLaunchContext()).thenReturn(launchCtx); + when(launchCtx.getEnvironment()).thenReturn(env); + + // Assign GPU resource + ResourceMappings.AssignedResources assigned = + new ResourceMappings.AssignedResources(); + assigned.updateAssignedResources( + ImmutableList.of(new GpuDevice(0, 0, 0), new GpuDevice(1, 1, 2))); + resourceMappings.addAssignedResources(ResourceInformation.GPU_URI, + assigned); + + commandPlugin.setRequestsGpu(true); + commandPlugin.updateDockerRunCommand(runCommand, nmContainer); + } + + @Test + public void testPluginMIGNoThrowsMulti() throws Exception { + DockerRunCommand runCommand = new DockerRunCommand("container_1", "user", + "fakeimage"); + + Map> originalCommandline = copyCommandLine( + runCommand.getDockerCommandWithArguments()); + + Configuration conf = new Configuration(); + conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, "true"); + MyNvidiaDockerV2CommandPlugin + commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf); + + Container nmContainer = mock(Container.class); + ResourceMappings resourceMappings = new ResourceMappings(); + Map env = new HashMap<>(); + env.put("NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS", "false"); + when(nmContainer.getResourceMappings()).thenReturn(resourceMappings); + ContainerLaunchContext launchCtx = mock(ContainerLaunchContext.class); + when(nmContainer.getLaunchContext()).thenReturn(launchCtx); + when(launchCtx.getEnvironment()).thenReturn(env); + + // Assign GPU resource + ResourceMappings.AssignedResources assigned = + new ResourceMappings.AssignedResources(); + assigned.updateAssignedResources( + ImmutableList.of(new GpuDevice(0, 0, 0), new GpuDevice(1, 1, 2))); + resourceMappings.addAssignedResources(ResourceInformation.GPU_URI, + assigned); + + commandPlugin.setRequestsGpu(true); + commandPlugin.updateDockerRunCommand(runCommand, nmContainer); + Map> newCommandLine = + runCommand.getDockerCommandWithArguments(); + // NVIDIA_VISIBLE_DEVICES will be set + Assert.assertTrue( + runCommand.getEnv().get("NVIDIA_VISIBLE_DEVICES").equals("0:0,1:2")); + // runtime should exist + Assert.assertTrue(newCommandLine.containsKey("runtime")); + } +} ================================================ FILE: examples/MIG-Support/resource-types/gpu-mig/yarn313to315MIG.patch ================================================ diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java 
b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java index 737baee70bb..0e113036a80 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java @@ -1655,6 +1655,10 @@ public static boolean isAclEnabled(Configuration conf) { @Private public static final String AUTOMATICALLY_DISCOVER_GPU_DEVICES = "auto"; + @Private + public static final String USE_MIG_ENABLED_GPUS = + NM_GPU_RESOURCE_PREFIX + "use-mig-enabled"; + /** * This setting controls where to how to invoke GPU binaries */ diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java index 26fd9050742..e84b920dcee 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java @@ -34,6 +34,12 @@ public AssignedGpuDevice(int index, int minorNumber, this.containerId = containerId.toString(); } + public AssignedGpuDevice(int index, int minorNumber, + int migIndex, ContainerId containerId) { + super(index, minorNumber, migIndex); + this.containerId = containerId.toString(); + } + public String getContainerId() { return containerId; } @@ -49,6 +55,7 @@ public boolean equals(Object obj) { } AssignedGpuDevice other = (AssignedGpuDevice) obj; return index == other.index && minorNumber == other.minorNumber + && migDeviceIndex == other.migDeviceIndex && containerId.equals(other.containerId); } @@ -68,12 +75,16 @@ public int compareTo(Object obj) { if (0 != result) { return result; } - return containerId.compareTo(other.containerId); + result = containerId.compareTo(other.containerId); + if (0 != result) { + return result; + } + return Integer.compare(migDeviceIndex, other.migDeviceIndex); } @Override public int hashCode() { final int prime = 47; - return prime * (prime * index + minorNumber) + containerId.hashCode(); + return prime * (prime * index + minorNumber + migDeviceIndex) + containerId.hashCode(); } } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java index bce1d9fa480..3cb42d3c58f 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java @@ 
-26,6 +26,7 @@ public class GpuDevice implements Serializable, Comparable { protected int index; protected int minorNumber; + protected int migDeviceIndex = -1; private static final long serialVersionUID = -6812314470754667710L; public GpuDevice(int index, int minorNumber) { @@ -33,6 +34,12 @@ public GpuDevice(int index, int minorNumber) { this.minorNumber = minorNumber; } + public GpuDevice(int index, int minorNumber, int migIndex) { + this.index = index; + this.minorNumber = minorNumber; + this.migDeviceIndex = migIndex; + } + public int getIndex() { return index; } @@ -41,13 +48,17 @@ public int getMinorNumber() { return minorNumber; } + public int getMIGIndex() { + return migDeviceIndex; + } + @Override public boolean equals(Object obj) { if (obj == null || !(obj instanceof GpuDevice)) { return false; } GpuDevice other = (GpuDevice) obj; - return index == other.index && minorNumber == other.minorNumber; + return index == other.index && minorNumber == other.minorNumber && migDeviceIndex == other.migDeviceIndex; } @Override @@ -62,17 +73,21 @@ public int compareTo(Object obj) { if (0 != result) { return result; } - return Integer.compare(minorNumber, other.minorNumber); + result = Integer.compare(minorNumber, other.minorNumber); + if (0 != result) { + return result; + } + return Integer.compare(migDeviceIndex, other.migDeviceIndex); } @Override public int hashCode() { final int prime = 47; - return prime * index + minorNumber; + return prime * index + minorNumber + migDeviceIndex; } @Override public String toString() { - return "(index=" + index + ",minor_number=" + minorNumber + ")"; + return "(index=" + index + ",minor_number=" + minorNumber + ",mig_index=" + migDeviceIndex + ")"; } } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java index 9d61b91a1f2..d775aab0226 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java @@ -26,6 +26,8 @@ public final class GpuDeviceSpecificationException extends YarnException { private static final String VALID_FORMAT_MESSAGE = "The valid format " + "should be: index:minor_number"; + private static final String VALID_MIG_FORMAT_MESSAGE = VALID_FORMAT_MESSAGE + + "or with MIG enabled: index:minor_number:mig_index"; private GpuDeviceSpecificationException(String message) { super(message); @@ -57,12 +59,25 @@ public static GpuDeviceSpecificationException createWithWrongValueSpecified( return new GpuDeviceSpecificationException(message); } + public static GpuDeviceSpecificationException createWithWrongValueSpecifiedMIG( + String device, String configValue) { + final String message = createIllegalFormatMessageMIG(device, configValue); + return new GpuDeviceSpecificationException(message); + } + public static GpuDeviceSpecificationException createWithDuplicateValueSpecified( String device, String 
configValue) { final String message = createDuplicateFormatMessage(device, configValue); return new GpuDeviceSpecificationException(message); } + private static String createIllegalFormatMessageMIG(String device, + String configValue) { + return String.format("Illegal format of individual GPU device: %s, " + + "the whole config value was: '%s'! " + VALID_MIG_FORMAT_MESSAGE, + device, configValue); + } + private static String createIllegalFormatMessage(String device, String configValue) { return String.format("Illegal format of individual GPU device: %s, " + @@ -79,4 +94,4 @@ private static String createDuplicateFormatMessage(String device, "! Current value of the configuration is: %s", device, configValue); } -} \ No newline at end of file +} diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java index ce767229e50..c74651b41df 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java @@ -31,6 +31,7 @@ import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformation; import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformationParser; import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuDeviceInformation; +import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuMigDevice; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -69,6 +70,7 @@ private GpuDeviceInformation lastDiscoveredGpuInformation = null; private List gpuDevicesFromUser; + private Boolean useMIGEnabledGPUs = false; private void validateConfOrThrowException() throws YarnException { if (conf == null) { @@ -194,8 +196,17 @@ private boolean IsAutoDiscoveryEnabled() { for (int i = 0; i < numberOfGpus; i++) { List gpuInfos = lastDiscoveredGpuInformation.getGpus(); - gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber())); + if (useMIGEnabledGPUs && + gpuInfos.get(i).getMIGMode().getCurrentMigMode().equalsIgnoreCase("enabled")) { + LOG.info("GPU id " + i + " has MIG mode enabled."); + for (PerGpuMigDevice dev: gpuInfos.get(i).getMIGDevices()) { + gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber(), dev.getMigDeviceIndex())); + } + } else { + gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber())); + } } + LOG.info("Discovered GPU devices: " + gpuDevices); } return gpuDevices; } @@ -218,18 +229,39 @@ private boolean IsAutoDiscoveryEnabled() { for (String device : devices.split(",")) { if (device.trim().length() > 0) { String[] splitByColon = device.trim().split(":"); - if (splitByColon.length != 2) { - throw GpuDeviceSpecificationException. 
- createWithWrongValueSpecified(device, devices); - } - - GpuDevice gpuDevice = parseGpuDevice(device, splitByColon, devices); - if (!gpuDevices.contains(gpuDevice)) { - gpuDevices.add(gpuDevice); + if (useMIGEnabledGPUs) { + if (splitByColon.length != 2 && splitByColon.length != 3) { + throw GpuDeviceSpecificationException. + createWithWrongValueSpecifiedMIG(device, devices); + } + + GpuDevice gpuDevice; + if (splitByColon.length == 3) { + gpuDevice = parseGpuMIGDevice(device, splitByColon, devices); + } else { + gpuDevice = parseGpuDevice(device, splitByColon, devices); + } + if (!gpuDevices.contains(gpuDevice)) { + gpuDevices.add(gpuDevice); + } else { + throw GpuDeviceSpecificationException + .createWithDuplicateValueSpecified(device, devices); + } } else { - throw GpuDeviceSpecificationException - .createWithDuplicateValueSpecified(device, devices); + if (splitByColon.length != 2) { + throw GpuDeviceSpecificationException. + createWithWrongValueSpecified(device, devices); + } + + GpuDevice gpuDevice = parseGpuDevice(device, splitByColon, devices); + if (!gpuDevices.contains(gpuDevice)) { + gpuDevices.add(gpuDevice); + } else { + throw GpuDeviceSpecificationException + .createWithDuplicateValueSpecified(device, devices); + } } + } } LOG.info("Allowed GPU devices:" + gpuDevices); @@ -237,6 +269,19 @@ private boolean IsAutoDiscoveryEnabled() { return gpuDevices; } + private GpuDevice parseGpuMIGDevice(String device, String[] splitByColon, + String allowedDevicesStr) throws YarnException { + try { + int index = Integer.parseInt(splitByColon[0]); + int minorNumber = Integer.parseInt(splitByColon[1]); + int migIndex = Integer.parseInt(splitByColon[2]); + return new GpuDevice(index, minorNumber, migIndex); + } catch (NumberFormatException e) { + throw GpuDeviceSpecificationException. 
+ createWithWrongValueSpecified(device, allowedDevicesStr, e); + } + } + private GpuDevice parseGpuDevice(String device, String[] splitByColon, String allowedDevicesStr) throws YarnException { try { @@ -268,6 +313,9 @@ public synchronized void initialize(Configuration config) LOG.warn(msg); } } + useMIGEnabledGPUs = conf.getBoolean(YarnConfiguration.USE_MIG_ENABLED_GPUS, false); + LOG.info("Use MIG enabled is: " + useMIGEnabledGPUs); + } private void lookUpAutoDiscoveryBinary(Configuration config) diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java index 051afd6c561..996cb58ac45 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java @@ -36,7 +36,7 @@ public static DockerCommandPlugin createGpuDockerCommandPlugin( } // nvidia-docker2 if (impl.equals(YarnConfiguration.NVIDIA_DOCKER_V2)) { - return new NvidiaDockerV2CommandPlugin(); + return new NvidiaDockerV2CommandPlugin(conf); } throw new YarnException( diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java index ff25eb6ced6..c2cc0e5a2d1 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java @@ -21,7 +21,9 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.ResourceInformation; +import org.apache.hadoop.yarn.conf.YarnConfiguration; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings; import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu.GpuResourceAllocator; @@ -45,8 +47,12 @@ private String nvidiaRuntime = "nvidia"; private String nvidiaVisibleDevices = "NVIDIA_VISIBLE_DEVICES"; + private String nvidiaMigThrowOnMultiGpus = "NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS"; + private Boolean isMigEnabled = false; - public NvidiaDockerV2CommandPlugin() {} + public NvidiaDockerV2CommandPlugin(Configuration conf) { + isMigEnabled 
= conf.getBoolean(YarnConfiguration.USE_MIG_ENABLED_GPUS, false); + } private Set getAssignedGpus(Container container) { ResourceMappings resourceMappings = container.getResourceMappings(); @@ -84,10 +90,23 @@ public synchronized void updateDockerRunCommand( return; } Map environment = new HashMap<>(); + if (isMigEnabled && assignedResources.size() > 1) { + Map existingEnv = container.getLaunchContext().getEnvironment(); + Boolean shouldThrowOnMultipleGpus = Boolean.parseBoolean( + existingEnv.getOrDefault(nvidiaMigThrowOnMultiGpus, "true")); + if (shouldThrowOnMultipleGpus) { + throw new ContainerExecutionException("Allocating more than 1 GPU per container is " + + "not supported with use of MIG!"); + } + } String gpuIndexList = ""; for (GpuDevice gpuDevice : assignedResources) { - gpuIndexList = gpuIndexList + gpuDevice.getIndex() + ","; - LOG.info("nvidia docker2 assigned gpu index: " + gpuDevice.getIndex()); + String deviceIndex = String.valueOf(gpuDevice.getIndex()); + if (gpuDevice.getMIGIndex() != -1) { + deviceIndex = gpuDevice.getIndex() + ":" + gpuDevice.getMIGIndex(); + } + gpuIndexList = gpuIndexList + deviceIndex + ","; + LOG.info("nvidia docker2 assigned gpu index: " + deviceIndex); } dockerRunCommand.addRuntime(nvidiaRuntime); environment.put(nvidiaVisibleDevices, diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java index 11ff2a4c49c..939ed46aac7 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java @@ -22,8 +22,10 @@ import org.apache.hadoop.classification.InterfaceStability; import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlElementWrapper; import javax.xml.bind.annotation.XmlRootElement; import javax.xml.bind.annotation.adapters.XmlAdapter; +import java.util.List; /** * Capture single GPU device information such as memory size, temperature, @@ -37,6 +39,8 @@ private String uuid = "N/A"; private int minorNumber = -1; + private List migDevices; + private PerGpuMigMode migMode; private PerGpuUtilizations gpuUtilizations; private PerGpuMemoryUsage gpuMemoryUsage; private PerGpuTemperature temperature; @@ -107,6 +111,25 @@ public void setUuid(String uuid) { this.uuid = uuid; } + @XmlElement(name = "mig_mode") + public PerGpuMigMode getMIGMode() { + return migMode; + } + + public void setMIGMode(PerGpuMigMode mode) { + this.migMode = mode; + } + + @XmlElementWrapper( name = "mig_devices" ) + @XmlElement(name = "mig_device") + public List getMIGDevices() { + return migDevices; + } + + public void setMIGDevices(List devices) { + this.migDevices = devices; + } + @XmlElement(name = "product_name") public String getProductName() { return productName; diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java 
b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java new file mode 100644 index 00000000000..4ce7cec6e55 --- /dev/null +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; + +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlRootElement; + +/** + * GPU MIG Device Information + */ +@InterfaceAudience.Private +@InterfaceStability.Unstable +@XmlRootElement(name = "mig_device") +public class PerGpuMigDevice { + private int index; + + /** + * MIG device index + * @return MIG device index + */ + @XmlElement(name = "index") + public int getMigDeviceIndex() { + return index; + } + + public void setMigDeviceIndex(int index) { + this.index = index; + } +} diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java new file mode 100644 index 00000000000..b706df2c3bb --- /dev/null +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; + +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlRootElement; + +/** + * GPU MIG Mode + */ +@InterfaceAudience.Private +@InterfaceStability.Unstable +@XmlRootElement(name = "mig_mode") +public class PerGpuMigMode { + private String currentMigMode; + + /** + * Current MIG mode + * @return MIG mode enabled or disabled + */ + @XmlElement(name = "current_mig") + public String getCurrentMigMode() { + return currentMigMode; + } + + public void setCurrentMigMode(String migMode) { + this.currentMigMode = migMode; + } +} diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java index f0f100c1f8b..02b213b6734 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java @@ -372,6 +372,37 @@ public void testGetNumberOfUsableGpusFromConfig() throws YarnException { assertEquals(4, usableGpuDevices.get(3).getMinorNumber()); } + @Test + public void testGetNumberOfUsableGpusFromConfigMIG() throws YarnException { + Configuration conf = createConfigWithAllowedDevices("0:0,1:1:0,1:1:3,2:2,3:4"); + conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, "true"); + GpuDiscoverer discoverer = new GpuDiscoverer(); + discoverer.initialize(conf); + + List usableGpuDevices = discoverer.getGpusUsableByYarn(); + assertEquals(5, usableGpuDevices.size()); + + assertEquals(0, usableGpuDevices.get(0).getIndex()); + assertEquals(0, usableGpuDevices.get(0).getMinorNumber()); + assertEquals(-1, usableGpuDevices.get(0).getMIGIndex()); + + assertEquals(1, usableGpuDevices.get(1).getIndex()); + assertEquals(1, usableGpuDevices.get(1).getMinorNumber()); + assertEquals(0, usableGpuDevices.get(1).getMIGIndex()); + + assertEquals(1, usableGpuDevices.get(2).getIndex()); + assertEquals(1, usableGpuDevices.get(2).getMinorNumber()); + assertEquals(3, usableGpuDevices.get(2).getMIGIndex()); + + assertEquals(2, usableGpuDevices.get(3).getIndex()); + assertEquals(2, usableGpuDevices.get(3).getMinorNumber()); + assertEquals(-1, usableGpuDevices.get(3).getMIGIndex()); + + assertEquals(3, usableGpuDevices.get(4).getIndex()); + assertEquals(4, usableGpuDevices.get(4).getMinorNumber()); + assertEquals(-1, usableGpuDevices.get(4).getMIGIndex()); + } + @Test public void testGetNumberOfUsableGpusFromConfigDuplicateValues() throws YarnException { @@ -512,4 +543,5 @@ public void testScriptNotCalled() throws YarnException { verify(gpuSpy, never()).getGpuDeviceInformation(); } + } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java 
b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java index b0b523360ef..798a95cb009 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java @@ -20,10 +20,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Sets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext; import org.apache.hadoop.yarn.api.records.ResourceInformation; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings; import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand; +import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException; import org.junit.Assert; import org.junit.Test; @@ -69,7 +73,13 @@ private boolean commandlinesEquals(Map> cli1, extends NvidiaDockerV2CommandPlugin { private boolean requestsGpu = false; - MyNvidiaDockerV2CommandPlugin() {} + MyNvidiaDockerV2CommandPlugin() { + super(new Configuration()); + } + + MyNvidiaDockerV2CommandPlugin(Configuration conf) { + super(conf); + } public void setRequestsGpu(boolean r) { requestsGpu = r; @@ -127,4 +137,118 @@ public void testPlugin() throws Exception { // runtime should exist Assert.assertTrue(newCommandLine.containsKey("runtime")); } -} \ No newline at end of file + + @Test + public void testPluginMIG() throws Exception { + DockerRunCommand runCommand = new DockerRunCommand("container_1", "user", + "fakeimage"); + + Map> originalCommandline = copyCommandLine( + runCommand.getDockerCommandWithArguments()); + + Configuration conf = new Configuration(); + conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, "true"); + MyNvidiaDockerV2CommandPlugin + commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf); + + Container nmContainer = mock(Container.class); + ResourceMappings resourceMappings = new ResourceMappings(); + when(nmContainer.getResourceMappings()).thenReturn(resourceMappings); + + // Assign GPU resource + ResourceMappings.AssignedResources assigned = + new ResourceMappings.AssignedResources(); + assigned.updateAssignedResources( + ImmutableList.of(new GpuDevice(0, 0, 0))); + resourceMappings.addAssignedResources(ResourceInformation.GPU_URI, + assigned); + + commandPlugin.setRequestsGpu(true); + commandPlugin.updateDockerRunCommand(runCommand, nmContainer); + Map> newCommandLine = + runCommand.getDockerCommandWithArguments(); + + // Command line will be updated + Assert.assertFalse(commandlinesEquals(originalCommandline, newCommandLine)); + // NVIDIA_VISIBLE_DEVICES will be set + Assert.assertTrue( + runCommand.getEnv().get("NVIDIA_VISIBLE_DEVICES").equals("0:0")); + // runtime should exist + Assert.assertTrue(newCommandLine.containsKey("runtime")); + } + + @Test(expected = ContainerExecutionException.class) + public void testPluginMIGThrowsMulti() throws Exception 
{ + DockerRunCommand runCommand = new DockerRunCommand("container_1", "user", + "fakeimage"); + + Map> originalCommandline = copyCommandLine( + runCommand.getDockerCommandWithArguments()); + + Configuration conf = new Configuration(); + conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, "true"); + MyNvidiaDockerV2CommandPlugin + commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf); + + Container nmContainer = mock(Container.class); + ResourceMappings resourceMappings = new ResourceMappings(); + Map env = new HashMap<>(); + env.put("NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS", "true"); + when(nmContainer.getResourceMappings()).thenReturn(resourceMappings); + ContainerLaunchContext launchCtx = mock(ContainerLaunchContext.class); + when(nmContainer.getLaunchContext()).thenReturn(launchCtx); + when(launchCtx.getEnvironment()).thenReturn(env); + + // Assign GPU resource + ResourceMappings.AssignedResources assigned = + new ResourceMappings.AssignedResources(); + assigned.updateAssignedResources( + ImmutableList.of(new GpuDevice(0, 0, 0), new GpuDevice(1, 1, 2))); + resourceMappings.addAssignedResources(ResourceInformation.GPU_URI, + assigned); + + commandPlugin.setRequestsGpu(true); + commandPlugin.updateDockerRunCommand(runCommand, nmContainer); + } + + @Test + public void testPluginMIGNoThrowsMulti() throws Exception { + DockerRunCommand runCommand = new DockerRunCommand("container_1", "user", + "fakeimage"); + + Map> originalCommandline = copyCommandLine( + runCommand.getDockerCommandWithArguments()); + + Configuration conf = new Configuration(); + conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, "true"); + MyNvidiaDockerV2CommandPlugin + commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf); + + Container nmContainer = mock(Container.class); + ResourceMappings resourceMappings = new ResourceMappings(); + Map env = new HashMap<>(); + env.put("NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS", "false"); + when(nmContainer.getResourceMappings()).thenReturn(resourceMappings); + ContainerLaunchContext launchCtx = mock(ContainerLaunchContext.class); + when(nmContainer.getLaunchContext()).thenReturn(launchCtx); + when(launchCtx.getEnvironment()).thenReturn(env); + + // Assign GPU resource + ResourceMappings.AssignedResources assigned = + new ResourceMappings.AssignedResources(); + assigned.updateAssignedResources( + ImmutableList.of(new GpuDevice(0, 0, 0), new GpuDevice(1, 1, 2))); + resourceMappings.addAssignedResources(ResourceInformation.GPU_URI, + assigned); + + commandPlugin.setRequestsGpu(true); + commandPlugin.updateDockerRunCommand(runCommand, nmContainer); + Map> newCommandLine = + runCommand.getDockerCommandWithArguments(); + // NVIDIA_VISIBLE_DEVICES will be set + Assert.assertTrue( + runCommand.getEnv().get("NVIDIA_VISIBLE_DEVICES").equals("0:0,1:2")); + // runtime should exist + Assert.assertTrue(newCommandLine.containsKey("runtime")); + } +} ================================================ FILE: examples/MIG-Support/resource-types/gpu-mig/yarn321to323MIG.patch ================================================ diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java index ad4d87daa1a..95259b1d956 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java +++ 
b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java @@ -1716,6 +1716,10 @@ public static boolean isAclEnabled(Configuration conf) { @Private public static final String AUTOMATICALLY_DISCOVER_GPU_DEVICES = "auto"; + @Private + public static final String USE_MIG_ENABLED_GPUS = + NM_GPU_RESOURCE_PREFIX + "use-mig-enabled"; + /** * This setting controls where to how to invoke GPU binaries */ diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java index 26fd9050742..e84b920dcee 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java @@ -34,6 +34,12 @@ public AssignedGpuDevice(int index, int minorNumber, this.containerId = containerId.toString(); } + public AssignedGpuDevice(int index, int minorNumber, + int migIndex, ContainerId containerId) { + super(index, minorNumber, migIndex); + this.containerId = containerId.toString(); + } + public String getContainerId() { return containerId; } @@ -49,6 +55,7 @@ public boolean equals(Object obj) { } AssignedGpuDevice other = (AssignedGpuDevice) obj; return index == other.index && minorNumber == other.minorNumber + && migDeviceIndex == other.migDeviceIndex && containerId.equals(other.containerId); } @@ -68,12 +75,16 @@ public int compareTo(Object obj) { if (0 != result) { return result; } - return containerId.compareTo(other.containerId); + result = containerId.compareTo(other.containerId); + if (0 != result) { + return result; + } + return Integer.compare(migDeviceIndex, other.migDeviceIndex); } @Override public int hashCode() { final int prime = 47; - return prime * (prime * index + minorNumber) + containerId.hashCode(); + return prime * (prime * index + minorNumber + migDeviceIndex) + containerId.hashCode(); } } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java index bce1d9fa480..3cb42d3c58f 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java @@ -26,6 +26,7 @@ public class GpuDevice implements Serializable, Comparable { protected int index; protected int minorNumber; + protected int migDeviceIndex = -1; private static final long serialVersionUID = -6812314470754667710L; public GpuDevice(int index, int minorNumber) { @@ 
-33,6 +34,12 @@ public GpuDevice(int index, int minorNumber) { this.minorNumber = minorNumber; } + public GpuDevice(int index, int minorNumber, int migIndex) { + this.index = index; + this.minorNumber = minorNumber; + this.migDeviceIndex = migIndex; + } + public int getIndex() { return index; } @@ -41,13 +48,17 @@ public int getMinorNumber() { return minorNumber; } + public int getMIGIndex() { + return migDeviceIndex; + } + @Override public boolean equals(Object obj) { if (obj == null || !(obj instanceof GpuDevice)) { return false; } GpuDevice other = (GpuDevice) obj; - return index == other.index && minorNumber == other.minorNumber; + return index == other.index && minorNumber == other.minorNumber && migDeviceIndex == other.migDeviceIndex; } @Override @@ -62,17 +73,21 @@ public int compareTo(Object obj) { if (0 != result) { return result; } - return Integer.compare(minorNumber, other.minorNumber); + result = Integer.compare(minorNumber, other.minorNumber); + if (0 != result) { + return result; + } + return Integer.compare(migDeviceIndex, other.migDeviceIndex); } @Override public int hashCode() { final int prime = 47; - return prime * index + minorNumber; + return prime * index + minorNumber + migDeviceIndex; } @Override public String toString() { - return "(index=" + index + ",minor_number=" + minorNumber + ")"; + return "(index=" + index + ",minor_number=" + minorNumber + ",mig_index=" + migDeviceIndex + ")"; } } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java index 9d61b91a1f2..ffc2a4c19af 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java @@ -26,6 +26,8 @@ public final class GpuDeviceSpecificationException extends YarnException { private static final String VALID_FORMAT_MESSAGE = "The valid format " + "should be: index:minor_number"; + private static final String VALID_MIG_FORMAT_MESSAGE = VALID_FORMAT_MESSAGE + + " or with MIG enabled: index:minor_number:mig_index"; private GpuDeviceSpecificationException(String message) { super(message); @@ -57,12 +59,31 @@ public static GpuDeviceSpecificationException createWithWrongValueSpecified( return new GpuDeviceSpecificationException(message); } + public static GpuDeviceSpecificationException createWithWrongValueSpecifiedMIG( + String device, String configValue, Exception cause) { + final String message = createIllegalFormatMessageMIG(device, configValue); + return new GpuDeviceSpecificationException(message, cause); + } + + public static GpuDeviceSpecificationException createWithWrongValueSpecifiedMIG( + String device, String configValue) { + final String message = createIllegalFormatMessageMIG(device, configValue); + return new GpuDeviceSpecificationException(message); + } + public static GpuDeviceSpecificationException createWithDuplicateValueSpecified( String device, String 
configValue) { final String message = createDuplicateFormatMessage(device, configValue); return new GpuDeviceSpecificationException(message); } + private static String createIllegalFormatMessageMIG(String device, + String configValue) { + return String.format("Illegal format of individual GPU device: %s, " + + "the whole config value was: '%s'! " + VALID_MIG_FORMAT_MESSAGE, + device, configValue); + } + private static String createIllegalFormatMessage(String device, String configValue) { return String.format("Illegal format of individual GPU device: %s, " + @@ -79,4 +100,4 @@ private static String createDuplicateFormatMessage(String device, "! Current value of the configuration is: %s", device, configValue); } -} \ No newline at end of file +} diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java index f710ff0bccd..1517e12599a 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java @@ -36,6 +36,7 @@ import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformation; import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformationParser; import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuDeviceInformation; +import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuMigDevice; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -70,6 +71,7 @@ private GpuDeviceInformation lastDiscoveredGpuInformation = null; private List gpuDevicesFromUser; + private Boolean useMIGEnabledGPUs = false; private void validateConfOrThrowException() throws YarnException { if (conf == null) { @@ -188,8 +190,17 @@ private boolean isAutoDiscoveryEnabled() { for (int i = 0; i < numberOfGpus; i++) { List gpuInfos = lastDiscoveredGpuInformation.getGpus(); - gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber())); + if (useMIGEnabledGPUs && + gpuInfos.get(i).getMIGMode().getCurrentMigMode().equalsIgnoreCase("enabled")) { + LOG.info("GPU id " + i + " has MIG mode enabled."); + for (PerGpuMigDevice dev: gpuInfos.get(i).getMIGDevices()) { + gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber(), dev.getMigDeviceIndex())); + } + } else { + gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber())); + } } + LOG.info("Discovered GPU devices: " + gpuDevices); } return gpuDevices; } @@ -212,28 +223,56 @@ private boolean isAutoDiscoveryEnabled() { for (String device : devices.split(",")) { if (device.trim().length() > 0) { String[] splitByColon = device.trim().split(":"); - if (splitByColon.length != 2) { - throwIfNecessary(GpuDeviceSpecificationException - .createWithWrongValueSpecified(device, devices), conf); - LOG.warn("Wrong GPU specification string {}, ignored", device); - } - GpuDevice gpuDevice; - try { - gpuDevice = parseGpuDevice(splitByColon); - } catch (NumberFormatException e) { - throwIfNecessary(GpuDeviceSpecificationException - 
.createWithWrongValueSpecified(device, devices, e), conf); - LOG.warn("Cannot parse GPU device numbers: {}", device); - continue; - } + if (useMIGEnabledGPUs) { + if (splitByColon.length != 2 && splitByColon.length != 3) { + throwIfNecessary(GpuDeviceSpecificationException + .createWithWrongValueSpecifiedMIG(device, devices), conf); + LOG.warn("Wrong GPU specification string {}, ignored", device); + } + GpuDevice gpuDevice; + try { + if (splitByColon.length == 3) { + gpuDevice = parseGpuMIGDevice(splitByColon); + } else { + gpuDevice = parseGpuDevice(splitByColon); + } + } catch (NumberFormatException e) { + throwIfNecessary(GpuDeviceSpecificationException + .createWithWrongValueSpecifiedMIG(device, devices, e), conf); + LOG.warn("Cannot parse GPU device numbers: {}", device); + continue; + } + if (!gpuDevices.contains(gpuDevice)) { + gpuDevices.add(gpuDevice); + } else { + throw GpuDeviceSpecificationException + .createWithDuplicateValueSpecified(device, devices); + } - if (!gpuDevices.contains(gpuDevice)) { - gpuDevices.add(gpuDevice); } else { - throwIfNecessary(GpuDeviceSpecificationException - .createWithDuplicateValueSpecified(device, devices), conf); - LOG.warn("CPU device is duplicated: {}", device); + if (splitByColon.length != 2) { + throwIfNecessary(GpuDeviceSpecificationException + .createWithWrongValueSpecified(device, devices), conf); + LOG.warn("Wrong GPU specification string {}, ignored", device); + } + GpuDevice gpuDevice; + try { + gpuDevice = parseGpuDevice(splitByColon); + } catch (NumberFormatException e) { + throwIfNecessary(GpuDeviceSpecificationException + .createWithWrongValueSpecified(device, devices, e), conf); + LOG.warn("Cannot parse GPU device numbers: {}", device); + continue; + } + + if (!gpuDevices.contains(gpuDevice)) { + gpuDevices.add(gpuDevice); + } else { + throwIfNecessary(GpuDeviceSpecificationException + .createWithDuplicateValueSpecified(device, devices), conf); + LOG.warn("CPU device is duplicated: {}", device); + } } } } @@ -248,6 +287,12 @@ private GpuDevice parseGpuDevice(String[] splitByColon) { return new GpuDevice(index, minorNumber); } + private GpuDevice parseGpuMIGDevice(String[] splitByColon) { + int index = Integer.parseInt(splitByColon[0]); + int minorNumber = Integer.parseInt(splitByColon[1]); + int migIndex = Integer.parseInt(splitByColon[2]); + return new GpuDevice(index, minorNumber, migIndex); + } public synchronized void initialize(Configuration config, NvidiaBinaryHelper nvidiaHelper) throws YarnException { @@ -269,6 +314,9 @@ public synchronized void initialize(Configuration config, LOG.warn(msg); } } + useMIGEnabledGPUs = conf.getBoolean(YarnConfiguration.USE_MIG_ENABLED_GPUS, false); + LOG.info("Use MIG enabled is: " + useMIGEnabledGPUs); + } private void lookUpAutoDiscoveryBinary(Configuration config) diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java index 051afd6c561..996cb58ac45 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java +++ 
b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java @@ -36,7 +36,7 @@ public static DockerCommandPlugin createGpuDockerCommandPlugin( } // nvidia-docker2 if (impl.equals(YarnConfiguration.NVIDIA_DOCKER_V2)) { - return new NvidiaDockerV2CommandPlugin(); + return new NvidiaDockerV2CommandPlugin(conf); } throw new YarnException( diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java index ff25eb6ced6..c2cc0e5a2d1 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java @@ -21,7 +21,9 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.ResourceInformation; +import org.apache.hadoop.yarn.conf.YarnConfiguration; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings; import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu.GpuResourceAllocator; @@ -45,8 +47,12 @@ private String nvidiaRuntime = "nvidia"; private String nvidiaVisibleDevices = "NVIDIA_VISIBLE_DEVICES"; + private String nvidiaMigThrowOnMultiGpus = "NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS"; + private Boolean isMigEnabled = false; - public NvidiaDockerV2CommandPlugin() {} + public NvidiaDockerV2CommandPlugin(Configuration conf) { + isMigEnabled = conf.getBoolean(YarnConfiguration.USE_MIG_ENABLED_GPUS, false); + } private Set getAssignedGpus(Container container) { ResourceMappings resourceMappings = container.getResourceMappings(); @@ -84,10 +90,23 @@ public synchronized void updateDockerRunCommand( return; } Map environment = new HashMap<>(); + if (isMigEnabled && assignedResources.size() > 1) { + Map existingEnv = container.getLaunchContext().getEnvironment(); + Boolean shouldThrowOnMultipleGpus = Boolean.parseBoolean( + existingEnv.getOrDefault(nvidiaMigThrowOnMultiGpus, "true")); + if (shouldThrowOnMultipleGpus) { + throw new ContainerExecutionException("Allocating more than 1 GPU per container is " + + "not supported with use of MIG!"); + } + } String gpuIndexList = ""; for (GpuDevice gpuDevice : assignedResources) { - gpuIndexList = gpuIndexList + gpuDevice.getIndex() + ","; - LOG.info("nvidia docker2 assigned gpu index: " + gpuDevice.getIndex()); + String deviceIndex = String.valueOf(gpuDevice.getIndex()); + if (gpuDevice.getMIGIndex() != -1) { + deviceIndex = gpuDevice.getIndex() + ":" + gpuDevice.getMIGIndex(); + } + gpuIndexList = gpuIndexList + deviceIndex + ","; + LOG.info("nvidia docker2 assigned gpu 
index: " + deviceIndex); } dockerRunCommand.addRuntime(nvidiaRuntime); environment.put(nvidiaVisibleDevices, diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java index 11ff2a4c49c..939ed46aac7 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java @@ -22,8 +22,10 @@ import org.apache.hadoop.classification.InterfaceStability; import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlElementWrapper; import javax.xml.bind.annotation.XmlRootElement; import javax.xml.bind.annotation.adapters.XmlAdapter; +import java.util.List; /** * Capture single GPU device information such as memory size, temperature, @@ -37,6 +39,8 @@ private String uuid = "N/A"; private int minorNumber = -1; + private List migDevices; + private PerGpuMigMode migMode; private PerGpuUtilizations gpuUtilizations; private PerGpuMemoryUsage gpuMemoryUsage; private PerGpuTemperature temperature; @@ -107,6 +111,25 @@ public void setUuid(String uuid) { this.uuid = uuid; } + @XmlElement(name = "mig_mode") + public PerGpuMigMode getMIGMode() { + return migMode; + } + + public void setMIGMode(PerGpuMigMode mode) { + this.migMode = mode; + } + + @XmlElementWrapper( name = "mig_devices" ) + @XmlElement(name = "mig_device") + public List getMIGDevices() { + return migDevices; + } + + public void setMIGDevices(List devices) { + this.migDevices = devices; + } + @XmlElement(name = "product_name") public String getProductName() { return productName; diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java new file mode 100644 index 00000000000..4ce7cec6e55 --- /dev/null +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; + +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlRootElement; + +/** + * GPU MIG Device Information + */ +@InterfaceAudience.Private +@InterfaceStability.Unstable +@XmlRootElement(name = "mig_device") +public class PerGpuMigDevice { + private int index; + + /** + * MIG device index + * @return MIG device index + */ + @XmlElement(name = "index") + public int getMigDeviceIndex() { + return index; + } + + public void setMigDeviceIndex(int index) { + this.index = index; + } +} diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java new file mode 100644 index 00000000000..b706df2c3bb --- /dev/null +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; + +import javax.xml.bind.annotation.XmlElement; +import javax.xml.bind.annotation.XmlRootElement; + +/** + * GPU MIG Mode + */ +@InterfaceAudience.Private +@InterfaceStability.Unstable +@XmlRootElement(name = "mig_mode") +public class PerGpuMigMode { + private String currentMigMode; + + /** + * Current MIG mode + * @return MIG mode enabled or disabled + */ + @XmlElement(name = "current_mig") + public String getCurrentMigMode() { + return currentMigMode; + } + + public void setCurrentMigMode(String migMode) { + this.currentMigMode = migMode; + } +} diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java index 8261895b2a9..6c1f500009c 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java @@ -373,6 +373,37 @@ public void testGetNumberOfUsableGpusFromConfig() throws YarnException { assertEquals(4, usableGpuDevices.get(3).getMinorNumber()); } + @Test + public void testGetNumberOfUsableGpusFromConfigMIG() throws YarnException { + Configuration conf = createConfigWithAllowedDevices("0:0,1:1:0,1:1:3,2:2,3:4"); + conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, "true"); + GpuDiscoverer discoverer = new GpuDiscoverer(); + discoverer.initialize(conf, binaryHelper); + + List usableGpuDevices = discoverer.getGpusUsableByYarn(); + assertEquals(5, usableGpuDevices.size()); + + assertEquals(0, usableGpuDevices.get(0).getIndex()); + assertEquals(0, usableGpuDevices.get(0).getMinorNumber()); + assertEquals(-1, usableGpuDevices.get(0).getMIGIndex()); + + assertEquals(1, usableGpuDevices.get(1).getIndex()); + assertEquals(1, usableGpuDevices.get(1).getMinorNumber()); + assertEquals(0, usableGpuDevices.get(1).getMIGIndex()); + + assertEquals(1, usableGpuDevices.get(2).getIndex()); + assertEquals(1, usableGpuDevices.get(2).getMinorNumber()); + assertEquals(3, usableGpuDevices.get(2).getMIGIndex()); + + assertEquals(2, usableGpuDevices.get(3).getIndex()); + assertEquals(2, usableGpuDevices.get(3).getMinorNumber()); + assertEquals(-1, usableGpuDevices.get(3).getMIGIndex()); + + assertEquals(3, usableGpuDevices.get(4).getIndex()); + assertEquals(4, usableGpuDevices.get(4).getMinorNumber()); + assertEquals(-1, usableGpuDevices.get(4).getMIGIndex()); + } + @Test public void testGetNumberOfUsableGpusFromConfigDuplicateValues() throws YarnException { @@ -513,4 +544,4 @@ public void testScriptNotCalled() throws YarnException, IOException { verify(gpuSpy, never()).getGpuDeviceInformation(); } -} \ No newline at end of file +} diff --git 
a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java index b0b523360ef..798a95cb009 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java @@ -20,10 +20,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Sets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext; import org.apache.hadoop.yarn.api.records.ResourceInformation; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings; import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand; +import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException; import org.junit.Assert; import org.junit.Test; @@ -69,7 +73,13 @@ private boolean commandlinesEquals(Map> cli1, extends NvidiaDockerV2CommandPlugin { private boolean requestsGpu = false; - MyNvidiaDockerV2CommandPlugin() {} + MyNvidiaDockerV2CommandPlugin() { + super(new Configuration()); + } + + MyNvidiaDockerV2CommandPlugin(Configuration conf) { + super(conf); + } public void setRequestsGpu(boolean r) { requestsGpu = r; @@ -127,4 +137,118 @@ public void testPlugin() throws Exception { // runtime should exist Assert.assertTrue(newCommandLine.containsKey("runtime")); } -} \ No newline at end of file + + @Test + public void testPluginMIG() throws Exception { + DockerRunCommand runCommand = new DockerRunCommand("container_1", "user", + "fakeimage"); + + Map> originalCommandline = copyCommandLine( + runCommand.getDockerCommandWithArguments()); + + Configuration conf = new Configuration(); + conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, "true"); + MyNvidiaDockerV2CommandPlugin + commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf); + + Container nmContainer = mock(Container.class); + ResourceMappings resourceMappings = new ResourceMappings(); + when(nmContainer.getResourceMappings()).thenReturn(resourceMappings); + + // Assign GPU resource + ResourceMappings.AssignedResources assigned = + new ResourceMappings.AssignedResources(); + assigned.updateAssignedResources( + ImmutableList.of(new GpuDevice(0, 0, 0))); + resourceMappings.addAssignedResources(ResourceInformation.GPU_URI, + assigned); + + commandPlugin.setRequestsGpu(true); + commandPlugin.updateDockerRunCommand(runCommand, nmContainer); + Map> newCommandLine = + runCommand.getDockerCommandWithArguments(); + + // Command line will be updated + Assert.assertFalse(commandlinesEquals(originalCommandline, newCommandLine)); + // NVIDIA_VISIBLE_DEVICES will be set + Assert.assertTrue( + 
runCommand.getEnv().get("NVIDIA_VISIBLE_DEVICES").equals("0:0")); + // runtime should exist + Assert.assertTrue(newCommandLine.containsKey("runtime")); + } + + @Test(expected = ContainerExecutionException.class) + public void testPluginMIGThrowsMulti() throws Exception { + DockerRunCommand runCommand = new DockerRunCommand("container_1", "user", + "fakeimage"); + + Map> originalCommandline = copyCommandLine( + runCommand.getDockerCommandWithArguments()); + + Configuration conf = new Configuration(); + conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, "true"); + MyNvidiaDockerV2CommandPlugin + commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf); + + Container nmContainer = mock(Container.class); + ResourceMappings resourceMappings = new ResourceMappings(); + Map env = new HashMap<>(); + env.put("NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS", "true"); + when(nmContainer.getResourceMappings()).thenReturn(resourceMappings); + ContainerLaunchContext launchCtx = mock(ContainerLaunchContext.class); + when(nmContainer.getLaunchContext()).thenReturn(launchCtx); + when(launchCtx.getEnvironment()).thenReturn(env); + + // Assign GPU resource + ResourceMappings.AssignedResources assigned = + new ResourceMappings.AssignedResources(); + assigned.updateAssignedResources( + ImmutableList.of(new GpuDevice(0, 0, 0), new GpuDevice(1, 1, 2))); + resourceMappings.addAssignedResources(ResourceInformation.GPU_URI, + assigned); + + commandPlugin.setRequestsGpu(true); + commandPlugin.updateDockerRunCommand(runCommand, nmContainer); + } + + @Test + public void testPluginMIGNoThrowsMulti() throws Exception { + DockerRunCommand runCommand = new DockerRunCommand("container_1", "user", + "fakeimage"); + + Map> originalCommandline = copyCommandLine( + runCommand.getDockerCommandWithArguments()); + + Configuration conf = new Configuration(); + conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, "true"); + MyNvidiaDockerV2CommandPlugin + commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf); + + Container nmContainer = mock(Container.class); + ResourceMappings resourceMappings = new ResourceMappings(); + Map env = new HashMap<>(); + env.put("NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS", "false"); + when(nmContainer.getResourceMappings()).thenReturn(resourceMappings); + ContainerLaunchContext launchCtx = mock(ContainerLaunchContext.class); + when(nmContainer.getLaunchContext()).thenReturn(launchCtx); + when(launchCtx.getEnvironment()).thenReturn(env); + + // Assign GPU resource + ResourceMappings.AssignedResources assigned = + new ResourceMappings.AssignedResources(); + assigned.updateAssignedResources( + ImmutableList.of(new GpuDevice(0, 0, 0), new GpuDevice(1, 1, 2))); + resourceMappings.addAssignedResources(ResourceInformation.GPU_URI, + assigned); + + commandPlugin.setRequestsGpu(true); + commandPlugin.updateDockerRunCommand(runCommand, nmContainer); + Map> newCommandLine = + runCommand.getDockerCommandWithArguments(); + // NVIDIA_VISIBLE_DEVICES will be set + Assert.assertTrue( + runCommand.getEnv().get("NVIDIA_VISIBLE_DEVICES").equals("0:0,1:2")); + // runtime should exist + Assert.assertTrue(newCommandLine.containsKey("runtime")); + } +} ================================================ FILE: examples/MIG-Support/yarn-unpatched/README.md ================================================ # MIG Support for Spark on YARN using unmodified versions of Apache Hadoop 3.1.2+ This document describes a solution for utilizing MIG with YARN when upgrading to a recent 3.3+ version or patching older versions of Apache Hadoop is not 
feasible. Please refer to the corresponding alternatives for more information:

- [Device Plugins README](../device-plugins/gpu-mig/README.md)
- [YARN patch README](../resource-types/gpu-mig/README.md)

## Introduction

We provide a set of scripts that wrap the original `nvidia-smi` from the NVIDIA GPU Driver and `nvidia-container-cli` included in the [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-docker).

`nvidia-smi` is a wrapper script that parses the XML output of `nvidia-smi -q -x`, which YARN uses to discover GPUs. It replaces each MIG-enabled GPU with the list of `<gpu>` elements corresponding to every `<mig_device>` element of that GPU, with additional annotation to construct the MIG identifier for `nvidia-container-cli`. This reverse mapping is performed by the modified `nvidia` Docker runtime using `nvidia-container-cli-wrapper.sh`.

## Requirements

Please see the [MIG Application Considerations](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#app-considerations) and [CUDA Device Enumeration](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html#cuda-visible-devices). Note that this method only works with drivers >= R470 (470.42.01+).

## Installation

These instructions assume YARN is already installed and configured with GPU scheduling enabled using Docker and the NVIDIA Container Toolkit (nvidia-docker2). See [Using GPU on YARN](https://hadoop.apache.org/docs/r3.1.2/hadoop-yarn/hadoop-yarn-site/UsingGpus.html) if you need more information.

Enable and configure your [GPUs with MIG](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html) on all of the nodes it applies to.

Download the contents of [scripts](./scripts/) to every YARN NodeManager (worker) machine to some location, for example `/usr/local/yarn-mig-scripts`. Make sure that the scripts are executable by the Docker daemon user (i.e., `root`) and by the YARN NM service user (typically `yarn`). Note that the scripts leave the original outputs untouched if the environment variable `MIG_AS_GPU_ENABLED` is not 1.

### YARN Configuration

#### Customizing yarn-env.sh

In `$YARN_CONF_DIR/yarn-env.sh`:

- Add `export MIG_AS_GPU_ENABLED=1` to enable replacing MIG-enabled GPUs with a list of MIG devices, as if they were physical GPUs.
- Customize the `REAL_NVIDIA_SMI_PATH` value if nvidia-smi is not at the default location `/usr/bin/nvidia-smi`.
- Add `ENABLE_NON_MIG_GPUS=0` if you want to prevent discovery of physical GPUs that are not subdivided into MIG devices. The default is `ENABLE_NON_MIG_GPUS=1`, and physical GPUs in the MIG-disabled state are listed along with MIG sub-devices on the node.

Modify the following config in `$YARN_CONF_DIR/yarn-site.xml`:

```xml
<property>
  <name>yarn.nodemanager.resource-plugins.gpu.path-to-discovery-executables</name>
  <value>/usr/local/yarn-mig-scripts/</value>
</property>
```

By default, `yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices` is set to `auto` and `/usr/local/yarn-mig-scripts/nvidia-smi` will be called by YARN to discover GPUs. If you disable the default automatic GPU discovery, you can manually specify the list of MIG instances to use by setting `yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices` to the list of 0-based indices corresponding to the desired `<gpu>` elements in the output of

```bash
MIG_AS_GPU_ENABLED=1 /usr/local/yarn-mig-scripts/nvidia-smi -q -x
```

In other words, if you want to allow MIG 1:2 and 2:0 and they are listed as the 3rd and 5th `<gpu>` elements, the value for `yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices` should be "2,4".
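To double-check which 0-based index maps to which physical GPU or MIG device before editing `yarn-site.xml`, a small helper along the following lines can enumerate the `<gpu>` elements the wrapper emits. This is a hypothetical sketch, not one of the provided scripts; it assumes the scripts are installed under `/usr/local/yarn-mig-scripts`, that the wrapper output parses as well-formed XML, and that each emitted `<gpu>` element carries the `<_mig2gpu_device_id>` annotation added by `mig2gpu.sh`.

```python
import os
import subprocess
import xml.etree.ElementTree as ET

# Run the wrapper the same way YARN would, with MIG-as-GPU translation enabled.
env = dict(os.environ, MIG_AS_GPU_ENABLED="1")
xml_out = subprocess.run(
    ["/usr/local/yarn-mig-scripts/nvidia-smi", "-q", "-x"],
    env=env, capture_output=True, text=True, check=True,
).stdout

# Each emitted <gpu> element carries a <_mig2gpu_device_id> child:
# "<gpu_index>" for a plain GPU, "<gpu_index>:<mig_index>" for a MIG device.
for idx, gpu in enumerate(ET.fromstring(xml_out).findall("gpu")):
    print(f"allowed-gpu-devices index {idx}: {gpu.findtext('_mig2gpu_device_id')}")
```

The printed indices are the values you would list in `yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices` if you opt out of automatic discovery.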
### NVIDIA Docker Runtime Configuration

Modify the `[nvidia-container-cli]` section in `/etc/nvidia-container-runtime/config.toml`:

```toml
path = "/usr/local/yarn-mig-scripts/nvidia-container-cli-wrapper.sh"
environment = [ "MIG_AS_GPU_ENABLED=1", "REAL_NVIDIA_SMI_PATH=/if/non-default/path/nvidia-smi" ]
```

Note: the values for `MIG_AS_GPU_ENABLED`, `REAL_NVIDIA_SMI_PATH`, and `ENABLE_NON_MIG_GPUS` should be identical to the ones specified in `yarn-env.sh`.

================================================ FILE: examples/MIG-Support/yarn-unpatched/scripts/mig2gpu.sh ================================================

#!/bin/bash
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -e

# This file contains the logic for parsing and manipulating the well-formed
# pretty-printed XML output generated by nvidia-smi. It replaces each MIG-enabled gpu element
# with a list of gpu elements corresponding to its configured MIG devices.
# If there is at least one MIG-enabled GPU, the output for non-MIG GPUs can be suppressed
# by setting ENABLE_NON_MIG_GPUS=0 (it is included by default).
# XML fragments are viewed and manipulated using bash arrays of lines. Each element of interest is tracked by
# a start offset into the line array pointing to the line with the opening tag and the end offset,
# which is the line number past the closing tag.
#
# NOTE: this is not a real XML parser, but it is sufficient to handle XML without nested
# tags mixed on the same line. When making changes try to avoid non-bash dependencies.

# Include both MIG and non-MIG devices by default
# Set ENABLE_NON_MIG_GPUS=0 to discover only GPU devices with the current MIG mode Disabled
ENABLE_NON_MIG_GPUS=${ENABLE_NON_MIG_GPUS:-1}

# If setting YARN up to use Cgroups without official YARN support,
# enabling this tells the script to use the NVIDIA capabilities access
# device number for the minor number so that the YARN Cgroup code
# denies access to MIG devices properly.
ENABLE_MIG_GPUS_FOR_CGROUPS=${ENABLE_MIG_GPUS_FOR_CGROUPS:-0} # For stored input test: NVIDIA_SMI_QX=./src/resources/tom-nvidia-smi-xq.xml # For live input test: NVIDIA_SMI_QX=/dev/stdin NVIDIA_SMI_QX="${NVIDIA_SMI_QX:-"/dev/stdin"}" mig2gpu_inputLines=() # buffer global output here mig2gpu_out=() mig2gpu_migEnabled=0 mig2gpu_driverVersion="INVALID_DRIVER_VERSION" # buffer non-MIG GPU output here mig2gpu_nonMigGpu_out=() mig2gpu_migGpu_out=() # Slice of original XML defining the current GPU element mig2gpu_gpu_lineNumberStart=-1 mig2gpu_gpu_lineNumberEnd=-1 # Slice of original XML defining the current MIG element mig2gpu_mig_lineNumberStart=-1 mig2gpu_mig_lineNumberEnd=-1 mig2gpu_migIndex=-1 # Parent GPU context for MIG mig2gpu_gpuIdx=-1 mig2gpu_migGpuInstanceId=-1 mig2gpu_migComputeInstanceUuid=-1 mig2gpu_productName="INVALID_GPU_PRODUCT_NAME" mig2gpu_gpuUuid="INVALID_GPU_UUID" mig2gpu_gpuMinorNumber="INVALID_GPU_MINOR_NUMBER" mig2gpu_gpu_utilization_lineNumberStart=-1 mig2gpu_gpu_utilization_lineNumberEnd=-1 mig2gpu_gpu_temperature_lineNumberStart=-1 mig2gpu_gpu_temperature_lineNumberEnd=-1 # The function to replace a MIG-enabled GPU with the "fake" GPU device elements # corresponding to MIG devices contained within the given GPU element # # The minimum GPU content YARN needs from GPU for parse to succeed: # # # 495.29.05 # # Quadro RTX 6000 # GPU-903720f4-f8d1-11e0-3b2f-4bd740b2f424 # 0 # # 673 MiB # &23547 MiB # # # &23 % # # # 38 C # 94 C # 91 C # # # # # A MIG device looks like this: # # 0 # 3 # 0 # # # 14 # 1 # 0 # 1 # 0 # 0 # # # # # 0 # # # # 6016 MiB # 3 MiB # 6012 MiB # # # 8191 MiB # 0 MiB # 8191 MiB # # # # To satisfy the minimum parseable GPU element, we need to # 1) add a element, parent's orginal text + MIG + index # 2) add a element accoring to https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#cuda-gi # MIG-// # 3) add parent's 0 (don't care) # 4) use MIG's own element unchanged # 5) copy element from parent # 6) copy element from parent # # To enable bidirectional translation to/from fake # 7) add a <_mig2gpu_device_id> element: ":", e.g. 
0:0 function processParentGpuGlobals { local lineNumber # increment 0-based GPU iteration order index mig2gpu_gpuIdx=$((mig2gpu_gpuIdx+1)) for ((lineNumber=mig2gpu_gpu_lineNumberStart; lineNumber'*'') if [[ "$line" =~ 'Enabled' ]]; then mig2gpu_migEnabled=1 else mig2gpu_migEnabled=0 fi ;; $'\t'*''*) if [[ "$line" =~ $'\t\t'(.*)'' ]]; then mig2gpu_productName="${BASH_REMATCH[1]}" fi ;; $'\t'*''*) if [[ "$line" =~ $'\t\t'(.*)'' ]]; then mig2gpu_gpuUuid="${BASH_REMATCH[1]}" fi ;; $'\t'*''*) mig2gpu_gpuMinorNumber="$line" ;; $'\t'*''*) mig2gpu_gpu_utilization_lineNumberStart="$lineNumber" ;; $'\t'*''*) mig2gpu_gpu_utilization_lineNumberEnd=$((lineNumber+1)) ;; $'\t'*''*) mig2gpu_gpu_temperature_lineNumberStart="$lineNumber" ;; $'\t'*''*) mig2gpu_gpu_temperature_lineNumberEnd=$((lineNumber+1)) ;; esac done } function addOriginalGpuIndexAsDeviceId { local afterUuidLineStart=$((mig2gpu_gpu_lineNumberStart+3)) local afterUuidGpuLength=$((mig2gpu_gpu_lineNumberEnd-afterUuidLineStart)) mig2gpu_nonMigGpu_out+=( "${mig2gpu_inputLines[@]:$mig2gpu_gpu_lineNumberStart:3}" ) mig2gpu_nonMigGpu_out+=( $'\t\t'"<_mig2gpu_device_id>$mig2gpu_gpuIdx") mig2gpu_nonMigGpu_out+=( "${mig2gpu_inputLines[@]:$afterUuidLineStart:$afterUuidGpuLength}" ) } function replaceParentGpuWithMigs { for ((lineNumber=mig2gpu_gpu_lineNumberStart; lineNumber'*) mig2gpu_mig_lineNumberStart=$lineNumber ;; $'\t'*''*) if [[ "$line" =~ $'\t'*''(.*)'' ]]; then mig2gpu_migIndex="${BASH_REMATCH[1]}" fi ;; $'\t'*'_instance_id>'*) if [[ "$line" =~ $'\t'*''(.*)'' ]]; then mig2gpu_migGpuInstanceId="${BASH_REMATCH[1]}" elif [[ "$line" =~ $'\t'*''(.*)'' ]]; then mig2gpu_migComputeInstanceId="${BASH_REMATCH[1]}" fi ;; $'\t'*''*) local fbMemoryUsage_lineNumberStart=$lineNumber ;; $'\t'*''*) local fbMemoryUsage_lineNumberEnd=$((lineNumber+1)) local fbMemryUsageLength=$((fbMemoryUsage_lineNumberEnd-fbMemoryUsage_lineNumberStart)) local fbMemoryUsage=("${mig2gpu_inputLines[@]:$fbMemoryUsage_lineNumberStart:fbMemryUsageLength}") local migFbMemoryUsage=("${fbMemoryUsage[@]//$'\t\t\t'/$'\t\t'}") ;; $'\t'*''*) mig2gpu_mig_lineNumberEnd=$((lineNumber+1)) # mig2gpu_migGpu_out+=("${mig2gpu_inputLines[$mig2gpu_gpu_lineNumberStart]}") mig2gpu_migGpu_out+=($'\t\t'"$mig2gpu_productName (MIG)") # We don't really use it since driver-dependent # but R450 & R460 form is more useful for debugging # https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#cuda-visible-devices # local migUuid="MIG-$mig2gpu_gpuUuid/$mig2gpu_migGpuInstanceId/$mig2gpu_migComputeInstanceId" mig2gpu_migGpu_out+=($'\t\t'"$migUuid") # https://github.com/NVIDIA/nvidia-container-runtime#nvidia_visible_devices # The scheme : is not annotated with any # driver version caveats, so adding this for stability and simplicity local migDeviceId="$mig2gpu_gpuIdx:$mig2gpu_migIndex" mig2gpu_migGpu_out+=($'\t\t'"<_mig2gpu_device_id>$migDeviceId") # if using this with CGROUP workaround we need the minor number to be from nvidia-caps access if [[ "$ENABLE_MIG_GPUS_FOR_CGROUPS" == 1 ]]; then mig_minor_dev_num=`cat /proc/driver/nvidia-caps/mig-minors | grep gpu$mig2gpu_gpuIdx/gi$mig2gpu_migGpuInstanceId/access | cut -d ' ' -f 2` mig2gpu_migGpu_out+=($'\t\t'"$mig_minor_dev_num") else mig2gpu_migGpu_out+=("$mig2gpu_gpuMinorNumber") fi mig2gpu_migGpu_out+=("${migFbMemoryUsage[@]}") local gpuUtilizationLength=$((mig2gpu_gpu_utilization_lineNumberEnd - mig2gpu_gpu_utilization_lineNumberStart)) local gpuUtilization=("${mig2gpu_inputLines[@]:$mig2gpu_gpu_utilization_lineNumberStart:gpuUtilizationLength}") 
mig2gpu_migGpu_out+=("${gpuUtilization[@]}") local gpuTemperatureLength=$((mig2gpu_gpu_temperature_lineNumberEnd - mig2gpu_gpu_temperature_lineNumberStart)) mig2gpu_migGpu_out+=("${mig2gpu_inputLines[@]:$mig2gpu_gpu_temperature_lineNumberStart:$gpuTemperatureLength}") # mig2gpu_migGpu_out+=("${mig2gpu_inputLines[$((mig2gpu_gpu_lineNumberEnd-1))]}") ;; esac done } function processGpuElement { processParentGpuGlobals if [[ "$mig2gpu_migEnabled" != "1" ]]; then addOriginalGpuIndexAsDeviceId else # scan gpu element lines twice because the mig section appears before # the info needed from parent replaceParentGpuWithMigs fi } function mig2gpuMain { local line local lineNumber # simplified regex-free parser relying on the fact # that nvidia-smi output is pretty-printed with tabs while IFS= read -r line; do lineNumber=${#mig2gpu_inputLines[@]} mig2gpu_inputLines+=("$line") case "$line" in # document-level tags '<'*) mig2gpu_out+=("$line") ;; $'\t'*) mig2gpu_driverVersion="$line" ;; *) # ignore infeasible ;; esac done < "$NVIDIA_SMI_QX" for outLine in "${mig2gpu_out[@]}"; do printf '%s\n' "$outLine" if [[ "$outLine" =~ '' ]]; then printf '%s\n' "$mig2gpu_driverVersion" printf '%s\n' "${mig2gpu_migGpu_out[@]}" # output non-MIG only if ENABLE_NON_MIG_GPUS is set # https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#cuda-visible-devices # currently mixing MIG and non-MIG GPUs is not supported by the driver # "Note that these constraints may be relaxed in future NVIDIA driver releases for MIG" if [[ "${#mig2gpu_migGpu_out[@]}" == "0" || "$ENABLE_NON_MIG_GPUS" == "1" ]]; then printf '%s\n' "${mig2gpu_nonMigGpu_out[@]}" fi fi done } mig2gpuMain ================================================ FILE: examples/MIG-Support/yarn-unpatched/scripts/nvidia-container-cli-wrapper.sh ================================================ #!/bin/bash # Copyright (c) 2022-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # This script is executed by the `nvidia` Docker runtime on the host before creating the container. # It intercepts the device assigned by YARN, a 0-based index and converts it to a pair # GPU device index:MIG device index that is stored in _mig2gpu_device_is elememnt # by mig2gpu.sh in the nvidia-smi-wrapper.sh which limits all the processes withing the container # to the corresponding MIG Compute Instance https://github.com/NVIDIA/nvidia-container-runtime#nvidia_visible_devices. 
# customize in /etc/nvidia-container-runtime/config.toml # [nvidia-container-cli] # environment = [ "VAR1=VAL1", "VAR2=VAL2" ] REAL_NVIDIA_CONTAINER_CLI_PATH=${REAL_NVIDIA_CONTAINER_CLI_PATH:-"/usr/bin/nvidia-container-cli"} REAL_NVIDIA_SMI_PATH=${REAL_NVIDIA_SMI_PATH:-"/usr/bin/nvidia-smi"} MIG_AS_GPU_ENABLED=${MIG_AS_GPU_ENABLED:-"0"} THIS_PATH="$(readlink -f $0)" THIS_DIR="$(dirname $THIS_PATH)" if [[ "$MIG_AS_GPU_ENABLED" == "1" ]]; then realArgs=() for arg in "$@"; do case "$arg" in "--device="*) nvcli_migDeviceIds=() # map CSV of indexes 0,3,10 to ,0,3,10, # so we can do an easy "contains" test # the device N is included if deviceArgWithLeadingTrailingComma # matches =~ ",N," deviceArgWithLeadingTrailingComma=",${arg#*=}," current_gpu_idx=-1 while read -r line; do case "$line" in # found the device id constructed in mig2gpu.sh with the original nvidia-smi enumeration # gpu index, mig index *"<_mig2gpu_device_id>"*) current_gpu_idx=$((current_gpu_idx+1)) if [[ "$deviceArgWithLeadingTrailingComma" =~ ",${current_gpu_idx}," && "$line" =~ '<_mig2gpu_device_id>'(.*)'' ]]; then nvcli_migDeviceIds+=("${BASH_REMATCH[1]}") fi ;; esac done < <("$REAL_NVIDIA_SMI_PATH" -q -x | "$THIS_DIR/mig2gpu.sh") # make sure the above redirect into the while read loop does not use the here-string (<<<) method because different # versions of bash materialize newlines differently in the string. Older versions treat it as a single # line and newer versions leave it as a multiline string. Here it needs to be a multiline. if (( ${#nvcli_migDeviceIds[@]} )); then migDeviceIdsCsv=$(IFS=','; echo "${nvcli_migDeviceIds[*]}") realArgs+=("--device=$migDeviceIdsCsv") else realArgs+=("$arg") fi ;; *) realArgs+=("$arg") ;; esac done "$REAL_NVIDIA_CONTAINER_CLI_PATH" "${realArgs[@]}" else "$REAL_NVIDIA_CONTAINER_CLI_PATH" "$@" fi ================================================ FILE: examples/MIG-Support/yarn-unpatched/scripts/nvidia-smi ================================================ #!/bin/bash # Copyright (c) 2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # This script is designed as a drop-in replacement for YARN node manager's automatic # MIG-aware GPU discovery. 
# YARN config
# yarn.nodemanager.resource-plugins.gpu.path-to-discovery-executables
# should point to this script on NM host, e.g.
# <property>
#   <name>yarn.nodemanager.resource-plugins.gpu.path-to-discovery-executables</name>
#   <value>/usr/local/yarn-mig-scripts/</value>
# </property>
#
# customize in yarn-env.sh
REAL_NVIDIA_SMI_PATH=${REAL_NVIDIA_SMI_PATH:-"/usr/bin/nvidia-smi"}
MIG_AS_GPU_ENABLED=${MIG_AS_GPU_ENABLED:-"0"}

THIS_PATH="$(readlink -f $0)"
THIS_DIR="$(dirname $THIS_PATH)"

for arg in "$@"; do
  case "$arg" in
    "-q"|"--query")
      QUERY_ARG=1
      ;;
    "-x"|"--xml-format")
      XML_FORMAT_ARG=1
      ;;
  esac
done

if [[ "$MIG_AS_GPU_ENABLED" == "1" && "$XML_FORMAT_ARG" == "1" && "$QUERY_ARG" == "1" ]]; then
  "$REAL_NVIDIA_SMI_PATH" "$@" | "$THIS_DIR/mig2gpu.sh"
else
  "$REAL_NVIDIA_SMI_PATH" "$@"
fi

================================================ FILE: examples/ML+DL-Examples/Optuna-Spark/README.md ================================================

# Distributed Hyperparameter Tuning

These examples demonstrate distributed hyperparameter tuning with [Optuna](https://optuna.readthedocs.io/en/stable/index.html) on Apache Spark, accelerated with [RAPIDS](https://rapids.ai/) on GPU. We showcase how to set up and tune XGBoost on GPU, with deployment on Spark Standalone or Databricks clusters.

## Contents:
- [Overview](#overview)
- [Examples](#examples)
- [Running Optuna on Spark Standalone](#running-optuna-on-spark-standalone)
    - [Setup Database for Optuna](#1-setup-database-for-optuna)
    - [Setup Optuna Python Environment](#2-setup-optuna-python-environment)
    - [Start Standalone Cluster and Run](#3-start-standalone-cluster-and-run)
- [Running Optuna on Databricks](#running-optuna-on-databricks)
    - [Upload Init Script and Notebook](#1-upload-init-script-and-notebook)
    - [Create Cluster](#2-create-cluster)
    - [Run Notebook](#3-run-notebook)
- [Benchmarks](#benchmarks)
- [How Does it Work?](#how-does-it-work)
- [Implementation Notes](#implementation-notes)

---

## Overview

Optuna is a lightweight Python library for hyperparameter tuning, integrating state-of-the-art hyperparameter optimization algorithms. At a high level, we optimize hyperparameters in three steps:
1. Wrap model training with an `objective` function that returns a loss metric.
2. In each `trial`, suggest hyperparameters based on previous results.
3. Create a `study` object, which executes the optimization and stores the trial results.

**Local example**: tuning XGBoost with Optuna (from [Optuna docs](https://optuna.org/#code_examples)):

```python
import xgboost as xgb
import optuna

# 1. Define an objective function to be maximized.
def objective(trial):
    ...

    # 2. Suggest values of the hyperparameters using a trial object.
    param = {
        "objective": "binary:logistic",
        "booster": trial.suggest_categorical("booster", ["gbtree", "gblinear", "dart"]),
        "lambda": trial.suggest_float("lambda", 1e-8, 1.0, log=True),
        "alpha": trial.suggest_float("alpha", 1e-8, 1.0, log=True),
        "subsample": trial.suggest_float("subsample", 0.2, 1.0),
        "colsample_bytree": trial.suggest_float("colsample_bytree", 0.2, 1.0),
    }
    booster = xgb.train(param, dtrain)
    ...
    return accuracy

# 3. Create a study object and optimize the objective function.
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)
```

To run **distributed tuning** on Spark, we take the following steps:
1. Each worker receives a copy of the same dataset.
2. Each worker runs a subset of the trials in parallel.
3. Workers write trial results and receive new hyperparameters using a shared database.
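The shared database is what ties these steps together. As a rough sketch (not taken from the notebooks), each Spark task can attach to the same study by name through Optuna's RDBStorage; the storage URL, user, and study name below are placeholders that mirror the standalone MySQL setup described later:

```python
import optuna

# Placeholder storage URL; matches the MySQL user/database created in the
# standalone setup below (replace <driver-host> with your driver's address).
STORAGE = "mysql://optuna_user:optuna_password@<driver-host>/optuna"
STUDY_NAME = "optuna-spark-xgb"

def objective(trial):
    # Toy objective; the notebooks train XGBoost here and return its metric.
    x = trial.suggest_float("x", -10, 10)
    return -(x - 2) ** 2

# Driver: create (or reuse) the study once.
optuna.create_study(study_name=STUDY_NAME, storage=STORAGE,
                    direction="maximize", load_if_exists=True)

# Worker task: attach to the same study and run a share of the trials;
# results and new hyperparameter suggestions flow through the database.
def run_trials(n_trials):
    study = optuna.load_study(study_name=STUDY_NAME, storage=STORAGE)
    study.optimize(objective, n_trials=n_trials)
```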
### Examples

We provide **2 notebooks**, with differences in the backend/implementation. See the [implementation notes](#implementation-notes) for more details.

- `optuna-joblibspark.ipynb`:
  - Uses the [Joblib Spark backend](https://github.com/joblib/joblib-spark) to distribute tasks on the Spark cluster.
  - Implements *Worker-I/O*, where each worker reads the full dataset from a specified filepath (e.g., a distributed file system).
  - Builds on [this Databricks example](https://docs.databricks.com/en/machine-learning/automl-hyperparam-tuning/optuna.html).
- `optuna-dataframe.ipynb`:
  - Uses Spark dataframes to distribute tasks on the cluster.
  - Implements *Spark-I/O*, where Spark reads the dataset from a specified filepath, then duplicates and repartitions it so that each worker task is mapped onto a copy of the dataset.
  - Dataframe operations are accelerated on GPU with the [Spark-RAPIDS Accelerator](https://nvidia.github.io/spark-rapids/).

## Running Optuna on Spark Standalone

### 1. Setup Database for Optuna

Optuna offers an RDBStorage option which allows experiments to be persisted across different machines and processes, thereby enabling Optuna tasks to be distributed. This section walks you through setting up MySQL as the backend for RDBStorage in Optuna.

We highly recommend installing MySQL on the driver node. This eliminates concerns about MySQL connectivity between the worker nodes and the driver, simplifying the management of database connections. (For Databricks, the installation is handled by the init script.)

1. Install MySQL:
    ``` shell
    sudo apt install mysql-server
    ```
2. Configure the MySQL bind address in `/etc/mysql/mysql.conf.d/mysqld.cnf`:
    ``` shell
    bind-address = YOUR_DRIVER_HOST_IP
    mysqlx-bind-address = YOUR_DRIVER_HOST_IP
    ```
3. Restart MySQL:
    ``` shell
    sudo systemctl restart mysql.service
    ```
4. Set up the user:
    ``` shell
    sudo mysql
    ```
    ``` mysql
    mysql> CREATE USER 'optuna_user'@'%' IDENTIFIED BY 'optuna_password';
    Query OK, 0 rows affected (0.01 sec)

    mysql> GRANT ALL PRIVILEGES ON *.* TO 'optuna_user'@'%' WITH GRANT OPTION;
    Query OK, 0 rows affected (0.01 sec)

    mysql> FLUSH PRIVILEGES;
    Query OK, 0 rows affected (0.01 sec)

    mysql> EXIT;
    Bye
    ```
5. Create a database for Optuna:
    ``` shell
    mysql -u optuna_user -p -e "CREATE DATABASE IF NOT EXISTS optuna"
    ```

Troubleshooting:
> If you encounter `"ERROR 2002 (HY000): Can't connect to local MySQL server through socket '/tmp/mysql.sock' (2)"`, try the command: `ln -s /var/run/mysqld/mysqld.sock /tmp/mysql.sock`

### 2. Setup Optuna Python Environment

Install the MySQL client and create a conda environment with the required libraries. We use [RAPIDS](https://docs.rapids.ai/install/#get-rapids) for GPU-accelerated ETL. See the [docs](https://docs.rapids.ai/install/#get-rapids) for version selection.

``` shell
sudo apt install libmysqlclient-dev

conda create -n optuna-spark -c rapidsai -c conda-forge -c nvidia \
    cudf=26.02 cuml=26.02 python=3.10 'cuda-version>=12.0,<=12.5'
conda activate optuna-spark

pip install mysqlclient
pip install optuna joblib joblibspark ipywidgets
```

### 3. Start Standalone Cluster and Run

Configure your standalone cluster settings.
This example creates a local cluster with a single GPU worker:
```shell
export SPARK_HOME=/path/to/spark
export SPARK_WORKER_OPTS="-Dspark.worker.resource.gpu.amount=1 \
-Dspark.worker.resource.gpu.discoveryScript=$SPARK_HOME/examples/src/main/scripts/getGpusResources.sh"
export MASTER=spark://$(hostname):7077; export SPARK_WORKER_INSTANCES=1; export CORES_PER_WORKER=8

${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m 16G ${MASTER}
```

You can now run the notebook using the `optuna-spark` Python kernel! The notebook contains instructions to attach to the standalone cluster.

## Running Optuna on Databricks

### 1. Upload Init Script and Notebook

- Make sure your [Databricks CLI](https://docs.databricks.com/en/dev-tools/cli/tutorial.html) is configured for your Databricks workspace.
- Copy the desired notebook into your Databricks workspace. For example:
    ```shell
    databricks workspace import /Users/someone@example.com/optuna/optuna-joblibspark.ipynb --format JUPYTER --file optuna-joblibspark.ipynb
    ```
- Copy the init script `databricks/init_optuna.sh`:
    ```shell
    databricks workspace import /Users/someone@example.com/optuna/init_optuna.sh --format AUTO --file databricks/init_optuna.sh
    ```

### 2. Create Cluster

*For Databricks Azure*: Use the cluster startup script, which is configured to create a 4-node GPU cluster:
```shell
export INIT_PATH=/Users/someone@example.com/optuna/init_optuna.sh
cd databricks
chmod +x start_cluster.sh
./start_cluster.sh
```

Or, create a cluster via the web UI:
- Go to `Compute > Create compute` and set the desired cluster settings.
- Under `Advanced Options > Init Scripts`, upload the init script from your workspace.
- Under `Advanced Options > Spark > Environment variables`, set `LIBCUDF_CUFILE_POLICY=OFF`.
- Make sure to use a GPU cluster and include task GPU resources.

The init script will install the required libraries on all nodes, including RAPIDS and the Spark-RAPIDS plugin for GPU-accelerated ETL. On the driver, it will set up the MySQL server backend.

### 3. Run Notebook

Locate the notebook in your workspace and click `Connect` to attach it to the cluster. The notebook is ready to run!

## Benchmarks

The graph below shows running times comparing distributed (8 GPUs) vs. single-GPU hyperparameter tuning with 100 trials on synthetic regression datasets.

![Databricks benchmarking results](images/runtimes.png)

## How does it work?

The Optuna tasks are serialized and distributed to Spark workers to run. On the executor side, each Optuna task loads the Optuna study from RDBStorage and then runs its assigned set of trials. During tuning, the tasks send intermediate results back to RDBStorage to persist, and request the next hyperparameters to run, which Optuna samples based on the results stored in RDBStorage.

**Using JoblibSpark**: each Optuna task is a Spark application with only 1 job, 1 stage, and 1 task, and the Spark applications are submitted from local threads. Here the parameter `n_jobs` configures the Spark backend to limit how many Spark applications are submitted at the same time; thus Optuna with JoblibSpark uses application-level parallelism rather than task-level parallelism. For larger datasets, ensure that a single XGBoost training task can run on a single node without running out of CPU or GPU memory. A minimal sketch of this pattern follows below.
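As a rough illustration of the application-level parallelism described above, here is a minimal sketch using the Joblib Spark backend. It assumes an active SparkSession; `run_trials` is a placeholder for the worker function (the notebooks define a similar `task` function), `objective` is the user-defined Optuna objective, and the trial split mirrors the notebooks' `partition_trials` helper.

```python
import joblib
from joblibspark import register_spark

# Register the Spark backend with joblib; requires an active SparkSession.
register_spark()

# e.g. 100 trials split across 4 Optuna tasks, as in the notebooks' partition_trials().
trials_per_task = [25, 25, 25, 25]

# Each delayed call becomes one Optuna task submitted through the Spark backend;
# n_jobs limits how many of these are in flight at the same time.
with joblib.parallel_backend("spark", n_jobs=4):
    results = joblib.Parallel()(
        joblib.delayed(run_trials)(objective, n) for n in trials_per_task
    )
```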
Application parallelism with JoblibSpark: ![Optuna on JoblibSpark](images/optuna.svg) ### Implementation Notes ###### Data I/O: Since each worker requires the full dataset to perform hyperparameter tuning, there are two strategies to get the data into worker memory: - **Worker I/O**: *each worker reads the dataset* from the filepath once the task has begun. In practice, this requires the dataset to be written to a distributed file system accessible to all workers prior to tuning. The `optuna-joblibspark` notebook demonstrates this. - **Spark I/O**: Spark reads the dataset and *creates a copy of the dataset for each worker*, then maps the tuning task onto each copy. In practice, this enables the code to be chained to other Dataframe operations (e.g. ETL stages) without the intermediate step of writing to DBFS, at the cost of some overhead during duplication. The `optuna-dataframe` notebook demonstrates this. - To achieve this, we coalesce the input Dataframe to a single partition, and recursively self-union until we have the desired number of copies (number of workers). Thus each partition will contain a duplicate of the entire dataset, and the Optuna task can be mapped directly onto the partitions. ###### Misc: - Please be aware that Optuna studies will continue where they left off from previous trials; delete and recreate the study if you would like to start anew. - Optuna in distributed mode is **non-deterministic** (see [this link](https://optuna.readthedocs.io/en/stable/faq.html#how-can-i-obtain-reproducible-optimization-results)), as trials are executed asynchronously by executors. Deterministic behavior can be achieved using Spark barriers to coordinate reads/writes to the database. - Reading data with GPU using cuDF requires disabling [GPUDirect Storage](https://docs.rapids.ai/api/cudf/nightly/user_guide/io/io/#magnum-io-gpudirect-storage-integration), i.e., setting the environment variable `LIBCUDF_CUFILE_POLICY=OFF`, to be compatible with the Databricks file system. Without GDS, cuDF will use a CPU bounce buffer when reading files, but all parsing and decoding will still be accelerated by the GPU. - Note that the storage doesn’t store the state of the instance of samplers and pruners. To resume a study with a sampler whose seed argument is specified, [the sampler can be pickled](https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/001_rdb.html#resume-study) and returned to the driver alongside the results. ================================================ FILE: examples/ML+DL-Examples/Optuna-Spark/optuna-examples/databricks/init_optuna.sh ================================================ #!/bin/bash # # Copyright (c) 2025-2026, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # set -x sudo rm -r /var/lib/apt/lists/* sudo apt clean && sudo apt update --fix-missing -y if [[ $DB_IS_DRIVER = "TRUE" ]]; then # setup database for optuna on driver # install mysql server sudo apt install -y mysql-server if [[ ! 
-f "/etc/mysql/mysql.conf.d/mysqld.cnf" ]]; then sudo apt remove --purge mysql\* sudo apt clean && sudo apt update --fix-missing -y sudo apt install -y mysql-server fi if [[ ! -f "/etc/mysql/mysql.conf.d/mysqld.cnf" ]]; then echo "ERROR: MYSQL installation failed" exit 1 fi # configure mysql BIND_ADDRESS=$DB_DRIVER_IP MYSQL_CONFIG_FILE="/etc/mysql/mysql.conf.d/mysqld.cnf" sudo sed -i "s/^bind-address\s*=.*/bind-address = $BIND_ADDRESS/" "$MYSQL_CONFIG_FILE" sudo sed -i "s/^mysqlx-bind-address\s*=.*/mysqlx-bind-address = $BIND_ADDRESS/" "$MYSQL_CONFIG_FILE" sudo systemctl restart mysql.service # setup user OPTUNA_USER="optuna_user" OPTUNA_PASSWORD="optuna_password" sudo mysql -u root -e " CREATE USER IF NOT EXISTS '$OPTUNA_USER'@'%' IDENTIFIED BY '$OPTUNA_PASSWORD'; GRANT ALL PRIVILEGES ON *.* TO '$OPTUNA_USER'@'%' WITH GRANT OPTION; FLUSH PRIVILEGES;" fi # rapids import SPARK_RAPIDS_VERSION=26.02.0 curl -L https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/${SPARK_RAPIDS_VERSION}/rapids-4-spark_2.12-${SPARK_RAPIDS_VERSION}.jar -o \ /databricks/jars/rapids-4-spark_2.12-${SPARK_RAPIDS_VERSION}.jar # setup cuda: install cudatoolkit 11.8 via runfile approach wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run sh cuda_11.8.0_520.61.05_linux.run --silent --toolkit # reset symlink and update library loading paths rm /usr/local/cuda ln -s /usr/local/cuda-11.8 /usr/local/cuda sudo /databricks/python3/bin/pip3 install \ --extra-index-url=https://pypi.nvidia.com \ "cudf-cu11==25.02.*" "cuml-cu11==25.02.*" # setup python environment sudo apt clean && sudo apt update --fix-missing -y sudo apt install pkg-config sudo apt install -y libmysqlclient-dev sudo /databricks/python3/bin/pip3 install --upgrade pip sudo /databricks/python3/bin/pip3 install mysqlclient xgboost sudo /databricks/python3/bin/pip3 install optuna joblib joblibspark if [[ $DB_IS_DRIVER = "TRUE" ]]; then # create optuna database and study sudo mysql -u $OPTUNA_USER -p$OPTUNA_PASSWORD -e "CREATE DATABASE IF NOT EXISTS optuna;" fi set +x ================================================ FILE: examples/ML+DL-Examples/Optuna-Spark/optuna-examples/databricks/start_cluster.sh ================================================ #!/bin/bash # # Copyright (c) 2025-2026, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # if [[ -z ${INIT_PATH} ]]; then echo "Please export INIT_PATH per README.md" exit 1 fi json_config=$(cat <\n", "\n", "# Distributed Hyperparameter Tuning: Optuna + Spark Dataframes\n", "\n", "\n", "This demo demonstrates distributed hyperparameter tuning for XGBoost using Spark Dataframes. \n", "We implement best practices to precompute data and maximize computations on the GPU. 
\n", "\n", "Reference: https://forecastegy.com/posts/xgboost-hyperparameter-tuning-with-optuna/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Note:\n", "Before running, please make sure you've followed the relevant [setup instructions](../README.md) for your environment (standalone or databricks).\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from typing import Iterable, List, Dict, Optional, Union, Sequence, Any\n", "import math\n", "import os\n", "import requests\n", "import pandas as pd\n", "import optuna\n", "from optuna.samplers import TPESampler\n", "import xgboost as xgb\n", "from pyspark.sql import SparkSession, DataFrame\n", "from pyspark import TaskContext, SparkConf\n", "from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType, StringType, BooleanType" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download the dataset\n", "\n", "We'll use the [red wine quality dataset](https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv) to regress wine quality based on features such as acidity, sugar content, etc. \n", "\n", "**Note**: This example uses a small dataset for demonstration purposes. The performance advantages of distributed training are best realized with large datasets and computational workloads." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "cwd = os.getcwd()\n", "os.mkdir(os.path.join(cwd, \"data\")) if not os.path.exists(os.path.join(cwd, \"data\")) else None\n", "filepath = os.path.join(cwd, \"data\", \"winequality-red.csv\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "File downloaded and saved to /home/rishic/Code/myforks/spark-rapids-examples/examples/ML+DL-Examples/Optuna-Spark/optuna-examples/data/winequality-red.csv\n" ] } ], "source": [ "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\"\n", "\n", "response = requests.get(url)\n", "if response.status_code == 200:\n", " with open(filepath, \"wb\") as f:\n", " f.write(response.content)\n", " print(f\"File downloaded and saved to {filepath}\")\n", "else:\n", " print(f\"Failed to download the file. Status code: {response.status_code}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 1. Running Optuna locally" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import cudf\n", "from cuml.metrics.regression import mean_squared_error\n", "from cuml.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare data" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
07.40.700.001.90.07611.034.00.99783.510.569.45
17.80.880.002.60.09825.067.00.99683.200.689.85
27.80.760.042.30.09215.054.00.99703.260.659.85
311.20.280.561.90.07517.060.00.99803.160.589.86
47.40.700.001.90.07611.034.00.99783.510.569.45
\n", "
" ], "text/plain": [ " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", "0 7.4 0.70 0.00 1.9 0.076 \n", "1 7.8 0.88 0.00 2.6 0.098 \n", "2 7.8 0.76 0.04 2.3 0.092 \n", "3 11.2 0.28 0.56 1.9 0.075 \n", "4 7.4 0.70 0.00 1.9 0.076 \n", "\n", " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", "0 11.0 34.0 0.9978 3.51 0.56 \n", "1 25.0 67.0 0.9968 3.20 0.68 \n", "2 15.0 54.0 0.9970 3.26 0.65 \n", "3 17.0 60.0 0.9980 3.16 0.58 \n", "4 11.0 34.0 0.9978 3.51 0.56 \n", "\n", " alcohol quality \n", "0 9.4 5 \n", "1 9.8 5 \n", "2 9.8 5 \n", "3 9.8 6 \n", "4 9.4 5 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = cudf.read_csv(filepath, delimiter=\";\")\n", "data.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Prepare the train/validation sets. Precompute the Quantile DMatrix, which is used by histogram-based tree methods to save memory." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "X = data.iloc[:, :-1].values\n", "y = data[\"quality\"].values\n", "X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n", "Xy_train_qdm = xgb.QuantileDMatrix(X_train, y_train) # Precompute Quantile DMatrix to avoid repeated quantization every trial." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Objective function\n", "\n", "We define the objective and a hyperparameter search space to optimize via the `trial.suggest_` methods. \n", "\n", "In each trial, new hyperparameters will be suggested based on previous results. See [optuna.trial.Trial](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html) API for a full list of functions." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def objective(trial):\n", " params = {\n", " \"objective\": \"reg:squarederror\",\n", " \"verbosity\": 0,\n", " \"learning_rate\": trial.suggest_float(\"learning_rate\", 1e-3, 0.1, log=True),\n", " \"max_depth\": trial.suggest_int(\"max_depth\", 1, 10),\n", " \"subsample\": trial.suggest_float(\"subsample\", 0.05, 1.0),\n", " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.05, 1.0),\n", " \"min_child_weight\": trial.suggest_int(\"min_child_weight\", 1, 20),\n", " \"tree_method\": \"gpu_hist\",\n", " \"device\": \"cuda\",\n", " }\n", "\n", " booster = xgb.train(params=params, dtrain=Xy_train_qdm, num_boost_round=trial.suggest_int(\"num_boost_round\", 100, 500))\n", " predictions = booster.inplace_predict(X_val)\n", " rmse = mean_squared_error(y_val, predictions, squared=False).get()\n", " \n", " return rmse " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create the study and optimize. By default, the study results will be stored in memory." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[I 2024-12-11 23:47:48,356] A new study created in memory with name: optuna-xgboost-local\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2024-12-11 23:47:48,724] Trial 0 finished with value: 0.6377619522504244 and parameters: {'learning_rate': 0.005611516415334507, 'max_depth': 10, 'subsample': 0.7453942447208348, 'colsample_bytree': 0.6187255599871848, 'min_child_weight': 4, 'num_boost_round': 162}. 
Best is trial 0 with value: 0.6377619522504244.\n", "[I 2024-12-11 23:47:49,676] Trial 1 finished with value: 0.6703788974319568 and parameters: {'learning_rate': 0.0013066739238053278, 'max_depth': 9, 'subsample': 0.6210592611560484, 'colsample_bytree': 0.7226689489062432, 'min_child_weight': 1, 'num_boost_round': 488}. Best is trial 0 with value: 0.6377619522504244.\n", "[I 2024-12-11 23:47:49,819] Trial 2 finished with value: 0.6181751362616256 and parameters: {'learning_rate': 0.04622589001020832, 'max_depth': 3, 'subsample': 0.2227337188467456, 'colsample_bytree': 0.22423428436076215, 'min_child_weight': 7, 'num_boost_round': 310}. Best is trial 2 with value: 0.6181751362616256.\n", "[I 2024-12-11 23:47:49,942] Trial 3 finished with value: 0.6698576232920956 and parameters: {'learning_rate': 0.007309539835912915, 'max_depth': 3, 'subsample': 0.6312602499862605, 'colsample_bytree': 0.18251916761943976, 'min_child_weight': 6, 'num_boost_round': 246}. Best is trial 2 with value: 0.6181751362616256.\n", "[I 2024-12-11 23:47:50,060] Trial 4 finished with value: 0.6704590546150145 and parameters: {'learning_rate': 0.008168455894760165, 'max_depth': 8, 'subsample': 0.23969009305044175, 'colsample_bytree': 0.538522716492931, 'min_child_weight': 12, 'num_boost_round': 118}. Best is trial 2 with value: 0.6181751362616256.\n", "[I 2024-12-11 23:47:50,214] Trial 5 finished with value: 0.6088806682631155 and parameters: {'learning_rate': 0.016409286730647923, 'max_depth': 2, 'subsample': 0.11179901333601554, 'colsample_bytree': 0.9514412603906666, 'min_child_weight': 20, 'num_boost_round': 424}. Best is trial 5 with value: 0.6088806682631155.\n", "[I 2024-12-11 23:47:50,289] Trial 6 finished with value: 0.7103495949713845 and parameters: {'learning_rate': 0.0040665633135147945, 'max_depth': 1, 'subsample': 0.700021375186549, 'colsample_bytree': 0.4681448690526212, 'min_child_weight': 3, 'num_boost_round': 298}. Best is trial 5 with value: 0.6088806682631155.\n", "[I 2024-12-11 23:47:50,693] Trial 7 finished with value: 0.7255199474722185 and parameters: {'learning_rate': 0.001171593739230706, 'max_depth': 10, 'subsample': 0.29584098252001606, 'colsample_bytree': 0.6793961701362828, 'min_child_weight': 7, 'num_boost_round': 308}. Best is trial 5 with value: 0.6088806682631155.\n", "[I 2024-12-11 23:47:50,858] Trial 8 finished with value: 0.6060010014477214 and parameters: {'learning_rate': 0.0123999678368461, 'max_depth': 2, 'subsample': 0.9711053963763306, 'colsample_bytree': 0.7863761821930588, 'min_child_weight': 19, 'num_boost_round': 458}. Best is trial 8 with value: 0.6060010014477214.\n", "[I 2024-12-11 23:47:51,199] Trial 9 finished with value: 0.6292433375858283 and parameters: {'learning_rate': 0.015696396388661146, 'max_depth': 10, 'subsample': 0.13406787694932354, 'colsample_bytree': 0.23618371929818793, 'min_child_weight': 1, 'num_boost_round': 230}. 
Best is trial 8 with value: 0.6060010014477214.\n" ] } ], "source": [ "study = optuna.create_study(study_name=\"optuna-xgboost-local\", sampler=TPESampler(seed=42))\n", "study.optimize(objective, n_trials=10)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best RMSE: 0.6060010014477214\n", "Best hyperparameters: {'learning_rate': 0.0123999678368461, 'max_depth': 2, 'subsample': 0.9711053963763306, 'colsample_bytree': 0.7863761821930588, 'min_child_weight': 19, 'num_boost_round': 458}\n" ] } ], "source": [ "trial = study.best_trial\n", "print(\"Best RMSE: \", trial.value)\n", "print(\"Best hyperparameters: \", trial.params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 2. Distributed Optuna on Spark " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PySpark\n", "\n", "For standalone users, we need to create the Spark session with the Spark-Rapids plugin. For Databricks users, the Spark session will be preconfigured and this cell can be skipped." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Plugin file already exists. Skipping download.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "24/12/11 23:47:51 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "24/12/11 23:47:51 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "24/12/11 23:47:52 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "24/12/11 23:47:52 WARN RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1, private revision bd4e99e18e20234ee0c54f95f4b0bfce18a6255e\n", "24/12/11 23:47:52 WARN RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\n" ] } ], "source": [ "def get_rapids_jar():\n", " SPARK_RAPIDS_VERSION = \"26.02.0\"\n", " rapids_jar = f\"rapids-4-spark_2.12-{SPARK_RAPIDS_VERSION}.jar\"\n", " if not os.path.exists(rapids_jar):\n", " print(\"Downloading Spark Rapids jar\")\n", " url = f\"https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/{SPARK_RAPIDS_VERSION}/{rapids_jar}\"\n", " response = requests.get(url)\n", " if response.status_code == 200:\n", " with open(rapids_jar, \"wb\") as f:\n", " f.write(response.content)\n", " print(f\"File '{rapids_jar}' downloaded and saved successfully.\")\n", " else:\n", " print(f\"Failed to download the plugin. Status code: {response.status_code}\")\n", " else:\n", " print(\"Plugin file already exists. Skipping download.\")\n", " return rapids_jar\n", "\n", "def initialize_spark(rapids_jar: str):\n", " import socket\n", " hostname = socket.gethostname()\n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", "\n", " conf = SparkConf()\n", " conf.setMaster(f\"spark://{hostname}:7077\") # Assuming master is on host and default port. \n", " conf.set(\"spark.task.maxFailures\", \"1\")\n", " conf.set(\"spark.task.resource.gpu.amount\", f\"{1/4}\") # Setting to 1/4 for single-node demo. In practice, set to 1. 
\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.jars\", rapids_jar)\n", " conf.set(\"spark.executorEnv.PYTHONPATH\", rapids_jar)\n", " conf.set(\"spark.rapids.memory.gpu.minAllocFraction\", \"0.0001\")\n", " conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n", " conf.set(\"spark.locality.wait\", \"0s\")\n", " conf.set(\"spark.sql.cache.serializer\", \"com.nvidia.spark.ParquetCachedBatchSerializer\")\n", " conf.set(\"spark.rapids.memory.gpu.pooling.enabled\", \"false\")\n", " conf.set(\"spark.sql.execution.sortBeforeRepartition\", \"false\")\n", " conf.set(\"spark.rapids.sql.format.parquet.reader.type\", \"MULTITHREADED\")\n", " conf.set(\"spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel\", \"20\")\n", " conf.set(\"spark.rapids.sql.multiThreadedRead.numThreads\", \"20\")\n", " conf.set(\"spark.rapids.sql.python.gpu.enabled\", \"true\")\n", " conf.set(\"spark.rapids.memory.pinnedPool.size\", \"2G\")\n", " conf.set(\"spark.python.daemon.module\", \"rapids.daemon\")\n", " conf.set(\"spark.rapids.sql.batchSizeBytes\", \"512m\")\n", " conf.set(\"spark.sql.adaptive.enabled\", \"false\")\n", " conf.set(\"spark.sql.files.maxPartitionBytes\", \"512m\")\n", " conf.set(\"spark.rapids.sql.concurrentGpuTasks\", \"2\")\n", " conf.set(\"spark.rapids.sql.explain\", \"NONE\")\n", " \n", " spark = SparkSession.builder.appName(\"optuna-spark-xgboost\").config(conf=conf).getOrCreate()\n", " return spark\n", "\n", "if 'spark' not in globals():\n", " rapids_jar = get_rapids_jar()\n", " spark = initialize_spark(rapids_jar)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Helper Class\n", "\n", "First we'll define a helper class. This will store the hyperparameters we want optimized in each trial, and easily convert that into a schema for the output dataframe." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "class OptunaParams:\n", " def __init__(self):\n", " self.hyperparameters = {}\n", "\n", " def add_categorical_param(self, name: str, choices: Sequence[Union[None, bool, int, float, str]]):\n", " \"\"\"\n", " Adds a categorical hyperparameter to be tuned via Optuna's trial.suggest_categorical().\n", " \"\"\"\n", " self.hyperparameters[name] = { \"type\": \"categorical\", \"choices\": choices }\n", " \n", " def add_int_param(self, name: str, low: int, high: int, step: int = 1, log: bool = False):\n", " \"\"\"\n", " Adds an integer hyperparameter to be tuned via Optuna's trial.suggest_int().\n", " \"\"\"\n", " self.hyperparameters[name] = { \"type\": \"int\", \"low\": low, \"high\": high, \"step\": step, \"log\": log }\n", " \n", " def add_float_param(self, name: str, low: float, high: float, step: Optional[float] = None, log: bool = False):\n", " \"\"\"\n", " Adds a float hyperparameter to be tuned via Optuna's trial.suggest_float().\n", " \"\"\"\n", " self.hyperparameters[name] = { \"type\": \"float\", \"low\": low, \"high\": high, \"step\": step,\"log\": log }\n", "\n", " def suggest_params(self, trial) -> Dict[str, Union[int, float, str, bool]]:\n", " \"\"\"\n", " Converts the hyperparameter space into a dictionary of suggested values in Optuna format,\n", " to be called within the objective function.\n", " \"\"\"\n", " suggested_params = {}\n", " for name, config in self.hyperparameters.items():\n", " if config[\"type\"] == \"categorical\":\n", " suggested_params[name] = trial.suggest_categorical(name, config[\"choices\"])\n", " elif config[\"type\"] == \"int\":\n", " suggested_params[name] = trial.suggest_int(\n", " name, config[\"low\"], config[\"high\"], step=config[\"step\"], log=config[\"log\"]\n", " )\n", " elif config[\"type\"] == \"float\":\n", " suggested_params[name] = trial.suggest_float(\n", " name, config[\"low\"], config[\"high\"], step=config.get(\"step\", None), log=config[\"log\"]\n", " )\n", " return suggested_params\n", "\n", " def to_schema(self) -> StructType:\n", " \"\"\"\n", " Converts the hyperparameter space into a Spark StructType output schema.\n", " \"\"\"\n", " fields = []\n", " for name, config in self.hyperparameters.items():\n", " if config[\"type\"] == \"float\":\n", " fields.append(StructField(name, DoubleType(), False))\n", " elif config[\"type\"] == \"int\":\n", " fields.append(StructField(name, IntegerType(), False))\n", " elif config[\"type\"] == \"categorical\":\n", " if isinstance(config[\"choices\"][0], str):\n", " fields.append(StructField(name, StringType(), False))\n", " elif isinstance(config[\"choices\"][0], bool):\n", " fields.append(StructField(name, BooleanType(), False))\n", " elif isinstance(config[\"choices\"][0], (int, float)):\n", " fields.append(StructField(name, DoubleType(), False))\n", " else:\n", " raise ValueError(f\"Unsupported categorical type for field {name}\")\n", " \n", " # Study will also return the best achieved loss:\n", " fields.append(StructField(\"best_value\", DoubleType(), False)) \n", " return StructType(fields)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Optuna Task\n", "\n", "This implementation demonstrates **Spark I/O**.\n", "\n", "This means that Spark will read the dataset and create a duplicate of the dataset for each worker (1 partition = 1 duplicate), then map the tuning task onto each partition. \n", "In practice, this enables the code to be chained to other Dataframe operations (e.g. 
ETL stages) without the intermediate step of writing to DBFS, at the cost of some overhead during duplication.\n", "\n", "For the alternative implementation using **Worker I/O**, see the [JoblibSpark notebook](optuna-joblibspark.ipynb). " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the task, each worker will:\n", "1. Concatenate the pandas partition batches to form the dataset\n", "2. Load the study from the MySQL storage backend\n", "3. Optimize over the objective for the assigned number of trials, sending results back to the database after each iteration" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def task_udf(pdf_iter: Iterable[pd.DataFrame],\n", " xgb_params: Dict[str, Any],\n", " optuna_params: OptunaParams,\n", " trials_per_task: List[int], \n", " driver_ip: str,\n", " study_name: str,\n", " seed: int) -> Iterable[pd.DataFrame]:\n", "\n", " import cudf\n", " from cuml.metrics.regression import mean_squared_error\n", " from cuml.model_selection import train_test_split\n", " \n", " tc = TaskContext.get()\n", " assert \"gpu\" in tc.resources(), \"GPU resource not found.\"\n", " num_trials = trials_per_task[tc.partitionId()]\n", "\n", " df_list = []\n", " for pdf in pdf_iter:\n", " df_list.append(cudf.DataFrame.from_pandas(pdf))\n", " \n", " data = cudf.concat(df_list)\n", " X = data.iloc[:, :-1].values\n", " y = data[\"quality\"].values\n", " X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n", "\n", " tuning_max_bin = \"max_bin\" in optuna_params.hyperparameters\n", " if not tuning_max_bin:\n", " max_bin = xgb_params.get(\"max_bin\", 256)\n", " # Precompute Quantile DMatrix to avoid repeated quantization every trial.\n", " Xy_train_qdm = xgb.QuantileDMatrix(X_train, y_train, max_bin=max_bin)\n", "\n", " def objective(trial):\n", " tuning_params = optuna_params.suggest_params(trial)\n", " xgb_params.update(tuning_params)\n", "\n", " if tuning_max_bin:\n", " # If tuning the max_bin param, we must recompute the QDM every trial, since the quantiles change.\n", " if \"n_estimators\" not in xgb_params:\n", " xgb_params[\"n_estimators\"] = 100 # Default value if not tuning.\n", "\n", " model = xgb.XGBRegressor(**xgb_params)\n", " model.fit(X_train, y_train)\n", " booster = model.get_booster()\n", " else:\n", " # Train the model with xgb.train() API using the precomputed QDM.\n", " num_boost_round = xgb_params.get(\"n_estimators\", 100)\n", " booster = xgb.train(params=xgb_params, dtrain=Xy_train_qdm, num_boost_round=num_boost_round)\n", " \n", " predictions = booster.inplace_predict(X_val)\n", " rmse = mean_squared_error(y_val, predictions, squared=False).get()\n", " \n", " return rmse\n", "\n", " study = optuna.load_study(\n", " study_name=study_name,\n", " storage=f\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\",\n", " sampler=TPESampler(seed=seed),\n", " )\n", "\n", " print(f\"Running {num_trials} trials on partition {tc.partitionId()}.\")\n", " study.optimize(objective, n_trials=num_trials)\n", "\n", " result_dict = {f\"{key}\": [value] for key, value in study.best_params.items()}\n", " result_dict['best_value'] = [study.best_value]\n", " \n", " yield pd.DataFrame(result_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup and run the Optuna study" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Get the driver IP for the MySQL database. 
\n", "- For standalone users, make sure you've followed the [database setup instructions](../README.md#setup-database-for-optuna). The database should be on 'localhost'. \n", "- For databricks users, the database should already be setup on the driver node by the init script." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# check if we're running on databricks\n", "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MySQL database is hosted on localhost\n" ] } ], "source": [ "if on_databricks:\n", " driver_ip = spark.conf.get(\"spark.driver.host\")\n", "else:\n", " driver_ip = \"localhost\"\n", "\n", "print(f\"MySQL database is hosted on {driver_ip}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a new study, referencing the MySQL database as the storage backend." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[I 2024-12-11 23:47:53,347] A new study created in RDB with name: optuna-xgboost-dataframe\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "study_name = \"optuna-xgboost-dataframe\"\n", "seed = 42\n", "\n", "try:\n", " # Delete the study if it already exists\n", " optuna.delete_study(\n", " study_name=study_name, \n", " storage=f\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\"\n", " )\n", "except:\n", " pass\n", "\n", "optuna.create_study(\n", " study_name=study_name,\n", " storage=f\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\",\n", " sampler=TPESampler(seed=seed)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the number of tasks, number of trials, and trials per task. \n", "\n", "**NOTE**: for standalone users running on a single worker, the 4 tasks will all be assigned to the same worker and will time-share the GPU for demonstration. In practice, you should set `spark.task.resource.gpu.amount=1` and set num_tasks to the number of workers in the cluster so that each task gets full access to the GPU." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def partition_trials(total_trials: int, total_tasks: int) -> List[int]:\n", " base_size = total_trials // total_tasks\n", " extra = total_trials % total_tasks\n", " partitions = [base_size] * total_tasks\n", " for i in range(extra):\n", " partitions[i] += 1\n", " \n", " return partitions" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Trials per task: [25, 25, 25, 25]\n" ] } ], "source": [ "num_tasks = 4\n", "num_trials = 100\n", "trials_per_task = partition_trials(num_trials, num_tasks)\n", "print(f\"Trials per task: {trials_per_task}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Define params\n", "Define the XGBoost model params and the hyperparams for Optuna to tune. 
" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# Keep these params consistent:\n", "xgb_params = {\n", " \"objective\": \"reg:squarederror\",\n", " \"verbosity\": 0,\n", " \"tree_method\": \"gpu_hist\",\n", " \"device\": \"cuda\",\n", " \"seed\": seed,\n", "}" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Tune these params:\n", "hyperparams = OptunaParams()\n", "hyperparams.add_int_param(\"n_estimators\", low=100, high=500)\n", "hyperparams.add_float_param(\"learning_rate\", low=1e-3, high=0.1, log=True)\n", "hyperparams.add_int_param(\"max_depth\", low=1, high=10)\n", "hyperparams.add_float_param(\"subsample\", low=0.05, high=1.0)\n", "hyperparams.add_float_param(\"colsample_bytree\", low=0.05, high=1.0)\n", "hyperparams.add_int_param(\"min_child_weight\", low=1, high=20)\n", "\n", "out_schema = hyperparams.to_schema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll also define the following helper function, which will create duplicates of the dataframe held in separate partitions." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "def coalesce_tree_union(df: DataFrame, num_duplicates: int):\n", " \"\"\"\n", " Coalesce the DataFrame to a single partition and recursively self-union to create duplicates.\n", " \"\"\"\n", " input_df = df.coalesce(1).cache()\n", " current_df = input_df\n", " \n", " if num_duplicates <= 1:\n", " return current_df\n", "\n", " recursions = int(math.log(num_duplicates, 2))\n", " remainder = num_duplicates - 2 ** recursions\n", "\n", " for _ in range(recursions):\n", " current_df = current_df.union(current_df)\n", "\n", " for _ in range(remainder):\n", " current_df = current_df.union(input_df)\n", " \n", " return current_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load dataset\n", "\n", "Read the data from the local directory with Spark and then duplicate it to prepare to run the task." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "if on_databricks:\n", " # once the dataset is in dbfs, databricks appends \"dbfs:\" to the filepath automatically\n", " filepath = '/FileStore/optuna-data/winequality-red.csv'\n", "else:\n", " cwd = os.getcwd()\n", " filepath = os.path.join(cwd, \"data\", \"winequality-red.csv\")\n", "\n", "in_schema = StructType([\n", " StructField(\"fixed acidity\", DoubleType(), True),\n", " StructField(\"volatile acidity\", DoubleType(), True),\n", " StructField(\"citric acid\", DoubleType(), True),\n", " StructField(\"residual sugar\", DoubleType(), True),\n", " StructField(\"chlorides\", DoubleType(), True),\n", " StructField(\"free sulfur dioxide\", DoubleType(), True),\n", " StructField(\"total sulfur dioxide\", DoubleType(), True),\n", " StructField(\"density\", DoubleType(), True),\n", " StructField(\"pH\", DoubleType(), True),\n", " StructField(\"sulphates\", DoubleType(), True),\n", " StructField(\"alcohol\", DoubleType(), True),\n", " StructField(\"quality\", IntegerType(), True)\n", "])\n", "\n", "data_df = spark.read.csv(filepath, header=True, schema=in_schema, sep=\";\")\n", "data_df = coalesce_tree_union(data_df, num_duplicates=num_tasks) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run the study\n", "\n", "Map the Optuna task onto the dataframe and collect the results (it might take a few minutes)." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "result_df = data_df.mapInPandas(lambda pdf_iter: \n", " task_udf(pdf_iter,\n", " xgb_params=xgb_params,\n", " optuna_params=hyperparams,\n", " trials_per_task=trials_per_task,\n", " driver_ip=driver_ip,\n", " study_name=study_name,\n", " seed=seed),\n", " schema=out_schema).toPandas()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best parameters: {'n_estimators': 419.0, 'learning_rate': 0.015039610889407229, 'max_depth': 10.0, 'subsample': 0.6630214978050138, 'colsample_bytree': 0.8524338650689898, 'min_child_weight': 2.0}\n", "Best value: 0.533100375625104\n" ] } ], "source": [ "results = result_df.iloc[0].to_dict()\n", "best_value = results.pop(\"best_value\")\n", "\n", "print(f\"Best parameters: {results}\")\n", "print(f\"Best value: {best_value}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "optuna-spark", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/ML+DL-Examples/Optuna-Spark/optuna-examples/optuna-joblibspark.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "#\n", "# Copyright (c) 2024, NVIDIA CORPORATION.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "#" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "# Distributed Hyperparameter Tuning: Optuna + JoblibSpark\n", "\n", "\n", "This demo demonstrates distributed hyperparameter tuning for XGBoost using the [JoblibSpark backend](https://github.com/joblib/joblib-spark), building on this [example from Databricks](https://docs.databricks.com/en/machine-learning/automl-hyperparam-tuning/optuna.html). \n", "We implement best practices to precompute data and maximize computations on the GPU. 
\n", "\n", "\n", "\n", "Reference: https://forecastegy.com/posts/xgboost-hyperparameter-tuning-with-optuna/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Note:\n", "Before running, please make sure you've followed the relevant [setup instructions](../README.md) for your environment (standalone or databricks).\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from typing import List\n", "import os\n", "import requests\n", "import joblib\n", "from joblibspark import register_spark\n", "import optuna\n", "from optuna.samplers import TPESampler\n", "import xgboost as xgb\n", "from pyspark.sql import SparkSession\n", "from pyspark import TaskContext, SparkConf" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download the dataset\n", "\n", "We'll use the [red wine quality dataset](https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv) to regress wine quality based on features such as acidity, sugar content, etc. \n", "\n", "**Note**: This example uses a small dataset for demonstration purposes. The performance advantages of distributed training are best realized with large datasets and computational workloads." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "cwd = os.getcwd()\n", "os.mkdir(os.path.join(cwd, \"data\")) if not os.path.exists(os.path.join(cwd, \"data\")) else None\n", "filepath = os.path.join(cwd, \"data\", \"winequality-red.csv\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "File downloaded and saved to /home/rishic/Code/myforks/spark-rapids-examples/examples/ML+DL-Examples/Optuna-Spark/optuna-examples/data/winequality-red.csv\n" ] } ], "source": [ "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\"\n", "\n", "response = requests.get(url)\n", "if response.status_code == 200:\n", " with open(filepath, \"wb\") as f:\n", " f.write(response.content)\n", " print(f\"File downloaded and saved to {filepath}\")\n", "else:\n", " print(f\"Failed to download the file. Status code: {response.status_code}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 1. Running Optuna locally" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import cudf\n", "from cuml.metrics.regression import mean_squared_error\n", "from cuml.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare data" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
07.40.700.001.90.07611.034.00.99783.510.569.45
17.80.880.002.60.09825.067.00.99683.200.689.85
27.80.760.042.30.09215.054.00.99703.260.659.85
311.20.280.561.90.07517.060.00.99803.160.589.86
47.40.700.001.90.07611.034.00.99783.510.569.45
\n", "
" ], "text/plain": [ " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", "0 7.4 0.70 0.00 1.9 0.076 \n", "1 7.8 0.88 0.00 2.6 0.098 \n", "2 7.8 0.76 0.04 2.3 0.092 \n", "3 11.2 0.28 0.56 1.9 0.075 \n", "4 7.4 0.70 0.00 1.9 0.076 \n", "\n", " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", "0 11.0 34.0 0.9978 3.51 0.56 \n", "1 25.0 67.0 0.9968 3.20 0.68 \n", "2 15.0 54.0 0.9970 3.26 0.65 \n", "3 17.0 60.0 0.9980 3.16 0.58 \n", "4 11.0 34.0 0.9978 3.51 0.56 \n", "\n", " alcohol quality \n", "0 9.4 5 \n", "1 9.8 5 \n", "2 9.8 5 \n", "3 9.8 6 \n", "4 9.4 5 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = cudf.read_csv(filepath, delimiter=\";\")\n", "data.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Prepare the train/validation sets. Precompute the Quantile DMatrix, which is used by histogram-based tree methods to save memory." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "X = data.iloc[:, :-1].values\n", "y = data[\"quality\"].values\n", "X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n", "Xy_train_qdm = xgb.QuantileDMatrix(X_train, y_train) # Precompute Quantile DMatrix to avoid repeated quantization every trial." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Objective function\n", "\n", "We define the objective and a hyperparameter search space to optimize via the `trial.suggest_` methods. \n", "\n", "In each trial, new hyperparameters will be suggested based on previous results. See [optuna.trial.Trial](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html) API for a full list of functions." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def objective(trial):\n", " params = {\n", " \"objective\": \"reg:squarederror\",\n", " \"verbosity\": 0,\n", " \"learning_rate\": trial.suggest_float(\"learning_rate\", 1e-3, 0.1, log=True),\n", " \"max_depth\": trial.suggest_int(\"max_depth\", 1, 10),\n", " \"subsample\": trial.suggest_float(\"subsample\", 0.05, 1.0),\n", " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.05, 1.0),\n", " \"min_child_weight\": trial.suggest_int(\"min_child_weight\", 1, 20),\n", " \"tree_method\": \"gpu_hist\",\n", " \"device\": \"cuda\",\n", " }\n", "\n", " booster = xgb.train(params=params, dtrain=Xy_train_qdm, num_boost_round=trial.suggest_int(\"num_boost_round\", 100, 500))\n", " predictions = booster.inplace_predict(X_val)\n", " rmse = mean_squared_error(y_val, predictions, squared=False).get()\n", " \n", " return rmse " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create the study and optimize. By default, the study results will be stored in memory." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[I 2024-12-11 23:42:09,341] A new study created in memory with name: optuna-xgboost-local\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2024-12-11 23:42:09,715] Trial 0 finished with value: 0.6377619522504244 and parameters: {'learning_rate': 0.005611516415334507, 'max_depth': 10, 'subsample': 0.7453942447208348, 'colsample_bytree': 0.6187255599871848, 'min_child_weight': 4, 'num_boost_round': 162}. 
Best is trial 0 with value: 0.6377619522504244.\n", "[I 2024-12-11 23:42:10,666] Trial 1 finished with value: 0.6703788974319568 and parameters: {'learning_rate': 0.0013066739238053278, 'max_depth': 9, 'subsample': 0.6210592611560484, 'colsample_bytree': 0.7226689489062432, 'min_child_weight': 1, 'num_boost_round': 488}. Best is trial 0 with value: 0.6377619522504244.\n", "[I 2024-12-11 23:42:10,806] Trial 2 finished with value: 0.6181751362616256 and parameters: {'learning_rate': 0.04622589001020832, 'max_depth': 3, 'subsample': 0.2227337188467456, 'colsample_bytree': 0.22423428436076215, 'min_child_weight': 7, 'num_boost_round': 310}. Best is trial 2 with value: 0.6181751362616256.\n", "[I 2024-12-11 23:42:10,922] Trial 3 finished with value: 0.6698576232920956 and parameters: {'learning_rate': 0.007309539835912915, 'max_depth': 3, 'subsample': 0.6312602499862605, 'colsample_bytree': 0.18251916761943976, 'min_child_weight': 6, 'num_boost_round': 246}. Best is trial 2 with value: 0.6181751362616256.\n", "[I 2024-12-11 23:42:11,039] Trial 4 finished with value: 0.6704590546150145 and parameters: {'learning_rate': 0.008168455894760165, 'max_depth': 8, 'subsample': 0.23969009305044175, 'colsample_bytree': 0.538522716492931, 'min_child_weight': 12, 'num_boost_round': 118}. Best is trial 2 with value: 0.6181751362616256.\n", "[I 2024-12-11 23:42:11,191] Trial 5 finished with value: 0.6088806682631155 and parameters: {'learning_rate': 0.016409286730647923, 'max_depth': 2, 'subsample': 0.11179901333601554, 'colsample_bytree': 0.9514412603906666, 'min_child_weight': 20, 'num_boost_round': 424}. Best is trial 5 with value: 0.6088806682631155.\n", "[I 2024-12-11 23:42:11,266] Trial 6 finished with value: 0.7103495949713845 and parameters: {'learning_rate': 0.0040665633135147945, 'max_depth': 1, 'subsample': 0.700021375186549, 'colsample_bytree': 0.4681448690526212, 'min_child_weight': 3, 'num_boost_round': 298}. Best is trial 5 with value: 0.6088806682631155.\n", "[I 2024-12-11 23:42:11,666] Trial 7 finished with value: 0.7255199474722185 and parameters: {'learning_rate': 0.001171593739230706, 'max_depth': 10, 'subsample': 0.29584098252001606, 'colsample_bytree': 0.6793961701362828, 'min_child_weight': 7, 'num_boost_round': 308}. Best is trial 5 with value: 0.6088806682631155.\n", "[I 2024-12-11 23:42:11,829] Trial 8 finished with value: 0.6060010014477214 and parameters: {'learning_rate': 0.0123999678368461, 'max_depth': 2, 'subsample': 0.9711053963763306, 'colsample_bytree': 0.7863761821930588, 'min_child_weight': 19, 'num_boost_round': 458}. Best is trial 8 with value: 0.6060010014477214.\n", "[I 2024-12-11 23:42:12,168] Trial 9 finished with value: 0.6292433375858283 and parameters: {'learning_rate': 0.015696396388661146, 'max_depth': 10, 'subsample': 0.13406787694932354, 'colsample_bytree': 0.23618371929818793, 'min_child_weight': 1, 'num_boost_round': 230}. 
Best is trial 8 with value: 0.6060010014477214.\n" ] } ], "source": [ "study = optuna.create_study(study_name=\"optuna-xgboost-local\", sampler=TPESampler(seed=42))\n", "study.optimize(objective, n_trials=10)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best RMSE: 0.6060010014477214\n", "Best hyperparameters: {'learning_rate': 0.0123999678368461, 'max_depth': 2, 'subsample': 0.9711053963763306, 'colsample_bytree': 0.7863761821930588, 'min_child_weight': 19, 'num_boost_round': 458}\n" ] } ], "source": [ "trial = study.best_trial\n", "print(\"Best RMSE: \", trial.value)\n", "print(\"Best hyperparameters: \", trial.params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 2. Distributed Optuna on Spark " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PySpark\n", "\n", "For standalone users, we need to create the Spark session. For Databricks users, the Spark session will be preconfigured." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "24/12/11 23:42:12 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "24/12/11 23:42:12 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "24/12/11 23:42:13 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" ] } ], "source": [ "def initialize_spark():\n", " import socket\n", " hostname = socket.gethostname()\n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", "\n", " conf = SparkConf()\n", " conf.setMaster(f\"spark://{hostname}:7077\") # Assuming master is on host and default port. \n", " conf.set(\"spark.task.maxFailures\", \"1\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", " \n", " spark = SparkSession.builder.appName(\"optuna-joblibspark-xgboost\").config(conf=conf).getOrCreate()\n", " return spark\n", "\n", "if 'spark' not in globals():\n", " spark = initialize_spark()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Optuna Task\n", "\n", "This implementation demonstrates **Worker I/O**. \n", "\n", "This means that each worker will read the full dataset from the filepath rather than passing the data in a dataframe. \n", "In practice, this requires the dataset to be written to a distributed file system accessible to all workers prior to tuning. \n", "\n", "For the alternative implementation using **Spark I/O**, see the [Spark Dataframe notebook](optuna-dataframe.ipynb)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the task, each worker will:\n", "1. Read the dataset from the filepath\n", "2. Load the study from the MySQL storage backend\n", "3. 
Optimize over the objective for the assigned number of trials, sending results back to the database after each iteration\n", "\n", "Here we use Optuna's [Define-and-Run](https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/009_ask_and_tell.html#define-and-run) API, which allows us to predefine the hyperparameter space and pass it to the task." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def task(num_trials: int, xgb_params: dict, optuna_params: dict, driver_ip: str, study_name: str, seed: int, filepath: str):\n", " import cudf\n", " from cuml.metrics.regression import mean_squared_error\n", " from cuml.model_selection import train_test_split\n", "\n", " tc = TaskContext.get()\n", " assert \"gpu\" in tc.resources(), \"GPU resource not found.\"\n", "\n", " if filepath.startswith(\"/dbfs/\"):\n", " # Check to ensure GPU direct storage is disabled for cuDF on databricks.\n", " libcudf_policy = os.environ.get('LIBCUDF_CUFILE_POLICY')\n", " if libcudf_policy != 'OFF':\n", " raise RuntimeError(\"Set LIBCUDF_CUFILE_POLICY=OFF to read from DBFS with cuDF.\")\n", " \n", " data = cudf.read_csv(filepath, delimiter=\";\")\n", " X = data.iloc[:, :-1].values\n", " y = data[\"quality\"].values\n", " X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=seed)\n", "\n", " tuning_max_bin = \"max_bin\" in optuna_params\n", " if not tuning_max_bin:\n", " max_bin = xgb_params.get(\"max_bin\", 256)\n", " # Precompute Quantile DMatrix to avoid repeated quantization every trial.\n", " Xy_train_qdm = xgb.QuantileDMatrix(X_train, y_train, max_bin=max_bin)\n", "\n", " study = optuna.load_study(\n", " study_name=study_name,\n", " storage=f\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\",\n", " sampler=TPESampler(seed=seed),\n", " )\n", "\n", " print(f\"Running {num_trials} trials on partition {tc.partitionId()}.\")\n", "\n", " ### Objective ###\n", " for _ in range(num_trials):\n", " trial = study.ask(optuna_params)\n", " xgb_params.update(trial.params)\n", "\n", " if tuning_max_bin:\n", " # If tuning the max_bin param, we must recompute the QDM every trial.\n", " if \"n_estimators\" not in xgb_params:\n", " xgb_params[\"n_estimators\"] = 100 # Default value if not tuning.\n", "\n", " model = xgb.XGBRegressor(**xgb_params)\n", " model.fit(X_train, y_train)\n", " booster = model.get_booster()\n", " else:\n", " # Train the model with xgb.train() API using the precomputed QDM.\n", " num_boost_round = xgb_params.get(\"n_estimators\", 100)\n", " booster = xgb.train(params=xgb_params, dtrain=Xy_train_qdm, num_boost_round=num_boost_round)\n", " \n", " # Perform in-place predictions on GPU using the booster.\n", " predictions = booster.inplace_predict(X_val)\n", " rmse = mean_squared_error(y_val, predictions, squared=False).get()\n", " \n", " study.tell(trial, rmse)\n", "\n", " return study.best_params, study.best_value" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# This will register the Spark Session with the Joblib Spark backend.\n", "register_spark()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup and run the Optuna study" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Get the driver IP for the MySQL database. \n", "- For standalone users, make sure you've followed the [database setup instructions](../README.md#setup-database-for-optuna). The database should be on 'localhost'. 
\n", "- For databricks users, the database should already be setup on the driver node by the init script." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# check if we're running on databricks\n", "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MySQL database is hosted on localhost\n" ] } ], "source": [ "if on_databricks:\n", " driver_ip = spark.conf.get(\"spark.driver.host\")\n", "else:\n", " driver_ip = \"localhost\"\n", "\n", "print(f\"MySQL database is hosted on {driver_ip}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a new study, referencing the MySQL database as the storage backend." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[I 2024-12-11 23:42:13,928] A new study created in RDB with name: optuna-xgboost-joblibspark\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "study_name = \"optuna-xgboost-joblibspark\"\n", "seed = 42\n", "\n", "try:\n", " # Delete the study if it already exists\n", " optuna.delete_study(\n", " study_name=study_name, \n", " storage=f\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\"\n", " )\n", "except:\n", " pass\n", "\n", "optuna.create_study(\n", " study_name=study_name,\n", " storage=f\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\",\n", " sampler=TPESampler(seed=seed)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the number of tasks, number of trials, and trials per task. \n", "\n", "**NOTE**: for standalone users running on a single worker, the 4 tasks will all be assigned to the same worker and executed sequentially in this demonstration. This can easily be scaled up to run concurrently by adding more workers." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def partition_trials(total_trials: int, total_tasks: int) -> List[int]:\n", " base_size = total_trials // total_tasks\n", " extra = total_trials % total_tasks\n", " partitions = [base_size] * total_tasks\n", " for i in range(extra):\n", " partitions[i] += 1\n", " \n", " return partitions" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Trials per task: [25, 25, 25, 25]\n" ] } ], "source": [ "num_tasks = 4\n", "num_trials = 100\n", "trials_per_task = partition_trials(num_trials, num_tasks)\n", "print(f\"Trials per task: {trials_per_task}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Define params\n", "Define the XGBoost model params and the hyperparams for Optuna to tune. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Keep these params consistent:\n", "xgb_params = {\n", " \"objective\": \"reg:squarederror\",\n", " \"verbosity\": 0,\n", " \"tree_method\": \"gpu_hist\",\n", " \"device\": f\"cuda\",\n", " \"seed\": seed,\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Tune these params:\n", "optuna_params = {\n", " \"n_estimators\": optuna.distributions.IntDistribution(100, 500),\n", " \"learning_rate\": optuna.distributions.FloatDistribution(1e-3, 0.1, log=True),\n", " \"max_depth\": optuna.distributions.IntDistribution(1, 10),\n", " \"subsample\": optuna.distributions.FloatDistribution(0.05, 1.0),\n", " \"colsample_bytree\": optuna.distributions.FloatDistribution(0.05, 1.0),\n", " \"min_child_weight\": optuna.distributions.IntDistribution(1, 20),\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**For Databricks**: we must download the dataset to DBFS so that all workers can access it." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/optuna-data\")\n", " filepath = \"/dbfs/FileStore/optuna-data/winequality-red.csv\"\n", " url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\"\n", "\n", " response = requests.get(url)\n", " if response.status_code == 200:\n", " with open(filepath, \"wb\") as f:\n", " f.write(response.content)\n", " print(f\"File downloaded and saved to {filepath}\")\n", " else:\n", " print(f\"Failed to download the file. Status code: {response.status_code}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run the study\n", "\n", "Run parallel threads to execute the Optuna task and collect the reuslts (it might take a few minutes)." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/rishic/anaconda3/envs/optuna-spark/lib/python3.10/site-packages/joblibspark/backend.py:115: UserWarning: Spark version does not support stage-level scheduling.\n", " warnings.warn(\"Spark version does not support stage-level scheduling.\")\n", "/home/rishic/anaconda3/envs/optuna-spark/lib/python3.10/site-packages/joblibspark/backend.py:154: UserWarning: User-specified n_jobs (4) is greater than the max number of concurrent tasks (1) this cluster can run now.If dynamic allocation is enabled for the cluster, you might see more executors allocated.\n", " warnings.warn(f\"User-specified n_jobs ({n_jobs}) is greater than the max number of \"\n", " \r" ] } ], "source": [ "with joblib.parallel_backend(\"spark\", n_jobs=num_tasks):\n", " results = joblib.Parallel()(\n", " joblib.delayed(task)(num_trials,\n", " xgb_params,\n", " optuna_params,\n", " driver_ip,\n", " study_name,\n", " seed,\n", " filepath) for num_trials in trials_per_task\n", " )" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best parameters: {'n_estimators': 463, 'learning_rate': 0.05206124631137337, 'max_depth': 9, 'subsample': 0.7434942725744815, 'colsample_bytree': 0.877391644494205, 'min_child_weight': 4}\n", "Best value: 0.5324732150787205\n" ] } ], "source": [ "best_params = min(results, key=lambda x: x[1])[0]\n", "best_value = min(results, key=lambda x: x[1])[1]\n", "\n", "print(f\"Best parameters: {best_params}\")\n", "print(f\"Best value: {best_value}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "optuna-spark", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/README.md ================================================ # Deep Learning Inference on Spark Example notebooks demonstrating **distributed deep learning inference** using the [predict_batch_udf](https://developer.nvidia.com/blog/distributed-deep-learning-made-easy-with-spark-3-4/#distributed_inference) introduced in Spark 3.4.0. These notebooks also demonstrate model serving integrations with [Triton Inference Server](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html) and [vLLM serve](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html). ## Contents: - [Overview](#overview) - [Running Locally](#running-locally) - [Running on Cloud](#running-on-cloud-environments) - [Inference Serving Integration](#inference-serving) ## Overview These notebooks demonstrate how models from external frameworks (Torch, Huggingface, Tensorflow, vLLM) trained on single-worker machines can be used for large-scale distributed inference on Spark clusters. 
For example, a basic model trained in TensorFlow and saved on disk as "mnist_model" can be used in Spark as follows: ``` import numpy as np from pyspark.ml.functions import predict_batch_udf from pyspark.sql.types import ArrayType, FloatType def predict_batch_fn(): import tensorflow as tf model = tf.keras.models.load_model("/path/to/mnist_model") def predict(inputs: np.ndarray) -> np.ndarray: return model.predict(inputs) return predict mnist = predict_batch_udf(predict_batch_fn, return_type=ArrayType(FloatType()), batch_size=1024, input_tensor_shapes=[[784]]) df = spark.read.parquet("mnist_data") predictions = df.withColumn("preds", mnist("data")).collect() ``` In this simple case, the `predict_batch_fn` will use TensorFlow APIs to load the model and return a simple `predict` function. The `predict_batch_udf` will handle the data conversion from Spark DataFrame columns into batched numpy inputs. #### Notebook List Below is a full list of the notebooks and their links. All notebooks have been saved with sample outputs for quick browsing. | | Framework | Notebook Name | Description | Link | ------------- | ------------- | ------------- | ------------- | ------------- | 1 | HuggingFace | DeepSeek-R1 | LLM batch inference using the DeepSeek-R1-Distill-Llama reasoning model to solve word problems. | [Link](huggingface/deepseek-r1_torch.ipynb) | 2 | HuggingFace | Qwen-2.5-7b | LLM batch inference using the Qwen-2.5-7b model for text summarization. | [Link](huggingface/qwen-2.5-7b_torch.ipynb) | 3 | HuggingFace | Gemma-7b | LLM batch inference using the Google Gemma-7b model for code comprehension tasks. | [Link](huggingface/gemma-7b_torch.ipynb) | 4 | HuggingFace | Sentence Transformers | Sentence embeddings using SentenceTransformers in Torch. | [Link](huggingface/sentence_transformers_torch.ipynb) | 5+6 | HuggingFace | Conditional Generation | Sentence translation using the T5 text-to-text transformer (Torch and Tensorflow). | [Torch Link](huggingface/conditional_generation_torch.ipynb), [TF Link](huggingface/conditional_generation_tf.ipynb) | 7+8 | HuggingFace | Pipelines | Sentiment analysis using Huggingface pipelines (Torch and Tensorflow). | [Torch Link](huggingface/pipelines_torch.ipynb), [TF Link](huggingface/pipelines_tf.ipynb) | 9 | vLLM | Qwen-2.5-14b-tensor-parallel | Tensor-parallel LLM batch inference using the Qwen-2.5-14b model to summarize unstructured text data into a structured schema, using vLLM serve. | [Link](vllm/qwen-2.5-14b-tensor-parallel_vllm.ipynb) | 10 | vLLM | Qwen-2.5-7b | LLM batch inference using the Qwen-2.5-7b model for text summarization, using vLLM serve. | [Link](vllm/qwen-2.5-7b_vllm.ipynb) | 11 | PyTorch | Image Classification | Training a model to predict clothing categories in FashionMNIST, and deploying with Torch-TensorRT accelerated inference. | [Link](pytorch/image_classification_torch.ipynb) | 12 | PyTorch | Housing Regression | Training a model to predict housing prices in the California Housing Dataset, and deploying with Torch-TensorRT accelerated inference. | [Link](pytorch/housing_regression_torch.ipynb) | 13 | Tensorflow | Image Classification | Training and deploying a model to predict hand-written digits in MNIST. | [Link](tensorflow/image_classification_tf.ipynb) | 14 | Tensorflow | Keras Preprocessing | Training and deploying a model with preprocessing layers to predict likelihood of pet adoption in the PetFinder mini dataset.
| [Link](tensorflow/keras_preprocessing_tf.ipynb) | 15 | Tensorflow | Keras Resnet50 | Deploying ResNet-50 to perform flower recognition from flower images. | [Link](tensorflow/keras_resnet50_tf.ipynb) | 16 | Tensorflow | Text Classification | Training and deploying a model to perform sentiment analysis on the IMDB dataset. | [Link](tensorflow/text_classification_tf.ipynb) ## Running Locally To run the notebooks locally, please follow these instructions: #### Create environment Each notebook has a suffix `_torch`, `_tf`, or `_vllm` specifying the environment used. **For PyTorch:** ``` conda create -n spark-dl-torch -c conda-forge python=3.11 conda activate spark-dl-torch conda install -c conda-forge libstdcxx-ng pip install -r torch_requirements.txt ``` **For TensorFlow:** ``` conda create -n spark-dl-tf -c conda-forge python=3.11 conda activate spark-dl-tf conda install -c conda-forge libstdcxx-ng pip install -r tf_requirements.txt ``` **For vLLM:** ``` conda create -n spark-dl-vllm -c conda-forge python=3.11 conda activate spark-dl-vllm pip install -r vllm_requirements.txt ``` #### Start Cluster For demonstration, these instructions just use a local Standalone cluster with a single executor, but they can be run on any distributed Spark cluster. If you haven't already, [install Spark](https://spark.apache.org/downloads.html) on your system. ```shell # Replace with your Spark installation path export SPARK_HOME=
``` ```shell # Configure and start cluster export MASTER=spark://$(hostname):7077 export SPARK_WORKER_INSTANCES=1 export CORES_PER_WORKER=8 export SPARK_WORKER_OPTS="-Dspark.worker.resource.gpu.amount=1 \ -Dspark.worker.resource.gpu.discoveryScript=$SPARK_HOME/examples/src/main/scripts/getGpusResources.sh" ${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m 16G ${MASTER} ``` The notebooks are ready to run! Each notebook has a cell to connect to the standalone cluster and create a SparkSession. **Notes**: - Please create separate environments for different frameworks as specified above. This avoids conflicts between the CUDA libraries bundled with each framework. - `requirements.txt` installs pyspark>=3.4.0. Make sure the installed PySpark version is compatible with your system's Spark installation. - The notebooks require an NVIDIA GPU on your system. - The PyTorch notebooks include model compilation and accelerated inference with TensorRT. While not included in the notebooks, Tensorflow also supports [integration with TensorRT](https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html), but as of writing it is not supported in TF==2.17.0. - Note that some Huggingface models may be gated and will require a login, e.g.: ```python from huggingface_hub import login login() ``` ## Running on Cloud Environments We also provide instructions to run the notebooks on CSP Spark environments. See the instructions for [Databricks](databricks/README.md) and [GCP Dataproc](dataproc/README.md). ## Inference Serving *(Figure: a model deployed on an inference server running as a sidecar process alongside the Spark executor.)* The notebooks demonstrate deploying models on an inference server as a sidecar process, as shown above. The process looks like this: - Prior to inference, launch a server process on each node. - Define a predict function, which creates a client that sends/receives inference requests to the local server. - Wrap the predict function in a predict_batch_udf to launch parallel inference requests using Spark. This logically separates the CPU parallelism from the GPU parallelism for streamlined deployment. For instance, say we want to run a 20GB model on a GPU with 25GB of memory. - With `predict_batch_udf` using an in-process framework, we must set `spark.task.resource.gpu.amount=1`, which limits parallelism to 1 task (i.e. model instance) per GPU for the entire application due to memory constraints. - Using an inference server, we can set `spark.task.resource.gpu.amount=(num_cores)` to leverage all the executor CPUs for Dataframe operations (reading/preprocessing/writing), while the server loads 1 instance of the model on the GPU for inference. See [`server_utils.py`](server_utils.py) for more details on how we manage servers on the Spark cluster. ### Triton Inference Server Each notebook has a section that demonstrates model serving with [Triton Inference Server](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html), an open-source serving platform for deep learning models, which includes many [major features](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html#triton-major-features) to streamline inference. To leverage Triton through Python, we use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles communication with the Triton server. Triton allows you to define a Python function encapsulating the inference logic, including complex pipelines such as model ensembles or concurrent execution.
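As a rough sketch of the client side of this pattern, the snippet below wraps a server-backed predict function in `predict_batch_udf`. It assumes a PyTriton server hosting a model named `TextModel` has already been launched on each node (for example, via the helpers in `server_utils.py`); the model name, server address, batch size, and column name are illustrative, and details such as input encoding are simplified relative to the notebooks.

```python
import numpy as np
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import StringType

def server_predict_fn():
    # Runs once per Python worker: open a client to the sidecar server on this node.
    from pytriton.client import ModelClient
    client = ModelClient("localhost", "TextModel")  # assumed server address and model name

    def predict(inputs: np.ndarray) -> np.ndarray:
        # Forward the batch to the server and unpack the (single) output tensor.
        result = client.infer_batch(inputs)
        return next(iter(result.values()))

    return predict

generate = predict_batch_udf(server_predict_fn,
                             return_type=StringType(),
                             batch_size=32)

# e.g. df.withColumn("preds", generate("text")) routes each batch through the local server.
```

The notebooks follow this same shape, with server startup and shutdown handled by the utilities in `server_utils.py`.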
For more information on how PyTriton works, see the [PyTriton docs](https://triton-inference-server.github.io/pytriton/latest/high_level_design/). ### vLLM Server The vLLM notebooks demonstrate serving with [vLLM serve](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html), an OpenAI-compatible HTTP server to deploy vLLM models. If you do not need the custom inference logic provided by Triton, vLLM serve is a straightforward alternative to deploy a vLLM-compatible LLM. ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/README.md ================================================ # Spark DL Inference on Databricks **Note**: fields in `<angle brackets>` require user inputs. Make sure you are in [this](./) directory. ## Setup 1. Install the latest [databricks-cli](https://docs.databricks.com/en/dev-tools/cli/tutorial.html) and configure for your workspace. 2. Specify the path to your Databricks workspace: ```shell export WS_PATH= ``` ```shell export SPARK_DL_WS=${WS_PATH}/spark-dl databricks workspace mkdirs ${SPARK_DL_WS} ``` 3. Specify the local path to the notebook you wish to run. As an example for a PyTorch notebook: ```shell export NOTEBOOK_SRC=
``` 4. Specify the framework as torch, tf, or vllm, corresponding to the notebook you wish to run. Continuing with the PyTorch example: ```shell export FRAMEWORK=torch ``` This will tell the init script which libraries to install on the cluster. 5. Copy the notebook, the utils file, and the init script to the Databricks Workspace: ```shell databricks workspace import ${SPARK_DL_WS}/$(basename "$NOTEBOOK_SRC") --format JUPYTER --file $NOTEBOOK_SRC databricks workspace import ${SPARK_DL_WS}/server_utils.py --format AUTO --file $(realpath ../server_utils.py) databricks workspace import ${SPARK_DL_WS}/init_spark_dl.sh --format AUTO --file $(pwd)/setup/init_spark_dl.sh ``` 6. Launch the cluster with the provided script, passing the argument `aws` or `azure` based on your provider. Modify the scripts if you do not have the specific instance types. By default, the script will create a cluster with 2 A10 workers and 1 A10 driver. ```shell cd setup chmod +x start_cluster.sh ./start_cluster.sh aws # or ./start_cluster.sh azure ``` To create a cluster capable of tensor parallelism, include the argument `tp` to acquire multiple GPUs per node: ```shell ./start_cluster.sh aws tp # or ./start_cluster.sh azure tp ``` In this case, the Azure worker nodes will have 2 GPUs each and the AWS workers will have 4 GPUs each (since AWS does not have an instance type with 2 GPUs) to run the tensor parallel example.* 7. Navigate to the notebook in your workspace and attach it to the cluster. The default cluster name is `spark-dl-inference-$FRAMEWORK`. *Note that the RAPIDS Accelerator for Apache Spark is not compatible with this case, since [multiple GPUs per executor are not yet supported](https://docs.nvidia.com/spark-rapids/user-guide/latest/faq.html#why-are-multiple-gpus-per-executor-not-supported). ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/init_spark_dl.sh ================================================ #!/bin/bash # Copyright (c) 2025, NVIDIA CORPORATION. set -euxo pipefail # install requirements sudo /databricks/python3/bin/pip3 install --upgrade pip if [[ "${FRAMEWORK}" == "torch" ]]; then cat <<EOF > temp_requirements.txt datasets==3.* transformers nvidia-pytriton torch<=2.5.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu121 torch-tensorrt tensorrt --extra-index-url https://download.pytorch.org/whl/cu121 sentence_transformers sentencepiece nvidia-modelopt[all] --extra-index-url https://pypi.nvidia.com EOF elif [[ "${FRAMEWORK}" == "tf" ]]; then cat <<EOF > temp_requirements.txt datasets==3.* transformers nvidia-pytriton EOF elif [[ "${FRAMEWORK}" == "vllm" ]]; then cat <<EOF > temp_requirements.txt vllm==0.8.2 EOF else echo "Please export FRAMEWORK as torch, tf, or vllm per README" exit 1 fi sudo /databricks/python3/bin/pip3 install --upgrade --force-reinstall -r temp_requirements.txt rm temp_requirements.txt ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/start_cluster.sh ================================================ #!/bin/bash # Copyright (c) 2025, NVIDIA CORPORATION.
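# Creates a Databricks GPU cluster for the Spark DL inference examples.
# Usage: ./start_cluster.sh <aws|azure> [tp]
#   - The first argument selects the cloud provider and the matching GPU instance types.
#   - The optional "tp" flag provisions multi-GPU worker nodes for the tensor-parallel examples.
# Requires FRAMEWORK to be exported as torch, tf, or vllm; it is used to name the cluster.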
set -eo pipefail if [ $# -lt 1 ] || [ $# -gt 2 ]; then echo "Usage: $0 <aws|azure> [tp]" exit 1 fi CLOUD_PROVIDER=$1 TENSOR_PARALLEL=false # Check if the second argument is "tp" for tensor parallelism if [ $# -eq 2 ] && [ "$2" == "tp" ]; then TENSOR_PARALLEL=true fi if [[ "${FRAMEWORK}" != "vllm" && "${FRAMEWORK}" != "torch" && "${FRAMEWORK}" != "tf" ]]; then echo "Error: Please export FRAMEWORK as torch, tf, or vllm per README" exit 1 fi # Modify the node types below if your Databricks account does not have these specific instance types. # Modify EXECUTOR_CORES=(cores per node) and EXECUTOR_GPU_AMT=(GPUs per node) accordingly. # We recommend selecting instances with A10/L4+ GPUs for these examples. if [[ "${CLOUD_PROVIDER}" == "aws" ]]; then DRIVER_NODE_TYPE="g5.2xlarge" if [[ "${TENSOR_PARALLEL}" == "true" ]]; then # For tensor-parallelism examples, we default to the g5.12xlarge with 4 A10 GPUs (AWS does not have 2-GPU instances). NODE_TYPE="g5.12xlarge" EXECUTOR_CORES=48 EXECUTOR_GPU_AMT=4 else NODE_TYPE="g5.4xlarge" EXECUTOR_CORES=16 EXECUTOR_GPU_AMT=1 fi elif [[ "${CLOUD_PROVIDER}" == "azure" ]]; then DRIVER_NODE_TYPE="Standard_NV36ads_A10_v5" if [[ "${TENSOR_PARALLEL}" == "true" ]]; then # For tensor-parallelism examples, we default to the Standard_NV72ads_A10_v5 with 2 A10 GPUs. NODE_TYPE="Standard_NV72ads_A10_v5" EXECUTOR_CORES=72 EXECUTOR_GPU_AMT=2 else NODE_TYPE="Standard_NV36ads_A10_v5" EXECUTOR_CORES=36 EXECUTOR_GPU_AMT=1 fi else echo "Error: Cloud provider must be either 'aws' or 'azure'" exit 1 fi CLUSTER_SUFFIX="${FRAMEWORK}" if [[ "${TENSOR_PARALLEL}" == "true" ]]; then CLUSTER_SUFFIX="${FRAMEWORK}-tp" fi # Task GPU amount = Executor GPU amount / Executor cores TASK_GPU_AMT=$(awk "BEGIN {print ${EXECUTOR_GPU_AMT}/${EXECUTOR_CORES}}") json_config=$(cat < ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/README.md ================================================ # Spark DL Inference on Dataproc **Note**: fields in `<angle brackets>` require user inputs. Make sure you are in [this](./) directory. #### Setup GCloud CLI 1. Install the latest [gcloud-cli](https://cloud.google.com/sdk/docs/install) and initialize with `gcloud init`. 2. Configure the following settings: ```shell export PROJECT= export DATAPROC_REGION= export COMPUTE_REGION= export COMPUTE_ZONE= gcloud config set project ${PROJECT} gcloud config set dataproc/region ${DATAPROC_REGION} gcloud config set compute/region ${COMPUTE_REGION} gcloud config set compute/zone ${COMPUTE_ZONE} ``` #### Copy files to GCS 3. Create a GCS bucket if you don't already have one: ```shell export GCS_BUCKET= gcloud storage buckets create gs://${GCS_BUCKET} ``` 4. Specify the local path to the notebook(s) and copy to the GCS bucket. As an example for a torch notebook: ```shell export SPARK_DL_HOME=${GCS_BUCKET}/spark-dl gcloud storage cp
gs://${SPARK_DL_HOME}/notebooks/ ``` Repeat this step for any notebooks you wish to run. All notebooks under `gs://${SPARK_DL_HOME}/notebooks/` will be copied to the master node during initialization. 5. Copy the utils file to the GCS bucket. ```shell gcloud storage cp $(realpath ../server_utils.py) gs://${SPARK_DL_HOME}/ ``` #### Start cluster and run 6. Specify the framework to use (torch, tf, or vllm), which will determine what libraries to install on the cluster. For example: ```shell export FRAMEWORK=torch ``` Run the cluster startup script. The script will also retrieve and use the [spark-rapids initialization script](https://github.com/GoogleCloudDataproc/initialization-actions/blob/master/spark-rapids/spark-rapids.sh) to set up GPU resources. The script will create 2 L4 worker nodes and 1 L4 driver node by default, and name the cluster `${USER}-spark-dl-inference-${FRAMEWORK}`. ```shell cd setup chmod +x start_cluster.sh ./start_cluster.sh ``` To create a cluster capable of tensor parallelism, include the argument `tp` to acquire multiple GPUs per node: ```shell ./start_cluster.sh tp ``` In this case, the worker nodes will have 2 L4s each to run the tensor parallel example.* 7. Browse to the Jupyter web UI: - Go to `Dataproc` > `Clusters` > `(Cluster Name)` > `Web Interfaces` > `Jupyter/Lab` Or, get the link by running this command (under httpPorts > Jupyter/Lab): ```shell gcloud dataproc clusters describe ${CLUSTER_NAME} --region=${COMPUTE_REGION} ``` 8. Open and run the notebook interactively with the **Python 3 kernel**. The notebooks can be found under `Local Disk/spark-dl-notebooks` on the master node (folder icon on the top left > Local Disk). *Note that the RAPIDS Accelerator for Apache Spark is not applicable in this case, since [multiple GPUs per executor are not yet supported](https://docs.nvidia.com/spark-rapids/user-guide/latest/faq.html#why-are-multiple-gpus-per-executor-not-supported). ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/setup/init_spark_dl.sh ================================================ #!/bin/bash # Copyright (c) 2025, NVIDIA CORPORATION.
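# Dataproc initialization action for the Spark DL inference examples.
# On every node: reads the spark-dl-home, gcs-bucket, and requirements metadata attributes,
# mounts the GCS bucket at /mnt/gcs via gcsfuse, and pip-installs the requested packages.
# On the master node only: copies the notebooks and server_utils.py from GCS to local disk.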
set -euxo pipefail function get_metadata_attribute() { local -r attribute_name=$1 local -r default_value=$2 /usr/share/google/get_metadata_value "attributes/${attribute_name}" || echo -n "${default_value}" } SPARK_DL_HOME=$(get_metadata_attribute spark-dl-home UNSET) if [[ ${SPARK_DL_HOME} == "UNSET" ]]; then echo "Please set --metadata spark-dl-home" exit 1 fi GCS_BUCKET=$(get_metadata_attribute gcs-bucket UNSET) if [[ ${GCS_BUCKET} == "UNSET" ]]; then echo "Please set --metadata gcs-bucket" exit 1 fi REQUIREMENTS=$(get_metadata_attribute requirements UNSET) if [[ ${REQUIREMENTS} == "UNSET" ]]; then echo "Please set --metadata requirements" exit 1 fi # mount gcs bucket as fuse export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s` echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add - sudo apt-get update sudo apt-get install -y fuse gcsfuse sudo mkdir -p /mnt/gcs gcsfuse -o allow_other --implicit-dirs ${GCS_BUCKET} /mnt/gcs sudo chmod -R 777 /mnt/gcs # install requirements pip install --upgrade pip echo "${REQUIREMENTS}" > temp_requirements.txt pip install --upgrade --force-reinstall -r temp_requirements.txt rm temp_requirements.txt # copy notebooks to master ROLE=$(/usr/share/google/get_metadata_value attributes/dataproc-role) if [[ "${ROLE}" == 'Master' ]]; then if gsutil -q stat gs://${SPARK_DL_HOME}/notebooks/**; then mkdir spark-dl-notebooks gcloud storage cp -r gs://${SPARK_DL_HOME}/notebooks/* spark-dl-notebooks gcloud storage cp gs://${SPARK_DL_HOME}/server_utils.py . else echo "Failed to retrieve notebooks from gs://${SPARK_DL_HOME}/notebooks/" exit 1 fi fi sudo chmod -R a+rw /home/ sudo systemctl daemon-reload ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/setup/start_cluster.sh ================================================ #!/bin/bash # Copyright (c) 2025, NVIDIA CORPORATION. set -eo pipefail TENSOR_PARALLEL=false if [[ $# -gt 0 && "$1" == "tp" ]]; then TENSOR_PARALLEL=true echo "Tensor parallelism enabled - will use larger machine types with multiple GPUs" fi # configure arguments if [[ -z ${GCS_BUCKET} ]]; then echo "Please export GCS_BUCKET per README.md" exit 1 fi if [[ -z ${FRAMEWORK} ]]; then echo "Please export FRAMEWORK as 'torch', 'tf', or 'vllm'" exit 1 fi if [[ -z ${COMPUTE_REGION} ]]; then COMPUTE_REGION=$(gcloud config get-value compute/region) if [[ -z ${COMPUTE_REGION} ]]; then echo "Please export COMPUTE_REGION per README.md or set it in gcloud config." 
exit 1 fi fi SPARK_DL_HOME=${SPARK_DL_HOME:-${GCS_BUCKET}/spark-dl} # copy init script to gcs gcloud storage cp init_spark_dl.sh gs://${SPARK_DL_HOME}/init/ INIT_PATH=gs://${SPARK_DL_HOME}/init/init_spark_dl.sh # retrieve and upload spark-rapids initialization script to gcs curl -LO https://raw.githubusercontent.com/GoogleCloudDataproc/initialization-actions/master/spark-rapids/spark-rapids.sh # don't enable rapids plugin by default sed -i '/spark.plugins=com.nvidia.spark.SQLPlugin/d' spark-rapids.sh gcloud storage cp spark-rapids.sh gs://${SPARK_DL_HOME}/init/ # rm spark-rapids.sh COMMON_REQUIREMENTS="numpy pandas matplotlib portalocker pyarrow pydot scikit-learn huggingface datasets==3.* transformers nvidia-pytriton" TORCH_REQUIREMENTS="${COMMON_REQUIREMENTS} torch<=2.5.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu121 torch-tensorrt tensorrt --extra-index-url https://download.pytorch.org/whl/cu121 sentence_transformers sentencepiece nvidia-modelopt[all] --extra-index-url https://pypi.nvidia.com" TF_REQUIREMENTS="${COMMON_REQUIREMENTS} tensorflow[and-cuda] tf-keras" VLLM_REQUIREMENTS="datasets==3.* vllm==0.8.2" cluster_name=${USER}-spark-dl-inference-${FRAMEWORK} if [[ "${TENSOR_PARALLEL}" == "true" ]]; then cluster_name="${cluster_name}-tp" fi if [[ ${FRAMEWORK} == "torch" ]]; then requirements=${TORCH_REQUIREMENTS} echo "=========================================================" echo "Starting PyTorch cluster ${cluster_name}" echo "=========================================================" elif [[ ${FRAMEWORK} == "tf" ]]; then requirements=${TF_REQUIREMENTS} echo "=========================================================" echo "Starting Tensorflow cluster ${cluster_name}" echo "=========================================================" elif [[ ${FRAMEWORK} == "vllm" ]]; then requirements=${VLLM_REQUIREMENTS} echo "=========================================================" echo "Starting vLLM cluster ${cluster_name}" echo "=========================================================" else echo "Please export FRAMEWORK as torch, tf, or vllm" exit 1 fi if [[ "${TENSOR_PARALLEL}" == "true" ]]; then WORKER_MACHINE_TYPE="g2-standard-24" # 2 L4 GPUs per node else WORKER_MACHINE_TYPE="g2-standard-8" # 1 L4 GPU per node fi if gcloud dataproc clusters list | grep -q "${cluster_name}"; then echo "Cluster ${cluster_name} already exists." 
exit 0 fi CLUSTER_PARAMS=( --image-version=2.2-ubuntu --region "${COMPUTE_REGION}" --num-workers 2 --master-machine-type g2-standard-8 --worker-machine-type "${WORKER_MACHINE_TYPE}" --initialization-actions gs://"${SPARK_DL_HOME}"/init/spark-rapids.sh,"${INIT_PATH}" --metadata gpu-driver-provider="NVIDIA" --metadata gcs-bucket="${GCS_BUCKET}" --metadata spark-dl-home="${SPARK_DL_HOME}" --metadata requirements="${requirements}" --worker-local-ssd-interface=NVME --optional-components=JUPYTER --bucket "${GCS_BUCKET}" --enable-component-gateway --max-idle "60m" --subnet=default --no-shielded-secure-boot ) gcloud dataproc clusters create ${cluster_name} "${CLUSTER_PARAMS[@]}" ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_tf.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "777fc40d", "metadata": {}, "source": [ "\n", "\n", "# PySpark Huggingface Inferencing\n", "### Conditional generation with Tensorflow\n", "\n", "In this notebook, we demonstrate distributed inference with the T5 transformer to perform sentence translation. \n", "From: https://huggingface.co/docs/transformers/model_doc/t5" ] }, { "cell_type": "markdown", "id": "05c79ac4-bf25-421e-b55e-020d6d9e15d5", "metadata": {}, "source": [ "Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075)" ] }, { "cell_type": "code", "execution_count": 1, "id": "f6f0dbf3-712b-4c58-85eb-261ce15bb2be", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:53:50.831324: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "2025-02-04 13:53:50.838528: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2025-02-04 13:53:50.846226: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2025-02-04 13:53:50.848585: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2025-02-04 13:53:50.854859: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2025-02-04 13:53:51.229622: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "from transformers import AutoTokenizer, TFT5ForConditionalGeneration\n", "\n", "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n", "# See (https://github.com/huggingface/transformers/issues/5486) for more info. 
\n", "import os\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"" ] }, { "cell_type": "code", "execution_count": 2, "id": "275890d7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.17.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1738706031.770264 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706031.793270 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706031.796251 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "# Enable GPU memory growth\n", "gpus = tf.config.experimental.list_physical_devices('GPU')\n", "if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", " \n", "print(tf.__version__)" ] }, { "cell_type": "code", "execution_count": 3, "id": "2684fb41-9467-40c0-9d7e-a1cc867c5a3c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1738706032.132191 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706032.134996 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706032.137528 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706032.251302 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706032.252345 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706032.253281 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "2025-02-04 13:53:52.254192: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 43462 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n", "All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.\n", "\n", "All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.\n", "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.\n" ] } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"google-t5/t5-small\")\n", "model = TFT5ForConditionalGeneration.from_pretrained(\"google-t5/t5-small\")\n", "\n", "task_prefix = \"translate English to German: \"\n", "\n", "lines = [\n", " \"The house is wonderful\",\n", " \"Welcome to NYC\",\n", " \"HuggingFace is a company\"\n", "]\n", "\n", "input_sequences = [task_prefix + l for l in lines]" ] }, { "cell_type": "code", "execution_count": 4, "id": "6eb2dfdb-0ad3-4d0f-81a4-268d92c53759", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1738706033.555987 3625654 service.cc:146] XLA service 0x712d300025f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n", "I0000 00:00:1738706033.556005 3625654 service.cc:154] StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6\n", "2025-02-04 13:53:53.558887: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", "2025-02-04 13:53:53.569767: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907\n", "I0000 00:00:1738706033.604327 3625654 device_compiler.h:188] Compiled cluster using XLA! 
This line is logged at most once for the lifetime of the process.\n" ] } ], "source": [ "inputs = tokenizer(input_sequences, \n", " padding=True,\n", " return_tensors=\"tf\")\n", "outputs = model.generate(input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_length=128)" ] }, { "cell_type": "code", "execution_count": 5, "id": "720158d4-e0e0-4904-b096-e5aede756afd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['Das Haus ist wunderbar',\n", " 'Willkommen in NYC',\n", " 'HuggingFace ist ein Unternehmen']" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]" ] }, { "cell_type": "markdown", "id": "546eabe0", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": 6, "id": "68121304-f1df-466e-9347-c9d2b36a9b3a", "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.types import *\n", "from pyspark import SparkConf\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.functions import pandas_udf, col, struct\n", "from pyspark.ml.functions import predict_batch_udf" ] }, { "cell_type": "code", "execution_count": 7, "id": "2f6db1f0-7d68-4af7-8bd6-c9fa45906c61", "metadata": {}, "outputs": [], "source": [ "import json\n", "import pandas as pd\n", "import datasets\n", "from datasets import load_dataset\n", "datasets.disable_progress_bars()" ] }, { "cell_type": "markdown", "id": "0d636975", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific Spark configurations." ] }, { "cell_type": "code", "execution_count": 8, "id": "ca351245", "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "markdown", "id": "d3199f8b", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." ] }, { "cell_type": "code", "execution_count": 9, "id": "6279a849", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:53:54 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/04 13:53:54 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/02/04 13:53:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " \n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", " elif on_dataproc:\n", " conf.set(\"spark.executorEnv.TF_GPU_ALLOCATOR\", \"cuda_malloc_async\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", "\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "markdown", "id": "7f311650", "metadata": {}, "source": [ "Load the IMDB Movie Reviews dataset from Huggingface." ] }, { "cell_type": "code", "execution_count": 10, "id": "b8453111-d068-49bb-ab91-8ae3d8bcdb7a", "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset(\"imdb\", split=\"test\")\n", "dataset = dataset.to_pandas().drop(columns=\"label\")" ] }, { "cell_type": "markdown", "id": "6fd5b472-47e8-4804-9907-772793fedb2b", "metadata": {}, "source": [ "### Create PySpark DataFrame" ] }, { "cell_type": "code", "execution_count": 11, "id": "d24d9404-0269-476e-a9dd-1842667c915a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "StructType([StructField('text', StringType(), True)])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.createDataFrame(dataset).repartition(8)\n", "df.schema" ] }, { "cell_type": "code", "execution_count": 12, "id": "c76314b7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "25000" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.count()" ] }, { "cell_type": "code", "execution_count": 13, "id": "4384c762-1f79-4f60-876c-94b1f552e8fb", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:54:01 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n" ] }, { "data": { "text/plain": [ "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.

The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.

The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.

I really got nothing much left to say except, give us back CKY2K, cause Bam suck..

I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.take(1)" ] }, { "cell_type": "markdown", "id": "42ba3513-82dd-47e7-8193-eb4389458757", "metadata": {}, "source": [ "### Save the test dataset as parquet files" ] }, { "cell_type": "code", "execution_count": 14, "id": "e7eec8ec-4126-4890-b957-025809fad67d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:54:02 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n" ] } ], "source": [ "data_path = \"spark-dl-datasets/imdb_test\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path)" ] }, { "cell_type": "markdown", "id": "078425e1", "metadata": {}, "source": [ "#### Load and preprocess DataFrame\n", "\n", "Define our preprocess function. We'll take the first sentence from each sample as our input for translation." ] }, { "cell_type": "code", "execution_count": 15, "id": "b9a0889a-35b4-493a-8197-1146fc7efd53", "metadata": {}, "outputs": [], "source": [ "def preprocess(text: pd.Series, prefix: str = \"\") -> pd.Series:\n", " @pandas_udf(\"string\")\n", " def _preprocess(text: pd.Series) -> pd.Series:\n", " return pd.Series([prefix + s.split(\".\")[0] for s in text])\n", " return _preprocess(text)" ] }, { "cell_type": "code", "execution_count": 16, "id": "c483e4d4-9ab1-416f-a766-694e17490fd3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------------------------------------+\n", "| text|\n", "+----------------------------------------------------------------------------------------------------+\n", "|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\n", "|There were two things I hated about WASTED : The directing and the script . I know I`m opening my...|\n", "|I'm rather surprised that anybody found this film touching or moving.

The basic premis...|\n", "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\n", "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\n", "|This movie has been done before. It is basically a unoriginal combo of \"Napoleon Dynamite\" and \"S...|\n", "|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\n", "|This movie is over hyped!! I am sad to say that I manage to watch the first 15 minutes of this mo...|\n", "|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\n", "|MINOR PLOT SPOILERS AHEAD!!!

How did such talented actors get involved in such mindles...|\n", "|There is not one character on this sitcom with any redeeming qualities. They are all self-centere...|\n", "|Tommy Lee Jones was the best Woodroe and no one can play Woodroe F. Call better than he. Not only...|\n", "|My wife rented this movie and then conveniently never got to see it. If I ever want to torture he...|\n", "|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\n", "|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\n", "|you will likely be sorely disappointed by this sequel that's not a sequel.AWIL is a classic.but t...|\n", "|If I was British, I would be embarrassed by this portrayal of incompetence. A protection agent of...|\n", "|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\n", "|This show is like watching someone who is in training to someday host a show. There are some good...|\n", "|Sigh. I'm baffled when I see a short like this get attention and assignments and whatnot. I saw t...|\n", "+----------------------------------------------------------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "# Limit to N rows, since this can be slow\n", "df = spark.read.parquet(data_path).limit(256).repartition(8)\n", "df.show(truncate=100)" ] }, { "cell_type": "markdown", "id": "a9f8e538", "metadata": {}, "source": [ "Append a prefix to tell the model to translate English to French:" ] }, { "cell_type": "code", "execution_count": 17, "id": "831bc52c-a5c6-4c29-a6da-0566b5167773", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------------------------------------+\n", "| input|\n", "+----------------------------------------------------------------------------------------------------+\n", "|translate English to French: Doesn't anyone bother to check where this kind of sludge comes from ...|\n", "|translate English to French: There were two things I hated about WASTED : The directing and the s...|\n", "| translate English to French: I'm rather surprised that anybody found this film touching or moving|\n", "|translate English to French: Cultural Vandalism Is the new Hallmark production of Gulliver's Trav...|\n", "|translate English to French: I was at Wrestlemania VI in Toronto as a 10 year old, and the event ...|\n", "| translate English to French: This movie has been done before|\n", "|translate English to French: [ as a new resolution for this year 2005, i decide to write a commen...|\n", "|translate English to French: This movie is over hyped!! I am sad to say that I manage to watch th...|\n", "|translate English to French: This show had a promising start as sort of the opposite of 'Oceans 1...|\n", "|translate English to French: MINOR PLOT SPOILERS AHEAD!!!

How did such talented actors...|\n", "| translate English to French: There is not one character on this sitcom with any redeeming qualities|\n", "| translate English to French: Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\n", "| translate English to French: My wife rented this movie and then conveniently never got to see it|\n", "|translate English to French: This is one of those star-filled over-the-top comedies that could a)...|\n", "|translate English to French: This excruciatingly boring and unfunny movie made me think that Chap...|\n", "|translate English to French: you will likely be sorely disappointed by this sequel that's not a s...|\n", "|translate English to French: If I was British, I would be embarrassed by this portrayal of incomp...|\n", "|translate English to French: One of those movies in which there are no big twists whatsoever and ...|\n", "|translate English to French: This show is like watching someone who is in training to someday hos...|\n", "| translate English to French: Sigh|\n", "+----------------------------------------------------------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "input_df = df.select(preprocess(col(\"text\"), \"translate English to French: \").alias(\"input\")).cache()\n", "input_df.show(truncate=100)" ] }, { "cell_type": "markdown", "id": "ec53a65c", "metadata": {}, "source": [ "## Inference using Spark DL API\n", "\n", "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n", "\n", "- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \n", "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function" ] }, { "cell_type": "code", "execution_count": 18, "id": "e7ae69d3-70c2-4765-928f-c96a7ba59829", "metadata": {}, "outputs": [], "source": [ "def predict_batch_fn():\n", " import tensorflow as tf\n", " import numpy as np\n", " from transformers import TFT5ForConditionalGeneration, AutoTokenizer\n", "\n", " # Enable GPU memory growth\n", " print(\"initializing model\")\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", "\n", " model = TFT5ForConditionalGeneration.from_pretrained(\"google-t5/t5-small\")\n", " tokenizer = AutoTokenizer.from_pretrained(\"google-t5/t5-small\")\n", "\n", " def predict(inputs):\n", " flattened = np.squeeze(inputs).tolist()\n", " inputs = tokenizer(flattened, \n", " padding=True, \n", " return_tensors=\"tf\")\n", " outputs = model.generate(input_ids=inputs[\"input_ids\"],\n", " attention_mask=inputs[\"attention_mask\"],\n", " max_length=128)\n", " string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])\n", " print(\"predict: {}\".format(len(flattened)))\n", " return string_outputs\n", " \n", " return predict" ] }, { "cell_type": "code", "execution_count": 19, "id": "36684f59-d947-43f8-a2e8-c7a423764e88", "metadata": {}, "outputs": [], "source": [ "generate = predict_batch_udf(predict_batch_fn,\n", " return_type=StringType(),\n", " batch_size=32)" ] }, { "cell_type": "code", "execution_count": 20, "id": 
"6a01c855-8fa1-4765-a3a5-2c9dd872df10", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 24:====================================> (5 + 3) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 9.07 ms, sys: 8.83 ms, total: 17.9 ms\n", "Wall time: 19.3 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "preds = input_df.withColumn(\"preds\", generate(struct(\"input\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 21, "id": "d912d4b0-cd0b-44ea-859a-b23455cc2700", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 27:==================================================> (7 + 1) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 7.51 ms, sys: 4.96 ms, total: 12.5 ms\n", "Wall time: 12.4 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = input_df.withColumn(\"preds\", generate(\"input\"))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 22, "id": "5fe3d88b-30f7-468f-8db8-1f4118d0f26c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 30:=====================> (3 + 5) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 5.46 ms, sys: 5.98 ms, total: 11.4 ms\n", "Wall time: 11.4 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = input_df.withColumn(\"preds\", generate(col(\"input\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 23, "id": "4ad9b365-4b9a-438e-8fdf-47da55cb1cf4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 33:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------+--------------------------------------------------+\n", "| input| preds|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "|translate English to French: Doesn't anyone bot...|Ne s'ennuie-t-il pas de vérifier où viennent ce...|\n", "|translate English to French: There were two thi...|Il y avait deux choses que j'ai hâte de voir : ...|\n", "|translate English to French: I'm rather surpris...|Je suis plutôt surpris que quelqu'un ait trouvé...|\n", "|translate English to French: Cultural Vandalism...|Vandalisme culturel La nouvelle production Hall...|\n", "|translate English to French: I was at Wrestlema...|J'étais à Wrestlemania VI à Toronto en 10 ans, ...|\n", "|translate English to French: This movie has bee...| Ce film a été réalisé avant|\n", "|translate English to French: [ as a new resolut...|[ en tant que nouvelle résolution pour cette an...|\n", "|translate English to French: This movie is over...|Je suis triste de dire que je parviens à regard...|\n", "|translate English to French: This show had a pr...|Ce spectacle a eu un début prometteur en l'espè...|\n", "|translate English to French: MINOR PLOT SPOILER...|br />br /> Comment ces acteurs talentueux ont-i...|\n", "|translate English to French: There is not one c...|Il n'y a pas d'un personnage sur ce sitcom ayan...|\n", "|translate English to French: Tommy Lee Jones wa...|Tommy Lee Jones était le meilleur Woodroe et pe...|\n", "|translate English to French: My wife 
rented thi...|Ma femme a loué ce film et n'a jamais pu le voi...|\n", "|translate English to French: This is one of tho...|C’est l’une des comédies en étoiles à l’étoile ...|\n", "|translate English to French: This excruciatingl...|Ce film excruciant ennuyant et infaillible m’a ...|\n", "|translate English to French: you will likely be...|Vous serez probablement très déçu par cette séq...|\n", "|translate English to French: If I was British, ...|Si j'étais britannique, je seraitis embarrassé ...|\n", "|translate English to French: One of those movie...|Un des films dans lesquels il n'y a pas de gros...|\n", "|translate English to French: This show is like ...|Ce spectacle ressemble à l'observation d'une pe...|\n", "| translate English to French: Sigh| Pesée|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "preds.show(truncate=50)" ] }, { "cell_type": "code", "execution_count": 24, "id": "1eb0c83b-d91b-4f8c-a5e7-c35f55c88108", "metadata": {}, "outputs": [], "source": [ "input_df2 = df.select(preprocess(col(\"text\"), \"translate English to German: \").alias(\"input\")).cache()" ] }, { "cell_type": "code", "execution_count": 25, "id": "6f6b70f9-188a-402b-9143-78a5788140e4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 36:==================================================> (7 + 1) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 9.1 ms, sys: 4.04 ms, total: 13.1 ms\n", "Wall time: 14.9 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "preds = input_df2.withColumn(\"preds\", generate(struct(\"input\")))\n", "result = preds.collect()" ] }, { "cell_type": "code", "execution_count": 26, "id": "031a6a5e-7999-4653-b394-19ed478d8c96", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 39:==================================================> (7 + 1) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 6.62 ms, sys: 5.23 ms, total: 11.9 ms\n", "Wall time: 11.9 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = input_df2.withColumn(\"preds\", generate(\"input\"))\n", "result = preds.collect()" ] }, { "cell_type": "code", "execution_count": 27, "id": "229b6515-82f6-4e9c-90f0-a9c3cfb26301", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 42:==============> (2 + 6) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 8.67 ms, sys: 3.27 ms, total: 11.9 ms\n", "Wall time: 11.7 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = input_df2.withColumn(\"preds\", generate(col(\"input\")))\n", "result = preds.collect()" ] }, { "cell_type": "code", "execution_count": 28, "id": "8be750ac-fa39-452e-bb4c-c2270bc2f70d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 45:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------+--------------------------------------------------+\n", "| input| preds|\n", 
"+--------------------------------------------------+--------------------------------------------------+\n", "|translate English to German: Doesn't anyone bot...|Warum hat man sich nicht angeschaut, woher der ...|\n", "|translate English to German: There were two thi...|Es gab zwei Dinge, die ich hat an WASTED gehass...|\n", "|translate English to German: I'm rather surpris...|Ich bin ziemlich überrascht, dass jemand diesen...|\n", "|translate English to German: Cultural Vandalism...|Kultureller Vandalismus Ist die neue Hallmark-P...|\n", "|translate English to German: I was at Wrestlema...|Ich war als 10 Jahre alt bei Wrestlemania VI in...|\n", "|translate English to German: This movie has bee...| Dieser Film wurde bereits vorgenommen|\n", "|translate English to German: [ as a new resolut...|[ als neue Entschließung für dieses Jahr 2005, ...|\n", "|translate English to German: This movie is over...|Ich hoffe, dass ich die ersten 15 Minuten diese...|\n", "|translate English to German: This show had a pr...|Diese Show hatte einen vielversprechenden Start...|\n", "|translate English to German: MINOR PLOT SPOILER...|br />br />Wie haben sich so talentierte Schausp...|\n", "|translate English to German: There is not one c...|Es gibt keinen Charakter auf dieser Seite mit i...|\n", "|translate English to German: Tommy Lee Jones wa...|Tommy Lee Jones war der beste Woodroe und niema...|\n", "|translate English to German: My wife rented thi...|Meine Frau hat diesen Film vermietet und dann b...|\n", "|translate English to German: This is one of tho...|Dies ist eines der Sterne-gefüllten über-the-to...|\n", "|translate English to German: This excruciatingl...|Dieser schreckliche langweilige und unfunnelnde...|\n", "|translate English to German: you will likely be...|Sie werden wahrscheinlich ernsthaft enttäuscht ...|\n", "|translate English to German: If I was British, ...|Wenn ich Britisch wäre, wäre ich beschämt über ...|\n", "|translate English to German: One of those movie...|Einer der Filme, in denen es keine großen Drehu...|\n", "|translate English to German: This show is like ...|Diese Show ist wie ein jemanden, der in Ausbild...|\n", "| translate English to German: Sigh| Segnen|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "preds.show(truncate=50)" ] }, { "cell_type": "markdown", "id": "f5803188", "metadata": {}, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. 
\n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 29, "id": "6d09f972", "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "id": "2964ffee", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 30, "id": "f1083dc8", "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "id": "066c8695", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 31, "id": "afd00b7e-8150-4c95-a2e4-037e9c90f92a", "metadata": {}, "outputs": [], "source": [ "def triton_server(ports):\n", " import time\n", " import signal\n", " import numpy as np\n", " import tensorflow as tf\n", " from transformers import TFT5ForConditionalGeneration, AutoTokenizer\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", "\n", " print(f\"SERVER: Initializing Conditional Generation model on worker {TaskContext.get().partitionId()}.\")\n", "\n", " # Enable GPU memory growth\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", " \n", " tokenizer = AutoTokenizer.from_pretrained(\"google-t5/t5-small\")\n", " model = TFT5ForConditionalGeneration.from_pretrained(\"google-t5/t5-small\")\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " sentences = np.squeeze(inputs[\"text\"]).tolist()\n", " print(f\"SERVER: Received batch of size {len(sentences)}\")\n", " decoded_sentences = [s.decode(\"utf-8\") for s in sentences]\n", " inputs = tokenizer(decoded_sentences,\n", " padding=True,\n", " return_tensors=\"tf\")\n", " output_ids = model.generate(input_ids=inputs[\"input_ids\"],\n", " attention_mask=inputs[\"attention_mask\"],\n", " max_length=128)\n", " outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])\n", " return {\n", " \"translations\": outputs,\n", " }\n", "\n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"ConditionalGeneration\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"text\", dtype=object, shape=(-1,)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"translations\", dtype=object, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=64,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", 
" # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n", " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "id": "527da1b0", "metadata": {}, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "id": "4142ebfc", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": null, "id": "3d522f30", "metadata": {}, "outputs": [], "source": [ "model_name = \"ConditionalGeneration\"\n", "server_manager = TritonServerManager(model_name=model_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "7c18994c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server)" ] }, { "cell_type": "markdown", "id": "3f284eb3", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "id": "237e56dd", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": null, "id": "826db582", "metadata": {}, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url # or server_manager.host_to_grpc_url" ] }, { "cell_type": "markdown", "id": "f3f58e7b", "metadata": {}, "source": [ "Define the Triton inference function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 36, "id": "aff88b3f", "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " import numpy as np\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"Connecting to Triton model {model_name} at {url}.\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n", " flattened = np.squeeze(inputs).tolist() \n", " # Encode batch\n", " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n", " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n", " # Run inference\n", " result_data = client.infer_batch(encoded_batch_np)\n", " result_data = np.squeeze(result_data[\"translations\"], -1)\n", " return result_data\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 40, "id": "5d10c61c-6102-4d19-8dd6-0c7b5b65343e", "metadata": {}, "outputs": [], "source": [ "generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " return_type=StringType(),\n", " input_tensor_shapes=[[1]],\n", " batch_size=32)" ] }, { "cell_type": "markdown", "id": "a85e2ceb", "metadata": {}, 
"source": [ "#### Load and preprocess DataFrame" ] }, { "cell_type": "code", "execution_count": 37, "id": "2fa3664e", "metadata": {}, "outputs": [], "source": [ "def preprocess(text: pd.Series, prefix: str = \"\") -> pd.Series:\n", " @pandas_udf(\"string\")\n", " def _preprocess(text: pd.Series) -> pd.Series:\n", " return pd.Series([prefix + s.split(\".\")[0] for s in text])\n", " return _preprocess(text)" ] }, { "cell_type": "code", "execution_count": 38, "id": "5d6c54e7-534d-406f-b8e6-fd592efd0ab2", "metadata": {}, "outputs": [], "source": [ "df = spark.read.parquet(data_path).limit(256).repartition(8)" ] }, { "cell_type": "code", "execution_count": 39, "id": "dc1bbbe3-4232-49e5-80f6-99976524b73b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:55:37 WARN CacheManager: Asked to cache already cached data.\n" ] } ], "source": [ "input_df = df.select(preprocess(col(\"text\"), \"translate English to French: \").alias(\"input\")).cache()" ] }, { "cell_type": "markdown", "id": "e71f07d4", "metadata": {}, "source": [ "#### Run Inference" ] }, { "cell_type": "code", "execution_count": 41, "id": "2e0907da-a5d9-4c3b-9db4-ce5e70ca9bb4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 51:==================================================> (7 + 1) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 10.8 ms, sys: 8.12 ms, total: 18.9 ms\n", "Wall time: 30 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "preds = input_df.withColumn(\"preds\", generate(struct(\"input\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 42, "id": "9308bdd7-6f67-484d-8b51-dd1e1b2960ba", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 54:===========================================> (6 + 2) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 7.23 ms, sys: 3.43 ms, total: 10.7 ms\n", "Wall time: 21.2 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = input_df.withColumn(\"preds\", generate(\"input\"))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 43, "id": "38484ffd-370d-492b-8ca4-9eff9f242a9f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 57:===========================================> (6 + 2) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.81 ms, sys: 12.7 ms, total: 15.5 ms\n", "Wall time: 22.3 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = input_df.withColumn(\"preds\", generate(col(\"input\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 44, "id": "ebcb6699-3ac2-4529-ab0f-fab0a5e792da", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 60:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------+--------------------------------------------------+\n", "| input| preds|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "|translate English to French: Doesn't anyone bot...|Ne s'ennuie-t-il pas de vérifier où viennent ce...|\n", "|translate English to French: 
There were two thi...|Il y avait deux choses que j'ai hâte de voir : ...|\n", "|translate English to French: I'm rather surpris...|Je suis plutôt surpris que quelqu'un ait trouvé...|\n", "|translate English to French: Cultural Vandalism...|Vandalisme culturel La nouvelle production Hall...|\n", "|translate English to French: I was at Wrestlema...|J'étais à Wrestlemania VI à Toronto en 10 ans, ...|\n", "|translate English to French: This movie has bee...| Ce film a été réalisé avant|\n", "|translate English to French: [ as a new resolut...|[ en tant que nouvelle résolution pour cette an...|\n", "|translate English to French: This movie is over...|Je suis triste de dire que je parviens à regard...|\n", "|translate English to French: This show had a pr...|Ce spectacle a eu un début prometteur en l'espè...|\n", "|translate English to French: MINOR PLOT SPOILER...|br />br /> Comment ces acteurs talentueux ont-i...|\n", "|translate English to French: There is not one c...|Il n'y a pas d'un personnage sur ce sitcom ayan...|\n", "|translate English to French: Tommy Lee Jones wa...|Tommy Lee Jones était le meilleur Woodroe et pe...|\n", "|translate English to French: My wife rented thi...|Ma femme a loué ce film et n'a jamais pu le voi...|\n", "|translate English to French: This is one of tho...|C’est l’une des comédies en étoiles à l’étoile ...|\n", "|translate English to French: This excruciatingl...|Ce film excruciant ennuyant et infaillible m’a ...|\n", "|translate English to French: you will likely be...|Vous serez probablement très déçu par cette séq...|\n", "|translate English to French: If I was British, ...|Si j'étais britannique, je seraitis embarrassé ...|\n", "|translate English to French: One of those movie...|Un des films dans lesquels il n'y a pas de gros...|\n", "|translate English to French: This show is like ...|Ce spectacle ressemble à l'observation d'une pe...|\n", "| translate English to French: Sigh| Pesée|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "preds.show(truncate=50)" ] }, { "cell_type": "markdown", "id": "919e3113-64dd-482a-9233-6607b3f63c1e", "metadata": { "tags": [] }, "source": [ "#### Shut down server on each executor" ] }, { "cell_type": "code", "execution_count": 45, "id": "425d3b28-7705-45ba-8a18-ad34fc895219", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:56:54,506 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-04 13:56:59,695 - INFO - Sucessfully stopped 1 servers. 
\n" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 46, "id": "2dec80ca-7a7c-46a9-97c0-7afb1572f5b9", "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "f43118ab-fc0a-4f64-a126-4302e615654a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-tf", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_torch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "8f6659b4-88da-4207-8d32-2674da5383a0", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "source": [ "\n", "\n", "# PySpark DL Inference\n", "### Conditional generation with Huggingface\n", "\n", "In this notebook, we demonstrate distributed inference with the T5 transformer to perform sentence translation. \n", "From: https://huggingface.co/docs/transformers/model_doc/t5" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from transformers import T5Tokenizer, T5ForConditionalGeneration\n", "\n", "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n", "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n", "import os\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. 
This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n" ] } ], "source": [ "tokenizer = T5Tokenizer.from_pretrained(\"google-t5/t5-small\")\n", "model = T5ForConditionalGeneration.from_pretrained(\"google-t5/t5-small\")\n", "\n", "task_prefix = \"translate English to German: \"\n", "\n", "lines = [\n", " \"The house is wonderful\",\n", " \"Welcome to NYC\",\n", " \"HuggingFace is a company\"\n", "]\n", "\n", "input_sequences = [task_prefix + l for l in lines]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "inputs = tokenizer(input_sequences,\n", " padding=True, \n", " return_tensors=\"pt\")\n", "\n", "outputs = model.generate(input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_length=128)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['Das Haus ist wunderbar',\n", " 'Willkommen in NYC',\n", " 'HuggingFace ist ein Unternehmen']" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "1b8dae4a-3bfc-4430-b28a-7350db5efed4", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "from pyspark.sql.types import *\n", "from pyspark import SparkConf\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.functions import pandas_udf, col, struct\n", "from pyspark.ml.functions import predict_batch_udf" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "a93a1424-e483-4d37-a719-32fabee3f285", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "import json\n", "import pandas as pd\n", "import datasets\n", "from datasets import load_dataset\n", "datasets.disable_progress_bars()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific Spark configurations." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:34:55 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/04 13:34:55 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/02/04 13:34:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", "\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "f08c37a5-fb0c-45f6-8630-d2af67831641", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "source": [ "Load the IMBD Movie Reviews dataset from Huggingface." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "f0ec30c9-365a-43c5-9c53-3497400ee548", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "dataset = load_dataset(\"imdb\", split=\"test\")\n", "dataset = dataset.to_pandas().drop(columns=\"label\")" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "1e4269da-d2b3-46a5-9309-38a1ba825a47", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "source": [ "#### Create PySpark DataFrame" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "30dab34d-8e4b-4f30-b7c2-3dff49da018b", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "data": { "text/plain": [ "StructType([StructField('text', StringType(), True)])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.createDataFrame(dataset).repartition(8)\n", "df.schema" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "55c33cc0-5dfb-449c-ae79-80972fb04405", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "data": { "text/plain": [ "25000" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.count()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "efd6d6d9-1c2c-4131-8df4-a3ef75c3fc57", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:35:02 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n" ] }, { "data": { "text/plain": [ "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.

The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.

The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.

I really got nothing much left to say except, give us back CKY2K, cause Bam suck..

I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.take(1)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "65a5b258-1634-441e-8b36-29777e54592d", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:35:02 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n" ] } ], "source": [ "data_path = \"spark-dl-datasets/imdb_test\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path)" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "89b909f4-5732-428b-ad61-9a6c5cf94df2", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "source": [ "#### Load and preprocess DataFrame\n", "\n", "Define our preprocess function. We'll take the first sentence from each sample as our input for translation." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "eb7e53d6-bbd0-48d2-a3be-36847275e2a9", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "def preprocess(text: pd.Series, prefix: str = \"\") -> pd.Series:\n", " @pandas_udf(\"string\")\n", " def _preprocess(text: pd.Series) -> pd.Series:\n", " return pd.Series([prefix + s.split(\".\")[0] for s in text])\n", " return _preprocess(text)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "97eee1a4-9dc4-43b0-9578-6d7f8ff338bd", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------------------------------------+\n", "| text|\n", "+----------------------------------------------------------------------------------------------------+\n", "|The only reason I'm even giving this movie a 4 is because it was made in to an episode of Mystery...|\n", "|Awkward disaster mishmash has a team of scavengers coming across the overturned S.S. Poseidon, ho...|\n", "|Here is a fantastic concept for a film - a series of meteors crash into a small town and the resu...|\n", "|I walked out of the cinema having suffered this film after 30 mins. I left two friends pinned in ...|\n", "|A wildly uneven film where the major problem is the uneasy mix of comedy and thriller. To me, the...|\n", "|Leonard Rossiter and Frances de la Tour carry this film, not without a struggle, as the script wa...|\n", "|A good cast... 
A good idea but turns out it is flawed as hypnosis is not allowed as evidence in c...|\n", "|Yet again, I appear to be the only person on planet Earth who is capable of criticizing Japanese ...|\n", "|As a serious horror fan, I get that certain marketing ploys are used to sell movies, especially t...|\n", "|Upon writing this review I have difficulty trying to think of what to write about. Nothing much h...|\n", "|Simply awful. I'm including a spoiler warning here only because of including a coupla jokes from ...|\n", "|I am a fan of Ed Harris' work and I really had high expectations about this film. Having so good ...|\n", "|Well...I like Patricia Kaas. She is a beautiful lady and an extremely gifted and versatile singer...|\n", "|This is a new approach to comedy. It isn't funny.

The joke is that this, in and of its...|\n", "|It's been mentioned by others the inane dialogue in this series and I agree.

If Mom an...|\n", "|One of the most boring movies I've ever had to sit through, it's completely formulaic. Just a coo...|\n", "|This movie was playing on Lifetime Movie Network last month and I decided to check it out. I watc...|\n", "|1983's \"Frightmare\" is an odd little film. The director seems to be trying to combine the atmosph...|\n", "|'Felony' is a B-movie. No doubt about it.

Of course, if you take a look at the cast li...|\n", "|This movie defines the word \"confused\". All the actors stay true to the script. More's the pity, ...|\n", "+----------------------------------------------------------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "# Limit to N rows, since this can be slow\n", "df = spark.read.parquet(data_path).limit(512).repartition(8)\n", "df.show(truncate=100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Append a prefix to tell the model to translate English to French:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "fa14304d-b409-4d07-99ef-9da7c7c76158", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------------------------------------+\n", "| input|\n", "+----------------------------------------------------------------------------------------------------+\n", "|translate English to French: The only reason I'm even giving this movie a 4 is because it was mad...|\n", "|translate English to French: Awkward disaster mishmash has a team of scavengers coming across the...|\n", "|translate English to French: Here is a fantastic concept for a film - a series of meteors crash i...|\n", "| translate English to French: I walked out of the cinema having suffered this film after 30 mins|\n", "|translate English to French: A wildly uneven film where the major problem is the uneasy mix of co...|\n", "|translate English to French: Leonard Rossiter and Frances de la Tour carry this film, not without...|\n", "| translate English to French: A good cast|\n", "|translate English to French: Yet again, I appear to be the only person on planet Earth who is cap...|\n", "|translate English to French: As a serious horror fan, I get that certain marketing ploys are used...|\n", "|translate English to French: Upon writing this review I have difficulty trying to think of what t...|\n", "| translate English to French: Simply awful|\n", "|translate English to French: I am a fan of Ed Harris' work and I really had high expectations abo...|\n", "| translate English to French: Well|\n", "| translate English to French: This is a new approach to comedy|\n", "|translate English to French: It's been mentioned by others the inane dialogue in this series and ...|\n", "|translate English to French: One of the most boring movies I've ever had to sit through, it's com...|\n", "|translate English to French: This movie was playing on Lifetime Movie Network last month and I de...|\n", "| translate English to French: 1983's \"Frightmare\" is an odd little film|\n", "| translate English to French: 'Felony' is a B-movie|\n", "| translate English to French: This movie defines the word \"confused\"|\n", "+----------------------------------------------------------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "input_df = df.select(preprocess(col(\"text\"), \"translate English to French: \").alias(\"input\")).cache()\n", "input_df.show(truncate=100)" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "bc9cbdd2-1ca6-48e4-a549-792b3726525b", "showTitle": 
false, "tableResultSettingsMap": {}, "title": "" } }, "source": [ "## Inference using Spark DL API\n", "\n", "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n", "\n", "- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \n", "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "adb81177-442d-42ab-b86d-d8792201b4c8", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "def predict_batch_fn():\n", " import numpy as np\n", " import torch\n", " from transformers import T5ForConditionalGeneration, T5Tokenizer\n", " from pyspark import TaskContext\n", "\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " print(f\"Initializing model on worker {TaskContext.get().partitionId()}, device {device}.\")\n", " model = T5ForConditionalGeneration.from_pretrained(\"t5-small\").to(device)\n", " tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n", "\n", " def predict(inputs):\n", " flattened = np.squeeze(inputs).tolist()\n", " inputs = tokenizer(flattened, \n", " padding=True,\n", " return_tensors=\"pt\").to(device)\n", " outputs = model.generate(input_ids=inputs[\"input_ids\"],\n", " attention_mask=inputs[\"attention_mask\"],\n", " max_length=128)\n", " string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])\n", " print(\"predict: {}\".format(len(flattened)))\n", " return string_outputs\n", " \n", " return predict" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "20aab3a1-2284-4c07-9ce1-a20cf54d88f3", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "generate = predict_batch_udf(predict_batch_fn,\n", " return_type=StringType(),\n", " batch_size=32)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "a8d6f48e-09e7-4fc7-9d2f-1b68bc2976a7", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 24:=============================> (4 + 4) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 10.2 ms, sys: 5.05 ms, total: 15.2 ms\n", "Wall time: 7.41 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "preds = input_df.withColumn(\"preds\", generate(struct(\"input\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "abe2271d-0077-48f6-98b1-93524dd86447", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 27:=============================> 
(4 + 4) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.93 ms, sys: 1.98 ms, total: 5.91 ms\n", "Wall time: 4.08 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = input_df.withColumn(\"preds\", generate(\"input\"))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "77623711-a742-4262-8839-16fc3ddd1af7", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 30:==============> (2 + 6) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.85 ms, sys: 1.75 ms, total: 5.6 ms\n", "Wall time: 4.08 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = input_df.withColumn(\"preds\", generate(col(\"input\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "f339c654-52fd-4992-b054-188dfb260e5d", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------+--------------------------------------------------+\n", "| input| preds|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "|translate English to French: The only reason I'...|La seule raison pour laquelle je donne même ce ...|\n", "|translate English to French: Awkward disaster m...|La mishmash d’Awkward a eu une équipe de scaven...|\n", "|translate English to French: Here is a fantasti...|Voici un concept fantastique pour un film : une...|\n", "|translate English to French: I walked out of th...|Je me suis rendu du cinéma après avoir subi ce ...|\n", "|translate English to French: A wildly uneven fi...|Un film extrêmement inégal où le problème majeu...|\n", "|translate English to French: Leonard Rossiter a...|Leonard Rossiter et Frances de la Tour mettent ...|\n", "| translate English to French: A good cast| Une bonne étoile|\n", "|translate English to French: Yet again, I appea...|Encore une fois, je semble être la seule person...|\n", "|translate English to French: As a serious horro...|En tant que grand fan d'horreur, je peux obteni...|\n", "|translate English to French: Upon writing this ...|la suite de cette étude, j'ai de la difficulté ...|\n", "| translate English to French: Simply awful| Tout simplement terrible|\n", "|translate English to French: I am a fan of Ed H...|Je suis un fan de l'oeuvre d'Ed Harris et j'ai ...|\n", "| translate English to French: Well| Eh bien|\n", "|translate English to French: This is a new appr...| Il s’agit d’une nouvelle approche de la comédie.|\n", "|translate English to French: It's been mentione...|Il a été mentionné par d'autres le dialogue ina...|\n", "|translate English to French: One of the most bo...|Un des films les plus ennuyeux que je n'ai jama...|\n", "|translate English to French: This movie was pla...|Ce film jouait sur Lifetime Movie Network le mo...|\n", "|translate English to French: 1983's \"Frightmare...|Le film \"Frightmare\" de 1983 est un petit film ...|\n", "|translate English 
to French: 'Felony' is a B-movie| 'Felony' est un mouvement B|\n", "|translate English to French: This movie defines...| Ce film définit le mot «confus»|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "preds.show(truncate=50)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try English to German:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "input_df2 = df.select(preprocess(col(\"text\"), \"translate English to German: \").alias(\"input\")).cache()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 36:==================================================> (7 + 1) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 6.02 ms, sys: 705 μs, total: 6.73 ms\n", "Wall time: 4.24 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "preds = input_df2.withColumn(\"preds\", generate(struct(\"input\")))\n", "result = preds.collect()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 39:==============> (2 + 6) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 6.12 ms, sys: 319 μs, total: 6.43 ms\n", "Wall time: 3.88 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = input_df2.withColumn(\"preds\", generate(\"input\"))\n", "result = preds.collect()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 42:==============> (2 + 6) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 7.03 ms, sys: 16 μs, total: 7.05 ms\n", "Wall time: 3.9 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = input_df2.withColumn(\"preds\", generate(col(\"input\")))\n", "result = preds.collect()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------+--------------------------------------------------+\n", "| input| preds|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "|translate English to German: The only reason I'...|Der einzige Grund, warum ich sogar diesen Film ...|\n", "|translate English to German: Awkward disaster m...|Awkward-Katastrophenmischmash hat ein Team von ...|\n", "|translate English to German: Here is a fantasti...|Hier ist ein fantastisches Konzept für einen Fi...|\n", "|translate English to German: I walked out of th...|Ich ging aus dem Kino, nachdem ich diesen Film ...|\n", "|translate English to German: A wildly uneven fi...|Ein völlig ungleicher Film, in dem das Hauptpro...|\n", "|translate English to German: Leonard Rossiter a...|Leonard Rossiter und Frances de la Tour tragen ...|\n", "| translate English to German: A good cast| Gutes Casting|\n", "|translate English to German: Yet again, I appea...|Ich scheine wieder einmal die einzige Person au...|\n", "|translate English to German: As a serious horro...|Als ernsthafter Horrorfan erhalte ich, dass 
bes...|\n", "|translate English to German: Upon writing this ...|Ich habe Schwierigkeiten, mich an die Regeln zu...|\n", "| translate English to German: Simply awful| Einfach schrecklich|\n", "|translate English to German: I am a fan of Ed H...|Ich bin ein Fan von Ed Harris' Arbeit und hatte...|\n", "| translate English to German: Well| Nun|\n", "|translate English to German: This is a new appr...| Das ist ein neuer Ansatz für die Komödie|\n", "|translate English to German: It's been mentione...|Es wurde von anderen erwähnt, die unangenehme D...|\n", "|translate English to German: One of the most bo...|Einer der langwierigen Filme, die ich jemals du...|\n", "|translate English to German: This movie was pla...|Dieser Film spielte im letzten Monat auf Lifeti...|\n", "|translate English to German: 1983's \"Frightmare...| 1983 ist \"Frightmare\" ein merkwürdiger Film|\n", "|translate English to German: 'Felony' is a B-movie| 'Felony' ist ein B-Film|\n", "|translate English to German: This movie defines...| Dieser Film definiert das Wort \"verwirrt\"|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "preds.show(truncate=50)" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "a79a6f3a-cc34-46a4-aadd-16870423fffa", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. 
\n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "1e73757e-a451-4835-98e0-257ccf7a9025", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "71b1cb49-3d8f-4eeb-937a-c0c334bd2947", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "def triton_server(ports):\n", " import time\n", " import signal\n", " import numpy as np\n", " import torch\n", " from transformers import T5Tokenizer, T5ForConditionalGeneration\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", "\n", " print(f\"SERVER: Initializing Conditional Generation model on worker {TaskContext.get().partitionId()}.\")\n", " tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n", " model = T5ForConditionalGeneration.from_pretrained(\"t5-small\")\n", " \n", " DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " print(f\"SERVER: Using {DEVICE} device.\")\n", " model = model.to(DEVICE)\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " sentences = np.squeeze(inputs[\"text\"]).tolist()\n", " print(f\"SERVER: Received batch of size {len(sentences)}\")\n", " decoded_sentences = [s.decode(\"utf-8\") for s in sentences]\n", " inputs = tokenizer(decoded_sentences,\n", " padding=True,\n", " return_tensors=\"pt\").to(DEVICE)\n", " output_ids = model.generate(input_ids=inputs[\"input_ids\"],\n", " attention_mask=inputs[\"attention_mask\"],\n", " max_length=128)\n", " outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])\n", " return {\n", " \"translations\": outputs,\n", " }\n", "\n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"ConditionalGeneration\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"text\", dtype=object, shape=(-1,)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"translations\", dtype=object, shape=(-1,)),\n", " 
],\n", " config=ModelConfig(\n", " max_batch_size=64,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n", " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "1bf14846-15a3-4bc8-b0c5-ce71680d3550", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "5bf1fafc-d9c9-4fd7-901d-da97cf4ff496", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "model_name = \"ConditionalGeneration\"\n", "server_manager = TritonServerManager(model_name=model_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url # or server_manager.host_to_grpc_url" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the Triton inference function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "e203eb19-166d-4177-aa87-fd31b7e3c90e", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " import numpy as np\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"Connecting to Triton model {model_name} at {url}.\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n", " flattened = np.squeeze(inputs).tolist() \n", " # 
Encode batch\n", " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n", " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n", " # Run inference\n", " result_data = client.infer_batch(encoded_batch_np)\n", " result_data = np.squeeze(result_data[\"translations\"], -1)\n", " return result_data\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "be692f4a-cf86-4cf4-9530-7c62e479cacd", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " return_type=StringType(),\n", " input_tensor_shapes=[[1]],\n", " batch_size=32)" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "1b6b2a05-aea4-4e4d-a87d-0a6bd5ab554c", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "source": [ "#### Load and preprocess DataFrame" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "a5e83230-5178-4fec-bba2-0e69be40e68c", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "def preprocess(text: pd.Series, prefix: str = \"\") -> pd.Series:\n", " @pandas_udf(\"string\")\n", " def _preprocess(text: pd.Series) -> pd.Series:\n", " return pd.Series([prefix + s.split(\".\")[0] for s in text])\n", " return _preprocess(text)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "aad299b0-34bb-4edb-b1e4-cd0c82bb7455", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [ "df = spark.read.parquet(data_path).limit(512).repartition(8)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "7934a6fc-57bc-4104-a52c-076351e77cbe", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:35:39 WARN CacheManager: Asked to cache already cached data.\n" ] } ], "source": [ "input_df = df.select(preprocess(col(\"text\"), \"translate English to French: \").alias(\"input\")).cache()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Run Inference" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "0f6229ef-01c8-43c9-a259-c5df6a18d689", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 51:====================================> (5 + 3) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 5.09 ms, sys: 4.41 ms, total: 9.5 ms\n", "Wall time: 4.96 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "preds = 
input_df.withColumn(\"preds\", generate(struct(\"input\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "5a543b4c-8b29-4f61-9773-2639bbc7f728", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 54:===========================================> (6 + 2) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 5.4 ms, sys: 1.12 ms, total: 6.52 ms\n", "Wall time: 4.41 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = input_df.withColumn(\"preds\", generate(\"input\"))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "4c0cfc4e-ef0a-435e-9fdf-72b72b6def93", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 57:===========================================> (6 + 2) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 4.59 ms, sys: 1.79 ms, total: 6.38 ms\n", "Wall time: 4.55 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = input_df.withColumn(\"preds\", generate(col(\"input\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "2d756e2e-8b60-43cb-b5f9-e27de11be24d", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------+--------------------------------------------------+\n", "| input| preds|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "|translate English to French: The only reason I'...|La seule raison pour laquelle je donne même ce ...|\n", "|translate English to French: Awkward disaster m...|La mishmash d’Awkward a eu une équipe de scaven...|\n", "|translate English to French: Here is a fantasti...|Voici un concept fantastique pour un film : une...|\n", "|translate English to French: I walked out of th...|Je me suis rendu du cinéma après avoir subi ce ...|\n", "|translate English to French: A wildly uneven fi...|Un film extrêmement inégal où le problème majeu...|\n", "|translate English to French: Leonard Rossiter a...|Leonard Rossiter et Frances de la Tour mettent ...|\n", "| translate English to French: A good cast| Une bonne étoile|\n", "|translate English to French: Yet again, I appea...|Encore une fois, je semble être la seule person...|\n", "|translate English to French: As a serious horro...|En tant que grand fan d'horreur, je peux obteni...|\n", "|translate English to French: Upon writing this ...|la suite de cette étude, j'ai de la difficulté ...|\n", "| translate English to French: Simply awful| Tout simplement terrible|\n", "|translate English to French: I am a fan of Ed H...|Je suis un fan de l'oeuvre d'Ed Harris et j'ai ...|\n", "| translate English to French: Well| Eh bien|\n", "|translate English to 
French: This is a new appr...| Il s’agit d’une nouvelle approche de la comédie.|\n", "|translate English to French: It's been mentione...|Il a été mentionné par d'autres le dialogue ina...|\n", "|translate English to French: One of the most bo...|Un des films les plus ennuyeux que je n'ai jama...|\n", "|translate English to French: This movie was pla...|Ce film jouait sur Lifetime Movie Network le mo...|\n", "|translate English to French: 1983's \"Frightmare...|Le film \"Frightmare\" de 1983 est un petit film ...|\n", "|translate English to French: 'Felony' is a B-movie| 'Felony' est un mouvement B|\n", "|translate English to French: This movie defines...| Ce film définit le mot «confus»|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "preds.show(truncate=50)" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "86ae68d4-57da-41d9-91b4-625ef9465d60", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "source": [ "#### Shut down servers on each executor" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:35:53,794 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-04 13:35:58,983 - INFO - Sucessfully stopped 1 servers. \n" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "008c3e50-d321-4431-a9ab-919b35d1b042", "showTitle": false, "tableResultSettingsMap": {}, "title": "" } }, "outputs": [], "source": [] } ], "metadata": { "application/vnd.databricks.v1+notebook": { "dashboards": [], "environmentMetadata": null, "language": "python", "notebookMetadata": { "mostRecentlyExecutedCommandWithImplicitDF": { "commandId": 421988607303514, "dataframes": [ "_sqldf" ] }, "pythonIndentUnit": 4 }, "notebookName": "spark-triton-db.ipynb", "widgets": {} }, "kernelspec": { "display_name": "spark-dl-torch", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/deepseek-r1_torch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "# PySpark LLM Inference: DeepSeek-R1 Reasoning Q/A\n", "\n", "In this notebook, we demonstrate distributed batch inference with [DeepSeek-R1](https://github.com/deepseek-ai/DeepSeek-R1), using open weights on Huggingface.\n", "\n", "We use [DeepSeek-R1-Distill-Llama-8B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B) as demonstration. 
DeepSeek's distilled models are based on open-source LLMs (such as Llama/Qwen), and are fine-tuned using samples generated by DeepSeek-R1. We'll show how to use the model to reason through word problems.\n", "\n", "**Note:** Running this model on GPU with 16-bit precision requires **~18GB** of GPU RAM. Make sure your instances have sufficient GPU capacity." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n", "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n", "import os\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific configurations." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# For cloud environments, load the model to the distributed file system.\n", "if on_databricks:\n", " models_dir = \"/dbfs/FileStore/spark-dl-models\"\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n", " model_path = f\"{models_dir}/deepseek-r1-distill-llama-8b\"\n", "elif on_dataproc:\n", " models_dir = \"/mnt/gcs/spark-dl-models\"\n", " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n", " model_path = f\"{models_dir}/deepseek-r1-distill-llama-8b\"\n", "else:\n", " model_path = os.path.abspath(\"deepseek-r1-distill-llama-8b\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Download the model from huggingface hub." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import snapshot_download\n", "\n", "model_path = snapshot_download(\n", " repo_id=\"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\",\n", " local_dir=model_path\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Warmup: Running locally\n", "\n", "**Note:** If the driver node does not have sufficient GPU capacity, proceed to the PySpark section." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0ab193983c774a948e375407d7df1f83", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00\n", "\n", "To determine how many **r's** are in the word **strawberry**, let's follow these steps:\n", "\n", "1. **Write down the word:**\n", " \n", " S T R A W B E R R Y\n", "\n", "2. **Identify and count each occurrence of the letter R:**\n", " \n", " - **1.** S - no R\n", " - **2.** T - no R\n", " - **3.** R - **1 R**\n", " - **4.** A - no R\n", " - **5.** W - no R\n", " - **6.** B - no R\n", " - **7.** E - no R\n", " - **8.** R - **2 R's**\n", " - **9.** R - **3 R's**\n", " - **10.** Y - no R\n", "\n", "3. 
**Total count of R's:**\n", " \n", " There are **3 R's** in the word **strawberry**.\n", "\n", "\\boxed{3}\n" ] } ], "source": [ "res = pipe([\"How many r's are there in strawberry?\"], max_new_tokens=512, temperature=0.1)\n", "print(\"\\n\", res[0][0]['generated_text'])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Which number is bigger: 9.9 or 9.11? Let's see.\n", "\n", "First, I need to compare the whole number parts of both numbers. Both 9.9 and 9.11 have the same whole number part, which is 9.\n", "\n", "Since the whole numbers are equal, I'll compare the decimal parts. For 9.9, the decimal part is 0.9, and for 9.11, the decimal part is 0.11.\n", "\n", "To make it easier, I can express 0.9 as 0.90. Now, comparing 0.90 and 0.11, it's clear that 0.90 is greater than 0.11.\n", "\n", "Therefore, 9.9 is bigger than 9.11.\n", "\n", "\n", "To determine which number is larger between **9.9** and **9.11**, let's compare them step by step.\n", "\n", "1. **Compare the Whole Numbers:**\n", " - Both numbers have the same whole number part: **9**.\n", " \n", "2. **Compare the Decimal Parts:**\n", " - **9.9** can be written as **9.90**.\n", " - **9.11** remains **9.11**.\n", " \n", "3. **Analyze the Decimal Comparison:**\n", " - Compare the tenths place:\n", " - **9.90** has **9** in the tenths place.\n", " - **9.11** has **1** in the tenths place.\n", " - Since **9 > 1**, **9.90** is greater than **9.11**.\n", "\n", "4. **Conclusion:**\n", " - Therefore, **9.9** is larger than **9.11**.\n", "\n", "\\boxed{9.9}\n" ] } ], "source": [ "res = pipe([\"Which number is bigger: 9.9 or 9.11?\"], max_new_tokens=512, temperature=0.1)\n", "print(\"\\n\", res[0][0]['generated_text'])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Unload the model from GPU memory.\n", "del pipe\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.types import *\n", "from pyspark import SparkConf\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.functions import pandas_udf, col, struct, length\n", "from pyspark.ml.functions import predict_batch_udf" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "import datasets\n", "from datasets import load_dataset\n", "datasets.disable_progress_bars()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/10 09:40:01 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/10 09:40:01 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). 
For SparkR, use setLogLevel(newLevel).\n", "25/02/10 09:40:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.maxFailures\", \"1\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", "\n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load DataFrame\n", "\n", "Load the first 500 samples of the [Orca Math Word Problems dataset](https://huggingface.co/datasets/microsoft/orca-math-word-problems-200k) from Huggingface and store in a Spark Dataframe." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset(\"microsoft/orca-math-word-problems-200k\", split=\"train\", streaming=True)\n", "dataset = pd.Series([sample[\"question\"] for sample in dataset.take(500)])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------------------------------------+\n", "| question|\n", "+----------------------------------------------------------------------------------------------------+\n", "|Jungkook is the 5th place. Find the number of people who crossed the finish line faster than Jung...|\n", "|A number divided by 10 is 6. Yoongi got the result by subtracting 15 from a certain number. What ...|\n", "|Dongju selects a piece of paper with a number written on it, and wants to make a three-digit numb...|\n", "|You wanted to subtract 46 from a number, but you accidentally subtract 59 and get 43. How much do...|\n", "|The length of one span of Jinseo is about 12 centimeters (cm). When Jinseo measured the length of...|\n", "+----------------------------------------------------------------------------------------------------+\n", "only showing top 5 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "df = spark.createDataFrame(dataset, schema=StringType()).withColumnRenamed(\"value\", \"question\")\n", "df.show(5, truncate=100)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "data_path = \"spark-dl-datasets/orca_math\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").json(data_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Triton Inference Server\n", "We'll demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. 
\n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def triton_server(ports, model_path):\n", " import time\n", " import signal\n", " import numpy as np\n", " import torch\n", " from transformers import pipeline\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", "\n", " print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " pipe = pipeline(\"text-generation\", model=model_path, torch_dtype=torch.bfloat16, device=device)\n", " print(f\"SERVER: Using {device} device.\")\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " prompts = np.squeeze(inputs[\"prompts\"]).tolist()\n", " decoded_prompts = [p.decode(\"utf-8\") for p in prompts]\n", " # limit responses to 256 tokens, since reasoning tasks can take a while\n", " responses = pipe(decoded_prompts, max_new_tokens=256, temperature=0.2, return_full_text=False)\n", " return {\n", " \"responses\": np.array([r[0]['generated_text'] for r in responses]).reshape(-1, 1)\n", " }\n", "\n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"deepseek-r1\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"prompts\", dtype=object, shape=(-1,)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"responses\", dtype=object, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=16,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n", " print(\"SERVER: Received SIGTERM. 
Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "model_name = \"deepseek-r1\"\n", "server_manager = TritonServerManager(model_name=model_name, model_path=model_path)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-10 09:40:17,442 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-10 09:40:17,442 - INFO - Starting 1 servers.\n", " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (272659, [7000, 7001, 7002])}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server, wait_retries=24) # allow up to 2 minutes for model loading" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "host_to_grpc_url = server_manager.host_to_grpc_url # or server_manager.host_to_http_url" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the Triton inference function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " import numpy as np\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"Connecting to Triton model {model_name} at {url}.\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=500) as client:\n", " flattened = np.squeeze(inputs).tolist()\n", " # Encode batch\n", " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n", " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n", " # Run inference\n", " result_data = client.infer_batch(encoded_batch_np)\n", " result_data = np.squeeze(result_data[\"responses\"], -1)\n", " return result_data\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_grpc_url),\n", " return_type=StringType(),\n", " input_tensor_shapes=[[1]],\n", " batch_size=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load and preprocess DataFrame\n", "\n", "We'll select a few of the shorter questions for demonstration, since reasoning tasks can take a while." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "df = spark.read.json(data_path)\n", "df = df.filter(length(col(\"question\")) <= 100).limit(16).repartition(8).cache()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Run Inference" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 6:==============> (2 + 6) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 18.6 ms, sys: 8.31 ms, total: 26.9 ms\n", "Wall time: 1min 46s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "preds = df.withColumn(\"response\", generate(col(\"question\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 23:==================================================> (7 + 1) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 9.55 ms, sys: 4.51 ms, total: 14.1 ms\n", "Wall time: 1min 45s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"response\", generate(\"question\"))\n", "results = preds.collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sample output:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Q: There are 9 dogs and 23 cats. How many more cats are there than dogs? \n", "\n", "A: Let me think. So, I have 23 cats and 9 dogs. To find out how many more cats there are than dogs, I need to subtract the number of dogs from the number of cats. That would be 23 minus 9. Let me do the subtraction: 23 minus 9 is 14. So, there are 14 more cats than dogs.\n", "\n", "Wait, let me double-check that. If I have 9 dogs and 23 cats, subtracting the number of dogs from the number of cats should give me the difference. So, 23 minus 9 is indeed 14. Yeah, that seems right. I don't think I made a mistake there. So, the answer is 14 more cats than dogs.\n", "\n", "**Final Answer**\n", "The number of cats exceeds the number of dogs by \\boxed{14}.\n", "\\boxed{14}\n", "\n", "\n", "To determine how many more cats there are than dogs, we subtract the number of dogs from the number of cats. \n", "\n", "Given:\n", "- Number of cats = 23\n", "- Number of dogs = 9\n", "\n", "The calculation is:\n", "\\[ 23 - 9 = 14 \\]\n", "\n", "Thus, there are 14 more cats than dogs.\n", "\n", "\\[\n", "\\boxed{14}\n", "\\] \n", "\n" ] } ], "source": [ "print(f\"Q: {results[2].question} \\n\")\n", "print(f\"A: {results[2].response} \\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Shut down server on each executor" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-10 09:43:36,499 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-02-10 09:43:41,701 - INFO - Sucessfully stopped 1 servers. 
\n" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-torch", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/gemma-7b_torch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "# PySpark LLM Inference: Gemma-7b Code Comprehension\n", "\n", "In this notebook, we demonstrate distributed inference with the Google [Gemma-7b-instruct](https://huggingface.co/google/gemma-7b-it) LLM, using open-weights on Huggingface.\n", "\n", "The Gemma-7b-instruct is an instruction-fine-tuned version of the Gemma-7b base model. We'll show how to use the model to perform code comprehension tasks.\n", "\n", "**Note:** Running this model on GPU with 16-bit precision requires **~18 GB** of GPU RAM. Make sure your instances have sufficient GPU capacity." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n", "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n", "import os\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific configurations." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# For cloud environments, load the model to the distributed file system.\n", "if on_databricks:\n", " models_dir = \"/dbfs/FileStore/spark-dl-models\"\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n", " model_path = f\"{models_dir}/gemma-7b-it\"\n", "elif on_dataproc:\n", " models_dir = \"/mnt/gcs/spark-dl-models\"\n", " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n", " model_path = f\"{models_dir}/gemma-7b-it\"\n", "else:\n", " model_path = os.path.abspath(\"gemma-7b-it\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First visit the [Gemma Huggingface repository](https://huggingface.co/google/gemma-7b-it) to accept the terms to access the model, then login via huggingface_hub." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import login\n", "\n", "login()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once you have access, you can download the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import snapshot_download\n", "\n", "model_path = snapshot_download(\n", " repo_id=\"google/gemma-7b-it\",\n", " local_dir=model_path,\n", " ignore_patterns=\"*.gguf\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Warmup: Running locally\n", "\n", "**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "58494ca5858c40e39f924ad330a65885", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/4 [00:00Write me a poem about Apache Spark.\n", "\n", "In the realm of big data, a spark ignites,\n", "A framework born to conquer the night.\n", "Apache Spark, a lightning-fast tool,\n", "For processing data, swift and cool.\n", "\n", "With its resilient distributed architecture,\n", "It slices through terabytes with grace.\n", "No longer bound by memory's plight,\n", "Spark empowers us to analyze with might.\n", "\n", "From Python to Scala, it's a versatile spark,\n", "Unveiling insights hidden in the dark.\n", "\n" ] } ], "source": [ "input_text = \"Write me a poem about Apache Spark.\"\n", "inputs = tokenizer(input_text, return_tensors=\"pt\").to(\"cuda\")\n", "\n", "outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.1, do_sample=True)\n", "print(tokenizer.decode(outputs[0]))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Unload the model from GPU memory.\n", "del model\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.types import *\n", "from pyspark import SparkConf\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.functions import *\n", "from pyspark.ml.functions import predict_batch_udf" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "import datasets\n", "from datasets import load_dataset\n", "datasets.disable_progress_bars()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/10 09:44:33 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/10 09:44:33 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). 
For SparkR, use setLogLevel(newLevel).\n", "25/02/10 09:44:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.maxFailures\", \"1\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", "\n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load DataFrame\n", "\n", "Load the first 500 samples of the [Code Comprehension dataset](https://huggingface.co/datasets/imbue/code-comprehension) from Huggingface and store in a Spark Dataframe." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset(\"imbue/code-comprehension\", split=\"train\", streaming=True)\n", "dataset = pd.Series([sample[\"question\"] for sample in dataset.take(500)])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "df = spark.createDataFrame(dataset, schema=StringType()).withColumnRenamed(\"value\", \"prompt\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------------------------------------+\n", "| prompt|\n", "+----------------------------------------------------------------------------------------------------+\n", "|If we execute the code below, what will `result` be equal to?\\n\\n```python\\nN = 'quz'\\nN += 'bar'...|\n", "|```python\\nresult = 9 - 9 - 1 - 7 - 9 - 1 + 9 - 2 + 6 - 4 - 8 - 1\\n```\\n\\nOut of these options, w...|\n", "|```python\\nx = 'bas'\\nD = 'bar'.swapcase()\\nx = len(x)\\nx = str(x)\\nnu = 'bar'.isnumeric()\\nx += ...|\n", "|If we execute the code below, what will `result` be equal to?\\n\\n```python\\n\\nl = 'likewise'\\nmat...|\n", "|```python\\nresult = 'mazda' + 'isolated' + 'mistakes' + 'grew' + 'raid' + 'junk' + 'jamaica' + 'c...|\n", "+----------------------------------------------------------------------------------------------------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "df.show(5, truncate=100)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "If we execute the code below, what will `result` be equal to?\n", "\n", "```python\n", "N = 'quz'\n", "N += 'bar'\n", "N = N.swapcase()\n", "N = len(N)\n", "mu = 'bar'.strip()\n", "N = str(N)\n", "Q = N.isalpha()\n", "if N == 'bawr':\n", " N = 'BAWR'.lower()\n", "N = N + N\n", "N = '-'.join([N, N, N, 'foo'])\n", "if mu == N:\n", " N = 'bar'.upper()\n", "gamma = 'BAZ'.lower()\n", "\n", "result = N\n", "```\n" ] } ], "source": [ "print(df.take(1)[0].prompt)" ] }, { "cell_type": "code", 
"execution_count": 12, "metadata": {}, "outputs": [], "source": [ "data_path = \"spark-dl-datasets/code_comprehension\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").json(data_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "def triton_server(ports, model_path):\n", " import time\n", " import signal\n", " import numpy as np\n", " import torch\n", " from transformers import AutoTokenizer, AutoModelForCausalLM\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", "\n", " print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " tokenizer = AutoTokenizer.from_pretrained(model_path)\n", " model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\", torch_dtype=torch.bfloat16)\n", " print(f\"SERVER: Using {device} device.\")\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " prompts = np.squeeze(inputs[\"prompts\"]).tolist()\n", " print(f\"SERVER: Received batch of size {len(prompts)}\")\n", " decoded_prompts = [p.decode(\"utf-8\") for p in prompts]\n", " tokenized_inputs = tokenizer(decoded_prompts, padding=True, return_tensors=\"pt\").to(device)\n", " outputs = model.generate(**tokenized_inputs, max_new_tokens=256, temperature=0.1, do_sample=True)\n", " # Decode only the model output (excluding the input prompt) and remove special tokens.\n", " responses = np.array(tokenizer.batch_decode(outputs[:, tokenized_inputs.input_ids.shape[1]:], skip_special_tokens = True))\n", " return {\n", " \"responses\": responses.reshape(-1, 1),\n", " }\n", "\n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = 
TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"gemma-7b\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"prompts\", dtype=object, shape=(-1,)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"responses\", dtype=object, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=16,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n", " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "model_name = \"gemma-7b\"\n", "server_manager = TritonServerManager(model_name=model_name, model_path=model_path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-10 09:06:38,803 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-10 09:06:38,805 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (252119, [7000, 7001, 7002])}" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server, wait_retries=24) # allow up to 2 minutes for model loading" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "host_to_grpc_url = server_manager.host_to_grpc_url # or server_manager.host_to_http_url" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the Triton inference function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " import numpy as np\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"Connecting to Triton model {model_name} at {url}.\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=500) as client:\n", " flattened = np.squeeze(inputs).tolist()\n", " # Encode batch\n", " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n", " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n", " # Run inference\n", " result_data = 
client.infer_batch(encoded_batch_np)\n", " result_data = np.squeeze(result_data[\"responses\"], -1)\n", " return result_data\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_grpc_url),\n", " return_type=StringType(),\n", " input_tensor_shapes=[[1]],\n", " batch_size=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load and preprocess DataFrame\n", "\n", "We'll parallelize over a small set of questions for demonstration." ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "df = spark.read.json(data_path).limit(32).repartition(8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Run Inference" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 30:====================================> (5 + 3) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 5.6 ms, sys: 3.51 ms, total: 9.11 ms\n", "Wall time: 28.1 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "preds = df.withColumn(\"response\", generate(col(\"prompt\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 42:=============================> (4 + 4) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 8.12 ms, sys: 3.13 ms, total: 11.2 ms\n", "Wall time: 23.1 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"response\", generate(\"prompt\"))\n", "results = preds.collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sample output:" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Q: ```python\n", "result = ['mirrors', 'limousines', 'meaningful', 'cats', UNKNOWN, 'striking', 'wings', 'injured', 'wishlist', 'granny'].index('oracle')\n", "print(result)\n", "```\n", "\n", "The code above has one or more parts replaced with the word UNKNOWN. Knowing that running the code prints `4` to the console, what should go in place of UNKNOWN? \n", "\n", "A: \n", "\n", "The answer is `oracle`.\n", "\n", "The code is searching for the index of the word `oracle` in the list `result`, and the index is returned as `4`. \n", "\n" ] } ], "source": [ "print(f\"Q: {results[2].prompt} \\n\")\n", "print(f\"A: {results[2].response} \\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Shut down server on each executor" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-10 09:11:11,880 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-02-10 09:11:17,105 - INFO - Sucessfully stopped 1 servers. 
\n" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-torch", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/pipelines_tf.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "9e9fe848", "metadata": {}, "source": [ "\n", "\n", "# PySpark Huggingface Inferencing\n", "### Sentiment Analysis using Pipelines with Tensorflow\n", "\n", "In this notebook, we demonstrate distributed inference with Huggingface Pipelines to perform sentiment analysis. \n", "From: https://huggingface.co/docs/transformers/quicktour#pipeline-usage" ] }, { "cell_type": "markdown", "id": "1799fd4f", "metadata": {}, "source": [ "Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) " ] }, { "cell_type": "code", "execution_count": 1, "id": "0dd0f77b-ee1b-4477-a038-d25a4f1da0ea", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:57:08.242673: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "2025-02-04 13:57:08.249833: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2025-02-04 13:57:08.257735: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2025-02-04 13:57:08.259994: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2025-02-04 13:57:08.266655: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2025-02-04 13:57:08.649929: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "import tensorflow as tf\n", "from transformers import pipeline\n", "\n", "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n", "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n", "import os\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"" ] }, { "cell_type": "code", "execution_count": 2, "id": "d80fc3f8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1738706229.309141 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706229.333555 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706229.336487 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n" ] } ], "source": [ "device = 0 if tf.config.list_physical_devices('GPU') else -1" ] }, { "cell_type": "code", "execution_count": 3, "id": "e60a2877", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.17.0\n" ] } ], "source": [ "# Enable GPU memory growth\n", "gpus = tf.config.experimental.list_physical_devices('GPU')\n", "if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", "\n", "print(tf.__version__)" ] }, { "cell_type": "code", "execution_count": 4, "id": "553b28d2-a5d1-4d07-8a49-8f82b808e738", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision 714eb0f (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).\n", "Using a pipeline without specifying a model name and revision in production is not recommended.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1738706229.617170 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706229.620218 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706229.622781 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706229.732012 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706229.733045 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706229.733965 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "2025-02-04 13:57:09.734873: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 43096 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n", "All PyTorch model weights were used when initializing TFDistilBertForSequenceClassification.\n", "\n", "All the weights of TFDistilBertForSequenceClassification were initialized from the PyTorch model.\n", "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForSequenceClassification for predictions without further training.\n" ] } ], "source": [ "classifier = pipeline(\"sentiment-analysis\", device=device)" ] }, { "cell_type": "code", "execution_count": null, "id": "3b91fe91-b725-4564-ae93-56e3fb51e47c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'label': 'POSITIVE', 'score': 0.9997794032096863}]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "classifier((\"We are very happy to show you the 🤗 Transformers library.\"))" ] }, { "cell_type": "code", "execution_count": null, "id": "0be39eb3-462c-42ff-b8f4-09f4e4fe3a3c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "label: POSITIVE, with score: 0.9998\n", "label: NEGATIVE, with score: 0.5282\n" ] } ], "source": [ "results = classifier([\"We are very happy to show you the 🤗 Transformers library.\", \"We hope you don't hate it.\"])\n", "for result in results:\n", " print(f\"label: {result['label']}, with score: {round(result['score'], 4)}\")" ] }, { "cell_type": "markdown", "id": "e29ee6d8", "metadata": {}, "source": [ "Let's try a different model and tokenizer in the pipeline." 
] }, { "cell_type": "code", "execution_count": 7, "id": "cd9d3349", "metadata": {}, "outputs": [], "source": [ "model_name = \"nlptown/bert-base-multilingual-uncased-sentiment\"" ] }, { "cell_type": "code", "execution_count": 8, "id": "99e21b58", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "All PyTorch model weights were used when initializing TFBertForSequenceClassification.\n", "\n", "All the weights of TFBertForSequenceClassification were initialized from the PyTorch model.\n", "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForSequenceClassification for predictions without further training.\n" ] } ], "source": [ "from transformers import AutoTokenizer, TFAutoModelForSequenceClassification\n", "\n", "model = TFAutoModelForSequenceClassification.from_pretrained(model_name)\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "31079133", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'label': '5 stars', 'score': 0.7272477746009827}]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "classifier = pipeline(\"sentiment-analysis\", model=model, tokenizer=tokenizer, device=device)\n", "classifier(\"Nous sommes très heureux de vous présenter la bibliothèque 🤗 Transformers.\")" ] }, { "cell_type": "markdown", "id": "e6357234", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": 10, "id": "69dd6a1a-f450-47f0-9dbf-ad250585a011", "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.functions import col, struct, pandas_udf\n", "from pyspark.ml.functions import predict_batch_udf\n", "from pyspark.sql.types import *\n", "from pyspark.sql import SparkSession\n", "from pyspark import SparkConf" ] }, { "cell_type": "code", "execution_count": 11, "id": "287b1e96", "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import pandas as pd\n", "import datasets\n", "from datasets import load_dataset\n", "datasets.disable_progress_bars()" ] }, { "cell_type": "markdown", "id": "50e124cd", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific Spark configurations." ] }, { "cell_type": "code", "execution_count": 12, "id": "36001f55", "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "markdown", "id": "48c7271a", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." ] }, { "cell_type": "code", "execution_count": 13, "id": "6e0e0dd7", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:57:12 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/04 13:57:12 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). 
For SparkR, use setLogLevel(newLevel).\n", "25/02/04 13:57:12 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " \n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", " elif on_dataproc:\n", " conf.set(\"spark.executorEnv.TF_GPU_ALLOCATOR\", \"cuda_malloc_async\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", "\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "code", "execution_count": 14, "id": "42d70208", "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset(\"imdb\", split=\"test\")\n", "dataset = dataset.to_pandas().drop(columns=\"label\")" ] }, { "cell_type": "markdown", "id": "95ded4b2", "metadata": {}, "source": [ "#### Create PySpark DataFrame" ] }, { "cell_type": "code", "execution_count": 15, "id": "ac24f3c2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "StructType([StructField('text', StringType(), True)])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.createDataFrame(dataset).repartition(8)\n", "df.schema" ] }, { "cell_type": "code", "execution_count": 16, "id": "1db4db3a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "25000" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.count()" ] }, { "cell_type": "code", "execution_count": 17, "id": "517fe2e9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:57:20 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n" ] }, { "data": { "text/plain": [ "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.

The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.

The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.

I really got nothing much left to say except, give us back CKY2K, cause Bam suck..

I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.take(1)" ] }, { "cell_type": "code", "execution_count": 18, "id": "e176d28b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:57:20 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n" ] } ], "source": [ "data_path = \"spark-dl-datasets/imdb_test\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path)" ] }, { "cell_type": "markdown", "id": "395e0374", "metadata": {}, "source": [ "#### Load and preprocess DataFrame\n", "\n", "Define our preprocess function. We'll take the first sentence from each sample as our input for sentiment analysis." ] }, { "cell_type": "code", "execution_count": 19, "id": "9665b7b6-d7e9-4bd4-b29d-7a449ac5b574", "metadata": {}, "outputs": [], "source": [ "@pandas_udf(\"string\")\n", "def preprocess(text: pd.Series) -> pd.Series:\n", " return pd.Series([s.split(\".\")[0] for s in text])" ] }, { "cell_type": "code", "execution_count": 20, "id": "26693020", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------------------------------------+\n", "| input|\n", "+----------------------------------------------------------------------------------------------------+\n", "|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\n", "| There were two things I hated about WASTED : The directing and the script |\n", "| I'm rather surprised that anybody found this film touching or moving|\n", "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\n", "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\n", "| This movie has been done before|\n", "|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\n", "|This movie is over hyped!! I am sad to say that I manage to watch the first 15 minutes of this mo...|\n", "|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\n", "|MINOR PLOT SPOILERS AHEAD!!!

How did such talented actors get involved in such mindles...|\n", "| There is not one character on this sitcom with any redeeming qualities|\n", "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\n", "| My wife rented this movie and then conveniently never got to see it|\n", "|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\n", "|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\n", "| you will likely be sorely disappointed by this sequel that's not a sequel|\n", "| If I was British, I would be embarrassed by this portrayal of incompetence|\n", "|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\n", "| This show is like watching someone who is in training to someday host a show|\n", "| Sigh|\n", "+----------------------------------------------------------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "# Limit to N rows, since this can be slow\n", "df = spark.read.parquet(data_path).limit(256).repartition(8)\n", "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()\n", "df.show(truncate=100)" ] }, { "cell_type": "markdown", "id": "76dc525c", "metadata": {}, "source": [ "## Inference using Spark DL API\n", "\n", "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n", "\n", "- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \n", "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function" ] }, { "cell_type": "code", "execution_count": 21, "id": "0da9d25c-5ebe-4503-bb19-154fcc047cbf", "metadata": {}, "outputs": [], "source": [ "def predict_batch_fn():\n", " import tensorflow as tf\n", " from transformers import pipeline\n", "\n", " # Enable GPU memory growth\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", " \n", " device = 0 if tf.config.list_physical_devices('GPU') else -1\n", " pipe = pipeline(\"sentiment-analysis\", device=device)\n", " def predict(inputs):\n", " return pipe(inputs.tolist())\n", " return predict" ] }, { "cell_type": "code", "execution_count": 22, "id": "78afef29-ee30-4267-9fb6-be2dcb86cbba", "metadata": {}, "outputs": [], "source": [ "classify = predict_batch_udf(predict_batch_fn,\n", " return_type=StructType([\n", " StructField(\"label\", StringType(), True),\n", " StructField(\"score\", FloatType(), True)\n", " ]),\n", " batch_size=32)" ] }, { "cell_type": "code", "execution_count": 23, "id": "a5bc327e-89cf-4731-82e6-e66cb93deef1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 18:=======> (1 + 7) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 8.06 ms, sys: 2.92 ms, total: 11 ms\n", "Wall time: 4.86 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "# note: expanding the \"struct\" return_type to top-level columns\n", "preds = df.withColumn(\"preds\", classify(struct(\"input\"))).select(\"input\", \"preds.*\")\n", 
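"# collect() triggers the distributed inference and gathers results on the driver (fine for this 256-row sample)\n",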
"results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 24, "id": "ac642895-cfd6-47ee-9b21-02e7835424e4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 21:> (0 + 8) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 4.5 ms, sys: 1.43 ms, total: 5.93 ms\n", "Wall time: 1.19 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(\"input\")).select(\"input\", \"preds.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 25, "id": "76a44d80-d5db-405f-989c-7246379cfb95", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 24:> (0 + 8) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 5.9 ms, sys: 605 μs, total: 6.5 ms\n", "Wall time: 1.37 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(col(\"input\"))).select(\"input\", \"preds.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 26, "id": "c01761b3-c766-46b0-ae0b-fcf968ffb3a1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 27:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------------------------------------+--------+----------+\n", "| input| label| score|\n", "+--------------------------------------------------------------------------------+--------+----------+\n", "|Doesn't anyone bother to check where this kind of sludge comes from before bl...|NEGATIVE| 0.9984061|\n", "| There were two things I hated about WASTED : The directing and the script |NEGATIVE| 0.9979007|\n", "| I'm rather surprised that anybody found this film touching or moving|POSITIVE|0.83874947|\n", "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an ac...|NEGATIVE|0.99727434|\n", "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw the...|POSITIVE| 0.982114|\n", "| This movie has been done before|NEGATIVE|0.94210696|\n", "|[ as a new resolution for this year 2005, i decide to write a comment for eac...|NEGATIVE| 0.9967818|\n", "|This movie is over hyped!! I am sad to say that I manage to watch the first 1...|NEGATIVE| 0.9985843|\n", "|This show had a promising start as sort of the opposite of 'Oceans 11' but ha...|NEGATIVE|0.99926835|\n", "|MINOR PLOT SPOILERS AHEAD!!!

How did such talented actors get invo...|NEGATIVE|0.99956733|\n", "| There is not one character on this sitcom with any redeeming qualities|NEGATIVE| 0.9985662|\n", "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|POSITIVE| 0.994562|\n", "| My wife rented this movie and then conveniently never got to see it|NEGATIVE|0.99841607|\n", "|This is one of those star-filled over-the-top comedies that could a) be hyste...|NEGATIVE| 0.9953243|\n", "|This excruciatingly boring and unfunny movie made me think that Chaplin was t...|NEGATIVE| 0.9997607|\n", "| you will likely be sorely disappointed by this sequel that's not a sequel|NEGATIVE| 0.9997198|\n", "| If I was British, I would be embarrassed by this portrayal of incompetence|NEGATIVE| 0.9965172|\n", "|One of those movies in which there are no big twists whatsoever and you can p...|NEGATIVE| 0.9986059|\n", "| This show is like watching someone who is in training to someday host a show|NEGATIVE|0.97015846|\n", "| Sigh|NEGATIVE| 0.9923151|\n", "+--------------------------------------------------------------------------------+--------+----------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "preds.show(truncate=80)" ] }, { "cell_type": "markdown", "id": "fc8127d9", "metadata": {}, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 27, "id": "4d4be844-4b8c-47df-bd09-0c280c7ff16b", "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "id": "4f15dfcb", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 28, "id": "bfa7ec9d", "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "id": "1bf04546", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 29, "id": "7e53df9f-43cb-4c38-b8ac-dc2cbad99815", "metadata": {}, "outputs": [], "source": [ "def triton_server(ports):\n", " import time\n", " import signal\n", " import numpy as np\n", " import tensorflow as tf\n", " from transformers import pipeline\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", "\n", " print(f\"SERVER: Initializing pipeline on worker {TaskContext.get().partitionId()}.\")\n", " # Enable GPU memory 
growth\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", " \n", " device = 0 if tf.config.list_physical_devices('GPU') else -1\n", " \n", " pipe = pipeline(\"sentiment-analysis\", device=device)\n", " print(f\"SERVER: Using {device} device.\")\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " sentences = np.squeeze(inputs[\"text\"]).tolist()\n", " print(f\"SERVER: Received batch of size {len(sentences)}\")\n", " decoded_sentences = [s.decode(\"utf-8\") for s in sentences]\n", " return {\n", " \"outputs\": np.array([[json.dumps(o)] for o in pipe(decoded_sentences)])\n", " }\n", "\n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"SentimentAnalysis\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"text\", dtype=object, shape=(-1,)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"outputs\", dtype=object, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=64,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "id": "19d9028d", "metadata": {}, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "id": "5354c597", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": null, "id": "156de815", "metadata": {}, "outputs": [], "source": [ "model_name = \"SentimentAnalysis\"\n", "server_manager = TritonServerManager(model_name=model_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "d003a862", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server)" ] }, { "cell_type": "markdown", "id": "e4c4017c", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "id": "405edc49", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": null, "id": "19768ddb", "metadata": {}, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url # or server_manager.host_to_grpc_url" ] }, { "cell_type": "markdown", "id": "eb5dbb89", "metadata": {}, "source": [ "Define the Triton inference 
function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 34, "id": "431b864c", "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " import numpy as np\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"Connecting to Triton model {model_name} at {url}.\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n", " flattened = np.squeeze(inputs).tolist()\n", " # Encode batch\n", " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n", " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n", " # Run inference\n", " result_data = client.infer_batch(encoded_batch_np)\n", " result_data = np.squeeze(result_data[\"outputs\"], -1)\n", " return [json.loads(o) for o in result_data]\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 37, "id": "3930cfcd-3284-4c6a-a9b5-36b8053fe899", "metadata": {}, "outputs": [], "source": [ "classify = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " return_type=StructType([\n", " StructField(\"label\", StringType(), True),\n", " StructField(\"score\", FloatType(), True)\n", " ]),\n", " input_tensor_shapes=[[1]],\n", " batch_size=32)" ] }, { "cell_type": "markdown", "id": "5a8ec7be", "metadata": {}, "source": [ "#### Load and preprocess DataFrame" ] }, { "cell_type": "code", "execution_count": 35, "id": "d53fb283-bf9e-4571-8c68-b75a41f1f067", "metadata": {}, "outputs": [], "source": [ "@pandas_udf(\"string\")\n", "def preprocess(text: pd.Series) -> pd.Series:\n", " return pd.Series([s.split(\".\")[0] for s in text])" ] }, { "cell_type": "code", "execution_count": 36, "id": "29b0cc0d-c480-4e4a-bd41-207dc314cba5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:57:36 WARN CacheManager: Asked to cache already cached data.\n" ] } ], "source": [ "df = spark.read.parquet(data_path).limit(256).repartition(8)\n", "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()" ] }, { "cell_type": "markdown", "id": "da39990f", "metadata": {}, "source": [ "#### Run Inference" ] }, { "cell_type": "code", "execution_count": 38, "id": "8eecbf23-4e9e-4d4c-8645-98209b25db2c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 33:===========================================> (6 + 2) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 13.1 ms, sys: 8.29 ms, total: 21.4 ms\n", "Wall time: 7.54 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "# note: expanding the \"struct\" return_type to top-level columns\n", "preds = df.withColumn(\"preds\", classify(struct(\"input\"))).select(\"input\", \"preds.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 39, "id": "566ba28c-0ca4-4479-a24a-c8a362228b89", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 36:===========================================> (6 + 2) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 7.54 ms, sys: 3.13 ms, total: 10.7 ms\n", "Wall time: 7.02 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = 
df.withColumn(\"preds\", classify(\"input\")).select(\"input\", \"preds.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 40, "id": "44c7e776-08da-484a-ba07-9d6add1a0f15", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 39:===========================================> (6 + 2) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 6.26 ms, sys: 3 ms, total: 9.26 ms\n", "Wall time: 7.03 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(col(\"input\"))).select(\"input\", \"preds.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 41, "id": "f61d79f8-661e-4d9e-a3aa-c0754b854603", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------------------------------------+--------+----------+\n", "| input| label| score|\n", "+--------------------------------------------------------------------------------+--------+----------+\n", "|Doesn't anyone bother to check where this kind of sludge comes from before bl...|NEGATIVE| 0.9984061|\n", "| There were two things I hated about WASTED : The directing and the script |NEGATIVE| 0.9979007|\n", "| I'm rather surprised that anybody found this film touching or moving|POSITIVE|0.83874947|\n", "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an ac...|NEGATIVE|0.99727434|\n", "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw the...|POSITIVE| 0.982114|\n", "| This movie has been done before|NEGATIVE|0.94210696|\n", "|[ as a new resolution for this year 2005, i decide to write a comment for eac...|NEGATIVE| 0.9967818|\n", "|This movie is over hyped!! I am sad to say that I manage to watch the first 1...|NEGATIVE| 0.9985843|\n", "|This show had a promising start as sort of the opposite of 'Oceans 11' but ha...|NEGATIVE|0.99926835|\n", "|MINOR PLOT SPOILERS AHEAD!!!

How did such talented actors get invo...|NEGATIVE|0.99956733|\n", "| There is not one character on this sitcom with any redeeming qualities|NEGATIVE| 0.9985662|\n", "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|POSITIVE| 0.994562|\n", "| My wife rented this movie and then conveniently never got to see it|NEGATIVE|0.99841607|\n", "|This is one of those star-filled over-the-top comedies that could a) be hyste...|NEGATIVE| 0.9953243|\n", "|This excruciatingly boring and unfunny movie made me think that Chaplin was t...|NEGATIVE| 0.9997607|\n", "| you will likely be sorely disappointed by this sequel that's not a sequel|NEGATIVE| 0.9997198|\n", "| If I was British, I would be embarrassed by this portrayal of incompetence|NEGATIVE| 0.9965172|\n", "|One of those movies in which there are no big twists whatsoever and you can p...|NEGATIVE| 0.9986059|\n", "| This show is like watching someone who is in training to someday host a show|NEGATIVE|0.97015846|\n", "| Sigh|NEGATIVE| 0.9923151|\n", "+--------------------------------------------------------------------------------+--------+----------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "preds.show(truncate=80)" ] }, { "cell_type": "markdown", "id": "fac2ae57", "metadata": {}, "source": [ "#### Shut down server on each executor" ] }, { "cell_type": "code", "execution_count": 42, "id": "425d3b28-7705-45ba-8a18-ad34fc895219", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:57:58,747 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-04 13:58:03,931 - INFO - Sucessfully stopped 1 servers. \n" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 43, "id": "9f19643c-4ee4-44f2-b762-2078c0c8eba9", "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "6a538c47-317d-4cac-b9b9-559e88677518", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-tf", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/pipelines_torch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "60f7ac5d-4a95-4170-a0ac-a7faac9d9ef4", "metadata": {}, "source": [ "\n", "\n", "# PySpark Huggingface Inferencing\n", "### Sentiment Analysis using Pipelines with PyTorch\n", "\n", "In this notebook, we demonstrate distributed inference with Huggingface Pipelines to perform sentiment analysis. 
\n", "From: https://huggingface.co/docs/transformers/quicktour#pipeline-usage" ] }, { "cell_type": "code", "execution_count": 1, "id": "0dd0f77b-ee1b-4477-a038-d25a4f1da0ea", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import pipeline" ] }, { "cell_type": "code", "execution_count": 2, "id": "e1f756c6", "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "553b28d2-a5d1-4d07-8a49-8f82b808e738", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision 714eb0f (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).\n", "Using a pipeline without specifying a model name and revision in production is not recommended.\n", "Device set to use cuda\n" ] } ], "source": [ "classifier = pipeline(\"sentiment-analysis\", device=device)" ] }, { "cell_type": "code", "execution_count": 4, "id": "3b91fe91-b725-4564-ae93-56e3fb51e47c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'label': 'POSITIVE', 'score': 0.9997795224189758}]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "classifier((\"We are very happy to show you the 🤗 Transformers library.\"))" ] }, { "cell_type": "code", "execution_count": 5, "id": "0be39eb3-462c-42ff-b8f4-09f4e4fe3a3c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "label: POSITIVE, with score: 0.9998\n", "label: NEGATIVE, with score: 0.5309\n" ] } ], "source": [ "results = classifier([\"We are very happy to show you the 🤗 Transformers library.\", \"We hope you don't hate it.\"])\n", "for result in results:\n", " print(f\"label: {result['label']}, with score: {round(result['score'], 4)}\")" ] }, { "cell_type": "markdown", "id": "f752f929", "metadata": {}, "source": [ "Let's try a different model and tokenizer in the pipeline." 
] }, { "cell_type": "code", "execution_count": 6, "id": "9861865f", "metadata": {}, "outputs": [], "source": [ "model_name = \"nlptown/bert-base-multilingual-uncased-sentiment\"" ] }, { "cell_type": "code", "execution_count": 7, "id": "506e7834", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n", "\n", "model = AutoModelForSequenceClassification.from_pretrained(model_name)\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)" ] }, { "cell_type": "code", "execution_count": 8, "id": "312017fc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Device set to use cuda\n" ] }, { "data": { "text/plain": [ "[{'label': '5 stars', 'score': 0.7272652983665466}]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "classifier = pipeline(\"sentiment-analysis\", model=model, tokenizer=tokenizer, device=device)\n", "classifier(\"Nous sommes très heureux de vous présenter la bibliothèque 🤗 Transformers.\")" ] }, { "cell_type": "markdown", "id": "ae92b15e-0da0-46c3-81a3-fabaedbfc42c", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": 9, "id": "69dd6a1a-f450-47f0-9dbf-ad250585a011", "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.functions import col, struct, pandas_udf\n", "from pyspark.ml.functions import predict_batch_udf\n", "from pyspark.sql.types import *\n", "from pyspark.sql import SparkSession\n", "from pyspark import SparkConf" ] }, { "cell_type": "code", "execution_count": 10, "id": "42c19ad8", "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import pandas as pd\n", "import datasets\n", "from datasets import load_dataset\n", "datasets.disable_progress_bars()" ] }, { "cell_type": "markdown", "id": "3f1a0210", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific Spark configurations." ] }, { "cell_type": "code", "execution_count": 11, "id": "79aaf5ec", "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "markdown", "id": "b99f9c38", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." ] }, { "cell_type": "code", "execution_count": 12, "id": "6e0e0dd7", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:23:47 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/04 13:23:47 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/02/04 13:23:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", "\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "code", "execution_count": 13, "id": "42d70208", "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset(\"imdb\", split=\"test\")\n", "dataset = dataset.to_pandas().drop(columns=\"label\")" ] }, { "cell_type": "markdown", "id": "de0f421d", "metadata": {}, "source": [ "#### Create PySpark DataFrame" ] }, { "cell_type": "code", "execution_count": 14, "id": "ac24f3c2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "StructType([StructField('text', StringType(), True)])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.createDataFrame(dataset).repartition(8)\n", "df.schema" ] }, { "cell_type": "code", "execution_count": 15, "id": "b0d1876b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "25000" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.count()" ] }, { "cell_type": "code", "execution_count": 16, "id": "06ec6bb6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:23:54 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n" ] }, { "data": { "text/plain": [ "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.

The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.

The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.

I really got nothing much left to say except, give us back CKY2K, cause Bam suck..

I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.take(1)" ] }, { "cell_type": "code", "execution_count": 17, "id": "eeadf4e2", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:23:54 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n" ] } ], "source": [ "data_path = \"spark-dl-datasets/imdb_test\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path)" ] }, { "cell_type": "markdown", "id": "09cddc95", "metadata": {}, "source": [ "#### Load and preprocess DataFrame\n", "\n", "Define our preprocess function. We'll take the first sentence from each sample as our input for sentiment analysis." ] }, { "cell_type": "code", "execution_count": 18, "id": "9665b7b6-d7e9-4bd4-b29d-7a449ac5b574", "metadata": {}, "outputs": [], "source": [ "@pandas_udf(\"string\")\n", "def preprocess(text: pd.Series) -> pd.Series:\n", " return pd.Series([s.split(\".\")[0] for s in text])" ] }, { "cell_type": "code", "execution_count": 19, "id": "74cfa3ff", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------------------------------------+\n", "| input|\n", "+----------------------------------------------------------------------------------------------------+\n", "|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\n", "| There were two things I hated about WASTED : The directing and the script |\n", "| I'm rather surprised that anybody found this film touching or moving|\n", "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\n", "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\n", "| This movie has been done before|\n", "|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\n", "|This movie is over hyped!! I am sad to say that I manage to watch the first 15 minutes of this mo...|\n", "|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\n", "|MINOR PLOT SPOILERS AHEAD!!!

How did such talented actors get involved in such mindles...|\n", "| There is not one character on this sitcom with any redeeming qualities|\n", "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\n", "| My wife rented this movie and then conveniently never got to see it|\n", "|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\n", "|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\n", "| you will likely be sorely disappointed by this sequel that's not a sequel|\n", "| If I was British, I would be embarrassed by this portrayal of incompetence|\n", "|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\n", "| This show is like watching someone who is in training to someday host a show|\n", "| Sigh|\n", "+----------------------------------------------------------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "# Limit to N rows, since this can be slow\n", "df = spark.read.parquet(data_path).limit(256).repartition(8)\n", "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()\n", "df.show(truncate=100)" ] }, { "cell_type": "markdown", "id": "1ad92750", "metadata": {}, "source": [ "## Inference using Spark DL API\n", "\n", "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n", "\n", "- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \n", "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function" ] }, { "cell_type": "code", "execution_count": 20, "id": "0da9d25c-5ebe-4503-bb19-154fcc047cbf", "metadata": {}, "outputs": [], "source": [ "def predict_batch_fn():\n", " import torch\n", " from transformers import pipeline\n", " \n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " pipe = pipeline(\"sentiment-analysis\", device=device)\n", " def predict(inputs):\n", " return pipe(inputs.tolist())\n", " return predict" ] }, { "cell_type": "code", "execution_count": 21, "id": "78afef29-ee30-4267-9fb6-be2dcb86cbba", "metadata": {}, "outputs": [], "source": [ "classify = predict_batch_udf(predict_batch_fn,\n", " return_type=StructType([\n", " StructField(\"label\", StringType(), True),\n", " StructField(\"score\", FloatType(), True)\n", " ]),\n", " batch_size=32)" ] }, { "cell_type": "code", "execution_count": 22, "id": "a5bc327e-89cf-4731-82e6-e66cb93deef1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 18:====================================> (5 + 3) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 8.82 ms, sys: 2.5 ms, total: 11.3 ms\n", "Wall time: 3.59 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "# note: expanding the \"struct\" return_type to top-level columns\n", "preds = df.withColumn(\"preds\", classify(struct(\"input\"))).select(\"input\", \"preds.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 23, "id": "ac642895-cfd6-47ee-9b21-02e7835424e4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.19 ms, 
sys: 1.65 ms, total: 4.84 ms\n", "Wall time: 392 ms\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(\"input\")).select(\"input\", \"preds.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 24, "id": "76a44d80-d5db-405f-989c-7246379cfb95", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.43 ms, sys: 2.33 ms, total: 5.77 ms\n", "Wall time: 403 ms\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(col(\"input\"))).select(\"input\", \"preds.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 25, "id": "c01761b3-c766-46b0-ae0b-fcf968ffb3a1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------------------------------------+--------+----------+\n", "| input| label| score|\n", "+--------------------------------------------------------------------------------+--------+----------+\n", "|Doesn't anyone bother to check where this kind of sludge comes from before bl...|NEGATIVE| 0.9984042|\n", "| There were two things I hated about WASTED : The directing and the script |NEGATIVE| 0.9979019|\n", "| I'm rather surprised that anybody found this film touching or moving|POSITIVE| 0.839279|\n", "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an ac...|NEGATIVE|0.99726933|\n", "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw the...|POSITIVE|0.98212504|\n", "| This movie has been done before|NEGATIVE| 0.9419482|\n", "|[ as a new resolution for this year 2005, i decide to write a comment for eac...|NEGATIVE|0.99678314|\n", "|This movie is over hyped!! I am sad to say that I manage to watch the first 1...|NEGATIVE| 0.9985846|\n", "|This show had a promising start as sort of the opposite of 'Oceans 11' but ha...|NEGATIVE|0.99926823|\n", "|MINOR PLOT SPOILERS AHEAD!!!

How did such talented actors get invo...|NEGATIVE| 0.9995671|\n", "| There is not one character on this sitcom with any redeeming qualities|NEGATIVE|0.99856514|\n", "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|POSITIVE| 0.9945687|\n", "| My wife rented this movie and then conveniently never got to see it|NEGATIVE| 0.9984137|\n", "|This is one of those star-filled over-the-top comedies that could a) be hyste...|NEGATIVE| 0.9953224|\n", "|This excruciatingly boring and unfunny movie made me think that Chaplin was t...|NEGATIVE| 0.9997607|\n", "| you will likely be sorely disappointed by this sequel that's not a sequel|NEGATIVE|0.99971956|\n", "| If I was British, I would be embarrassed by this portrayal of incompetence|NEGATIVE|0.99651587|\n", "|One of those movies in which there are no big twists whatsoever and you can p...|NEGATIVE|0.99860746|\n", "| This show is like watching someone who is in training to someday host a show|NEGATIVE| 0.970153|\n", "| Sigh|NEGATIVE|0.99231356|\n", "+--------------------------------------------------------------------------------+--------+----------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "preds.show(truncate=80)" ] }, { "cell_type": "markdown", "id": "8ba1a6ce", "metadata": {}, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 26, "id": "4d4be844-4b8c-47df-bd09-0c280c7ff16b", "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "id": "ab52381b", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 32, "id": "4e6764c4", "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "id": "bab70481", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 33, "id": "7e53df9f-43cb-4c38-b8ac-dc2cbad99815", "metadata": {}, "outputs": [], "source": [ "def triton_server(ports):\n", " import time\n", " import signal\n", " import numpy as np\n", " import torch\n", " from transformers import pipeline\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", "\n", " print(f\"SERVER: Initializing pipeline on worker {TaskContext.get().partitionId()}.\")\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " pipe = 
pipeline(\"sentiment-analysis\", device=device)\n", " print(f\"SERVER: Using {device} device.\")\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " sentences = np.squeeze(inputs[\"text\"]).tolist()\n", " print(f\"SERVER: Received batch of size {len(sentences)}\")\n", " decoded_sentences = [s.decode(\"utf-8\") for s in sentences]\n", " return {\n", " \"outputs\": np.array([[json.dumps(o)] for o in pipe(decoded_sentences)])\n", " }\n", "\n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"SentimentAnalysis\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"text\", dtype=object, shape=(-1,)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"outputs\", dtype=object, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=64,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n", " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "id": "7c5f4f2d", "metadata": {}, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "id": "b5ef160a", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": 35, "id": "ad13db78", "metadata": {}, "outputs": [], "source": [ "model_name = \"SentimentAnalysis\"\n", "server_manager = TritonServerManager(model_name=model_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "e62d9739", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server)" ] }, { "cell_type": "markdown", "id": "f5ae0b8e", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "id": "9e2059f9", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": null, "id": "7ede428b", "metadata": {}, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url # or server_manager.host_to_grpc_url" ] }, { "cell_type": "markdown", "id": "72f16ff5", "metadata": {}, "source": [ "Define the Triton inference function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 38, "id": "14760940", "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, 
host_to_url):\n", " import socket\n", " import numpy as np\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"Connecting to Triton model {model_name} at {url}.\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n", " flattened = np.squeeze(inputs).tolist()\n", " # Encode batch\n", " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n", " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n", " # Run inference\n", " result_data = client.infer_batch(encoded_batch_np)\n", " result_data = np.squeeze(result_data[\"outputs\"], -1)\n", " return [json.loads(o) for o in result_data]\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 41, "id": "3930cfcd-3284-4c6a-a9b5-36b8053fe899", "metadata": {}, "outputs": [], "source": [ "classify = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " return_type=StructType([\n", " StructField(\"label\", StringType(), True),\n", " StructField(\"score\", FloatType(), True)\n", " ]),\n", " input_tensor_shapes=[[1]],\n", " batch_size=32)" ] }, { "cell_type": "markdown", "id": "a741e23a", "metadata": {}, "source": [ "#### Load and preprocess DataFrame" ] }, { "cell_type": "code", "execution_count": 39, "id": "ccc884a4", "metadata": {}, "outputs": [], "source": [ "@pandas_udf(\"string\")\n", "def preprocess(text: pd.Series) -> pd.Series:\n", " return pd.Series([s.split(\".\")[0] for s in text])" ] }, { "cell_type": "code", "execution_count": 40, "id": "c426fdbe", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:24:35 WARN CacheManager: Asked to cache already cached data.\n" ] } ], "source": [ "df = spark.read.parquet(data_path).limit(256).repartition(8)\n", "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()" ] }, { "cell_type": "markdown", "id": "7da06df4", "metadata": {}, "source": [ "#### Run Inference" ] }, { "cell_type": "code", "execution_count": 42, "id": "8eecbf23-4e9e-4d4c-8645-98209b25db2c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 10.5 ms, sys: 2.2 ms, total: 12.7 ms\n", "Wall time: 671 ms\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "# note: expanding the \"struct\" return_type to top-level columns\n", "preds = df.withColumn(\"preds\", classify(struct(\"input\"))).select(\"input\", \"preds.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 43, "id": "566ba28c-0ca4-4479-a24a-c8a362228b89", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.68 ms, sys: 1.87 ms, total: 3.55 ms\n", "Wall time: 396 ms\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(\"input\")).select(\"input\", \"preds.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 44, "id": "44c7e776-08da-484a-ba07-9d6add1a0f15", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.06 ms, sys: 5.02 ms, total: 8.08 ms\n", "Wall time: 408 ms\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(col(\"input\"))).select(\"input\", \"preds.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 45, "id": 
"f61d79f8-661e-4d9e-a3aa-c0754b854603", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------+--------+----------+\n", "| input| label| score|\n", "+----------------------------------------------------------------------+--------+----------+\n", "|Doesn't anyone bother to check where this kind of sludge comes from...|NEGATIVE| 0.9984042|\n", "|There were two things I hated about WASTED : The directing and the ...|NEGATIVE| 0.9979019|\n", "| I'm rather surprised that anybody found this film touching or moving|POSITIVE| 0.839279|\n", "|Cultural Vandalism Is the new Hallmark production of Gulliver's Tra...|NEGATIVE|0.99726933|\n", "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event...|POSITIVE|0.98212504|\n", "| This movie has been done before|NEGATIVE| 0.9419482|\n", "|[ as a new resolution for this year 2005, i decide to write a comme...|NEGATIVE|0.99678314|\n", "|This movie is over hyped!! I am sad to say that I manage to watch t...|NEGATIVE| 0.9985846|\n", "|This show had a promising start as sort of the opposite of 'Oceans ...|NEGATIVE|0.99926823|\n", "|MINOR PLOT SPOILERS AHEAD!!!

How did such talented actor...|NEGATIVE| 0.9995671|\n", "|There is not one character on this sitcom with any redeeming qualities|NEGATIVE|0.99856514|\n", "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|POSITIVE| 0.9945687|\n", "| My wife rented this movie and then conveniently never got to see it|NEGATIVE| 0.9984137|\n", "|This is one of those star-filled over-the-top comedies that could a...|NEGATIVE| 0.9953224|\n", "|This excruciatingly boring and unfunny movie made me think that Cha...|NEGATIVE| 0.9997607|\n", "|you will likely be sorely disappointed by this sequel that's not a ...|NEGATIVE|0.99971956|\n", "|If I was British, I would be embarrassed by this portrayal of incom...|NEGATIVE|0.99651587|\n", "|One of those movies in which there are no big twists whatsoever and...|NEGATIVE|0.99860746|\n", "|This show is like watching someone who is in training to someday ho...|NEGATIVE| 0.970153|\n", "| Sigh|NEGATIVE|0.99231356|\n", "+----------------------------------------------------------------------+--------+----------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "preds.show(truncate=70)" ] }, { "cell_type": "markdown", "id": "2248858c", "metadata": {}, "source": [ "#### Shut down server on each executor" ] }, { "cell_type": "code", "execution_count": 46, "id": "e3a4e51f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:24:40,325 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:24:45,576 - INFO - Sucessfully stopped 1 servers. \n" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 47, "id": "9f19643c-4ee4-44f2-b762-2078c0c8eba9", "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "a8b03e1e", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-torch", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/qwen-2.5-7b_torch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "# PySpark LLM Inference: Qwen-2.5 Text Summarization\n", "\n", "In this notebook, we demonstrate distributed batch inference with [Qwen-2.5](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct), using open weights on Huggingface.\n", "\n", "The Qwen-2.5-7b-instruct is an instruction-fine-tuned version of the Qwen-2.5-7b base model. We'll show how to use the model to perform text summarization.\n", "\n", "**Note:** Running this model on GPU with 16-bit precision requires **~16GB** of GPU RAM. Make sure your instances have sufficient GPU capacity." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset we'll use requires Zstandard compression." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting zstandard\n", " Downloading zstandard-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)\n", "Downloading zstandard-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.4/5.4 MB\u001b[0m \u001b[31m66.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: zstandard\n", "Successfully installed zstandard-0.23.0\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install zstandard" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n", "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n", "import os\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific configurations." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# For cloud environments, load the model to the distributed file system.\n", "if on_databricks:\n", " models_dir = \"/dbfs/FileStore/spark-dl-models\"\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n", " model_path = f\"{models_dir}/qwen-2.5-7b\"\n", "elif on_dataproc:\n", " models_dir = \"/mnt/gcs/spark-dl-models\"\n", " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n", " model_path = f\"{models_dir}/qwen-2.5-7b\"\n", "else:\n", " model_path = os.path.abspath(\"qwen-2.5-7b\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Download the model from huggingface hub." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import snapshot_download\n", "\n", "model_path = snapshot_download(\n", " repo_id=\"Qwen/Qwen2.5-7B-Instruct\",\n", " local_dir=model_path\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Warmup: Running locally\n", "\n", "**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "352b738e1a2442b0a997467aaf6eb0ad", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/4 [00:00system\\n\"),\n", " lit(system_prompt),\n", " lit(\"<|im_end|>\\n<|im_start|>user\\n\"),\n", " col(\"value\"),\n", " lit(\"<|im_end|>\\n<|im_start|>assistant\\n\")\n", " ).alias(\"prompt\")\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|im_start|>system\n", "You are a knowledgeable AI assistant. 
Your job is to create a 2-3 sentence summary \n", "of a research abstract that captures the main objective, methodology, and key findings, using clear \n", "language while preserving technical accuracy and quantitative results.<|im_end|>\n", "<|im_start|>user\n", "Epidemiology of hypoxaemia in children with acute lower respiratory infection.\n", "To determine the prevalence of hypoxaemia in children aged under 5 years suffering acute lower respiratory infections (ALRI), the risk factors for hypoxaemia in children under 5 years of age with ALRI, and the association of hypoxaemia with an increased risk of dying in children of the same age. Systematic review of the published literature. Out-patient clinics, emergency departments and hospitalisation wards in 23 health centres from 10 countries. Cohort studies reporting the frequency of hypoxaemia in children under 5 years of age with ALRI, and the association between hypoxaemia and the risk of dying. Prevalence of hypoxaemia measured in children with ARI and relative risks for the association between the severity of illness and the frequency of hypoxaemia, and between hypoxaemia and the risk of dying. Seventeen published studies were found that included 4,021 children under 5 with acute respiratory infections (ARI) and reported the prevalence of hypoxaemia. Out-patient children and those with a clinical diagnosis of upper ARI had a low risk of hypoxaemia (pooled estimate of 6% to 9%). The prevalence increased to 31% and to 43% in patients in emergency departments and in cases with clinical pneumonia, respectively, and it was even higher among hospitalised children (47%) and in those with radiographically confirmed pneumonia (72%). The cumulated data also suggest that hypoxaemia is more frequent in children living at high altitude. Three papers reported an association between hypoxaemia and death, with relative risks varying between 1.4 and 4.6. Papers describing predictors of hypoxaemia have focused on clinical signs for detecting hypoxaemia rather than on identifying risk factors for developing this complication. Hypoxaemia is a common and potentially lethal complication of ALRI in children under 5, particularly among those with severe disease and those living at high altitude. Given the observed high prevalence of hypoxaemia and its likely association with increased mortality, efforts should be made to improve the detection of hypoxaemia and to provide oxygen earlier to more children with severe ALRI.<|im_end|>\n", "<|im_start|>assistant\n", "\n" ] } ], "source": [ "print(df.take(1)[0].prompt)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "data_path = \"spark-dl-datasets/pubmed_abstracts\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. 
\n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def triton_server(ports, model_path):\n", " import time\n", " import signal\n", " import torch\n", " import numpy as np\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", " from transformers import AutoModelForCausalLM, AutoTokenizer\n", "\n", " print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n", " model = AutoModelForCausalLM.from_pretrained(\n", " model_path,\n", " torch_dtype=torch.bfloat16,\n", " device_map=\"auto\"\n", " )\n", " tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\"left\")\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " prompts = np.squeeze(inputs[\"prompts\"]).tolist()\n", " print(f\"SERVER: Received batch of size {len(prompts)}\")\n", " decoded_prompts = [p.decode(\"utf-8\") for p in prompts]\n", " tokenized_inputs = tokenizer(decoded_prompts, padding=True, return_tensors=\"pt\").to(model.device)\n", " generated_ids = model.generate(**tokenized_inputs, max_new_tokens=256)\n", " outputs = tokenizer.batch_decode(generated_ids[:, tokenized_inputs.input_ids.shape[1]:], skip_special_tokens = True)\n", " return {\n", " \"outputs\": np.array(outputs).reshape(-1, 1)\n", " }\n", "\n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"qwen-2.5\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"prompts\", dtype=object, shape=(-1,)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"outputs\", dtype=object, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=64,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " print(\"SERVER: Received SIGTERM. 
Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "model_name = \"qwen-2.5\"\n", "server_manager = TritonServerManager(model_name=model_name, model_path=model_path)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-16 11:49:25,237 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-16 11:49:25,239 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (3490378, [7000, 7001, 7002])}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server, wait_retries=24) # allow up to 2 minutes for model loading" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "host_to_grpc_url = server_manager.host_to_grpc_url # or server_manager.host_to_http_url" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " import numpy as np\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"Connecting to Triton model {model_name} at {url}.\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=500) as client:\n", " flattened = np.squeeze(inputs).tolist()\n", " # Encode batch\n", " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n", " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n", " # Run inference\n", " result_data = client.infer_batch(encoded_batch_np)\n", " result_data = np.squeeze(result_data[\"outputs\"], -1)\n", " return result_data\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_grpc_url),\n", " return_type=StringType(),\n", " input_tensor_shapes=[[1]],\n", " batch_size=8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load DataFrame\n", "\n", "We'll parallelize over a small set of prompts for demonstration." 
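Before loading the DataFrame, it can be useful to sanity-check a running server with one direct PyTriton client call from the driver. A minimal sketch, assuming the driver's hostname appears in the server manager's gRPC URL map (true for this standalone setup) and that the model is bound as `qwen-2.5`; the distributed inference below still goes through the `generate` UDF:

```python
import socket
import numpy as np
from pytriton.client import ModelClient

# Assumption: a Triton server is running on the driver's host (standalone setup).
url = server_manager.host_to_grpc_url[socket.gethostname()]

# One encoded prompt, shaped (batch=1, 1) to match the "prompts" input tensor.
prompt = "<|im_start|>user\nSay hello in one sentence.<|im_end|>\n<|im_start|>assistant\n"
sample = np.array([[prompt.encode("utf-8")]], dtype=np.bytes_)

with ModelClient(url, "qwen-2.5", inference_timeout_s=500) as client:
    outputs = client.infer_batch(sample)["outputs"]
    print(np.squeeze(outputs, -1))
```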
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "df = spark.read.parquet(data_path).limit(64).repartition(8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Run Inference" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 10:=====================> (3 + 5) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 10.5 ms, sys: 6.63 ms, total: 17.1 ms\n", "Wall time: 23.7 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "preds = df.withColumn(\"outputs\", generate(col(\"prompt\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 16:=====================> (3 + 5) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 8.1 ms, sys: 4.47 ms, total: 12.6 ms\n", "Wall time: 21.7 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"outputs\", generate(col(\"prompt\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Q: <|im_start|>system\n", "You are a knowledgeable AI assistant. Your job is to create a 2-3 sentence summary \n", "of a research abstract that captures the main objective, methodology, and key findings, using clear \n", "language while preserving technical accuracy and quantitative results.<|im_end|>\n", "<|im_start|>user\n", "Oral health promotion evaluation--time for development.\n", "Increasing emphasis is now being placed upon the evaluation of health service interventions to demonstrate their effects. A series of effectiveness reviews of the oral health education and promotion literature has demonstrated that many of these interventions are poorly and inadequately evaluated. It is therefore difficult to determine the effectiveness of many interventions. Based upon developments from the field of health promotion research this paper explores options for improving the quality of oral health promotion evaluation. It is essential that the methods and measures used in the evaluation of oral health promotion are appropriate to the intervention. For many oral health promotion interventions clinical measures and methods of evaluation may not be appropriate. This paper outlines an evaluation framework which can be used to assess the range of effects of oral health promotion programmes. Improving the quality of oral health promotion evaluation is a shared responsibility between researchers and those involved in the provision of programmes. The provision of adequate resources and training are essential requirements for this to be successfully achieved.<|im_end|>\n", "<|im_start|>assistant\n", " \n", "\n", "A: This research aims to improve the evaluation of oral health promotion programs by developing an appropriate framework. It explores how methods and measures should align with the specific nature of these interventions, emphasizing that both researchers and program providers must collaborate to ensure adequate resources and training are available for high-quality evaluations. 
\n", "\n" ] } ], "source": [ "print(f\"Q: {results[0].prompt} \\n\")\n", "print(f\"A: {results[0].outputs} \\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Shut down server on each executor" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-16 11:51:42,365 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-02-16 11:51:47,609 - INFO - Sucessfully stopped 1 servers. \n" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-torch", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/sentence_transformers_torch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "777fc40d", "metadata": {}, "source": [ "\n", "\n", "# PySpark Huggingface Inferencing\n", "### Sentence Transformers with PyTorch\n", "\n", "In this notebook, we demonstrate distributed inference with the Huggingface SentenceTransformer library for sentence embedding. \n", "From: https://huggingface.co/sentence-transformers" ] }, { "cell_type": "code", "execution_count": 2, "id": "c5f0d0a8", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from sentence_transformers import SentenceTransformer\n", "\n", "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n", "# See (https://github.com/huggingface/transformers/issues/5486) for more info. 
\n", "import os\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"" ] }, { "cell_type": "code", "execution_count": null, "id": "731faab7-a700-46f8-bba5-1c8764e5eacb", "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model = SentenceTransformer(\"paraphrase-MiniLM-L6-v2\", device=device)\n", "\n", "sentence = ['This framework generates embeddings for each input sentence']\n", "embedding = model.encode(sentence)" ] }, { "cell_type": "code", "execution_count": 3, "id": "96eea5ca-3cf7-46e3-b40c-598538112d24", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[-0.17621444 0.1206013 -0.29362372 -0.22985819 -0.08229247 0.2377093\n", " 0.33998525 -0.7809643 0.11812777 0.16337365 -0.13771524 0.24028276\n", " 0.4251256 0.17241786 0.10527937 0.5181643 0.062222 0.39928585\n", " -0.18165241 -0.58557856]\n" ] } ], "source": [ "print(embedding[0][:20])" ] }, { "cell_type": "markdown", "id": "546eabe0", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": 4, "id": "dbda3e66-005a-4ad0-8017-c1cc7cbf0058", "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.types import *\n", "from pyspark import SparkConf\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.functions import pandas_udf, col, struct\n", "from pyspark.ml.functions import predict_batch_udf" ] }, { "cell_type": "code", "execution_count": 5, "id": "b525c5c4", "metadata": {}, "outputs": [], "source": [ "import json\n", "import pandas as pd\n", "import datasets\n", "from datasets import load_dataset\n", "datasets.disable_progress_bars()" ] }, { "cell_type": "markdown", "id": "58e7c1bc", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific Spark configurations." ] }, { "cell_type": "code", "execution_count": 6, "id": "5a013217", "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "markdown", "id": "ad3c003d", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." ] }, { "cell_type": "code", "execution_count": 7, "id": "23ec67ba", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:40:01 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/04 13:40:01 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/02/04 13:40:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", "\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "markdown", "id": "4cfd1394", "metadata": {}, "source": [ "Load the IMBD Movie Reviews dataset from Huggingface." ] }, { "cell_type": "code", "execution_count": 8, "id": "9bc1edb5", "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset(\"imdb\", split=\"test\")\n", "dataset = dataset.to_pandas().drop(columns=\"label\")" ] }, { "cell_type": "markdown", "id": "59c71bff", "metadata": {}, "source": [ "#### Create PySpark DataFrame" ] }, { "cell_type": "code", "execution_count": 9, "id": "836e5f84-12c6-4c95-838e-53de7e46a20b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "StructType([StructField('text', StringType(), True)])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.createDataFrame(dataset).repartition(8)\n", "df.schema" ] }, { "cell_type": "code", "execution_count": 10, "id": "36703d23-37a3-40df-b09a-c68206d285b6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "25000" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.count()" ] }, { "cell_type": "code", "execution_count": 11, "id": "1f122ae3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:40:08 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n" ] }, { "data": { "text/plain": [ "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.

The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.

The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.

I really got nothing much left to say except, give us back CKY2K, cause Bam suck..

I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.take(1)" ] }, { "cell_type": "code", "execution_count": 12, "id": "14fd59fb", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:40:08 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n" ] } ], "source": [ "data_path = \"spark-dl-datasets/imdb_test\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path)" ] }, { "cell_type": "markdown", "id": "6bb083ec", "metadata": {}, "source": [ "#### Load and preprocess DataFrame\n", "\n", "Define our preprocess function. We'll take the first sentence from each sample as our input for translation." ] }, { "cell_type": "code", "execution_count": 13, "id": "2510bdd1", "metadata": {}, "outputs": [], "source": [ "@pandas_udf(\"string\")\n", "def preprocess(text: pd.Series) -> pd.Series:\n", " return pd.Series([s.split(\".\")[0] for s in text])" ] }, { "cell_type": "code", "execution_count": 14, "id": "5bb28548", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------------------------------------+\n", "| input|\n", "+----------------------------------------------------------------------------------------------------+\n", "|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\n", "| There were two things I hated about WASTED : The directing and the script |\n", "| I'm rather surprised that anybody found this film touching or moving|\n", "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\n", "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\n", "| This movie has been done before|\n", "|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\n", "|This movie is over hyped!! I am sad to say that I manage to watch the first 15 minutes of this mo...|\n", "|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\n", "|MINOR PLOT SPOILERS AHEAD!!!

How did such talented actors get involved in such mindles...|\n", "| There is not one character on this sitcom with any redeeming qualities|\n", "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\n", "| My wife rented this movie and then conveniently never got to see it|\n", "|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\n", "|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\n", "| you will likely be sorely disappointed by this sequel that's not a sequel|\n", "| If I was British, I would be embarrassed by this portrayal of incompetence|\n", "|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\n", "| This show is like watching someone who is in training to someday host a show|\n", "| Sigh|\n", "+----------------------------------------------------------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "# Limit to N rows, since this can be slow\n", "df = spark.read.parquet(data_path).limit(256).repartition(8)\n", "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()\n", "df.show(truncate=100)" ] }, { "cell_type": "markdown", "id": "014eae88", "metadata": {}, "source": [ "## Inference using Spark DL API\n", "\n", "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n", "\n", "- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \n", "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function" ] }, { "cell_type": "code", "execution_count": 15, "id": "f780c026-0f3f-4aea-8b61-5b3dbae83fb7", "metadata": {}, "outputs": [], "source": [ "def predict_batch_fn():\n", " import torch\n", " from sentence_transformers import SentenceTransformer\n", "\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " model = SentenceTransformer(\"paraphrase-MiniLM-L6-v2\", device=device)\n", " def predict(inputs):\n", " return model.encode(inputs.tolist())\n", " return predict" ] }, { "cell_type": "code", "execution_count": 16, "id": "f5c88ddc-ca19-4430-8b0e-b9fae143b237", "metadata": {}, "outputs": [], "source": [ "encode = predict_batch_udf(predict_batch_fn,\n", " return_type=ArrayType(FloatType()),\n", " batch_size=32)" ] }, { "cell_type": "code", "execution_count": 17, "id": "85344c22-4a4d-4cb0-8771-5836ae2794db", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 18:=====================> (3 + 5) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 10.6 ms, sys: 4.83 ms, total: 15.4 ms\n", "Wall time: 4.23 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "embeddings = df.withColumn(\"embedding\", encode(struct(\"input\")))\n", "results = embeddings.collect()" ] }, { "cell_type": "code", "execution_count": 18, "id": "c23bb885-6ab0-4471-943d-4c10414100fa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 6.7 ms, sys: 2.44 ms, total: 9.15 ms\n", "Wall time: 163 ms\n" ] } ], "source": [ "%%time\n", "embeddings = df.withColumn(\"embedding\", encode(\"input\"))\n", "results 
= embeddings.collect()" ] }, { "cell_type": "code", "execution_count": 19, "id": "93bc6da3-d853-4233-b805-cb4a46f4f9b9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 5.37 ms, sys: 2.73 ms, total: 8.1 ms\n", "Wall time: 232 ms\n" ] } ], "source": [ "%%time\n", "embeddings = df.withColumn(\"embedding\", encode(col(\"input\")))\n", "results = embeddings.collect()" ] }, { "cell_type": "code", "execution_count": 20, "id": "2073616f-7151-4760-92f2-441dd0bfe9fe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------+--------------------------------------------------+\n", "| input| embedding|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "|Doesn't anyone bother to check where this kind ...|[0.118947476, -0.053823642, -0.29726124, 0.0720...|\n", "|There were two things I hated about WASTED : Th...|[0.18953452, 0.11079162, 0.07503566, 0.01050696...|\n", "|I'm rather surprised that anybody found this fi...|[-0.0010759671, -0.14203517, -0.06649738, 0.129...|\n", "|Cultural Vandalism Is the new Hallmark producti...|[0.34815887, -0.2966917, -0.10905265, 0.1051652...|\n", "|I was at Wrestlemania VI in Toronto as a 10 yea...|[0.45902696, 0.019472413, 0.28720972, -0.070724...|\n", "| This movie has been done before|[-0.062292397, -0.025909504, -0.031942524, 0.01...|\n", "|[ as a new resolution for this year 2005, i dec...|[0.3469342, -0.14378615, 0.30223376, -0.1102267...|\n", "|This movie is over hyped!! I am sad to say that...|[0.13230576, -0.06588756, 0.0472389, 0.08353163...|\n", "|This show had a promising start as sort of the ...|[-0.19361982, -0.14412567, 0.15149693, -0.17715...|\n", "|MINOR PLOT SPOILERS AHEAD!!!

How did...|[-0.048036292, 0.050720096, -0.04668727, -0.316...|\n", "|There is not one character on this sitcom with ...|[0.13720773, -0.5963504, 0.30331734, -0.3830607...|\n", "|Tommy Lee Jones was the best Woodroe and no one...|[-0.20960267, -0.15760122, -0.30596405, -0.5181...|\n", "|My wife rented this movie and then conveniently...|[0.46534792, -0.40655977, 0.054217298, -0.03414...|\n", "|This is one of those star-filled over-the-top c...|[0.14433198, -0.016140658, 0.3775344, 0.0659043...|\n", "|This excruciatingly boring and unfunny movie ma...|[0.056464806, 0.01144963, -0.51797307, 0.089813...|\n", "|you will likely be sorely disappointed by this ...|[-0.44146675, -0.17866582, 0.49889183, -0.26819...|\n", "|If I was British, I would be embarrassed by thi...|[0.1191261, -0.15379854, 0.17487673, -0.5123498...|\n", "|One of those movies in which there are no big t...|[-0.016174048, -0.5558219, -0.024818476, 0.1543...|\n", "|This show is like watching someone who is in tr...|[0.033776704, -0.6682203, 0.30547586, -0.581407...|\n", "| Sigh|[-0.119870394, 0.40893683, 0.4174831, -0.010004...|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "embeddings.show(truncate=50)" ] }, { "cell_type": "markdown", "id": "0c9c6535", "metadata": {}, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. 
\n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 21, "id": "772e337e-1098-4c7b-ba81-8cb221a518e2", "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "id": "759385ac", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 22, "id": "485fb0de", "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "id": "ece5c38a", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 23, "id": "69d0c93a-bb0b-46c5-9d28-7b08a2e70964", "metadata": {}, "outputs": [], "source": [ "def triton_server(ports):\n", " import time\n", " import signal\n", " import numpy as np\n", " import torch\n", " from sentence_transformers import SentenceTransformer\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", "\n", " print(f\"SERVER: Initializing sentence transformer on worker {TaskContext.get().partitionId()}.\")\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " model = SentenceTransformer(\"paraphrase-MiniLM-L6-v2\", device=device)\n", " print(f\"SERVER: Using {device} device.\")\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " sentences = np.squeeze(inputs[\"text\"])\n", " print(f\"SERVER: Received batch of size {len(sentences)}\")\n", " decoded_sentences = [s.decode(\"utf-8\") for s in sentences]\n", " embeddings = model.encode(decoded_sentences)\n", " return {\n", " \"embeddings\": embeddings,\n", " }\n", "\n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"SentenceTransformer\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"text\", dtype=object, shape=(-1,)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"embeddings\", dtype=np.float32, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=64,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n", " print(\"SERVER: Received SIGTERM. 
Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "id": "79532110", "metadata": {}, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "id": "1b0371c8", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": 25, "id": "e66e8927", "metadata": {}, "outputs": [], "source": [ "model_name = \"SentenceTransformer\"\n", "server_manager = TritonServerManager(model_name=model_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "040df0dd", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server)" ] }, { "cell_type": "markdown", "id": "1fd19fae", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "id": "ddeadc74", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": null, "id": "c42d1578", "metadata": {}, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url # or server_manager.host_to_grpc_url" ] }, { "cell_type": "code", "execution_count": 28, "id": "807dbc45", "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " import numpy as np\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"Connecting to Triton model {model_name} at {url}.\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n", " flattened = np.squeeze(inputs).tolist()\n", " # Encode batch\n", " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n", " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n", " # Run inference\n", " result_data = client.infer_batch(encoded_batch_np)\n", " return result_data[\"embeddings\"]\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 31, "id": "9c712b8f-6eb4-4fb8-9f0a-04feef847fea", "metadata": {}, "outputs": [], "source": [ "encode = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " return_type=ArrayType(FloatType()),\n", " input_tensor_shapes=[[1]],\n", " batch_size=32)" ] }, { "cell_type": "markdown", "id": "af174106", "metadata": {}, "source": [ "#### Load and preprocess DataFrame" ] }, { "cell_type": "code", "execution_count": 29, "id": "2969d502-e97b-49d6-bf80-7d177ae867cf", "metadata": {}, "outputs": [], "source": [ "@pandas_udf(\"string\")\n", "def preprocess(text: pd.Series) -> pd.Series:\n", " return pd.Series([s.split(\".\")[0] for s in text])" ] }, { "cell_type": "code", "execution_count": 30, 
"id": "c8f1e6d6-6519-49e7-8465-4419547633b8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:40:22 WARN CacheManager: Asked to cache already cached data.\n" ] } ], "source": [ "df = spark.read.parquet(data_path).limit(256).repartition(8)\n", "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()" ] }, { "cell_type": "markdown", "id": "cf0ee731", "metadata": {}, "source": [ "#### Run Inference" ] }, { "cell_type": "code", "execution_count": 32, "id": "934c1a1f-b126-45b0-9c15-265236820ad3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 5.59 ms, sys: 5.1 ms, total: 10.7 ms\n", "Wall time: 605 ms\n" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "embeddings = df.withColumn(\"embedding\", encode(struct(\"input\")))\n", "results = embeddings.collect()" ] }, { "cell_type": "code", "execution_count": 33, "id": "f84cd3f6-b6a8-4142-859a-91f3c183457b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.57 ms, sys: 4.36 ms, total: 6.93 ms\n", "Wall time: 161 ms\n" ] } ], "source": [ "%%time\n", "embeddings = df.withColumn(\"embedding\", encode(\"input\"))\n", "results = embeddings.collect()" ] }, { "cell_type": "code", "execution_count": 34, "id": "921a4c01-e296-4406-be90-86f20c8c582d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 7.06 ms, sys: 605 μs, total: 7.67 ms\n", "Wall time: 191 ms\n" ] } ], "source": [ "%%time\n", "embeddings = df.withColumn(\"embedding\", encode(col(\"input\")))\n", "results = embeddings.collect()" ] }, { "cell_type": "code", "execution_count": 35, "id": "9f67584e-9c4e-474f-b6ea-7811b14d116e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------+--------------------------------------------------+\n", "| input| embedding|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "|Doesn't anyone bother to check where this kind ...|[0.118947476, -0.053823642, -0.29726124, 0.0720...|\n", "|There were two things I hated about WASTED : Th...|[0.18953452, 0.11079162, 0.07503566, 0.01050696...|\n", "|I'm rather surprised that anybody found this fi...|[-0.0010759671, -0.14203517, -0.06649738, 0.129...|\n", "|Cultural Vandalism Is the new Hallmark producti...|[0.34815887, -0.2966917, -0.10905265, 0.1051652...|\n", "|I was at Wrestlemania VI in Toronto as a 10 yea...|[0.45902696, 0.019472413, 0.28720972, -0.070724...|\n", "| This movie has been done before|[-0.062292397, -0.025909504, -0.031942524, 0.01...|\n", "|[ as a new resolution for this year 2005, i dec...|[0.3469342, -0.14378615, 0.30223376, -0.1102267...|\n", "|This movie is over hyped!! I am sad to say that...|[0.13230576, -0.06588756, 0.0472389, 0.08353163...|\n", "|This show had a promising start as sort of the ...|[-0.19361982, -0.14412567, 0.15149693, -0.17715...|\n", "|MINOR PLOT SPOILERS AHEAD!!!

How did...|[-0.048036292, 0.050720096, -0.04668727, -0.316...|\n", "|There is not one character on this sitcom with ...|[0.13720773, -0.5963504, 0.30331734, -0.3830607...|\n", "|Tommy Lee Jones was the best Woodroe and no one...|[-0.20960267, -0.15760122, -0.30596405, -0.5181...|\n", "|My wife rented this movie and then conveniently...|[0.46534792, -0.40655977, 0.054217298, -0.03414...|\n", "|This is one of those star-filled over-the-top c...|[0.14433198, -0.016140658, 0.3775344, 0.0659043...|\n", "|This excruciatingly boring and unfunny movie ma...|[0.056464806, 0.01144963, -0.51797307, 0.089813...|\n", "|you will likely be sorely disappointed by this ...|[-0.44146675, -0.17866582, 0.49889183, -0.26819...|\n", "|If I was British, I would be embarrassed by thi...|[0.1191261, -0.15379854, 0.17487673, -0.5123498...|\n", "|One of those movies in which there are no big t...|[-0.016174048, -0.5558219, -0.024818476, 0.1543...|\n", "|This show is like watching someone who is in tr...|[0.033776704, -0.6682203, 0.30547586, -0.581407...|\n", "| Sigh|[-0.119870394, 0.40893683, 0.4174831, -0.010004...|\n", "+--------------------------------------------------+--------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "embeddings.show(truncate=50)" ] }, { "cell_type": "markdown", "id": "e3b0077c-785f-41af-9fa9-812e7fb63810", "metadata": { "tags": [] }, "source": [ "#### Stop Triton Server on each executor" ] }, { "cell_type": "code", "execution_count": 36, "id": "ef780e30", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:40:23,196 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-04 13:40:28,390 - INFO - Sucessfully stopped 1 servers. \n" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 37, "id": "e82b9518-da7b-4ebc-8990-c8ab909bec18", "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "33a60f2d-295a-4270-a2fd-16559962edda", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-torch", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/housing_regression_torch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "792d95f9", "metadata": {}, "source": [ "\n", "\n", "# PySpark PyTorch Inference\n", "\n", "### Regression\n", "\n", "In this notebook, we will train an MLP to perform regression on the California housing dataset, and load it for distributed inference with Spark. \n", "\n", "Based on: https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-create-a-neural-network-for-regression-with-pytorch.md \n", "\n", "We also demonstrate accelerated inference via Torch-TensorRT model compilation. 
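As a preview of that compilation step, here is a minimal sketch of what a Torch-TensorRT compile-and-run call typically looks like, assuming `torch_tensorrt` is installed and a CUDA device is available; the stand-in model and input shape are illustrative only, and the notebook's own compilation cells come later:

```python
import torch
import torch_tensorrt

# Illustrative stand-in model and batch; not the trained MLP used later in this notebook.
model = torch.nn.Linear(8, 1).eval().cuda()
example_input = torch.randn(10, 8).cuda()

# Compile to a TensorRT-backed module and run inference with it.
trt_model = torch_tensorrt.compile(model, inputs=[example_input], enabled_precisions={torch.float32})
print(trt_model(example_input).shape)  # torch.Size([10, 1])
```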
" ] }, { "cell_type": "code", "execution_count": 1, "id": "75930360-c5ce-49ef-a69a-da88fa69a2ef", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import os\n", "import shutil\n", "import numpy as np\n", "from torch import nn\n", "from torch.utils.data import DataLoader\n", "from sklearn.datasets import fetch_california_housing\n", "from sklearn.preprocessing import StandardScaler" ] }, { "cell_type": "code", "execution_count": 2, "id": "1de685f4", "metadata": {}, "outputs": [], "source": [ "os.mkdir('models') if not os.path.exists('models') else None" ] }, { "cell_type": "code", "execution_count": 3, "id": "6d5bc0c7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'2.5.1+cu124'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.__version__" ] }, { "cell_type": "markdown", "id": "8754b174", "metadata": {}, "source": [ "### Load Dataset\n", "\n", "Each label corresponds to the average house value in units of 100,000, which we'll try to predict using the following features: \n", "['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population', 'AveOccup', 'Latitude', 'Longitude']" ] }, { "cell_type": "code", "execution_count": 4, "id": "2bee64cf-a44a-4aff-82db-c64ee3a8b0e8", "metadata": {}, "outputs": [], "source": [ "X, y = fetch_california_housing(return_X_y=True)" ] }, { "cell_type": "code", "execution_count": 5, "id": "8644e508-5e4c-4cdd-9ed1-9235887d9659", "metadata": {}, "outputs": [], "source": [ "class HousingDataset(torch.utils.data.Dataset):\n", " def __init__(self, X, y, scale_data=True):\n", " if not torch.is_tensor(X) and not torch.is_tensor(y):\n", " # Apply scaling if necessary\n", " if scale_data:\n", " X = StandardScaler().fit_transform(X)\n", " self.X = torch.from_numpy(X.astype(np.float32))\n", " self.y = torch.from_numpy(y.astype(np.float32))\n", "\n", " def __len__(self):\n", " return len(self.X)\n", "\n", " def __getitem__(self, i):\n", " return self.X[i], self.y[i]" ] }, { "cell_type": "code", "execution_count": 6, "id": "cc6b55c3-dc7b-4831-9943-83efd48091bf", "metadata": {}, "outputs": [], "source": [ "dataset = HousingDataset(X, y)\n", "trainloader = torch.utils.data.DataLoader(\n", " dataset, batch_size=10, shuffle=True, num_workers=1)" ] }, { "cell_type": "code", "execution_count": null, "id": "d868f39d-4695-4110-91d2-6f7a09d73b93", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([[ 6.5799e-01, 4.2594e-01, -1.4755e-01, -2.3638e-01, -4.0221e-01,\n", " -5.6793e-02, 8.8868e-01, -1.3528e+00],\n", " [ 6.7288e-01, -1.0043e+00, 5.7486e-01, -1.6537e-01, -3.3422e-01,\n", " -6.4971e-02, -1.2790e+00, 1.2327e+00],\n", " [-1.1616e-01, 2.8646e-02, -1.7830e-01, -2.3817e-01, -6.7154e-01,\n", " -3.6429e-02, -1.3258e+00, 1.2726e+00],\n", " [-3.2513e-01, -6.8648e-01, -3.4226e-01, -8.2805e-02, 5.1239e+00,\n", " 2.6689e-02, -7.7338e-01, 8.3340e-01],\n", " [ 1.0892e-01, -1.2427e+00, 2.7819e-01, -8.7150e-02, 3.0158e-01,\n", " -1.8564e-02, -1.1245e+00, 1.1628e+00],\n", " [-8.6416e-02, 5.8485e-01, -7.8085e-02, 8.1655e-02, -6.7154e-01,\n", " -1.6053e-02, -3.4733e-01, 1.2577e+00],\n", " [-1.2463e-01, 1.0810e-01, 2.6662e-01, -1.0883e-01, 3.4839e-01,\n", " -2.3125e-02, -7.7338e-01, 1.3325e+00],\n", " [-9.2662e-01, -1.6400e+00, -2.4824e-01, 6.0041e-01, 6.3361e-01,\n", " -1.0926e-01, -8.8574e-01, 1.2826e+00],\n", " [ 2.0038e+00, -6.0702e-01, 8.4770e-01, -2.1254e-01, 1.3745e+00,\n", " -5.0489e-03, -6.5165e-01, 2.5441e-01],\n", " [-3.9250e-01, 1.0616e+00, -1.8614e-01, -1.7073e-01, -3.8543e-01,\n", " 
-8.1186e-02, 1.0806e+00, -1.3827e+00]]),\n", " tensor([3.1090, 1.8430, 1.6890, 1.8670, 1.9600, 0.6200, 0.9860, 0.9440, 3.9120,\n", " 1.4390])]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "next(iter(trainloader))" ] }, { "cell_type": "markdown", "id": "1e817b9a", "metadata": {}, "source": [ "### Create and Train Model" ] }, { "cell_type": "code", "execution_count": 8, "id": "9a441b60-dca4-44d2-bc1c-aa7336d704bb", "metadata": {}, "outputs": [], "source": [ "class MLP(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.layers = nn.Sequential(\n", " nn.Linear(8, 64),\n", " nn.ReLU(),\n", " nn.Linear(64, 32),\n", " nn.ReLU(),\n", " nn.Linear(32, 1)\n", " )\n", "\n", " def forward(self, x):\n", " return self.layers(x)" ] }, { "cell_type": "code", "execution_count": null, "id": "15cff2b4-9d23-4d2b-808a-a5edb8eda135", "metadata": { "scrolled": true, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cuda device\n" ] } ], "source": [ "# Initialize the MLP\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(f\"Using {device} device\")\n", "mlp = MLP().to(device)\n", "\n", "# Define the loss function and optimizer\n", "loss_function = nn.L1Loss()\n", "optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)" ] }, { "cell_type": "code", "execution_count": 10, "id": "5e2db3f9-5db8-4b42-89ad-e77f23c4c1fe", "metadata": { "scrolled": true, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Starting epoch 1\n", "Loss after mini-batch 1: 0.004\n", "Loss after mini-batch 201: 0.701\n", "Loss after mini-batch 401: 0.463\n", "Loss after mini-batch 601: 0.329\n", "Loss after mini-batch 801: 0.285\n", "Loss after mini-batch 1001: 0.253\n", "Loss after mini-batch 1201: 0.247\n", "Loss after mini-batch 1401: 0.234\n", "Loss after mini-batch 1601: 0.232\n", "Loss after mini-batch 1801: 0.217\n", "Loss after mini-batch 2001: 0.211\n", "Starting epoch 2\n", "Loss after mini-batch 1: 0.001\n", "Loss after mini-batch 201: 0.205\n", "Loss after mini-batch 401: 0.212\n", "Loss after mini-batch 601: 0.206\n", "Loss after mini-batch 801: 0.205\n", "Loss after mini-batch 1001: 0.202\n", "Loss after mini-batch 1201: 0.202\n", "Loss after mini-batch 1401: 0.204\n", "Loss after mini-batch 1601: 0.198\n", "Loss after mini-batch 1801: 0.188\n", "Loss after mini-batch 2001: 0.188\n", "Starting epoch 3\n", "Loss after mini-batch 1: 0.001\n", "Loss after mini-batch 201: 0.197\n", "Loss after mini-batch 401: 0.193\n", "Loss after mini-batch 601: 0.196\n", "Loss after mini-batch 801: 0.189\n", "Loss after mini-batch 1001: 0.183\n", "Loss after mini-batch 1201: 0.191\n", "Loss after mini-batch 1401: 0.193\n", "Loss after mini-batch 1601: 0.181\n", "Loss after mini-batch 1801: 0.185\n", "Loss after mini-batch 2001: 0.181\n", "Starting epoch 4\n", "Loss after mini-batch 1: 0.001\n", "Loss after mini-batch 201: 0.190\n", "Loss after mini-batch 401: 0.181\n", "Loss after mini-batch 601: 0.189\n", "Loss after mini-batch 801: 0.180\n", "Loss after mini-batch 1001: 0.184\n", "Loss after mini-batch 1201: 0.180\n", "Loss after mini-batch 1401: 0.180\n", "Loss after mini-batch 1601: 0.184\n", "Loss after mini-batch 1801: 0.186\n", "Loss after mini-batch 2001: 0.179\n", "Starting epoch 5\n", "Loss after mini-batch 1: 0.000\n", "Loss after mini-batch 201: 0.181\n", "Loss after mini-batch 401: 0.177\n", "Loss after mini-batch 601: 0.185\n", "Loss after mini-batch 801: 0.179\n", "Loss 
after mini-batch 1001: 0.178\n", "Loss after mini-batch 1201: 0.173\n", "Loss after mini-batch 1401: 0.185\n", "Loss after mini-batch 1601: 0.177\n", "Loss after mini-batch 1801: 0.181\n", "Loss after mini-batch 2001: 0.178\n", "Training process has finished.\n" ] } ], "source": [ "# Run the training loop\n", "for epoch in range(0, 5): # 5 epochs at maximum\n", "\n", " # Print epoch\n", " print(f'Starting epoch {epoch+1}')\n", "\n", " # Set current loss value\n", " current_loss = 0.0\n", "\n", " # Iterate over the DataLoader for training data\n", " for i, data in enumerate(trainloader, 0):\n", "\n", " # Get and prepare inputs\n", " inputs, targets = data\n", " inputs, targets = inputs.to(device), targets.to(device)\n", " targets = targets.reshape((targets.shape[0], 1))\n", "\n", " # Zero the gradients\n", " optimizer.zero_grad()\n", "\n", " # Perform forward pass\n", " outputs = mlp(inputs)\n", "\n", " # Compute loss\n", " loss = loss_function(outputs, targets)\n", "\n", " # Perform backward pass\n", " loss.backward()\n", "\n", " # Perform optimization\n", " optimizer.step()\n", "\n", " # Print statistics\n", " current_loss += loss.item()\n", " if i % 200 == 0:\n", " print('Loss after mini-batch %5d: %.3f' %\n", " (i + 1, current_loss / 500))\n", " current_loss = 0.0\n", "\n", "# Process is complete.\n", "print('Training process has finished.')" ] }, { "cell_type": "markdown", "id": "352539f5", "metadata": {}, "source": [ "### Save Model State Dict\n", "This saves the serialized object to disk using pickle." ] }, { "cell_type": "code", "execution_count": 11, "id": "b950a3ed-ffe1-477f-a84f-f71c85dbf9ce", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saved PyTorch Model State to models/housing_model.pt\n" ] } ], "source": [ "torch.save(mlp.state_dict(), \"models/housing_model.pt\")\n", "print(\"Saved PyTorch Model State to models/housing_model.pt\")" ] }, { "cell_type": "markdown", "id": "0060fcca", "metadata": {}, "source": [ "### Save Model as TorchScript\n", "This saves an [intermediate representation of the compute graph](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format), which does not require pickle (or even python). 
" ] }, { "cell_type": "code", "execution_count": null, "id": "20fedb5d-c59e-4b0b-ba91-3dd15df1f09e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saved TorchScript Model to models/ts_housing_model.pt\n" ] } ], "source": [ "scripted = torch.jit.script(mlp)\n", "scripted.save(\"models/ts_housing_model.pt\")\n", "print(\"Saved TorchScript Model to models/ts_housing_model.pt\")" ] }, { "cell_type": "markdown", "id": "3101c0fe-65f1-411e-9192-e8a6b585ba0d", "metadata": {}, "source": [ "### Load and Test from Model State" ] }, { "cell_type": "code", "execution_count": 13, "id": "7411b00f-88d2-40f5-b716-a26733c968ff", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loaded_mlp = MLP().to(device)\n", "loaded_mlp.load_state_dict(torch.load(\"models/housing_model.pt\", weights_only=True))" ] }, { "cell_type": "code", "execution_count": 14, "id": "e226f449-2931-4492-9003-503cdc61f061", "metadata": {}, "outputs": [], "source": [ "testX, testY = next(iter(trainloader))" ] }, { "cell_type": "code", "execution_count": null, "id": "d46af47e-db7e-42ee-9bd3-6e7d93850be3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predictions:\n" ] }, { "data": { "text/plain": [ "tensor([[2.3652],\n", " [1.8444],\n", " [2.4587],\n", " [3.1243],\n", " [2.2726],\n", " [2.1818],\n", " [1.5222],\n", " [0.5554],\n", " [2.2508],\n", " [3.5971]], device='cuda:0', grad_fn=)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"Predictions:\")\n", "loaded_mlp(testX.to(device))" ] }, { "cell_type": "code", "execution_count": 16, "id": "13ae2c0f-1da5-45a4-bf32-ed8b562d7907", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Labels:\n" ] }, { "data": { "text/plain": [ "tensor([2.7370, 2.2110, 2.5360, 2.6330, 1.6540, 2.3360, 1.4600, 0.6590, 2.6380,\n", " 3.6220])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"Labels:\")\n", "testY" ] }, { "cell_type": "markdown", "id": "3bcd329d", "metadata": {}, "source": [ "### Load and Test from TorchScript" ] }, { "cell_type": "code", "execution_count": 17, "id": "422e317f-c9bd-4f76-9463-7af2935d401d", "metadata": {}, "outputs": [], "source": [ "scripted_mlp = torch.jit.load(\"models/ts_housing_model.pt\")" ] }, { "cell_type": "code", "execution_count": null, "id": "0cda8ec8-644e-4888-bfa0-b79425ece7c3", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predictions:\n" ] }, { "data": { "text/plain": [ "tensor([2.3652, 1.8444, 2.4587, 3.1243, 2.2726, 2.1818, 1.5222, 0.5554, 2.2508,\n", " 3.5971], device='cuda:0', grad_fn=)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"Predictions:\")\n", "scripted_mlp(testX.to(device)).flatten()" ] }, { "cell_type": "markdown", "id": "2a3b64e4", "metadata": {}, "source": [ "### Compile using the Torch JIT Compiler\n", "This leverages the [Torch-TensorRT inference compiler](https://pytorch.org/TensorRT/) for accelerated inference on GPUs using the `torch.compile` JIT interface under the hood. The compiler stack returns a [boxed-function](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) that triggers compilation on the first call. 
\n", "\n", "Modules compiled in this fashion are [not serializable with pickle](https://github.com/pytorch/pytorch/issues/101107#issuecomment-1542688089), so we cannot send the compiled model directly to Spark. " ] }, { "cell_type": "markdown", "id": "c613f24e", "metadata": {}, "source": [ "(You may see a warning about modelopt quantization. This is safe to ignore, as [implicit quantization](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#intro-quantization) is deprecated in the latest TensorRT. See [this link](https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq.html) for a guide to explicit quantization.)" ] }, { "cell_type": "code", "execution_count": 19, "id": "9ffb27fc", "metadata": {}, "outputs": [], "source": [ "import torch_tensorrt as trt\n", "import time" ] }, { "cell_type": "code", "execution_count": 20, "id": "e0c10f90", "metadata": {}, "outputs": [], "source": [ "# Optional: set the filename for the TensorRT timing cache\n", "timestamp = time.time()\n", "timing_cache = f\"/tmp/timing_cache-{timestamp}.bin\"\n", "with open(timing_cache, \"wb\") as f:\n", " pass" ] }, { "cell_type": "code", "execution_count": 21, "id": "b4aa2523", "metadata": {}, "outputs": [], "source": [ "inputs_bs1 = torch.randn((10, 8), dtype=torch.float).to(\"cuda\")\n", "# This indicates dimension 0 of inputs_bs1 is dynamic with a range of values [1, 50]. No recompilation will happen when the batch size changes.\n", "torch._dynamo.mark_dynamic(inputs_bs1, 0, min=1, max=50)\n", "trt_model = trt.compile(\n", " loaded_mlp,\n", " ir=\"torch_compile\",\n", " inputs=inputs_bs1,\n", " enabled_precisions={torch.float},\n", " timing_cache_path=timing_cache,\n", ")" ] }, { "cell_type": "code", "execution_count": 22, "id": "a5da8cab", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:torch_tensorrt.dynamo._compiler:Node linear_default of op type call_function does not have metadata. This could sometimes lead to undefined behavior.\n", "WARNING:torch_tensorrt.dynamo._compiler:Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Predictions:\n", "tensor([[2.3652],\n", " [1.8444],\n", " [2.4587],\n", " [3.1243],\n", " [2.2726],\n", " [2.1818],\n", " [1.5222],\n", " [0.5554],\n", " [2.2508],\n", " [3.5971]], device='cuda:0')\n" ] } ], "source": [ "stream = torch.cuda.Stream()\n", "with torch.no_grad(), torch.cuda.stream(stream):\n", " testX = testX.to(device)\n", " print(\"Predictions:\")\n", " print(trt_model(testX))" ] }, { "cell_type": "markdown", "id": "d2c55e07", "metadata": {}, "source": [ "### Compile using the Torch-TensorRT AOT Compiler\n", "Alternatively, use the Torch-TensorRT Dynamo backend for Ahead-of-Time (AOT) compilation to eagerly optimize the model in an explicit compilation phase. We first export the model to produce a traced graph representing the Tensor computation in an AOT fashion, which produces a `ExportedProgram` object which can be [serialized and reloaded](https://pytorch.org/TensorRT/user_guide/saving_models.html). We can then compile this IR using the Torch-TensorRT AOT compiler for inference. \n", "\n", "[Read the docs](https://pytorch.org/TensorRT/user_guide/torch_tensorrt_explained.html) for more information on JIT vs AOT compilation." 
] }, { "cell_type": "code", "execution_count": 23, "id": "bf36a50d", "metadata": {}, "outputs": [], "source": [ "example_inputs = (torch.randn((10, 8), dtype=torch.float).to(\"cuda\"),)\n", "\n", "# Mark dim 1 (batch size) as dynamic\n", "batch = torch.export.Dim(\"batch\", min=1, max=64)\n", "# Produce traced graph in ExportedProgram format\n", "exp_program = torch.export.export(loaded_mlp, args=example_inputs, dynamic_shapes={\"x\": {0: batch}})\n", "# Compile the traced graph to produce an optimized module\n", "trt_gm = trt.dynamo.compile(exp_program,\n", " tuple(example_inputs),\n", " enabled_precisions={torch.float},\n", " timing_cache_path=timing_cache,\n", " workspace_size=1<<30)" ] }, { "cell_type": "code", "execution_count": 24, "id": "4fc4efd5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", ".GraphModuleImpl'>\n" ] } ], "source": [ "print(type(exp_program))\n", "print(type(trt_gm))" ] }, { "cell_type": "code", "execution_count": 25, "id": "1bcf9c47", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predictions:\n", "tensor([[2.3653],\n", " [1.8443],\n", " [2.4586],\n", " [3.1242],\n", " [2.2725],\n", " [2.1815],\n", " [1.5221],\n", " [0.5556],\n", " [2.2508],\n", " [3.5971]], device='cuda:0')\n" ] } ], "source": [ "stream = torch.cuda.Stream()\n", "with torch.no_grad(), torch.cuda.stream(stream):\n", " print(\"Predictions:\")\n", " testX = testX.to(device)\n", " print(trt_gm(testX))" ] }, { "cell_type": "markdown", "id": "0eeb957a", "metadata": {}, "source": [ "We can run the optimized module with a few different batch sizes (without recompilation!):" ] }, { "cell_type": "code", "execution_count": null, "id": "49f72c14", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Output shapes:\n", "torch.Size([10, 1])\n", "torch.Size([1, 1])\n", "torch.Size([50, 1])\n" ] } ], "source": [ "inputs = (torch.randn((10, 8), dtype=torch.float).cuda(),)\n", "inputs_bs1 = (torch.randn((1, 8), dtype=torch.float).cuda(),)\n", "inputs_bs50 = (torch.randn((50, 8), dtype=torch.float).cuda(),)\n", "\n", "stream = torch.cuda.Stream()\n", "with torch.no_grad(), torch.cuda.stream(stream):\n", " print(\"Output shapes:\")\n", " print(trt_gm(*inputs).shape)\n", " print(trt_gm(*inputs_bs1).shape)\n", " print(trt_gm(*inputs_bs50).shape)" ] }, { "cell_type": "markdown", "id": "b4fef57d", "metadata": {}, "source": [ "We can serialize the ExportedProgram (a traced graph representing the model's forward function) using `torch.export.save` to be recompiled at a later date." 
] }, { "cell_type": "code", "execution_count": null, "id": "876fea4a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saved ExportedProgram to models/trt_housing_model.ep\n" ] } ], "source": [ "torch.export.save(exp_program, \"models/trt_housing_model.ep\")\n", "print(\"Saved ExportedProgram to models/trt_housing_model.ep\")" ] }, { "cell_type": "markdown", "id": "13631d1f-2c71-4bee-afcb-bd3b55ec87c5", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": null, "id": "bb71dd36", "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.functions import col, struct, pandas_udf, array\n", "from pyspark.ml.functions import predict_batch_udf\n", "from pyspark.sql.types import *\n", "from pyspark.sql import SparkSession\n", "from pyspark import SparkConf\n", "import json\n", "import pandas as pd" ] }, { "cell_type": "markdown", "id": "6769c060", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific Spark configurations." ] }, { "cell_type": "code", "execution_count": 29, "id": "f7727b58", "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "markdown", "id": "a3b7d360", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." ] }, { "cell_type": "code", "execution_count": 30, "id": "52e9dbdb", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:46:28 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/04 13:46:28 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/02/04 13:46:28 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", " \n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "markdown", "id": "e3b9937e-2c70-4d67-b95f-4d9d5ab17c12", "metadata": {}, "source": [ "### Create Spark DataFrame from Pandas DataFrame" ] }, { "cell_type": "code", "execution_count": 31, "id": "cf35da14-61a3-4e7b-9d4f-086bf5e931b3", "metadata": {}, "outputs": [], "source": [ "housing = fetch_california_housing()" ] }, { "cell_type": "code", "execution_count": 32, "id": "95148019-ea95-40e5-a529-fcdb9a06f928", "metadata": {}, "outputs": [], "source": [ "X = StandardScaler().fit_transform(housing.data.astype(np.float32))" ] }, { "cell_type": "code", "execution_count": 33, "id": "f82d957c-6747-4408-aac8-45305afbfe5e", "metadata": {}, "outputs": [], "source": [ "pdf = pd.DataFrame(X, columns=housing.feature_names)" ] }, { "cell_type": "code", "execution_count": 34, "id": "881afee9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------+------------+-----------+------------+-----------+------------+----------+------------+\n", "| MedInc| HouseAge| AveRooms| AveBedrms| Population| AveOccup| Latitude| Longitude|\n", "+------------+------------+-----------+------------+-----------+------------+----------+------------+\n", "| 0.20909257| -1.1632254| 0.38946992| 0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|\n", "|-0.098627955| 0.34647804| 0.27216315| -0.0129226| -0.6953838| -0.05380849| 1.0665938| -1.2479742|\n", "| -0.66006273| 1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496| -1.3827378|\n", "| 0.08218294| 0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507| -1.3028787|\n", "| 0.0784456| -1.4810578| 0.57265776| 0.32067496| 1.0345173|-0.024157424| 1.4411427| -0.52423614|\n", "| -0.82318723| -0.36864465| 0.07829511| -0.1808107|-0.67242444|-0.061470542| 1.9374212| -1.0083897|\n", "| 0.59671736| 0.5848523| 0.19346413| -0.1371872|-0.19645879| 0.009964322|0.96827507| -1.2928978|\n", "| -0.9612035| -1.5605159|-0.56329846| 0.027148023|-0.71127874| -0.08471591| 0.5328614| -0.13990337|\n", "| -0.74344087| -1.2426835| 0.27282518| 0.4037246| -0.9841421| -0.05610115| 1.2257773| -0.42940006|\n", "| 0.9784464| -0.2891866| 0.24374022| -0.24670053| 0.28922042| -0.01102468| 1.1087307| -1.2280084|\n", "| -0.5070446| -1.0043093|-0.78254056|0.0122275995| 2.8465424|-0.060435444| 0.8980464| -1.2080427|\n", "| -0.18690155| 1.2205169|0.015323491| 0.12183313|-0.41015765| 0.04452552| 1.010412| -1.3228445|\n", "| -1.2551856| 1.6178073| -0.3341509|-0.060125165| -0.7554314| -0.08777025| 1.0291398| -1.3477987|\n", "| 4.9607058| -1.9578062| 1.4854684| -0.03948475| 2.1833694|0.0029250523| 1.024457| -1.1581304|\n", "| 0.73652315| 
-1.6399739| 0.7913185| -0.05238397| 1.67738| 0.01944797| 1.0993668| -1.1331724|\n", "| -0.505834| 0.18756187|-0.47093546| -0.24297306|-0.60619545| -0.10791535| 0.977639| -1.2879055|\n", "| -0.88477343|-0.050812364| -0.6318951| -0.15244243| -0.5258376| -0.15618815| 0.9823201| -1.2879055|\n", "| -0.42840376| 0.9821427| -0.2266495| -0.36083496| -0.6883194| -0.08552282| 0.5328614| -0.12493005|\n", "| 0.9369153| -1.4810578| 0.6722208|-0.121177554| 0.3996021| 0.01291408| 1.1040496| -1.1082181|\n", "| -0.80702734| -0.92485124|-0.26602685| -0.1560743| 1.4398388| -0.09314839|0.55627036| -0.09498342|\n", "+------------+------------+-----------+------------+-----------+------------+----------+------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "schema = StructType([\n", " StructField(\"MedInc\",FloatType(),True),\n", " StructField(\"HouseAge\",FloatType(),True),\n", " StructField(\"AveRooms\",FloatType(),True),\n", " StructField(\"AveBedrms\",FloatType(),True),\n", " StructField(\"Population\",FloatType(),True),\n", " StructField(\"AveOccup\",FloatType(),True),\n", " StructField(\"Latitude\",FloatType(),True),\n", " StructField(\"Longitude\",FloatType(),True)\n", "])\n", "\n", "df = spark.createDataFrame(pdf, schema=schema).repartition(8)\n", "df.show(truncate=12)" ] }, { "cell_type": "code", "execution_count": 35, "id": "7b33d367-fbf9-4918-b755-5447125547c4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "StructType([StructField('MedInc', FloatType(), True), StructField('HouseAge', FloatType(), True), StructField('AveRooms', FloatType(), True), StructField('AveBedrms', FloatType(), True), StructField('Population', FloatType(), True), StructField('AveOccup', FloatType(), True), StructField('Latitude', FloatType(), True), StructField('Longitude', FloatType(), True)])" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.schema" ] }, { "cell_type": "code", "execution_count": 36, "id": "751bff7a-b687-4184-b3fa-b5f5b46ef5d1", "metadata": {}, "outputs": [], "source": [ "data_path = \"spark-dl-datasets/california_housing\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path)" ] }, { "cell_type": "markdown", "id": "88c3cd75", "metadata": {}, "source": [ "## Inference using Spark DL API\n", "\n", "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n", "\n", "- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \n", "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function" ] }, { "cell_type": "code", "execution_count": 37, "id": "1e40c266-24de-454d-a776-f3716ba50e90", "metadata": {}, "outputs": [], "source": [ "df = spark.read.parquet(data_path)" ] }, { "cell_type": "code", "execution_count": 38, "id": "5b144c17", "metadata": {}, "outputs": [], "source": [ "columns = df.columns" ] }, { "cell_type": "code", "execution_count": 39, "id": "3d608e2f-66a8-44b5-9cde-5f7837bf4247", "metadata": {}, "outputs": [], "source": [ "# get absolute path to model\n", "model_path = \"{}/models/trt_housing_model.ep\".format(os.getcwd())\n", "\n", "# For cloud environments, copy the model to the distributed file system.\n", "if on_databricks:\n", " 
dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n", " dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/trt_housing_model.ep\"\n", " shutil.copy(model_path, dbfs_model_path)\n", " model_path = dbfs_model_path\n", "elif on_dataproc:\n", " # GCS is mounted at /mnt/gcs by the init script\n", " models_dir = \"/mnt/gcs/spark-dl/models\"\n", " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n", " gcs_model_path = models_dir + \"/trt_housing_model.ep\"\n", " shutil.copy(model_path, gcs_model_path)\n", " model_path = gcs_model_path" ] }, { "cell_type": "markdown", "id": "2fd143e7", "metadata": {}, "source": [ "For inference on Spark, we'll load the ExportedProgram and compile the model with the Torch-TensorRT AOT compiler and cache on the executor. " ] }, { "cell_type": "code", "execution_count": 40, "id": "fc400771", "metadata": {}, "outputs": [], "source": [ "# A resource warning may occur due to unclosed file descriptors used by TensorRT across multiple PySpark daemon processes.\n", "# These can be safely ignored as the resources will be cleaned up when the worker processes terminate.\n", "\n", "import warnings\n", "warnings.simplefilter(\"ignore\", ResourceWarning)" ] }, { "cell_type": "code", "execution_count": 41, "id": "a2f45f5d-c941-4197-a274-1eec2af3fca4", "metadata": {}, "outputs": [], "source": [ "def predict_batch_fn():\n", " import torch\n", " import torch_tensorrt as trt\n", "\n", " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " if device != \"cuda\":\n", " raise ValueError(\"This function uses the TensorRT model which requires a GPU device\")\n", "\n", " example_inputs = (torch.randn((50, 8), dtype=torch.float).to(\"cuda\"),)\n", " exp_program = torch.export.load(model_path)\n", " trt_gm = trt.dynamo.compile(exp_program,\n", " tuple(example_inputs),\n", " enabled_precisions={torch.float},\n", " timing_cache_path=timing_cache,\n", " workspace_size=1<<30)\n", "\n", " print(\"Model compiled.\")\n", " \n", " def predict(inputs):\n", " stream = torch.cuda.Stream()\n", " with torch.no_grad(), torch.cuda.stream(stream), trt.logging.errors():\n", " print(f\"Predict {inputs.shape}\")\n", " torch_inputs = torch.from_numpy(inputs).to(device)\n", " outputs = trt_gm(torch_inputs) # .flatten()\n", " return outputs.detach().cpu().numpy()\n", "\n", " return predict" ] }, { "cell_type": "code", "execution_count": 42, "id": "220a00a4-e842-4f5d-a4b3-7693d09e2d31", "metadata": {}, "outputs": [], "source": [ "regress = predict_batch_udf(predict_batch_fn,\n", " return_type=FloatType(),\n", " input_tensor_shapes=[[8]],\n", " batch_size=50)" ] }, { "cell_type": "code", "execution_count": 43, "id": "0f3bf287-8ffc-4456-8772-e97c418d6aee", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 7:==============> (2 + 6) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 30.4 ms, sys: 13.1 ms, total: 43.5 ms\n", "Wall time: 10.1 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", regress(struct(*columns)))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 44, "id": "6cd23b71-296d-4ce7-b56c-567cc2eec79c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 31.6 ms, sys: 7.39 ms, total: 39 ms\n", "Wall time: 263 ms\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", regress(array(*columns)))\n", "results = 
preds.collect()" ] }, { "cell_type": "code", "execution_count": 45, "id": "75d16bd5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 28.7 ms, sys: 6.67 ms, total: 35.4 ms\n", "Wall time: 296 ms\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", regress(array(*columns)))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 46, "id": "764a40d8-25f7-425c-ba03-fe8c45f4b063", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+\n", "| MedInc| HouseAge| AveRooms| AveBedrms| Population| AveOccup| Latitude| Longitude| preds|\n", "+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+\n", "| 0.20909257| -1.1632254| 0.38946992| 0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|1.3746364|\n", "|-0.098627955| 0.34647804| 0.27216315| -0.0129226| -0.6953838| -0.05380849| 1.0665938| -1.2479742|1.8087528|\n", "| -0.66006273| 1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496| -1.3827378|1.4245079|\n", "| 0.08218294| 0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507| -1.3028787|2.3895802|\n", "| 0.0784456| -1.4810578| 0.57265776| 0.32067496| 1.0345173|-0.024157424| 1.4411427| -0.52423614|1.3616933|\n", "| -0.82318723| -0.36864465| 0.07829511| -0.1808107|-0.67242444|-0.061470542| 1.9374212| -1.0083897|0.7539238|\n", "| 0.59671736| 0.5848523| 0.19346413| -0.1371872|-0.19645879| 0.009964322|0.96827507| -1.2928978|2.6816423|\n", "| -0.9612035| -1.5605159|-0.56329846| 0.027148023|-0.71127874| -0.08471591| 0.5328614| -0.13990337|1.1731354|\n", "| -0.74344087| -1.2426835| 0.27282518| 0.4037246| -0.9841421| -0.05610115| 1.2257773| -0.42940006|1.0198532|\n", "| 0.9784464| -0.2891866| 0.24374022| -0.24670053| 0.28922042| -0.01102468| 1.1087307| -1.2280084| 2.708211|\n", "| -0.5070446| -1.0043093|-0.78254056|0.0122275995| 2.8465424|-0.060435444| 0.8980464| -1.2080427|2.0327075|\n", "| -0.18690155| 1.2205169|0.015323491| 0.12183313|-0.41015765| 0.04452552| 1.010412| -1.3228445|1.9909104|\n", "| -1.2551856| 1.6178073| -0.3341509|-0.060125165| -0.7554314| -0.08777025| 1.0291398| -1.3477987|1.2702764|\n", "| 4.9607058| -1.9578062| 1.4854684| -0.03948475| 2.1833694|0.0029250523| 1.024457| -1.1581304| 5.975229|\n", "| 0.73652315| -1.6399739| 0.7913185| -0.05238397| 1.67738| 0.01944797| 1.0993668| -1.1331724|1.9309721|\n", "| -0.505834| 0.18756187|-0.47093546| -0.24297306|-0.60619545| -0.10791535| 0.977639| -1.2879055|1.7610806|\n", "| -0.88477343|-0.050812364| -0.6318951| -0.15244243| -0.5258376| -0.15618815| 0.9823201| -1.2879055| 1.655031|\n", "| -0.42840376| 0.9821427| -0.2266495| -0.36083496| -0.6883194| -0.08552282| 0.5328614| -0.12493005|1.1175063|\n", "| 0.9369153| -1.4810578| 0.6722208|-0.121177554| 0.3996021| 0.01291408| 1.1040496| -1.1082181|2.1779811|\n", "| -0.80702734| -0.92485124|-0.26602685| -0.1560743| 1.4398388| -0.09314839|0.55627036| -0.09498342|0.9102398|\n", "+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "preds.show()" ] }, { "cell_type": "code", "execution_count": 47, "id": "0aa85f81", "metadata": {}, "outputs": [], "source": [ "# This will clear the engine cache (containing previously compiled TensorRT engines) and reset the CUDA 
Context.\n", "torch._dynamo.reset()" ] }, { "cell_type": "markdown", "id": "53536808", "metadata": {}, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 48, "id": "a9ab4cdf-8103-447e-9ac8-944e2e527239", "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "id": "1b77dc96", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 49, "id": "1ac83062", "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "id": "a4cc5d81", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 50, "id": "6632636e-67a3-406c-832c-758aac4245fd", "metadata": {}, "outputs": [], "source": [ "def triton_server(ports, model_path):\n", " import time\n", " import signal\n", " import numpy as np\n", " import torch\n", " import torch_tensorrt as trt\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", "\n", " print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " \n", " exp_program = torch.export.load(model_path)\n", " example_inputs = (torch.randn((50, 8), dtype=torch.float).to(\"cuda\"),)\n", " trt_gm = trt.dynamo.compile(exp_program,\n", " tuple(example_inputs),\n", " enabled_precisions={torch.float},\n", " workspace_size=1<<30)\n", "\n", " print(\"SERVER: Compiled model.\")\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " features = inputs[\"features\"]\n", " if len(inputs[\"features\"]) != 1:\n", " features = np.squeeze(features)\n", " stream = torch.cuda.Stream()\n", " with torch.no_grad(), torch.cuda.stream(stream):\n", " torch_inputs = torch.from_numpy(features).to(device)\n", " outputs = trt_gm(torch_inputs)\n", " return {\n", " \"preds\": outputs.cpu().numpy(),\n", " }\n", "\n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"HousingModel\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"features\", dtype=np.float32, shape=(-1,)),\n", " ],\n", " outputs=[\n", " 
Tensor(name=\"preds\", dtype=np.float32, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=50,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n", " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "id": "74121cd7", "metadata": {}, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "id": "6d6b7143", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": 54, "id": "2fb22db8", "metadata": {}, "outputs": [], "source": [ "model_name = \"HousingModel\"\n", "server_manager = TritonServerManager(model_name=model_name, model_path=model_path)" ] }, { "cell_type": "code", "execution_count": null, "id": "e067aa14", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server)" ] }, { "cell_type": "markdown", "id": "9a1ac038", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "id": "d4ac45ef", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": null, "id": "92760dac", "metadata": {}, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url # or server_manager.host_to_grpc_url" ] }, { "cell_type": "markdown", "id": "122ebe7c", "metadata": {}, "source": [ "Define the Triton inference function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 57, "id": "1ae91c54", "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"Connecting to Triton model {model_name} at {url}.\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n", " result_data = client.infer_batch(inputs)\n", " return result_data[\"preds\"]\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 60, "id": "d3e64fda-117b-4810-a9a2-dd498239496f", "metadata": {}, "outputs": [], "source": [ "regress = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " input_tensor_shapes=[[8]],\n", " return_type=FloatType(),\n", " batch_size=50)" ] }, { "cell_type": "markdown", "id": "20b8514e-01de-481f-86aa-75afd99bcc7c", "metadata": {}, "source": [ 
"### Run Inference" ] }, { "cell_type": "code", "execution_count": 58, "id": "5eae04bc-75ca-421a-87c8-ac507ce1f2f5", "metadata": {}, "outputs": [], "source": [ "df = spark.read.parquet(data_path)" ] }, { "cell_type": "code", "execution_count": 59, "id": "b350bd8e-9b8f-4511-9ddf-76d917b21b5f", "metadata": {}, "outputs": [], "source": [ "columns = df.columns" ] }, { "cell_type": "code", "execution_count": 61, "id": "a24149a5-3adc-4089-8769-13cf1e44547a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 16:> (0 + 8) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 25.8 ms, sys: 6.21 ms, total: 32.1 ms\n", "Wall time: 2.37 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "predictions = df.withColumn(\"preds\", regress(struct(*columns)))\n", "preds = predictions.collect()" ] }, { "cell_type": "code", "execution_count": 62, "id": "df2ce39f-30af-491a-8472-800fb1ce8458", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 17:> (0 + 8) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 171 ms, sys: 3.76 ms, total: 174 ms\n", "Wall time: 2.5 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "predictions = df.withColumn(\"preds\", regress(array(*columns)))\n", "preds = predictions.collect()" ] }, { "cell_type": "code", "execution_count": 63, "id": "ca6f3eaa-9569-45d0-88bf-9aa0757e1ecb", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 18:> (0 + 8) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 24.4 ms, sys: 4.83 ms, total: 29.2 ms\n", "Wall time: 1.97 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "predictions = df.withColumn(\"preds\", regress(array(*columns)))\n", "preds = predictions.collect()" ] }, { "cell_type": "code", "execution_count": 64, "id": "b79c62c8-e1e8-4467-8aef-8939c31833b8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+\n", "| MedInc| HouseAge| AveRooms| AveBedrms| Population| AveOccup| Latitude| Longitude| preds|\n", "+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+\n", "| 0.20909257| -1.1632254| 0.38946992| 0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|1.3746364|\n", "|-0.098627955| 0.34647804| 0.27216315| -0.0129226| -0.6953838| -0.05380849| 1.0665938| -1.2479742|1.8087528|\n", "| -0.66006273| 1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496| -1.3827378|1.4245079|\n", "| 0.08218294| 0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507| -1.3028787|2.3895802|\n", "| 0.0784456| -1.4810578| 0.57265776| 0.32067496| 1.0345173|-0.024157424| 1.4411427| -0.52423614|1.3616933|\n", "| -0.82318723| -0.36864465| 0.07829511| -0.1808107|-0.67242444|-0.061470542| 1.9374212| -1.0083897|0.7539238|\n", "| 0.59671736| 0.5848523| 0.19346413| -0.1371872|-0.19645879| 0.009964322|0.96827507| -1.2928978|2.6816423|\n", "| -0.9612035| -1.5605159|-0.56329846| 0.027148023|-0.71127874| -0.08471591| 0.5328614| -0.13990337|1.1731354|\n", "| -0.74344087| -1.2426835| 0.27282518| 0.4037246| -0.9841421| -0.05610115| 1.2257773| 
-0.42940006|1.0198532|\n", "| 0.9784464| -0.2891866| 0.24374022| -0.24670053| 0.28922042| -0.01102468| 1.1087307| -1.2280084| 2.708211|\n", "| -0.5070446| -1.0043093|-0.78254056|0.0122275995| 2.8465424|-0.060435444| 0.8980464| -1.2080427|2.0327075|\n", "| -0.18690155| 1.2205169|0.015323491| 0.12183313|-0.41015765| 0.04452552| 1.010412| -1.3228445|1.9909104|\n", "| -1.2551856| 1.6178073| -0.3341509|-0.060125165| -0.7554314| -0.08777025| 1.0291398| -1.3477987|1.2702764|\n", "| 4.9607058| -1.9578062| 1.4854684| -0.03948475| 2.1833694|0.0029250523| 1.024457| -1.1581304| 5.975229|\n", "| 0.73652315| -1.6399739| 0.7913185| -0.05238397| 1.67738| 0.01944797| 1.0993668| -1.1331724|1.9309721|\n", "| -0.505834| 0.18756187|-0.47093546| -0.24297306|-0.60619545| -0.10791535| 0.977639| -1.2879055|1.7610806|\n", "| -0.88477343|-0.050812364| -0.6318951| -0.15244243| -0.5258376| -0.15618815| 0.9823201| -1.2879055| 1.655031|\n", "| -0.42840376| 0.9821427| -0.2266495| -0.36083496| -0.6883194| -0.08552282| 0.5328614| -0.12493005|1.1175063|\n", "| 0.9369153| -1.4810578| 0.6722208|-0.121177554| 0.3996021| 0.01291408| 1.1040496| -1.1082181|2.1779811|\n", "| -0.80702734| -0.92485124|-0.26602685| -0.1560743| 1.4398388| -0.09314839|0.55627036| -0.09498342|0.9102398|\n", "+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "predictions.show()" ] }, { "cell_type": "markdown", "id": "3fec23b0-eaf2-4b6a-aa38-7a09873ed6eb", "metadata": { "tags": [] }, "source": [ "#### Stop Triton Server on each executor" ] }, { "cell_type": "code", "execution_count": 65, "id": "8084bdef", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 66, "id": "0138a029-87c5-497f-ac5c-3eed0e11b0f6", "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "d24147e7-5695-44a0-9961-b94bfba1cfff", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-torch", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/image_classification_torch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "9e87c927", "metadata": {}, "source": [ "\n", "\n", "# PySpark PyTorch Inference\n", "\n", "### Image Classification\n", "\n", "In this notebook, we will train an MLP to perform image classification on FashionMNIST, and load it for distributed inference with Spark.\n", "\n", "Based on: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html \n", "\n", "We also demonstrate accelerated inference via Torch-TensorRT model compilation. 
" ] }, { "cell_type": "code", "execution_count": 1, "id": "91d7ec98", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import os\n", "import shutil\n", "from torch import nn\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import ToTensor" ] }, { "cell_type": "code", "execution_count": 2, "id": "f71f801d", "metadata": {}, "outputs": [], "source": [ "os.mkdir('models') if not os.path.exists('models') else None" ] }, { "cell_type": "code", "execution_count": 3, "id": "d714f40d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'2.5.1+cu124'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.__version__" ] }, { "cell_type": "markdown", "id": "d0f6fb37", "metadata": {}, "source": [ "### Load Dataset" ] }, { "cell_type": "code", "execution_count": 4, "id": "1c942a46", "metadata": {}, "outputs": [], "source": [ "# Download training data from open datasets.\n", "training_data = datasets.FashionMNIST(\n", " root=\"datasets/data\",\n", " train=True,\n", " download=True,\n", " transform=ToTensor(),\n", ")\n", "\n", "# Download test data from open datasets.\n", "test_data = datasets.FashionMNIST(\n", " root=\"datasets/data\",\n", " train=False,\n", " download=True,\n", " transform=ToTensor(),\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "id": "4a89aa8e-ef62-4aac-8260-4b004f2c1b55", "metadata": {}, "outputs": [], "source": [ "classes = [\n", " \"T-shirt/top\",\n", " \"Trouser\",\n", " \"Pullover\",\n", " \"Dress\",\n", " \"Coat\",\n", " \"Sandal\",\n", " \"Shirt\",\n", " \"Sneaker\",\n", " \"Bag\",\n", " \"Ankle boot\",\n", "]" ] }, { "cell_type": "code", "execution_count": 6, "id": "10a97111", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28]) torch.float32\n", "Shape of y: torch.Size([64]) torch.int64\n" ] } ], "source": [ "batch_size = 64\n", "\n", "# Create data loaders.\n", "train_dataloader = DataLoader(training_data, batch_size=batch_size)\n", "test_dataloader = DataLoader(test_data, batch_size=batch_size)\n", "\n", "for X, y in test_dataloader:\n", " print(f\"Shape of X [N, C, H, W]: {X.shape} {X.dtype}\")\n", " print(f\"Shape of y: {y.shape} {y.dtype}\")\n", " break" ] }, { "cell_type": "markdown", "id": "ca7af350", "metadata": {}, "source": [ "### Create model" ] }, { "cell_type": "code", "execution_count": 7, "id": "512d0bc7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cuda device\n", "NeuralNetwork(\n", " (linear_relu_stack): Sequential(\n", " (0): Linear(in_features=784, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=512, out_features=512, bias=True)\n", " (3): ReLU()\n", " (4): Linear(in_features=512, out_features=10, bias=True)\n", " )\n", ")\n" ] } ], "source": [ "# Get cpu or gpu device for training.\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(f\"Using {device} device\")\n", "\n", "# Define model\n", "class NeuralNetwork(nn.Module):\n", " def __init__(self):\n", " super(NeuralNetwork, self).__init__()\n", " self.linear_relu_stack = nn.Sequential(\n", " nn.Linear(28*28, 512),\n", " nn.ReLU(),\n", " nn.Linear(512, 512),\n", " nn.ReLU(),\n", " nn.Linear(512, 10)\n", " )\n", "\n", " def forward(self, x):\n", " logits = self.linear_relu_stack(x)\n", " return logits\n", "\n", "model = NeuralNetwork().to(device)\n", "print(model)" ] }, { "cell_type": 
"markdown", "id": "4573c1b7", "metadata": {}, "source": [ "### Train Model" ] }, { "cell_type": "code", "execution_count": 8, "id": "4d4f5538", "metadata": {}, "outputs": [], "source": [ "loss_fn = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)" ] }, { "cell_type": "code", "execution_count": 9, "id": "92d9076a", "metadata": {}, "outputs": [], "source": [ "def train(dataloader, model, loss_fn, optimizer):\n", " size = len(dataloader.dataset)\n", " model.train()\n", " for batch, (X, y) in enumerate(dataloader):\n", " X, y = X.to(device), y.to(device)\n", " X = torch.flatten(X, start_dim=1, end_dim=-1)\n", "\n", " # Zero gradients\n", " optimizer.zero_grad()\n", "\n", " # Compute prediction error\n", " pred = model(X)\n", " loss = loss_fn(pred, y)\n", "\n", " # Backpropagation\n", " loss.backward()\n", " optimizer.step()\n", "\n", " if batch % 100 == 0:\n", " loss, current = loss.item(), (batch + 1) * len(X)\n", " print(f\"loss: {loss:>7f} [{current:>5d}/{size:>5d}]\")" ] }, { "cell_type": "code", "execution_count": 10, "id": "11c5650d", "metadata": {}, "outputs": [], "source": [ "def test(dataloader, model, loss_fn):\n", " size = len(dataloader.dataset)\n", " num_batches = len(dataloader)\n", " model.eval()\n", " test_loss, correct = 0, 0\n", " with torch.no_grad():\n", " for X, y in dataloader:\n", " X, y = X.to(device), y.to(device)\n", " X = torch.flatten(X, start_dim=1, end_dim=-1)\n", " pred = model(X)\n", " test_loss += loss_fn(pred, y).item()\n", " correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n", " test_loss /= num_batches\n", " correct /= size\n", " print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "854608e6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1\n", "-------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss: 2.298206 [ 64/60000]\n", "loss: 2.283203 [ 6464/60000]\n", "loss: 2.262282 [12864/60000]\n", "loss: 2.259791 [19264/60000]\n", "loss: 2.240928 [25664/60000]\n", "loss: 2.218922 [32064/60000]\n", "loss: 2.225280 [38464/60000]\n", "loss: 2.193091 [44864/60000]\n", "loss: 2.194699 [51264/60000]\n", "loss: 2.157922 [57664/60000]\n", "Test Error: \n", " Accuracy: 38.6%, Avg loss: 2.149652 \n", "\n", "Epoch 2\n", "-------------------------------\n", "loss: 2.164765 [ 64/60000]\n", "loss: 2.153999 [ 6464/60000]\n", "loss: 2.094229 [12864/60000]\n", "loss: 2.107332 [19264/60000]\n", "loss: 2.060189 [25664/60000]\n", "loss: 2.009164 [32064/60000]\n", "loss: 2.033063 [38464/60000]\n", "loss: 1.954014 [44864/60000]\n", "loss: 1.968186 [51264/60000]\n", "loss: 1.892358 [57664/60000]\n", "Test Error: \n", " Accuracy: 54.1%, Avg loss: 1.883826 \n", "\n", "Epoch 3\n", "-------------------------------\n", "loss: 1.922989 [ 64/60000]\n", "loss: 1.895849 [ 6464/60000]\n", "loss: 1.767882 [12864/60000]\n", "loss: 1.804950 [19264/60000]\n", "loss: 1.702711 [25664/60000]\n", "loss: 1.664090 [32064/60000]\n", "loss: 1.682484 [38464/60000]\n", "loss: 1.577310 [44864/60000]\n", "loss: 1.613093 [51264/60000]\n", "loss: 1.510797 [57664/60000]\n", "Test Error: \n", " Accuracy: 59.5%, Avg loss: 1.517127 \n", "\n", "Epoch 4\n", "-------------------------------\n", "loss: 1.588409 [ 64/60000]\n", "loss: 1.558777 [ 6464/60000]\n", "loss: 1.393466 [12864/60000]\n", "loss: 1.465835 [19264/60000]\n", "loss: 1.350062 [25664/60000]\n", "loss: 1.359687 
[32064/60000]\n", "loss: 1.370576 [38464/60000]\n", "loss: 1.287119 [44864/60000]\n", "loss: 1.330430 [51264/60000]\n", "loss: 1.238912 [57664/60000]\n", "Test Error: \n", " Accuracy: 62.4%, Avg loss: 1.254357 \n", "\n", "Epoch 5\n", "-------------------------------\n", "loss: 1.333722 [ 64/60000]\n", "loss: 1.322049 [ 6464/60000]\n", "loss: 1.143545 [12864/60000]\n", "loss: 1.250494 [19264/60000]\n", "loss: 1.123120 [25664/60000]\n", "loss: 1.166146 [32064/60000]\n", "loss: 1.181268 [38464/60000]\n", "loss: 1.112326 [44864/60000]\n", "loss: 1.155791 [51264/60000]\n", "loss: 1.079376 [57664/60000]\n", "Test Error: \n", " Accuracy: 64.0%, Avg loss: 1.092456 \n", "\n", "Done!\n" ] } ], "source": [ "epochs = 5\n", "for t in range(epochs):\n", " print(f\"Epoch {t+1}\\n-------------------------------\")\n", " train(train_dataloader, model, loss_fn, optimizer)\n", " test(test_dataloader, model, loss_fn)\n", "print(\"Done!\")" ] }, { "cell_type": "markdown", "id": "85d97839", "metadata": {}, "source": [ "### Save Model State Dict\n", "This saves the serialized object to disk using pickle." ] }, { "cell_type": "code", "execution_count": 12, "id": "5d5d24de", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saved PyTorch Model State to models/model.pt\n" ] } ], "source": [ "torch.save(model.state_dict(), \"models/model.pt\")\n", "print(\"Saved PyTorch Model State to models/model.pt\")" ] }, { "cell_type": "markdown", "id": "ac221ca7-e227-4c8c-8577-1eeda4a61fc7", "metadata": {}, "source": [ "### Save Model as TorchScript\n", "This saves an [intermediate representation of the compute graph](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format), which does not require pickle (or even python). 
" ] }, { "cell_type": "code", "execution_count": 13, "id": "6d9b3a45-7618-43e4-8bd3-8bb317a484d3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saved TorchScript Model to models/ts_model.pt\n" ] } ], "source": [ "scripted = torch.jit.script(model)\n", "scripted.save(\"models/ts_model.pt\")\n", "print(\"Saved TorchScript Model to models/ts_model.pt\")" ] }, { "cell_type": "markdown", "id": "12ee8916-f437-4a2a-9bf4-14ff5376d305", "metadata": {}, "source": [ "### Load Model State" ] }, { "cell_type": "code", "execution_count": 14, "id": "8fe3b5d1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_from_state = NeuralNetwork().to(device)\n", "model_from_state.load_state_dict(torch.load(\"models/model.pt\", weights_only=True))" ] }, { "cell_type": "code", "execution_count": 15, "id": "0c405bd0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n" ] } ], "source": [ "model_from_state.eval()\n", "x, y = test_data[0][0], test_data[0][1]\n", "with torch.no_grad():\n", " x = torch.flatten(x.to(device), start_dim=1, end_dim=-1)\n", " pred = model_from_state(x)\n", " predicted, actual = classes[pred[0].argmax(0)], classes[y]\n", " print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')" ] }, { "cell_type": "markdown", "id": "290c482a-1c5d-4bf2-bc3f-8a4e53d442b5", "metadata": {}, "source": [ "### Load Torchscript Model" ] }, { "cell_type": "code", "execution_count": 16, "id": "ef3c419e-d384-446c-b07b-1af93e07d6c0", "metadata": {}, "outputs": [], "source": [ "ts_model = torch.jit.load(\"models/ts_model.pt\")" ] }, { "cell_type": "code", "execution_count": 17, "id": "c92d6cdb", "metadata": {}, "outputs": [], "source": [ "x, y = test_data[0][0], test_data[0][1]" ] }, { "cell_type": "code", "execution_count": null, "id": "038af830-a360-45eb-ab4e-b1adab0af164", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n" ] } ], "source": [ "with torch.no_grad():\n", " pred = ts_model(torch.flatten(x.to(device), start_dim=1, end_dim=-1))\n", " predicted, actual = classes[pred[0].argmax(0)], classes[y]\n", " print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')" ] }, { "cell_type": "markdown", "id": "76980495", "metadata": {}, "source": [ "### Compile using the Torch JIT Compiler\n", "This leverages the [Torch-TensorRT inference compiler](https://pytorch.org/TensorRT/) for accelerated inference on GPUs using the `torch.compile` JIT interface under the hood. The compiler stack returns a [boxed-function](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) that triggers compilation on the first call. \n", "\n", "Modules compiled in this fashion are [not serializable with pickle](https://github.com/pytorch/pytorch/issues/101107#issuecomment-1542688089), so we cannot send the compiled model directly to Spark. " ] }, { "cell_type": "markdown", "id": "414bc856", "metadata": {}, "source": [ "(You may see a warning about modelopt quantization. This is safe to ignore, as [implicit quantization](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#intro-quantization) is deprecated in the latest TensorRT. 
See [this link](https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq.html) for a guide to explicit quantization.)" ] }, { "cell_type": "code", "execution_count": 19, "id": "362b266b", "metadata": {}, "outputs": [], "source": [ "import torch_tensorrt as trt\n", "import time" ] }, { "cell_type": "code", "execution_count": 20, "id": "f0ac1362", "metadata": {}, "outputs": [], "source": [ "# Optional: set the filename for the TensorRT timing cache\n", "timestamp = time.time()\n", "timing_cache = f\"/tmp/timing_cache-{timestamp}.bin\"\n", "with open(timing_cache, \"wb\") as f:\n", " pass" ] }, { "cell_type": "code", "execution_count": 21, "id": "f3e3bdc4", "metadata": {}, "outputs": [], "source": [ "inputs_bs1 = torch.randn((10, 784), dtype=torch.float).to(\"cuda\")\n", "# This indicates that dimension 0 of inputs_bs1 is dynamic, with values in the range [1, 64]. \n", "torch._dynamo.mark_dynamic(inputs_bs1, 0, min=1, max=64)\n", "trt_model = trt.compile(\n", " model,\n", " ir=\"torch_compile\",\n", " inputs=inputs_bs1,\n", " enabled_precisions={torch.float},\n", " timing_cache_path=timing_cache,\n", ")" ] }, { "cell_type": "code", "execution_count": 22, "id": "66f61302", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:torch_tensorrt.dynamo._compiler:Node linear_default of op type call_function does not have metadata. This could sometimes lead to undefined behavior.\n", "WARNING:torch_tensorrt.dynamo._compiler:Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n" ] } ], "source": [ "stream = torch.cuda.Stream()\n", "with torch.no_grad(), torch.cuda.stream(stream):\n", " pred = trt_model(torch.flatten(x.to(device), start_dim=1, end_dim=-1))\n", " predicted, actual = classes[pred[0].argmax(0)], classes[y]\n", " print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')" ] }, { "cell_type": "markdown", "id": "9ec04be8", "metadata": {}, "source": [ "### Compile using the Torch-TensorRT AOT Compiler\n", "Alternatively, use the Torch-TensorRT Dynamo backend for Ahead-of-Time (AOT) compilation to eagerly optimize the model in an explicit compilation phase. We first export the model ahead of time to produce a traced graph representing the tensor computation, which yields an `ExportedProgram` object that can be [serialized and reloaded](https://pytorch.org/TensorRT/user_guide/saving_models.html). We can then compile this IR using the Torch-TensorRT AOT compiler for inference. \n", "\n", "[Read the docs](https://pytorch.org/TensorRT/user_guide/torch_tensorrt_explained.html) for more information on JIT vs AOT compilation."
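, "\n", "\n", "As a minimal side-by-side sketch (illustrative only; `trt_jit`, `ep`, and `trt_aot` are names introduced here, while `model` and `inputs_bs1` come from the cells above), the two paths look like:\n", "```python\n", "# JIT: wrap the module; TensorRT compilation is deferred until the first call\n", "trt_jit = trt.compile(model, ir=\"torch_compile\", inputs=inputs_bs1, enabled_precisions={torch.float})\n", "\n", "# AOT: export a traced graph first, then compile it eagerly; the ExportedProgram is serializable\n", "ep = torch.export.export(model, (torch.randn((10, 784), device=\"cuda\"),))\n", "trt_aot = trt.dynamo.compile(ep, (torch.randn((10, 784), device=\"cuda\"),), enabled_precisions={torch.float})\n", "```"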
] }, { "cell_type": "code", "execution_count": 23, "id": "3e7e7689", "metadata": {}, "outputs": [], "source": [ "example_inputs = (torch.randn((10, 784), dtype=torch.float).to(\"cuda\"),)\n", "\n", "# Mark dim 1 (batch size) as dynamic\n", "batch = torch.export.Dim(\"batch\", min=1, max=64)\n", "# Produce traced graph in ExportedProgram format\n", "exp_program = torch.export.export(model_from_state, args=example_inputs, dynamic_shapes={\"x\": {0: batch}})\n", "# Compile the traced graph to produce an optimized module\n", "trt_gm = trt.dynamo.compile(exp_program, tuple(example_inputs), enabled_precisions={torch.float}, timing_cache_path=timing_cache)" ] }, { "cell_type": "code", "execution_count": 24, "id": "6fda0c0e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", ".GraphModuleImpl'>\n" ] } ], "source": [ "print(type(exp_program))\n", "print(type(trt_gm))" ] }, { "cell_type": "code", "execution_count": 25, "id": "5ed9e4c5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n" ] } ], "source": [ "stream = torch.cuda.Stream()\n", "with torch.no_grad(), torch.cuda.stream(stream):\n", " trt_gm(torch.flatten(x.to(device), start_dim=1, end_dim=-1))\n", " predicted, actual = classes[pred[0].argmax(0)], classes[y]\n", " print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')" ] }, { "cell_type": "markdown", "id": "38697a06", "metadata": {}, "source": [ "We can run the optimized module with a few different batch sizes (without recompilation!):" ] }, { "cell_type": "code", "execution_count": null, "id": "27871156", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Output shapes:\n", "torch.Size([10, 10])\n", "torch.Size([1, 10])\n", "torch.Size([50, 10])\n" ] } ], "source": [ "inputs = (torch.randn((10, 784), dtype=torch.float).cuda(),)\n", "inputs_bs1 = (torch.randn((1, 784), dtype=torch.float).cuda(),)\n", "inputs_bs50 = (torch.randn((50, 784), dtype=torch.float).cuda(),)\n", "\n", "stream = torch.cuda.Stream()\n", "with torch.no_grad(), torch.cuda.stream(stream):\n", " print(\"Output shapes:\")\n", " print(trt_gm(*inputs).shape)\n", " print(trt_gm(*inputs_bs1).shape)\n", " print(trt_gm(*inputs_bs50).shape)" ] }, { "cell_type": "markdown", "id": "ab974244", "metadata": {}, "source": [ "We can serialize the ExportedProgram (a traced graph representing the model's forward function) using `torch.export.save` to be recompiled at a later date." 
] }, { "cell_type": "code", "execution_count": null, "id": "d87e4b20", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saved ExportedProgram to models/trt_model.ep\n" ] } ], "source": [ "torch.export.save(exp_program, \"models/trt_model.ep\")\n", "print(\"Saved ExportedProgram to models/trt_model.ep\")" ] }, { "cell_type": "markdown", "id": "ad918393", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": 28, "id": "42c5feba", "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.functions import col, struct, pandas_udf, array\n", "from pyspark.ml.functions import predict_batch_udf\n", "from pyspark.sql.types import *\n", "from pyspark.sql import SparkSession\n", "from pyspark import SparkConf" ] }, { "cell_type": "code", "execution_count": 29, "id": "ef97321d", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import json\n", "import os" ] }, { "cell_type": "markdown", "id": "ece094d6", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific Spark configurations." ] }, { "cell_type": "code", "execution_count": 30, "id": "10eb841f", "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "markdown", "id": "425e94ac", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." ] }, { "cell_type": "code", "execution_count": 31, "id": "60ba6e74", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:50:47 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/04 13:50:47 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/02/04 13:50:48 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", " \n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "markdown", "id": "2cd11476", "metadata": {}, "source": [ "#### Create Spark DataFrame from Pandas DataFrame" ] }, { "cell_type": "code", "execution_count": 32, "id": "f063cbe7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((10000, 28, 28), dtype('uint8'))" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = test_data.data.numpy()\n", "data.shape, data.dtype" ] }, { "cell_type": "code", "execution_count": null, "id": "8c828393", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((10000, 784), dtype('float64'))" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = data.reshape(10000, 784) / 255.0\n", "data.shape, data.dtype" ] }, { "cell_type": "code", "execution_count": 34, "id": "7760bdbe", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...774775776777778779780781782783
00.00.00.00.0000000.00.0000000.0000000.00.0000000.000000...0.0000000.0000000.00.0000000.0000000.0000000.0000000.00.00.0
10.00.00.00.0000000.00.0000000.0000000.00.0000000.000000...0.0078430.0117650.00.0117650.6823530.7411760.2627450.00.00.0
20.00.00.00.0000000.00.0000000.0000000.00.0039220.000000...0.6431370.2274510.00.0000000.0000000.0000000.0000000.00.00.0
30.00.00.00.0000000.00.0000000.0000000.00.0000000.082353...0.0039220.0000000.00.0000000.0000000.0000000.0000000.00.00.0
40.00.00.00.0078430.00.0039220.0039220.00.0000000.000000...0.2784310.0470590.00.0000000.0000000.0000000.0000000.00.00.0
..................................................................
99950.00.00.00.0000000.00.0000000.0000000.00.0000000.000000...0.0000000.0000000.00.0000000.0000000.0000000.0000000.00.00.0
99960.00.00.00.0000000.00.0000000.0000000.00.0000000.121569...0.0000000.0000000.00.0000000.0000000.0000000.0000000.00.00.0
99970.00.00.00.0000000.00.0000000.0000000.00.0000000.000000...0.1058820.0000000.00.0000000.0000000.0000000.0000000.00.00.0
99980.00.00.00.0000000.00.0000000.0000000.00.0000000.000000...0.0000000.0000000.00.0000000.0000000.0000000.0000000.00.00.0
99990.00.00.00.0000000.00.0000000.0000000.00.0000000.000000...0.0000000.0000000.00.0000000.0000000.0000000.0000000.00.00.0
\n", "

10000 rows × 784 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 7 8 \\\n", "0 0.0 0.0 0.0 0.000000 0.0 0.000000 0.000000 0.0 0.000000 \n", "1 0.0 0.0 0.0 0.000000 0.0 0.000000 0.000000 0.0 0.000000 \n", "2 0.0 0.0 0.0 0.000000 0.0 0.000000 0.000000 0.0 0.003922 \n", "3 0.0 0.0 0.0 0.000000 0.0 0.000000 0.000000 0.0 0.000000 \n", "4 0.0 0.0 0.0 0.007843 0.0 0.003922 0.003922 0.0 0.000000 \n", "... ... ... ... ... ... ... ... ... ... \n", "9995 0.0 0.0 0.0 0.000000 0.0 0.000000 0.000000 0.0 0.000000 \n", "9996 0.0 0.0 0.0 0.000000 0.0 0.000000 0.000000 0.0 0.000000 \n", "9997 0.0 0.0 0.0 0.000000 0.0 0.000000 0.000000 0.0 0.000000 \n", "9998 0.0 0.0 0.0 0.000000 0.0 0.000000 0.000000 0.0 0.000000 \n", "9999 0.0 0.0 0.0 0.000000 0.0 0.000000 0.000000 0.0 0.000000 \n", "\n", " 9 ... 774 775 776 777 778 779 \\\n", "0 0.000000 ... 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 \n", "1 0.000000 ... 0.007843 0.011765 0.0 0.011765 0.682353 0.741176 \n", "2 0.000000 ... 0.643137 0.227451 0.0 0.000000 0.000000 0.000000 \n", "3 0.082353 ... 0.003922 0.000000 0.0 0.000000 0.000000 0.000000 \n", "4 0.000000 ... 0.278431 0.047059 0.0 0.000000 0.000000 0.000000 \n", "... ... ... ... ... ... ... ... ... \n", "9995 0.000000 ... 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 \n", "9996 0.121569 ... 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 \n", "9997 0.000000 ... 0.105882 0.000000 0.0 0.000000 0.000000 0.000000 \n", "9998 0.000000 ... 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 \n", "9999 0.000000 ... 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 \n", "\n", " 780 781 782 783 \n", "0 0.000000 0.0 0.0 0.0 \n", "1 0.262745 0.0 0.0 0.0 \n", "2 0.000000 0.0 0.0 0.0 \n", "3 0.000000 0.0 0.0 0.0 \n", "4 0.000000 0.0 0.0 0.0 \n", "... ... ... ... ... \n", "9995 0.000000 0.0 0.0 0.0 \n", "9996 0.000000 0.0 0.0 0.0 \n", "9997 0.000000 0.0 0.0 0.0 \n", "9998 0.000000 0.0 0.0 0.0 \n", "9999 0.000000 0.0 0.0 0.0 \n", "\n", "[10000 rows x 784 columns]" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pdf784 = pd.DataFrame(data)\n", "pdf784" ] }, { "cell_type": "code", "execution_count": 35, "id": "f7d2bc0d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
data
0[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
1[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
2[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003...
3[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
4[0.0, 0.0, 0.0, 0.00784313725490196, 0.0, 0.00...
......
9995[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
9996[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
9997[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
9998[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
9999[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
\n", "

10000 rows × 1 columns

\n", "
" ], "text/plain": [ " data\n", "0 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", "1 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", "2 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003...\n", "3 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", "4 [0.0, 0.0, 0.0, 0.00784313725490196, 0.0, 0.00...\n", "... ...\n", "9995 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", "9996 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", "9997 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", "9998 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", "9999 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", "\n", "[10000 rows x 1 columns]" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 1 column of array\n", "pdf1 = pd.DataFrame()\n", "pdf1['data'] = pdf784.values.tolist()\n", "pdf1" ] }, { "cell_type": "markdown", "id": "07b2a70b", "metadata": {}, "source": [ "Create dataframes with a single column of 784 floats and 784 separate columns." ] }, { "cell_type": "code", "execution_count": 36, "id": "4863d5ff", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 185 ms, sys: 28.9 ms, total: 214 ms\n", "Wall time: 1.5 s\n" ] }, { "data": { "text/plain": [ "StructType([StructField('data', ArrayType(FloatType(), True), True)])" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "# force FloatType since Spark defaults to DoubleType\n", "schema = StructType([StructField(\"data\",ArrayType(FloatType()), True)])\n", "df = spark.createDataFrame(pdf1, schema).repartition(8)\n", "df.schema" ] }, { "cell_type": "code", "execution_count": 37, "id": "831f4a01-3a49-4114-b9a0-2ae54526d72d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 66.9 ms, sys: 11.2 ms, total: 78.1 ms\n", "Wall time: 875 ms\n" ] }, { "data": { "text/plain": [ "StructType([StructField('data', ArrayType(FloatType(), True), True)])" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "# force FloatType since Spark defaults to DoubleType\n", "schema = StructType([StructField(str(x), FloatType(), True) for x in range(784)])\n", "df784 = spark.createDataFrame(pdf784, schema).repartition(8)\n", "df.schema" ] }, { "cell_type": "code", "execution_count": 38, "id": "e8ebae46", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:50:51 WARN TaskSetManager: Stage 0 contains a task of very large size (4030 KiB). The maximum recommended task size is 1000 KiB.\n", "[Stage 0:=======> (1 + 7) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.09 ms, sys: 1.6 ms, total: 3.69 ms\n", "Wall time: 1.71 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "data_path_1 = \"spark-dl-datasets/fashion_mnist_1\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path_1 = \"dbfs:/FileStore/\" + data_path_1\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path_1)" ] }, { "cell_type": "code", "execution_count": 39, "id": "922314ce-2996-4666-9fc9-bcd98d16bb56", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:50:53 WARN package: Truncated the string representation of a plan since it was too large. 
This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:50:53 WARN TaskSetManager: Stage 3 contains a task of very large size (7847 KiB). The maximum recommended task size is 1000 KiB.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.94 ms, sys: 61 μs, total: 3 ms\n", "Wall time: 943 ms\n" ] } ], "source": [ "%%time\n", "data_path_784 = \"spark-dl-datasets/fashion_mnist_784\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path_784 = \"dbfs:/FileStore/\" + data_path_784\n", "\n", "df784.write.mode(\"overwrite\").parquet(data_path_784)" ] }, { "cell_type": "markdown", "id": "fce89cb0", "metadata": {}, "source": [ "## Inference using Spark DL API\n", "\n", "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n", "\n", "- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \n", "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function" ] }, { "cell_type": "markdown", "id": "59395856-a588-43c6-93c8-c83100716ac1", "metadata": { "tags": [] }, "source": [ "### 1 column of 784 float" ] }, { "cell_type": "code", "execution_count": 40, "id": "79b151d9-d112-43b6-a479-887e2fd0e2b1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.read.parquet(data_path_1)\n", "len(df.columns)" ] }, { "cell_type": "code", "execution_count": 41, "id": "3e6a4dbb", "metadata": {}, "outputs": [], "source": [ "# A resource warning may occur due to unclosed file descriptors used by TensorRT across multiple PySpark daemon processes.\n", "# These can be safely ignored as the resources will be cleaned up when the worker processes terminate.\n", "\n", "import warnings\n", "warnings.simplefilter(\"ignore\", ResourceWarning)" ] }, { "cell_type": "code", "execution_count": 42, "id": "16e523c2", "metadata": {}, "outputs": [], "source": [ "# get absolute path to model\n", "model_path = \"{}/models/trt_model.ep\".format(os.getcwd())\n", "\n", "# For cloud environments, copy the model to the distributed file system.\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n", " dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/model.pt\"\n", " shutil.copy(model_path, dbfs_model_path)\n", " model_path = dbfs_model_path\n", "elif on_dataproc:\n", " # GCS is mounted at /mnt/gcs by the init script\n", " models_dir = \"/mnt/gcs/spark-dl/models\"\n", " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n", " gcs_model_path = models_dir + \"/trt_model.ep\"\n", " shutil.copy(model_path, gcs_model_path)\n", " model_path = gcs_model_path" ] }, { "cell_type": "code", "execution_count": 43, "id": "73dc73cb-25e3-4798-a019-e1abd684eaa1", "metadata": {}, "outputs": [], "source": [ "def predict_batch_fn():\n", " import torch\n", " import torch_tensorrt as trt\n", " \n", " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " if device != \"cuda\":\n", " raise ValueError(\"This function uses the TensorRT model which requires a GPU device\")\n", "\n", " example_inputs = (torch.randn((50, 784), dtype=torch.float).to(\"cuda\"),)\n", " exp_program = 
torch.export.load(model_path)\n", " trt_gm = trt.dynamo.compile(exp_program,\n", " tuple(example_inputs),\n", " enabled_precisions={torch.float},\n", " workspace_size=1<<30)\n", "\n", " def predict(inputs: np.ndarray):\n", " stream = torch.cuda.Stream()\n", " with torch.no_grad(), torch.cuda.stream(stream):\n", " # use array to combine columns into tensors\n", " torch_inputs = torch.from_numpy(inputs).to(device)\n", " outputs = trt_gm(torch_inputs)\n", " return outputs.detach().cpu().numpy()\n", "\n", " return predict" ] }, { "cell_type": "code", "execution_count": 44, "id": "df68cca1-2d47-4e88-8aad-9899402aee97", "metadata": {}, "outputs": [], "source": [ "mnist = predict_batch_udf(predict_batch_fn,\n", " input_tensor_shapes=[[784]],\n", " return_type=ArrayType(FloatType()),\n", " batch_size=50)" ] }, { "cell_type": "code", "execution_count": 45, "id": "63555b3b-3673-4712-97aa-fd728c6c4979", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 167 ms, sys: 76.2 ms, total: 243 ms\n", "Wall time: 10.9 s\n" ] } ], "source": [ "%%time\n", "# first pass compiles and caches model/fn\n", "preds = df.withColumn(\"preds\", mnist(struct(df.columns))).collect()" ] }, { "cell_type": "code", "execution_count": 46, "id": "5dbf058a-70d6-4199-af9d-13843d078950", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 234 ms, sys: 64.1 ms, total: 298 ms\n", "Wall time: 685 ms\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", mnist(*df.columns)).collect()" ] }, { "cell_type": "code", "execution_count": 47, "id": "3f5ed801-6ca5-43a0-bf9c-2535a0dfe2e8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 403 ms, sys: 60.1 ms, total: 463 ms\n", "Wall time: 809 ms\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", mnist(*[col(c) for c in df.columns])).collect()" ] }, { "cell_type": "markdown", "id": "c6dbec03-9b64-46c4-a748-f889be571384", "metadata": { "tags": [] }, "source": [ "### Check predictions" ] }, { "cell_type": "code", "execution_count": 48, "id": "f1f1e5fd-5866-4b78-b9d3-709e6b383a0c", "metadata": {}, "outputs": [], "source": [ "predictions = preds[0].preds\n", "img = preds[0].data" ] }, { "cell_type": "code", "execution_count": 49, "id": "76b76502-adb7-45ec-a365-2e61cdd576fc", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 50, "id": "c163953a-1504-444f-b39f-86b61d34e440", "metadata": {}, "outputs": [], "source": [ "img = np.array(img).reshape(28,28)" ] }, { "cell_type": "code", "execution_count": 51, "id": "bc0fad05-50ab-4ae5-b9fd-e50133c4c92a", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJJ5JREFUeJzt3Xt0lfWd7/HPTkg2t2THEHKTQAMotHLplErKqBRLDpDOuEA5HW9zDrg6MNLgqlKrJz1Watuz0uIa66lDca2zWqir4oVzREbHYgUljAp0QBjGXlLAKGEgoWCTDQlJdrJ/5w/GzERB+P5M8kvC+7XWXovs/Xx4fnnyJJ882TvfRJxzTgAA9LKU0AsAAFyaKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQQwKvYAPSyaTOnr0qDIyMhSJREIvBwBg5JzTqVOnVFhYqJSU81/n9LkCOnr0qIqKikIvAwDwCdXW1mrUqFHnfbzPFVBGRoYk6Vp9WYOUFng16Ha9dVXLhCkgmHYl9Lpe6vx6fj49VkCrV6/Www8/rLq6Ok2dOlWPPfaYpk+ffsHcBz92G6Q0DYpQQANOr/1YlQICgvn3T78LPY3SIy9CeOaZZ7RixQqtXLlSb731lqZOnaq5c+fq+PHjPbE7AEA/1CMF9Mgjj2jJkiW644479JnPfEaPP/64hg4dqp/97Gc9sTsAQD/U7QXU1tamPXv2qLS09D92kpKi0tJS7dix4yPbt7a2Kh6Pd7kBAAa+bi+gEydOqKOjQ3l5eV3uz8vLU11d3Ue2r6ysVCwW67zxCjgAuDQE/0XUiooKNTY2dt5qa2tDLwkA0Au6/VVwOTk5Sk1NVX19fZf76+vrlZ+f/5Hto9GootFody8DANDHdfsVUHp6uqZNm6atW7d23pdMJrV161bNmDGju3cHAOineuT3gFasWKFFixbp85//vKZPn65HH31UTU1NuuOOO3pidwCAfqhHCujmm2/WH//4Rz344IOqq6vTZz/7WW3evPkjL0wAAFy6Is71rZkl8XhcsVhMszSfSQgYsFLHF5szx/7O/lxp7vzfmzPAJ9XuEtqmTWpsbFRmZuZ5twv+KjgAwKWJAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEH0yDRsoL+q+YH9b1bdP3+jOVPVcP4BjeczPu2MOZO/356RpF9smG3OFH3vTa994dLFFRAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCYBp2HxYZZP/wuI4O+46cs2c8RdLSzRmXaDNnBhWPMWckacttD5szX/zV3ebMlX+z25ypNyek3Tdf75GSHvzuU+bM2u/5HXOzSMSe6cVzHBePKyAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACCLiXN+a0hePxxWLxTRL8zUokhZ6Od3HY4BiJDXVnPEaRuqrb506Xfxh7TSv3NiiP5ozg0oPe+2rLzvzcrE5c8Pl+82ZLZMyzBn0fe0uoW3apMbGRmVmZp53O66AAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACCIQaEXcMnwGNzp2tvt+/EYetqXh4pKUsrUT5szW770v732VfrLFebMlfIYRppiHzQbSbF/bL3OIUlpP8w2Z25e9y/mzIbF3zRnLlu3w5xB38QVEAAgCAoIABBEtxfQd77zHUUikS63iRMndvduAAD9XI88B3TVVVdpy5Yt/7GTQTzVBADoqkeaYdCgQcrPz++J/xoAMED0yHNABw4cUGFhocaOHavbb79dhw+f/1VCra2tisfjXW4AgIGv2wuopKRE69at0+bNm7VmzRrV1NTouuuu06lTp865fWVlpWKxWOetqKiou5cEAOiDur2AysrK9JWvfEVTpkzR3Llz9dJLL6mhoUHPPvvsObevqKhQY2Nj5622tra7lwQA6IN6/NUBWVlZuvLKK3Xw4MFzPh6NRhWNRnt6GQCAPqbHfw/o9OnTOnTokAoKCnp6VwCAfqTbC+jee+9VVVWV3n33Xb355pu68cYblZqaqltvvbW7dwUA6Me6/UdwR44c0a233qqTJ09q5MiRuvbaa7Vz506NHDmyu3cFAOjHur2Ann766e7+L2HhM1jUYzCmJCnZYY7Eb/2COTNm+R/MmcdPXmfOSFLhq700ncolzZFIdKh9N57DSOu+YH9e9p1EpjnzzEMPmzMr//bL5kz9DH69oy9iFhwAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABNHjf5Cu10Qi9ozP4M7e5DMk1GPIpc9QUV9Xf2OPOVPdmGfONLenmzOSNPzZnV45q0iq5wDYXpJ2yp55s+kKc+bRP33KnLlr1BZz5oHblpgzkpS53uN86K2vRT778d1XD+EKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEEMnGnYfZ3H5Fqficku0XuTrYdtH2nOtDv7mOXUFPuE73c3jTVnJKlAdV45K9fh8XFqS3T/Qs4j77E3zZlvVVSbM9cevcqcWXlgvjlz8//cbM5I0ssvFJkzyVMeo8R9+E617kN/OYArIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIYuAMI/UYlhdJS/fbVaLNI2Rfn9d+PBy978+9ct/MfdacefLfvmDOJGUfnljwiH2YZq/qw+eDr181p5kz/33MTnNm1d455kxqkd8wzYxf2r9GNF7rtaveE/G47nA9M+SYKyAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACCLinMdUxB4Uj8cVi8U0S/M1KGIfbjiQHF9uHxLammXfz6+WrrKHJK1r+Lw5Myr9fXPme7+8yZyJ/cE+wFSSrv+bXebMxt981pwZ9G9RcyalzeN9ivh9evvsq3VE0pwZOeGEOROLtpgzZ9r9vpasHP8P5syyDUvNmeL/scOc6cvaXULbtEmNjY3KzMw873ZcAQEAgqCAAABBmAto+/btuuGGG1RYWKhIJKLnn3++y+POOT344IMqKCjQkCFDVFpaqgMHDnTXegEAA4S5gJqamjR16lStXr36nI+vWrVKP/7xj/X4449r165dGjZsmObOnauWFvvPbQEAA5f5L6KWlZWprKzsnI855/Too4/qgQce0Pz58yVJTzzxhPLy8vT888/rlltu+WSrBQAMGN36HFBNTY3q6upUWlraeV8sFlNJSYl27Dj3qzxaW1sVj8e73AAAA1+3FlBdXZ
0kKS8vr8v9eXl5nY99WGVlpWKxWOetqKioO5cEAOijgr8KrqKiQo2NjZ232tra0EsCAPSCbi2g/Px8SVJ9fX2X++vr6zsf+7BoNKrMzMwuNwDAwNetBVRcXKz8/Hxt3bq18754PK5du3ZpxowZ3bkrAEA/Z34V3OnTp3Xw4MHOt2tqarRv3z5lZ2dr9OjRuvvuu/X9739fV1xxhYqLi/Xtb39bhYWFWrBgQXeuGwDQz5kLaPfu3br++us7316xYoUkadGiRVq3bp3uu+8+NTU1aenSpWpoaNC1116rzZs3a/Dgwd23agBAv3dJDyN9Z5XfjwX/7safmzMr/vmvzJn09HZz5gdTnzNnNjdMMWckKUX2U+d3jXkX3uhD3tt7uTnTcVnCnJGkwbXp5kzmO/bjkNJuz3Sk2weEJs3fYp7lUu2ZZJp9fZGk/Ticntlszny26Ig5I0lHT8fMmWvy3jFn3nrf/urfI+9nmTOSNPor/+qVs2AYKQCgT6OAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAIz1m5A8Nvb/97r9zCg39hzuT8o/3PUZxeeMqc+dnR68yZxja/P5VxR9Eb5syJtmHmTM3gpDmjDvtkZklqu8y+r8RX/mTOjIo1mjMjo6fNmWiqfaK6JGUNsk+cTniM0G7qiJozMzOr7ftJ2vcjSW8OGm/OpMp+Dg0b1GbOvDR9jTkjSbfefq85E3typ9e+LoQrIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIYsAMI225Ybo5kxbZ57evb+WbMzn/6z1z5uXxz5kzD5+wH4ehKfZBiJK08vUF5kxK3H7KuSyPgZoe80vP7ithzsQPXGbOHGwYYc7U2meeKrXV2UOeOqL2AbDOY2bs6+mfM2duX/yKfUeSrsv6gznzucGHzZmX064yZ/7in+80ZyTpbyp+Zc68/GSm174uhCsgAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAhiwAwjrf98770rWT+sNWf+MudfzJmfNtgHFP5V1j+bM9+t/UtzRpIu/2WqOZMY6jF9UmnmRCTpN4TTpdjX15Fu308yzb4+n7W1Zvkcb0kesYjHzFif/Qz/N/uk2cf/6Xr7jiT9Yf4ac+Y3bfZ3alFsvznzj5mTzRlJ+uvYv5ozv/qzpabtIx2t0r9suuB2XAEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBADZhhp2une29e9l282Zzb8abo5k5seN2e+eei/mjPNf3+5OSNJZ3Lt37/4DMdMbTNHlEz1G8KZ0ksDNX0Gdzr77FevtUlS+xC/nJXPcTg12n7eFVTZ9yNJX5s205wpHNxgzlSfzjNnFl6+15yRpBEp9g/umcuHmbZvT6RKFzF/mSsgAEAQFBAAIAhzAW3fvl033HCDCgsLFYlE9Pzzz3d5fPHixYpEIl1u8+bN6671AgAGCHMBNTU1aerUqVq9evV5t5k3b56OHTvWeXvqqac+0SIBAAOP+UUIZWVlKisr+9htotGo8vPzvRcFABj4euQ5oG3btik3N1cTJkzQsmXLdPLkyfNu29raqng83uUGABj4ur2A5s2bpyeeeEJbt27VD3/4Q1VVVamsrEwdHR3n3L6yslKxWKzzVlRU1N1LAgD0Qd3+e0C33HJL578nT56sKVOmaNy4cdq2bZtmz579ke0rKiq0YsWKzrfj8TglBACXgB5/GfbYsWOVk5OjgwcPnvPxaDSqzMzMLjcAwMDX4wV05MgRnTx5UgUFBT29KwBAP2L+Edzp06e7XM3U1NRo3759ys7OVnZ2th566CEtXLhQ+fn5OnTokO677z6NHz9ec+fO7daFAwD6N3MB7d69W9dff33n2x88f7No0SKtWbNG+/fv189//nM1NDSosLBQc+bM0fe+9z1Fo9HuWzUAoN8zF9CsWbPknDvv4y+//PInWpCvQc29t68DbfbfcXo43z448P+dtj8f1viz/2LOJGN+EysTw+y5yLlfDPmxOtLtGd8hnB9zap9/V700WNRrQKjncRh0xp5J8Rga63M++By7jqjfgfj1+qnmzIYVD5sz/5Q+zpy5esi75owkxZNJc2ZY9QnT9u0drRe1HbPgAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEES3/0nuUIacsE949TUu7bg5c6TdPl74vpfuNGcyRti/p0gMM0ckSelxe8Z5fMvTMdiekcdUa0lqH+qxK4/pzD7r85m67Ssx3L7ApMdXk9Q2+5TqlIsbtNxF/FN+07AzDtuPw7cOzzdn/u+4LebMtjMeJ6ukwtRT5kzHgXds27vERW3HFRAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABDFghpFmHjptziRch9e+xqa1mDMzdywzZ/LfsA9C/NNEc8R7yGVLjj3TEbW/T9H3PQZJen5r5TMstX2o/X1qz7CfeynDL27A43+WbPWZlCpFj6aZM2mn7R8nr2PnMZw2tdVvGOmZXHvunfVXmDO/u/8fzBlpuEdGuixliDmTOmG8aXvX0SoduPB2XAEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBADZhhpSnObOZMW8RvU+K9tmebM8F/ZBwf+aaLHAMWkPdIxxD4QUpLasuw7i56wH/P0U/b1pS84bs5I0snGYeZMR5v90yj9cNScyd5u/34x4veh1Zls+7nXXOgxWNRjGKl85or6zSJVIsO+viEp9o/T/J13mjO/nPETc0aS2mU/9yIJ28TiSPLitucKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCGDDDSNXeYY40Js947epvd37NnEkbZZ+G2Jpjf59SWu37iST8JjWm5zebM2N/FDdn2t87Ys7UR0vMGUkav/WkPXS01p7JHWGOvHdTrjnTPMY2RPIDkaH24b7ujMdw36TH+dpuz3Sk+k1ldYPsueZR9kzGTvuw4rrpQ80ZSRqXZr/uaH/nXdv2LnFR23EFBAAIggICAARhKqDKykpdffXVysjIUG5urhYsWKDq6uou27S0tKi8vFwjRozQ8OHDtXDhQtXX13frogEA/Z+pgKqqqlReXq6dO3fqlVdeUSKR0Jw5c9TU1NS5zT333KMXXnhBGzZsUFVVlY4ePaqbbrqp2xcOAOjfTC9C2Lx5c5e3161bp9zcXO3Zs0czZ85UY2OjfvrTn2r9+vX60pe+JElau3atPv3pT2vnz
p36whe+0H0rBwD0a5/oOaDGxkZJUnZ2tiRpz549SiQSKi0t7dxm4sSJGj16tHbs2HHO/6O1tVXxeLzLDQAw8HkXUDKZ1N13361rrrlGkyZNkiTV1dUpPT1dWVlZXbbNy8tTXV3dOf+fyspKxWKxzltRUZHvkgAA/Yh3AZWXl+vtt9/W008//YkWUFFRocbGxs5bba3H71QAAPodr19EXb58uV588UVt375do0aN6rw/Pz9fbW1tamho6HIVVF9fr/z8/HP+X9FoVNFo1GcZAIB+zHQF5JzT8uXLtXHjRr366qsqLi7u8vi0adOUlpamrVu3dt5XXV2tw4cPa8aMGd2zYgDAgGC6AiovL9f69eu1adMmZWRkdD6vE4vFNGTIEMViMX31q1/VihUrlJ2drczMTN11112aMWMGr4ADAHRhKqA1a9ZIkmbNmtXl/rVr12rx4sWSpB/96EdKSUnRwoUL1draqrlz5+onP/lJtywWADBwmArIuQsP2Rs8eLBWr16t1atXey/KS6r99RQvNo268Ebn4JIeGY95n6nNHkMDh3sMMI34vRYl2Wp/CrH5ypHmTHrNe+bM5c8eMmckqf4vxpozJ6/JMGeyRpw2Z1pP24fnRt5PN2ckSQ1p9n35zLT1OfV8Mn6zSBVJ2HfmhtsHwDYX2vfz11v+1pyRpJq//D/mTOqIbNP2LtkmvX/h7ZgFBwAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCC8/iJqX9TxuwPmzLCUVq99PXHNT82Zv04sNWcizanmTGosYc4kO/wmJrsW++kTX95ozqTcdaU509xqn+YsSc6dsofeH2KONL6bZc5E7IPOpTTPMdD2U89r4rRL9Qh5TKOP+Iyjl+QG2w96aoP986JjmP2dSq/vvS/fLX9WfOGN/pP29hbptQtvxxUQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAAQxYIaR+shKafbKpXpMXfzRdU+bM4WD/mTO/OLkn5szL74xzZyRJHXYBzy+fyTLnElttn+f5HrxW6uIx0BNl2YfPun8ZsZ6ibR7DO/0mffpMyvVYz8uxXMoq8c5nox67ssoJeE3YLU52WbOtGXZqqI9cXHbcwUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFc0sNIt5y6yis3Zehhc+Zyj8GiE9LazZlPDT5pzvy3Wf9kzkjS+t9cbQ8dHmKOpNgPgxIx+7BPSYp4DJ+MJO2ZlDN+gyStXO/sRpIU8ZjB6VLsC/QZNOuztrP78nmnfM4h+24GtdgzklTT3mHODD6RMG3f3n5x23MFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBXNrDSI9O8MrFRjebM+PS/mjO/CJ+pTmTFrEPGizL2G/OSNKUz9eaMyNKTpszean2zMb4n5kzkvTSUfuA2vak/fu45rY0cybpsZ9omm2I5AeGeAzC9TkOg1LsUzh95oqebol6pKTYEPvEz8LhjeZMW0eqOZP0mcoq6Y0z48yZYzMGm7bvaJV0ETOOuQICAARBAQEAgjAVUGVlpa6++mplZGQoNzdXCxYsUHV1dZdtZs2apUgk0uV25513duuiAQD9n6mAqqqqVF5erp07d+qVV15RIpHQnDlz1NTU1GW7JUuW6NixY523VatWdeuiAQD9n+lFCJs3b+7y9rp165Sbm6s9e/Zo5syZnfcPHTpU+fn53bNCAMCA9ImeA2psPPtqj+zs7C73P/nkk8rJydGkSZNUUVGh5ubzv2qstbVV8Xi8yw0AMPB5vww7mUzq7rvv1jXXXKNJkyZ13n/bbbdpzJgxKiws1P79+3X//ferurpazz333Dn/n8rKSj300EO+ywAA9FPeBVReXq63335br7/+epf7ly5d2vnvyZMnq6CgQLNnz9ahQ4c0btxHX39eUVGhFStWdL4dj8dVVFTkuywAQD/hVUDLly/Xiy++qO3bt2vUqFEfu21JSYkk6eDBg+csoGg0qmjU75fEAAD9l6mAnHO66667tHHjRm3btk3FxcUXzOzbt0+SVFBQ4LVAAMDAZCqg8vJyrV+/Xps2bVJGRobq6uokSbFYTEOGDNGhQ4e0fv16ffnLX9aIESO0f/9+3XPPPZo5c6amTJnSI+8AAKB/MhXQmjVrJJ39ZdP/bO3atVq8eLHS09O1ZcsWPfroo2pqalJRUZEWLlyoBx54oNsWDAAYGMw/gvs4RUVFqqqq+kQLAgBcGi7padhDPCcFL8v6jTlzMBExZ8qz7NOm/dgn8Z5l/52t5mSbOTM0Zag58+mc6gtvdA7LLttrzlyWal+fjzda7JOjG5J+a8tPtX9sB3tMYu+Q/fOixdnP15EpreaMJBWnDTdn9rTaz/HxafZjt6Mly5yRpPV/LDFnRlW+adq+3SV04CK2YxgpACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARxaQ8j/brfX2Itm/B1cybaYB98mkzrne8POqJ++zlyvT0XybMPhcx4c4g5U/jLo+aMJLlU+/vUcdkwcyb1VIs5o2PHzRGXaLfvR1JkqH2IaWS4x+DTC0zYP6d2++BOX81X2f+Qps/n7fA9h82Z9mN15sxZ9kGzPYUrIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEESfmwXn/n02VLsSkseYKNO+OuxzySSpPWGf45Xa7jELLtJLs+BS/PaTbPGYBddsP+YdbRFzpj3p97F1Ht+TdbSn2vfjc+65NnvEec6CS9q/NESS9uPgNQsu2Xuz4Nrb7Z/rSY9zqD1p/9i2O/vXlN7SrrNrcxf4+EbchbboZUeOHFFRUVHoZQAAPqHa2lqNGjXqvI/3uQJKJpM6evSoMjIyFIl0/c43Ho+rqKhItbW1yszMDLTC8DgOZ3EczuI4nMVxOKsvHAfnnE6dOqXCwkKlfMxPWPrcj+BSUlI+tjElKTMz85I+wT7AcTiL43AWx+EsjsNZoY9DLBa74Da8CAEAEAQFBAAIol8VUDQa1cqVKxWN+v0l04GC43AWx+EsjsNZHIez+tNx6HMvQgAAXBr61RUQAGDgoIAAAEFQQACAICggAEAQ/aaAVq9erU996lMaPHiwSkpK9Otf/zr0knrdd77zHUUikS63iRMnhl5Wj9u+fbtuuOEGFRYWKhKJ6Pnnn+/yuHNODz74oAoKCjRkyBCVlpbqwIEDYRbbgy50HBYvXvyR82PevHlhFttDKisrdfXVVysjI0O5ublasGCBqquru2zT0tKi8vJyjRgxQsOHD9fChQtVX18faMU942KOw6xZsz5yPtx5552BVnxu
/aKAnnnmGa1YsUIrV67UW2+9palTp2ru3Lk6fvx46KX1uquuukrHjh3rvL3++uuhl9TjmpqaNHXqVK1evfqcj69atUo//vGP9fjjj2vXrl0aNmyY5s6dq5YW+yDJvuxCx0GS5s2b1+X8eOqpp3pxhT2vqqpK5eXl2rlzp1555RUlEgnNmTNHTU1Nndvcc889euGFF7RhwwZVVVXp6NGjuummmwKuuvtdzHGQpCVLlnQ5H1atWhVoxefh+oHp06e78vLyzrc7OjpcYWGhq6ysDLiq3rdy5Uo3derU0MsISpLbuHFj59vJZNLl5+e7hx9+uPO+hoYGF41G3VNPPRVghb3jw8fBOecWLVrk5s+fH2Q9oRw/ftxJclVVVc65sx/7tLQ0t2HDhs5tfve73zlJbseOHaGW2eM+fBycc+6LX/yi+/rXvx5uURehz18BtbW1ac+ePSotLe28LyUlRaWlpdqxY0fAlYVx4MABFRYWauzYsbr99tt1+PDh0EsKqqamRnV1dV3Oj1gsppKSkkvy/Ni2bZtyc3M1YcIELVu2TCdPngy9pB7V2NgoScrOzpYk7dmzR4lEosv5MHHiRI0ePXpAnw8fPg4fePLJJ5WTk6NJkyapoqJCzc3NIZZ3Xn1uGOmHnThxQh0dHcrLy+tyf15enn7/+98HWlUYJSUlWrdunSZMmKBjx47poYce0nXXXae3335bGRkZoZcXRF1dnSSd8/z44LFLxbx583TTTTepuLhYhw4d0re+9S2VlZVpx44dSk31+Fs9fVwymdTdd9+ta665RpMmTZJ09nxIT09XVlZWl20H8vlwruMgSbfddpvGjBmjwsJC7d+/X/fff7+qq6v13HPPBVxtV32+gPAfysrKOv89ZcoUlZSUaMyYMXr22Wf11a9+NeDK0Bfccsstnf+ePHmypkyZonHjxmnbtm2aPXt2wJX1jPLycr399tuXxPOgH+d8x2Hp0qWd/548ebIKCgo0e/ZsHTp0SOPGjevtZZ5Tn/8RXE5OjlJTUz/yKpb6+nrl5+cHWlXfkJWVpSuvvFIHDx4MvZRgPjgHOD8+auzYscrJyRmQ58fy5cv14osv6rXXXuvy51vy8/PV1tamhoaGLtsP1PPhfMfhXEpKSiSpT50Pfb6A0tPTNW3aNG3durXzvmQyqa1bt2rGjBkBVxbe6dOndejQIRUUFIReSjDFxcXKz8/vcn7E43Ht2rXrkj8/jhw5opMnTw6o88M5p+XLl2vjxo169dVXVVxc3OXxadOmKS0trcv5UF1drcOHDw+o8+FCx+Fc9u3bJ0l963wI/SqIi/H000+7aDTq1q1b537729+6pUuXuqysLFdXVxd6ab3qG9/4htu2bZurqalxb7zxhistLXU5OTnu+PHjoZfWo06dOuX27t3r9u7d6yS5Rx55xO3du9e99957zjnnfvCDH7isrCy3adMmt3//fjd//nxXXFzszpw5E3jl3evjjsOpU6fcvffe63bs2OFqamrcli1b3Oc+9zl3xRVXuJaWltBL7zbLli1zsVjMbdu2zR07dqzz1tzc3LnNnXfe6UaPHu1effVVt3v3bjdjxgw3Y8aMgKvufhc6DgcPHnTf/e533e7du11NTY3btGmTGzt2rJs5c2bglXfVLwrIOecee+wxN3r0aJeenu6mT5/udu7cGXpJve7mm292BQUFLj093V1++eXu5ptvdgcPHgy9rB732muvOUkfuS1atMg5d/al2N/+9rddXl6ei0ajbvbs2a66ujrsonvAxx2H5uZmN2fOHDdy5EiXlpbmxowZ45YsWTLgvkk71/svya1du7ZzmzNnzrivfe1r7rLLLnNDhw51N954ozt27Fi4RfeACx2Hw4cPu5kzZ7rs7GwXjUbd+PHj3Te/+U3X2NgYduEfwp9jAAAE0eefAwIADEwUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACOL/AyBNQnoqGwl/AAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "plt.imshow(img)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 52, "id": "56f36efb-e3a2-49f9-b9fb-1657bc25e5c5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[-1.0776339769363403, -3.4281859397888184, 1.0321333408355713, -2.1151161193847656, 0.7665405869483948, 0.7089913487434387, 0.6775667071342468, 0.3138602077960968, 2.9969606399536133, 0.7927607297897339]\n", "predicted label: Bag\n" ] } ], "source": [ "print(predictions)\n", "print(\"predicted label:\", classes[np.argmax(predictions)])" ] }, { "cell_type": "markdown", "id": "56ca1195-ea0f-405f-87fe-857e5c0c76a5", "metadata": {}, "source": [ "### 784 columns of float" ] }, { "cell_type": "code", "execution_count": 53, "id": "e0ab0af6-b5c9-4b74-9dd6-baa7737cc986", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "784" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.read.parquet(data_path_784)\n", "len(df.columns)" ] }, { "cell_type": "code", "execution_count": 54, "id": "13ae45dc-85a0-4864-8a58-9dc29ae4efd7", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 225 ms, sys: 91.1 ms, total: 316 ms\n", "Wall time: 3.16 s\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", mnist(struct(df.columns))).collect()" ] }, { "cell_type": "code", "execution_count": 55, "id": "0b3fb48b-f871-41f2-ac57-346899a6fe48", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 283 ms, sys: 67.8 ms, total: 351 ms\n", "Wall time: 1.47 s\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", mnist(array(*df.columns))).collect()" ] }, { "cell_type": "code", "execution_count": 56, "id": "b59114ad", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 543 ms, sys: 65.1 ms, total: 608 ms\n", "Wall time: 1.36 s\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", mnist(array(*df.columns))).collect()" ] }, { "cell_type": "markdown", "id": "dc48ec42-0df6-4e6a-b019-1270ab71d2cf", "metadata": { "tags": [] }, "source": [ "### Check predictions" ] }, { "cell_type": "code", "execution_count": 57, "id": "d815c701-9f5b-422c-b3f9-fbc30456953c", "metadata": {}, "outputs": [], "source": [ "preds = df.withColumn(\"preds\", mnist(array(*df.columns))).limit(10).toPandas()" ] }, { "cell_type": "code", "execution_count": 58, "id": "b571b742-5079-42b2-8524-9181a0dec2c7", "metadata": {}, "outputs": [], "source": [ "sample = preds.iloc[0]\n", "predictions = sample.preds\n", "img = sample.drop('preds').to_numpy(dtype=float)" ] }, { "cell_type": "code", "execution_count": 59, "id": "d33d6a4e-e6b9-489d-ac21-c4eddc801784", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 60, "id": "6d10061e-aca6-4f81-bdfe-72e327ed7349", "metadata": {}, "outputs": [], "source": [ "img = np.array(img).reshape(28,28)" ] }, { "cell_type": "code", "execution_count": 61, "id": "01f70e08-2c1d-419f-8676-3f6f4aba760f", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHpRJREFUeJzt3X1wlPW99/HPZpNsAoSNIeSpBBpQoZWHnlJJuVWKJQOkczyiTMenP8BxYLTBKVKrk46K2s6kxRnr6FA8f7RQ71t8mhEYPb3pKJowtoEOKDeH0zZCTix4QoLS5oGEPJD9nT84bu+FAP1dbPLdhPdr5prJ7l7fXF+uvchnr+y134Scc04AAAyzNOsGAABXJgIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJtKtGzhXLBZTc3OzcnJyFAqFrNsBAHhyzqmzs1MlJSVKS7vweU7KBVBzc7NKS0ut2wAAXKZjx45p0qRJF3w85QIoJydHknSjvqN0ZRh3g4tJn/wl75qy/33Cu2b3f03zrsnKOONdI0lnBobnt9IDbnjO7mOxYP+ejPCAd01He7Z3zdypR71r2v+527vG9fZ61yC4M+rXB/pN/Of5hQxZAG3cuFHPPPOMWlpaNGfOHL3wwguaN2/eJeu++LVbujKUHiKAUll6WsS7JnOc/3MaHuO/nXBG2LtGktwwBZCGKYBCAQMoHCCA0vqyvGsyxmZ616SH+r1rXCjmXYPL8D8TRi/1NsqQ/G977bXXtG7dOq1fv14ffvih5syZoyVLlujECf9XvwCA0WlIAujZZ5/VqlWrdO+99+qrX/2qXnzxRY0ZM0a/+tWvhmJzAIARKOkB1NfXp/3796uiouLvG0lLU0VFherr689bv7e3Vx0dHQkLAGD0S3oAff755xoYGFBhYWHC/YWFhWppaTlv/ZqaGkWj0fjCFXAAcGUw/yBqdXW12tvb48uxY8esWwIADIOkXwWXn5+vcDis1tbWhPtbW1tVVFR03vqRSESRiP9VTgCAkS3pZ0CZmZmaO3eudu3aFb8vFotp165dmj9/frI3BwAYoYbkc0Dr1q3TihUr9I1vfEPz5s3Tc889p66uLt17771DsTkAwAg0JAF0xx136LPPPtMTTzyhlpYWfe1rX9POnTvPuzABAHDlGrJJCGvWrNGaNWuG6tsjBfz1Bv9RPG8Wv+ld80xmp3dNWeQz7xpJCsv/E/NpAT5lPzbNfzTMgPP/jXk44ASA/9c9xbtmz9/KvGv+dcq/edf8y+Lve9dkvfUH7xoMPfOr4AAAVyYCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmhmwYKUa/9qn+r1/29Ya9a5p7c71r0uS8ayQpppB3TU8sw7smmt7tXZMVOuNdE2RQqiQ1dk/0rjnWlutd89vu8/9I5aW0Xe3/Y8t/KxgOnAEBAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwwDRuB9RQNeNdMzzjtXVMSafOuyc/o9K6RpBP9471rMkL++6E/5v9fr9tFvGtywj3eNZJUnNXuXfONIv8J5LMyj3vX9OV4lyBFcQYEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABMNIEVhmQbd3TY/zH1g5LsBAzX4X9q6RpDFpfd41nQNZ3jXtA9neNZ/3jvOuuSn3Y+8aKdg+D4diAWr8j4dYxL8GqYkzIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYYRorAxmb3etc0D0S8a2LO/3VSv3fFWRmhgWGp6Y35/9crG/O5d03bwBjvmqBOnfF/bpvP5HjX9BUHfXaRajgDAgCYIIAAACaSHkBPPvmkQqFQwjJjxoxkbwYAMMINyXtA1113nd59992/bySdt5oAAImGJBnS09NVVFQ0FN8aADBKDMl7QIcPH1ZJSYmmTp2qe+65R0ePHr3gur29vero6EhYAACjX9IDqLy8XFu2bNHOnTu1adMmNTU16aabblJnZ+eg69fU1CgajcaX0tLSZLcEAEhBSQ+gyspKffe739Xs2bO1ZMkS/eY3v1FbW5tef/31Qdevrq5We3t7fDl27FiyWwIApKAhvzogNzdX1157rY4cOTLo45FIRJGI/wfYAAAj25B/DujUqVNqbGxUcXHxUG8KADCCJD2AHn74YdXV1emTTz7R73//e912220Kh8O66667kr0pAMAIlvRfwX366ae66667dPLkSU2cOFE33nij9uzZo4kTJyZ7UwCAESzpAfTqq68m+1siReVE+rxrMhXzrkkL+ddkhYINrOx3/v8lpkZOeNdcndXiXdPcf5V3TXeA4a+SlJXmv/96YxneNR2xLO+azLH+xx1SE7PgAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmBjyP0iH0SsUct41Xc5/YGX7mTHeNUrv9q9RsCGmOeHT3jU//vifvWu+N7XWu+bj7iLvGknKDbD/ugYyA23LVyQSbNAsUg9nQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAE0zDRmA9Z/wPn4EAr3mae6PeNUFNzGr1rsnQgHdN9DtHvGtmNB33rtnTebV3jSS1BZhA3tkf8a4JMn08FuN182jBMwkAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEw0gR2Kke/+GTY0N93jUx5/866ZPTE7xrJOmfxnziXRMbptdxnw3keNeUZX8WaFt7/1bmXXOi27+/zJD/INeengzvGqQmzoAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYYBgpAuvu8h9GmhGKDUEn5xtwoUB1BeFO75r/c/J/BdhSr3fF7s4Z3jX/Ev3Qu0aS/q15lndNzxn/Hyc5aT3eNQOn+bE1WnAGBAAwQQABAEx4B9Du3bt1yy23qKSkRKFQSNu3b0943DmnJ554QsXFxcrOzlZFRYUOHz6crH4BAKOEdwB1dXVpzpw52rhx46CPb9iwQc8//7xefPFF7d27V2PHjtWSJUvU0+P/u14AwOjl/W5eZWWlKisrB33MOafnnntOjz32mG699VZJ0ksvvaTCwkJt375dd9555+V1CwAYNZL6HlBTU5NaWlpUUVERvy8ajaq8vFz19fWD1vT29qqjoyNhAQCMfkkNoJaWFklSYWFhwv2FhYXxx85VU1OjaDQaX0pLS5PZEgAgRZlfBVddXa329vb4cuzYMeuWAADDIKkBVFRUJElqbW1NuL+1tT
X+2LkikYjGjx+fsAAARr+kBlBZWZmKioq0a9eu+H0dHR3au3ev5s+fn8xNAQBGOO+r4E6dOqUjR47Ebzc1NenAgQPKy8vT5MmTtXbtWv3kJz/RNddco7KyMj3++OMqKSnRsmXLktk3AGCE8w6gffv26eabb47fXrdunSRpxYoV2rJlix555BF1dXVp9erVamtr04033qidO3cqKysreV0DAEY87wBauHChnHMXfDwUCunpp5/W008/fVmNIfXFuvyHQmZoeIaRBjUzs9+7ZufHX/WumaaPvGvqT5R519yf94F3jST1x/x/O5+Vfsa/JjTgXRPqDnvXIDWZXwUHALgyEUAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBM+I8zBv5HWpf/VOKxacMzDftMLNjE5HFp/n82JKMhO9C2fB1rzvOuKbwuM9C2unv96/LGdnvXjAkwDTt8mtfNowXPJADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMMI0Vg4Z6Qd01WyL/m9ECGd01+5JR3TVDR/xyeAavZjRHvmvBi//0dVFrIeddkBGgvdMa/BqmJMyAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmGEaKwMK9/pMk+5z/wMpImv/0ySCDMYMa39QzLNsZ2zx8/6YxkT7vmrHp/jVh74pgQ3CRmjgDAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIJhpAgs/fTwbCcm/+GT0YDN9bsB75r0j//Lu8Z/K9L4T/yHfQY1Jedv3jV9Mf/Rohkh/+c2PHy7AUOMMyAAgAkCCABgwjuAdu/erVtuuUUlJSUKhULavn17wuMrV65UKBRKWJYuXZqsfgEAo4R3AHV1dWnOnDnauHHjBddZunSpjh8/Hl9eeeWVy2oSADD6eF+EUFlZqcrKyouuE4lEVFRUFLgpAMDoNyTvAdXW1qqgoEDTp0/XAw88oJMnT15w3d7eXnV0dCQsAIDRL+kBtHTpUr300kvatWuXfvazn6murk6VlZUaGBj8wtOamhpFo9H4UlpamuyWAAApKOmfA7rzzjvjX8+aNUuzZ8/WtGnTVFtbq0WLFp23fnV1tdatWxe/3dHRQQgBwBVgyC/Dnjp1qvLz83XkyJFBH49EIho/fnzCAgAY/YY8gD799FOdPHlSxcXFQ70pAMAI4v0ruFOnTiWczTQ1NenAgQPKy8tTXl6ennrqKS1fvlxFRUVqbGzUI488oquvvlpLlixJauMAgJHNO4D27dunm2++OX77i/dvVqxYoU2bNungwYP69a9/rba2NpWUlGjx4sX68Y9/rEgkkryuAQAjnncALVy4UM65Cz7+29/+9rIawsgRDjDvMyfNf2Bl74D/tTLjwj3eNUENfPbZsGwn65MLf5wh2fIyu71r/to3Zgg6OV8oNiybwTBgFhwAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwETS/yQ3rhzppy88Ff1CMuQ/DTuIsPx7S3UDx5qHbVsTMk951wSZhp0R8n8NHDrjXYIUxRkQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwwjRWCZnf4DP8OhkHfNV3OOe9ekhWLeNZKUERqeYalBuP6+YdvWVeld3jXTx7V612SF/H8EZXSNvkGzVyrOgAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJhgGCkCS+/1H/iZFuA1TzR82rsmktbvXSNJ/W4gUF2q6o4F2w85aT3eNeH04RkSyjDS0YMzIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYYRoqUlxE6410zNq030LZi8h+wmsr6FWxwZ5D91+/CgbblK6ObYaSjBWdAAAATBBAAwIRXANXU1Oj6669XTk6OCgoKtGzZMjU0NCSs09PTo6qqKk2YMEHjxo3T8uXL1dramtSmAQAjn1cA1dXVqaqqSnv27NE777yj/v5+LV68WF1dXfF1HnroIb311lt64403VFdXp+bmZt1+++1JbxwAMLJ5XYSwc+fOhNtbtmxRQUGB9u/frwULFqi9vV2//OUvtXXrVn3729+WJG3evFlf+cpXtGfPHn3zm99MXucAgBHtst4Dam9vlyTl5eVJkvbv36/+/n5VVFTE15kxY4YmT56s+vr6Qb9Hb2+vOjo6EhYAwOgXOIBisZjWrl2rG264QTNnzpQktbS0KDMzU7m5uQnrFhYWqqWlZdDvU1NTo2g0Gl9KS0uDtgQAGEECB1BVVZUOHTqkV1999bIaqK6uVnt7e3w5duzYZX0/AMDIEOiDqGvWrNHbb7+t3bt3a9KkSfH7i4qK1NfXp7a2toSzoNbWVhUVFQ36vSKRiCKRSJA2AAAjmNcZkHNOa9as0bZt2/Tee++prKws4fG5c+cqIyNDu3btit/X0NCgo0ePav78+cnpGAAwKnidAVVVVWnr1q3asWOHcnJy4u/rRKNRZWdnKxqN6r777tO6deuUl5en8ePH68EHH9T8+fO5Ag4AkMArgDZt2iRJWrhwYcL9mzdv1sqVKyVJP//5z5WWlqbly5ert7dXS5Ys0S9+8YukNAsAGD28Asi5Sw8BzMrK0saNG7Vx48bATWFkSO8ensGdAwGulQkywFSSelywulTVGQs2uDMt5P/cZoQG/LcT4LnNbB9dz9GVjFlwAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATgf4iKiBJ4e7UnUocVrAp0H8d8J/onMo+G8gOVBd0/w2HzJOnvWuGZ247fHEGBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwATDSJHyYs7/dVJGKNig1H/vKwhUl6paBqKB6rLS+rxrwrGsQNvylfZZm3cNw0hTE2dAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATDCMFIGFBpx3TX1veAg6SZ7/HGXDSBt6igPV/VP2J941aQFGfv7f7hzvGtfV5V2D1MQZEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMMI0VgsSz/waIT0k571xRmtHnXTE7/m3eNJLWcyQ1Ul6r+41SwYaSVOf/uXdPtIt41ueFu7xql82Nrt
OAMCABgggACAJjwCqCamhpdf/31ysnJUUFBgZYtW6aGhoaEdRYuXKhQKJSw3H///UltGgAw8nkFUF1dnaqqqrRnzx6988476u/v1+LFi9V1zh+IWrVqlY4fPx5fNmzYkNSmAQAjn9e7eTt37ky4vWXLFhUUFGj//v1asGBB/P4xY8aoqKgoOR0CAEaly3oPqL29XZKUl5eXcP/LL7+s/Px8zZw5U9XV1eruvvCVLr29vero6EhYAACjX+DrGWOxmNauXasbbrhBM2fOjN9/9913a8qUKSopKdHBgwf16KOPqqGhQW+++eag36empkZPPfVU0DYAACNU4ACqqqrSoUOH9MEHHyTcv3r16vjXs2bNUnFxsRYtWqTGxkZNmzbtvO9TXV2tdevWxW93dHSotLQ0aFsAgBEiUACtWbNGb7/9tnbv3q1JkyZddN3y8nJJ0pEjRwYNoEgkokjE/wNsAICRzSuAnHN68MEHtW3bNtXW1qqsrOySNQcOHJAkFRcH+0Q2AGB08gqgqqoqbd26VTt27FBOTo5aWlokSdFoVNnZ2WpsbNTWrVv1ne98RxMmTNDBgwf10EMPacGCBZo9e/aQ/AMAACOTVwBt2rRJ0tkPm/7/Nm/erJUrVyozM1PvvvuunnvuOXV1dam0tFTLly/XY489lrSGAQCjg/ev4C6mtLRUdXV1l9UQAODKwFhZBBaKXfwFyWCuy8z2rvnVyfMvXrmU5qyrvGsk6eWmed41efo40LaGQ3a4P1Ddr07e6F2TERrwrvnhxA8uvdK5LvFCGCMHw0gBACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYCLlLjbgeZh0dHYpGo1qoW5UeyrBuBwDg6YzrV612qL29XePHj7/gepwBAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMBEunUD5/piNN0Z9UspNaUOAPCPOKN+SX//eX4hKRdAnZ2dkqQP9BvjTgAAl6Ozs1PRaPSCj6fcNOxYLKbm5mbl5OQoFAolPNbR0aHS0lIdO3bsohNWRzv2w1nsh7PYD2exH85Khf3gnFNnZ6dKSkqUlnbhd3pS7gwoLS1NkyZNuug648ePv6IPsC+wH85iP5zFfjiL/XCW9X642JnPF7gIAQBgggACAJgYUQEUiUS0fv16RSIR61ZMsR/OYj+cxX44i/1w1kjaDyl3EQIA4Mowos6AAACjBwEEADBBAAEATBBAAAATIyaANm7cqC9/+cvKyspSeXm5/vCHP1i3NOyefPJJhUKhhGXGjBnWbQ253bt365ZbblFJSYlCoZC2b9+e8LhzTk888YSKi4uVnZ2tiooKHT582KbZIXSp/bBy5crzjo+lS5faNDtEampqdP311ysnJ0cFBQVatmyZGhoaEtbp6elRVVWVJkyYoHHjxmn58uVqbW016nho/CP7YeHChecdD/fff79Rx4MbEQH02muvad26dVq/fr0+/PBDzZkzR0uWLNGJEyesWxt21113nY4fPx5fPvjgA+uWhlxXV5fmzJmjjRs3Dvr4hg0b9Pzzz+vFF1/U3r17NXbsWC1ZskQ9PT3D3OnQutR+kKSlS5cmHB+vvPLKMHY49Orq6lRVVaU9e/bonXfeUX9/vxYvXqyurq74Og899JDeeustvfHGG6qrq1Nzc7Nuv/12w66T7x/ZD5K0atWqhONhw4YNRh1fgBsB5s2b56qqquK3BwYGXElJiaupqTHsavitX7/ezZkzx7oNU5Lctm3b4rdjsZgrKipyzzzzTPy+trY2F4lE3CuvvGLQ4fA4dz8459yKFSvcrbfeatKPlRMnTjhJrq6uzjl39rnPyMhwb7zxRnydP/3pT06Sq6+vt2pzyJ27H5xz7lvf+pb7/ve/b9fUPyDlz4D6+vq0f/9+VVRUxO9LS0tTRUWF6uvrDTuzcfjwYZWUlGjq1Km65557dPToUeuWTDU1NamlpSXh+IhGoyovL78ij4/a2loVFBRo+vTpeuCBB3Ty5EnrloZUe3u7JCkvL0+StH//fvX39yccDzNmzNDkyZNH9fFw7n74wssvv6z8/HzNnDlT1dXV6u7utmjvglJuGOm5Pv/8cw0MDKiwsDDh/sLCQv35z3826spGeXm5tmzZounTp+v48eN66qmndNNNN+nQoUPKycmxbs9ES0uLJA16fHzx2JVi6dKluv3221VWVqbGxkb96Ec/UmVlperr6xUOh63bS7pYLKa1a9fqhhtu0MyZMyWdPR4yMzOVm5ubsO5oPh4G2w+SdPfdd2vKlCkqKSnRwYMH9eijj6qhoUFvvvmmYbeJUj6A8HeVlZXxr2fPnq3y8nJNmTJFr7/+uu677z7DzpAK7rzzzvjXs2bN0uzZszVt2jTV1tZq0aJFhp0NjaqqKh06dOiKeB/0Yi60H1avXh3/etasWSouLtaiRYvU2NioadOmDXebg0r5X8Hl5+crHA6fdxVLa2urioqKjLpKDbm5ubr22mt15MgR61bMfHEMcHycb+rUqcrPzx+Vx8eaNWv09ttv6/3330/48y1FRUXq6+tTW1tbwvqj9Xi40H4YTHl5uSSl1PGQ8gGUmZmpuXPnateuXfH7YrGYdu3apfnz5xt2Zu/UqVNqbGxUcXGxdStmysrKVFRUlHB8dHR0aO/evVf88fHpp5/q5MmTo+r4cM5pzZo12rZtm9577z2VlZUlPD537lxlZGQkHA8NDQ06evToqDoeLrUfBnPgwAFJSq3jwfoqiH/Eq6++6iKRiNuyZYv74x//6FavXu1yc3NdS0uLdWvD6gc/+IGrra11TU1N7ne/+52rqKhw+fn57sSJE9atDanOzk730UcfuY8++shJcs8++6z76KOP3F/+8hfnnHM//elPXW5urtuxY4c7ePCgu/XWW11ZWZk7ffq0cefJdbH90NnZ6R5++GFXX1/vmpqa3Lvvvuu+/vWvu2uuucb19PRYt540DzzwgItGo662ttYdP348vnR3d8fXuf/++93kyZPde++95/bt2+fmz5/v5s+fb9h18l1qPxw5csQ9/fTTbt++fa6pqcnt2LHDTZ061S1YsMC480QjIoCcc+6FF15wkydPdpmZmW7evHluz5491i0NuzvuuMMVFxe7zMxM96Uvfcndcccd7siRI9ZtDbn333/fSTpvWbFihXPu7KXYjz/+uCssLHSRSMQtWrTINTQ02DY9BC62H7q7u93ixYvdxIkTXUZGhpsyZYpbtWrVqHuRNti/X5LbvHlzfJ3Tp0+7733ve+6qq65yY8aMcbfddps7fvy4XdND4FL74ejRo27BggUuLy/PRSIRd/XVV7sf/vCHrr293bbxc/DnGAAAJlL+PSAAwOhEAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADAxH8DR+VYJGWV2nkAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "plt.imshow(img)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 62, "id": "8e1c07cc-b2bc-4902-a9a6-4ac7f02c5fe4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 2.44647 3.8623989 0.14587203 3.2146688 1.0799949 -2.5363288\n", " 0.86715794 -3.8287208 -2.02238 -2.9016623 ]\n", "predicted label: Trouser\n" ] } ], "source": [ "print(predictions)\n", "print(\"predicted label:\", classes[np.argmax(predictions)])" ] }, { "cell_type": "code", "execution_count": 63, "id": "3d47a8ec", "metadata": {}, "outputs": [], "source": [ "# This will clear the engine cache (containing previously compiled TensorRT engines) and resets the CUDA Context.\n", "torch._dynamo.reset()" ] }, { "cell_type": "markdown", "id": "281c7889", "metadata": {}, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 64, "id": "53ca290a-ccc3-4923-a292-944921bab36d", "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "id": "d8abea75", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 65, "id": "e616b207", "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "id": "606934ac", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 66, "id": "8fa92fe4-2e04-4d82-a357-bfdfca38bd8c", "metadata": {}, "outputs": [], "source": [ "def triton_server(ports, model_path):\n", " import time\n", " import signal\n", " import numpy as np\n", " import torch\n", " from torch import nn\n", " import torch_tensorrt as trt\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", "\n", " print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " \n", " exp_program = torch.export.load(model_path)\n", " example_inputs = (torch.randn((50, 784), dtype=torch.float).to(\"cuda\"),)\n", " trt_gm = trt.dynamo.compile(exp_program,\n", " tuple(example_inputs),\n", " enabled_precisions={torch.float},\n", " workspace_size=1<<30)\n", "\n", " print(\"SERVER: Compiled model.\")\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " images 
= inputs[\"images\"]\n", " if len(images) != 1:\n", " images = np.squeeze(images)\n", " stream = torch.cuda.Stream()\n", " with torch.no_grad(), torch.cuda.stream(stream):\n", " torch_inputs = torch.from_numpy(images).to(device)\n", " outputs = trt_gm(torch_inputs)\n", " return {\n", " \"labels\": outputs.cpu().numpy(),\n", " }\n", " \n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"ImageClassifier\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"images\", dtype=np.float32, shape=(-1,)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"labels\", dtype=np.float32, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=64,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n", " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "id": "8fea6e5e", "metadata": {}, "source": [ "#### Start Triton servers " ] }, { "cell_type": "markdown", "id": "f837300c", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": 68, "id": "f72c53d6", "metadata": {}, "outputs": [], "source": [ "model_name = \"ImageClassifier\"\n", "server_manager = TritonServerManager(model_name=model_name, model_path=model_path)" ] }, { "cell_type": "code", "execution_count": null, "id": "65d3f7be", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server)" ] }, { "cell_type": "markdown", "id": "90ed191b", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "id": "86c1545a", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": null, "id": "c4c2833f", "metadata": {}, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url # or server_manager.host_to_grpc_url" ] }, { "cell_type": "markdown", "id": "c6771c93", "metadata": {}, "source": [ "Define the Triton inference function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 71, "id": "cec9a48c", "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " from pytriton.client import ModelClient\n", 
"\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"Connecting to Triton model {model_name} at {url}.\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n", " result_data = client.infer_batch(inputs)\n", " return result_data[\"labels\"]\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 73, "id": "0262fd4a-9845-44b9-8c75-1c105e7deeca", "metadata": {}, "outputs": [], "source": [ "mnist = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " input_tensor_shapes=[[784]],\n", " return_type=ArrayType(FloatType()),\n", " batch_size=50)" ] }, { "cell_type": "markdown", "id": "30a4362d-7514-4b84-b238-f704a97e1e72", "metadata": {}, "source": [ "#### Run inference" ] }, { "cell_type": "code", "execution_count": 72, "id": "ab94d4d1-dac6-4474-9eb0-59478aa98f7d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "StructType([StructField('data', ArrayType(FloatType(), True), True)])" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.read.parquet(data_path_1)\n", "df.schema" ] }, { "cell_type": "code", "execution_count": 74, "id": "fc5f6baa-052e-4b89-94b6-4821cf01952a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 157 ms, sys: 47.6 ms, total: 205 ms\n", "Wall time: 2.49 s\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", mnist(struct(df.columns))).collect()" ] }, { "cell_type": "code", "execution_count": 75, "id": "a85dea35-e41d-482d-8a8f-52d3c108f038", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 183 ms, sys: 60.3 ms, total: 243 ms\n", "Wall time: 1.49 s\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", mnist(*df.columns)).collect()" ] }, { "cell_type": "code", "execution_count": 76, "id": "bc3f0dbe-c52b-41d6-8097-8cebaa5ee5a8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 383 ms, sys: 43.9 ms, total: 427 ms\n", "Wall time: 1.6 s\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", mnist(*[col(c) for c in df.columns])).collect()" ] }, { "cell_type": "code", "execution_count": 77, "id": "99fb5e8d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted label: Bag\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJJ5JREFUeJzt3Xt0lfWd7/HPTkg2t2THEHKTQAMotHLplErKqBRLDpDOuEA5HW9zDrg6MNLgqlKrJz1Watuz0uIa66lDca2zWqir4oVzREbHYgUljAp0QBjGXlLAKGEgoWCTDQlJdrJ/5w/GzERB+P5M8kvC+7XWXovs/Xx4fnnyJJ882TvfRJxzTgAA9LKU0AsAAFyaKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQQwKvYAPSyaTOnr0qDIyMhSJREIvBwBg5JzTqVOnVFhYqJSU81/n9LkCOnr0qIqKikIvAwDwCdXW1mrUqFHnfbzPFVBGRoYk6Vp9WYOUFng16Ha9dVXLhCkgmHYl9Lpe6vx6fj49VkCrV6/Www8/rLq6Ok2dOlWPPfaYpk+ffsHcBz92G6Q0DYpQQANOr/1YlQICgvn3T78LPY3SIy9CeOaZZ7RixQqtXLlSb731lqZOnaq5c+fq+PHjPbE7AEA/1CMF9Mgjj2jJkiW644479JnPfEaPP/64hg4dqp/97Gc9sTsAQD/U7QXU1tamPXv2qLS09D92kpKi0tJS7dix4yPbt7a2Kh6Pd7kBAAa+bi+gEydOqKOjQ3l5eV3uz8vLU11d3Ue2r6ysVCwW67zxCjgAuDQE/0XUiooKNTY2dt5qa2tDLwkA0Au6/VVwOTk5Sk1NVX19fZf76+vrlZ+f/5Hto9GootFody8DANDHdfsVUHp6uqZNm6atW7d23pdMJrV161bNmDGju3cHAOineuT3gFasWKFFixbp85//vKZPn65HH31UTU1NuuOOO3pidwCAfqhHCujmm2/WH//4Rz344IOqq6vTZz/7WW3evPkjL0wAAFy6Is71rZkl8XhcsVhMszSfSQgYsFLHF5szx/7O/lxp7vzfmzPAJ9XuEtqmTWpsbFRmZuZ5twv+KjgAwKWJAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEH0yDRsoL+q+YH9b1bdP3+jOVPVcP4BjeczPu2MOZO/356RpF9smG3OFH3vTa994dLFFRAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCYBp2HxYZZP/wuI4O+46cs2c8RdLSzRmXaDNnBhWPMWckacttD5szX/zV3ebMlX+z25ypNyek3Tdf75GSHvzuU+bM2u/5HXOzSMSe6cVzHBePKyAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACCLiXN+a0hePxxWLxTRL8zUokhZ6Od3HY4BiJDXVnPEaRuqrb506Xfxh7TSv3NiiP5ozg0oPe+2rLzvzcrE5c8Pl+82ZLZMyzBn0fe0uoW3apMbGRmVmZp53O66AAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACCIQaEXcMnwGNzp2tvt+/EYetqXh4pKUsrUT5szW770v732VfrLFebMlfIYRppiHzQbSbF/bL3OIUlpP8w2Z25e9y/mzIbF3zRnLlu3w5xB38QVEAAgCAoIABBEtxfQd77zHUUikS63iRMndvduAAD9XI88B3TVVVdpy5Yt/7GTQTzVBADoqkeaYdCgQcrPz++J/xoAMED0yHNABw4cUGFhocaOHavbb79dhw+f/1VCra2tisfjXW4AgIGv2wuopKRE69at0+bNm7VmzRrV1NTouuuu06lTp865fWVlpWKxWOetqKiou5cEAOiDur2AysrK9JWvfEVTpkzR3Llz9dJLL6mhoUHPPvvsObevqKhQY2Nj5622tra7lwQA6IN6/NUBWVlZuvLKK3Xw4MFzPh6NRhWNRnt6GQCAPqbHfw/o9OnTOnTokAoKCnp6VwCAfqTbC+jee+9VVVWV3n33Xb355pu68cYblZqaqltvvbW7dwUA6Me6/UdwR44c0a233qqTJ09q5MiRuvbaa7Vz506NHDmyu3cFAOjHur2Ann766e7+L2HhM1jUYzCmJCnZYY7Eb/2COTNm+R/MmcdPXmfOSFLhq700ncolzZFIdKh9N57DSOu+YH9e9p1EpjnzzEMPmzMr//bL5kz9DH69oy9iFhwAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABNHjf5Cu10Qi9ozP4M7e5DMk1GPIpc9QUV9Xf2OPOVPdmGfONLenmzOSNPzZnV45q0iq5wDYXpJ2yp55s+kKc+bRP33KnLlr1BZz5oHblpgzkpS53uN86K2vRT778d1XD+EKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEEMnGnYfZ3H5Fqficku0XuTrYdtH2nOtDv7mOXUFPuE73c3jTVnJKlAdV45K9fh8XFqS3T/Qs4j77E3zZlvVVSbM9cevcqcWXlgvjlz8//cbM5I0ssvFJkzyVMeo8R9+E617kN/OYArIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIYuAMI/UYlhdJS/fbVaLNI2Rfn9d+PBy978+9ct/MfdacefLfvmDOJGUfnljwiH2YZq/qw+eDr181p5kz/33MTnNm1d455kxqkd8wzYxf2r9GNF7rtaveE/G47nA9M+SYKyAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACCLinMdUxB4Uj8cVi8U0S/M1KGIfbjiQHF9uHxLammXfz6+WrrKHJK1r+Lw5Myr9fXPme7+8yZyJ/cE+wFSSrv+bXebMxt981pwZ9G9RcyalzeN9ivh9evvsq3VE0pwZOeGEOROLtpgzZ9r9vpasHP8P5syyDUvNmeL/scOc6cvaXULbtEmNjY3KzMw873ZcAQEAgqCAAABBmAto+/btuuGGG1RYWKhIJKLnn3++y+POOT344IMqKCjQkCFDVFpaqgMHDnTXegEAA4S5gJqamjR16lStXr36nI+vWrVKP/7xj/X4449r165dGjZsmObOnauWFvvPbQEAA5f5L6KWlZWprKzsnI855/Too4/qgQce0Pz58yVJTzzxhPLy8vT888/rlltu+WSrBQAMGN36HFBNTY3q6upUWlraeV8sFlNJSYl27Dj3qzxaW1sVj8e73AAAA1+3FlBdXZ
0kKS8vr8v9eXl5nY99WGVlpWKxWOetqKioO5cEAOijgr8KrqKiQo2NjZ232tra0EsCAPSCbi2g/Px8SVJ9fX2X++vr6zsf+7BoNKrMzMwuNwDAwNetBVRcXKz8/Hxt3bq18754PK5du3ZpxowZ3bkrAEA/Z34V3OnTp3Xw4MHOt2tqarRv3z5lZ2dr9OjRuvvuu/X9739fV1xxhYqLi/Xtb39bhYWFWrBgQXeuGwDQz5kLaPfu3br++us7316xYoUkadGiRVq3bp3uu+8+NTU1aenSpWpoaNC1116rzZs3a/Dgwd23agBAv3dJDyN9Z5XfjwX/7safmzMr/vmvzJn09HZz5gdTnzNnNjdMMWckKUX2U+d3jXkX3uhD3tt7uTnTcVnCnJGkwbXp5kzmO/bjkNJuz3Sk2weEJs3fYp7lUu2ZZJp9fZGk/Ticntlszny26Ig5I0lHT8fMmWvy3jFn3nrf/urfI+9nmTOSNPor/+qVs2AYKQCgT6OAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAIz1m5A8Nvb/97r9zCg39hzuT8o/3PUZxeeMqc+dnR68yZxja/P5VxR9Eb5syJtmHmTM3gpDmjDvtkZklqu8y+r8RX/mTOjIo1mjMjo6fNmWiqfaK6JGUNsk+cTniM0G7qiJozMzOr7ftJ2vcjSW8OGm/OpMp+Dg0b1GbOvDR9jTkjSbfefq85E3typ9e+LoQrIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIYsAMI225Ybo5kxbZ57evb+WbMzn/6z1z5uXxz5kzD5+wH4ehKfZBiJK08vUF5kxK3H7KuSyPgZoe80vP7ithzsQPXGbOHGwYYc7U2meeKrXV2UOeOqL2AbDOY2bs6+mfM2duX/yKfUeSrsv6gznzucGHzZmX064yZ/7in+80ZyTpbyp+Zc68/GSm174uhCsgAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAhiwAwjrf98770rWT+sNWf+MudfzJmfNtgHFP5V1j+bM9+t/UtzRpIu/2WqOZMY6jF9UmnmRCTpN4TTpdjX15Fu308yzb4+n7W1Zvkcb0kesYjHzFif/Qz/N/uk2cf/6Xr7jiT9Yf4ac+Y3bfZ3alFsvznzj5mTzRlJ+uvYv5ozv/qzpabtIx2t0r9suuB2XAEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBADZhhp2une29e9l282Zzb8abo5k5seN2e+eei/mjPNf3+5OSNJZ3Lt37/4DMdMbTNHlEz1G8KZ0ksDNX0Gdzr77FevtUlS+xC/nJXPcTg12n7eFVTZ9yNJX5s205wpHNxgzlSfzjNnFl6+15yRpBEp9g/umcuHmbZvT6RKFzF/mSsgAEAQFBAAIAhzAW3fvl033HCDCgsLFYlE9Pzzz3d5fPHixYpEIl1u8+bN6671AgAGCHMBNTU1aerUqVq9evV5t5k3b56OHTvWeXvqqac+0SIBAAOP+UUIZWVlKisr+9htotGo8vPzvRcFABj4euQ5oG3btik3N1cTJkzQsmXLdPLkyfNu29raqng83uUGABj4ur2A5s2bpyeeeEJbt27VD3/4Q1VVVamsrEwdHR3n3L6yslKxWKzzVlRU1N1LAgD0Qd3+e0C33HJL578nT56sKVOmaNy4cdq2bZtmz579ke0rKiq0YsWKzrfj8TglBACXgB5/GfbYsWOVk5OjgwcPnvPxaDSqzMzMLjcAwMDX4wV05MgRnTx5UgUFBT29KwBAP2L+Edzp06e7XM3U1NRo3759ys7OVnZ2th566CEtXLhQ+fn5OnTokO677z6NHz9ec+fO7daFAwD6N3MB7d69W9dff33n2x88f7No0SKtWbNG+/fv189//nM1NDSosLBQc+bM0fe+9z1Fo9HuWzUAoN8zF9CsWbPknDvv4y+//PInWpCvQc29t68DbfbfcXo43z448P+dtj8f1viz/2LOJGN+EysTw+y5yLlfDPmxOtLtGd8hnB9zap9/V700WNRrQKjncRh0xp5J8Rga63M++By7jqjfgfj1+qnmzIYVD5sz/5Q+zpy5esi75owkxZNJc2ZY9QnT9u0drRe1HbPgAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEES3/0nuUIacsE949TUu7bg5c6TdPl74vpfuNGcyRti/p0gMM0ckSelxe8Z5fMvTMdiekcdUa0lqH+qxK4/pzD7r85m67Ssx3L7ApMdXk9Q2+5TqlIsbtNxF/FN+07AzDtuPw7cOzzdn/u+4LebMtjMeJ6ukwtRT5kzHgXds27vERW3HFRAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABDFghpFmHjptziRch9e+xqa1mDMzdywzZ/LfsA9C/NNEc8R7yGVLjj3TEbW/T9H3PQZJen5r5TMstX2o/X1qz7CfeynDL27A43+WbPWZlCpFj6aZM2mn7R8nr2PnMZw2tdVvGOmZXHvunfVXmDO/u/8fzBlpuEdGuixliDmTOmG8aXvX0SoduPB2XAEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBADZhhpSnObOZMW8RvU+K9tmebM8F/ZBwf+aaLHAMWkPdIxxD4QUpLasuw7i56wH/P0U/b1pS84bs5I0snGYeZMR5v90yj9cNScyd5u/34x4veh1Zls+7nXXOgxWNRjGKl85or6zSJVIsO+viEp9o/T/J13mjO/nPETc0aS2mU/9yIJ28TiSPLitucKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCGDDDSNXeYY40Js947epvd37NnEkbZZ+G2Jpjf59SWu37iST8JjWm5zebM2N/FDdn2t87Ys7UR0vMGUkav/WkPXS01p7JHWGOvHdTrjnTPMY2RPIDkaH24b7ujMdw36TH+dpuz3Sk+k1ldYPsueZR9kzGTvuw4rrpQ80ZSRqXZr/uaH/nXdv2LnFR23EFBAAIggICAARhKqDKykpdffXVysjIUG5urhYsWKDq6uou27S0tKi8vFwjRozQ8OHDtXDhQtXX13frogEA/Z+pgKqqqlReXq6dO3fqlVdeUSKR0Jw5c9TU1NS5zT333KMXXnhBGzZsUFVVlY4ePaqbbrqp2xcOAOjfTC9C2Lx5c5e3161bp9zcXO3Zs0czZ85UY2OjfvrTn2r9+vX60pe+JElau3atPv3pT2vnz
p36whe+0H0rBwD0a5/oOaDGxkZJUnZ2tiRpz549SiQSKi0t7dxm4sSJGj16tHbs2HHO/6O1tVXxeLzLDQAw8HkXUDKZ1N13361rrrlGkyZNkiTV1dUpPT1dWVlZXbbNy8tTXV3dOf+fyspKxWKxzltRUZHvkgAA/Yh3AZWXl+vtt9/W008//YkWUFFRocbGxs5bba3H71QAAPodr19EXb58uV588UVt375do0aN6rw/Pz9fbW1tamho6HIVVF9fr/z8/HP+X9FoVNFo1GcZAIB+zHQF5JzT8uXLtXHjRr366qsqLi7u8vi0adOUlpamrVu3dt5XXV2tw4cPa8aMGd2zYgDAgGC6AiovL9f69eu1adMmZWRkdD6vE4vFNGTIEMViMX31q1/VihUrlJ2drczMTN11112aMWMGr4ADAHRhKqA1a9ZIkmbNmtXl/rVr12rx4sWSpB/96EdKSUnRwoUL1draqrlz5+onP/lJtywWADBwmArIuQsP2Rs8eLBWr16t1atXey/KS6r99RQvNo268Ebn4JIeGY95n6nNHkMDh3sMMI34vRYl2Wp/CrH5ypHmTHrNe+bM5c8eMmckqf4vxpozJ6/JMGeyRpw2Z1pP24fnRt5PN2ckSQ1p9n35zLT1OfV8Mn6zSBVJ2HfmhtsHwDYX2vfz11v+1pyRpJq//D/mTOqIbNP2LtkmvX/h7ZgFBwAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCC8/iJqX9TxuwPmzLCUVq99PXHNT82Zv04sNWcizanmTGosYc4kO/wmJrsW++kTX95ozqTcdaU509xqn+YsSc6dsofeH2KONL6bZc5E7IPOpTTPMdD2U89r4rRL9Qh5TKOP+Iyjl+QG2w96aoP986JjmP2dSq/vvS/fLX9WfOGN/pP29hbptQtvxxUQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAAQxYIaR+shKafbKpXpMXfzRdU+bM4WD/mTO/OLkn5szL74xzZyRJHXYBzy+fyTLnElttn+f5HrxW6uIx0BNl2YfPun8ZsZ6ibR7DO/0mffpMyvVYz8uxXMoq8c5nox67ssoJeE3YLU52WbOtGXZqqI9cXHbcwUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFc0sNIt5y6yis3Zehhc+Zyj8GiE9LazZlPDT5pzvy3Wf9kzkjS+t9cbQ8dHmKOpNgPgxIx+7BPSYp4DJ+MJO2ZlDN+gyStXO/sRpIU8ZjB6VLsC/QZNOuztrP78nmnfM4h+24GtdgzklTT3mHODD6RMG3f3n5x23MFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBXNrDSI9O8MrFRjebM+PS/mjO/CJ+pTmTFrEPGizL2G/OSNKUz9eaMyNKTpszean2zMb4n5kzkvTSUfuA2vak/fu45rY0cybpsZ9omm2I5AeGeAzC9TkOg1LsUzh95oqebol6pKTYEPvEz8LhjeZMW0eqOZP0mcoq6Y0z48yZYzMGm7bvaJV0ETOOuQICAARBAQEAgjAVUGVlpa6++mplZGQoNzdXCxYsUHV1dZdtZs2apUgk0uV25513duuiAQD9n6mAqqqqVF5erp07d+qVV15RIpHQnDlz1NTU1GW7JUuW6NixY523VatWdeuiAQD9n+lFCJs3b+7y9rp165Sbm6s9e/Zo5syZnfcPHTpU+fn53bNCAMCA9ImeA2psPPtqj+zs7C73P/nkk8rJydGkSZNUUVGh5ubzv2qstbVV8Xi8yw0AMPB5vww7mUzq7rvv1jXXXKNJkyZ13n/bbbdpzJgxKiws1P79+3X//ferurpazz333Dn/n8rKSj300EO+ywAA9FPeBVReXq63335br7/+epf7ly5d2vnvyZMnq6CgQLNnz9ahQ4c0btxHX39eUVGhFStWdL4dj8dVVFTkuywAQD/hVUDLly/Xiy++qO3bt2vUqFEfu21JSYkk6eDBg+csoGg0qmjU75fEAAD9l6mAnHO66667tHHjRm3btk3FxcUXzOzbt0+SVFBQ4LVAAMDAZCqg8vJyrV+/Xps2bVJGRobq6uokSbFYTEOGDNGhQ4e0fv16ffnLX9aIESO0f/9+3XPPPZo5c6amTJnSI+8AAKB/MhXQmjVrJJ39ZdP/bO3atVq8eLHS09O1ZcsWPfroo2pqalJRUZEWLlyoBx54oNsWDAAYGMw/gvs4RUVFqqqq+kQLAgBcGi7padhDPCcFL8v6jTlzMBExZ8qz7NOm/dgn8Z5l/52t5mSbOTM0Zag58+mc6gtvdA7LLttrzlyWal+fjzda7JOjG5J+a8tPtX9sB3tMYu+Q/fOixdnP15EpreaMJBWnDTdn9rTaz/HxafZjt6Mly5yRpPV/LDFnRlW+adq+3SV04CK2YxgpACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARxaQ8j/brfX2Itm/B1cybaYB98mkzrne8POqJ++zlyvT0XybMPhcx4c4g5U/jLo+aMJLlU+/vUcdkwcyb1VIs5o2PHzRGXaLfvR1JkqH2IaWS4x+DTC0zYP6d2++BOX81X2f+Qps/n7fA9h82Z9mN15sxZ9kGzPYUrIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEESfmwXn/n02VLsSkseYKNO+OuxzySSpPWGf45Xa7jELLtJLs+BS/PaTbPGYBddsP+YdbRFzpj3p97F1Ht+TdbSn2vfjc+65NnvEec6CS9q/NESS9uPgNQsu2Xuz4Nrb7Z/rSY9zqD1p/9i2O/vXlN7SrrNrcxf4+EbchbboZUeOHFFRUVHoZQAAPqHa2lqNGjXqvI/3uQJKJpM6evSoMjIyFIl0/c43Ho+rqKhItbW1yszMDLTC8DgOZ3EczuI4nMVxOKsvHAfnnE6dOqXCwkKlfMxPWPrcj+BSUlI+tjElKTMz85I+wT7AcTiL43AWx+EsjsNZoY9DLBa74Da8CAEAEAQFBAAIol8VUDQa1cqVKxWN+v0l04GC43AWx+EsjsNZHIez+tNx6HMvQgAAXBr61RUQAGDgoIAAAEFQQACAICggAEAQ/aaAVq9erU996lMaPHiwSkpK9Otf/zr0knrdd77zHUUikS63iRMnhl5Wj9u+fbtuuOEGFRYWKhKJ6Pnnn+/yuHNODz74oAoKCjRkyBCVlpbqwIEDYRbbgy50HBYvXvyR82PevHlhFttDKisrdfXVVysjI0O5ublasGCBqquru2zT0tKi8vJyjRgxQsOHD9fChQtVX18faMU942KOw6xZsz5yPtx5552BVnxu
/aKAnnnmGa1YsUIrV67UW2+9palTp2ru3Lk6fvx46KX1uquuukrHjh3rvL3++uuhl9TjmpqaNHXqVK1evfqcj69atUo//vGP9fjjj2vXrl0aNmyY5s6dq5YW+yDJvuxCx0GS5s2b1+X8eOqpp3pxhT2vqqpK5eXl2rlzp1555RUlEgnNmTNHTU1Nndvcc889euGFF7RhwwZVVVXp6NGjuummmwKuuvtdzHGQpCVLlnQ5H1atWhVoxefh+oHp06e78vLyzrc7OjpcYWGhq6ysDLiq3rdy5Uo3derU0MsISpLbuHFj59vJZNLl5+e7hx9+uPO+hoYGF41G3VNPPRVghb3jw8fBOecWLVrk5s+fH2Q9oRw/ftxJclVVVc65sx/7tLQ0t2HDhs5tfve73zlJbseOHaGW2eM+fBycc+6LX/yi+/rXvx5uURehz18BtbW1ac+ePSotLe28LyUlRaWlpdqxY0fAlYVx4MABFRYWauzYsbr99tt1+PDh0EsKqqamRnV1dV3Oj1gsppKSkkvy/Ni2bZtyc3M1YcIELVu2TCdPngy9pB7V2NgoScrOzpYk7dmzR4lEosv5MHHiRI0ePXpAnw8fPg4fePLJJ5WTk6NJkyapoqJCzc3NIZZ3Xn1uGOmHnThxQh0dHcrLy+tyf15enn7/+98HWlUYJSUlWrdunSZMmKBjx47poYce0nXXXae3335bGRkZoZcXRF1dnSSd8/z44LFLxbx583TTTTepuLhYhw4d0re+9S2VlZVpx44dSk31+Fs9fVwymdTdd9+ta665RpMmTZJ09nxIT09XVlZWl20H8vlwruMgSbfddpvGjBmjwsJC7d+/X/fff7+qq6v13HPPBVxtV32+gPAfysrKOv89ZcoUlZSUaMyYMXr22Wf11a9+NeDK0Bfccsstnf+ePHmypkyZonHjxmnbtm2aPXt2wJX1jPLycr399tuXxPOgH+d8x2Hp0qWd/548ebIKCgo0e/ZsHTp0SOPGjevtZZ5Tn/8RXE5OjlJTUz/yKpb6+nrl5+cHWlXfkJWVpSuvvFIHDx4MvZRgPjgHOD8+auzYscrJyRmQ58fy5cv14osv6rXXXuvy51vy8/PV1tamhoaGLtsP1PPhfMfhXEpKSiSpT50Pfb6A0tPTNW3aNG3durXzvmQyqa1bt2rGjBkBVxbe6dOndejQIRUUFIReSjDFxcXKz8/vcn7E43Ht2rXrkj8/jhw5opMnTw6o88M5p+XLl2vjxo169dVXVVxc3OXxadOmKS0trcv5UF1drcOHDw+o8+FCx+Fc9u3bJ0l963wI/SqIi/H000+7aDTq1q1b537729+6pUuXuqysLFdXVxd6ab3qG9/4htu2bZurqalxb7zxhistLXU5OTnu+PHjoZfWo06dOuX27t3r9u7d6yS5Rx55xO3du9e99957zjnnfvCDH7isrCy3adMmt3//fjd//nxXXFzszpw5E3jl3evjjsOpU6fcvffe63bs2OFqamrcli1b3Oc+9zl3xRVXuJaWltBL7zbLli1zsVjMbdu2zR07dqzz1tzc3LnNnXfe6UaPHu1effVVt3v3bjdjxgw3Y8aMgKvufhc6DgcPHnTf/e533e7du11NTY3btGmTGzt2rJs5c2bglXfVLwrIOecee+wxN3r0aJeenu6mT5/udu7cGXpJve7mm292BQUFLj093V1++eXu5ptvdgcPHgy9rB732muvOUkfuS1atMg5d/al2N/+9rddXl6ei0ajbvbs2a66ujrsonvAxx2H5uZmN2fOHDdy5EiXlpbmxowZ45YsWTLgvkk71/svya1du7ZzmzNnzrivfe1r7rLLLnNDhw51N954ozt27Fi4RfeACx2Hw4cPu5kzZ7rs7GwXjUbd+PHj3Te/+U3X2NgYduEfwp9jAAAE0eefAwIADEwUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACOL/AyBNQnoqGwl/AAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Sample prediction\n", "sample = preds[0]\n", "predictions = sample.preds\n", "img = sample.data\n", "\n", "img = np.array(img).reshape(28,28)\n", "plt.figure()\n", "plt.imshow(img)\n", "\n", "print(\"Predicted label:\", classes[np.argmax(predictions)])" ] }, { "cell_type": "markdown", "id": "7a26690a-9dc4-4c36-9904-568d73e2be3c", "metadata": { "tags": [] }, "source": [ "#### Stop Triton Server on each executor" ] }, { "cell_type": "code", "execution_count": null, "id": "e02838ba", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 14:00:18,330 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-04 14:00:28,520 - INFO - Sucessfully stopped 1 servers. \n" ] }, { "data": { "text/plain": [ "[True]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 79, "id": "a0608fff-7cfb-489e-96c9-8e1d92e57562", "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "08de2664-3d60-487b-90da-6d0f3b8b9203", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-torch", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/requirements.txt ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. numpy pandas matplotlib portalocker pyarrow h5py pydot scikit-learn jupyterlab pyspark>=3.4.0 huggingface datasets transformers ipywidgets nvidia-pytriton ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/server_utils.py ================================================ # # Copyright (c) 2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import inspect import logging import os import socket import subprocess import sys import time from multiprocessing import Process from typing import Any, Callable, Dict, List, Optional, Set, Tuple import psutil import requests from pyspark import RDD from pyspark.sql import SparkSession logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger("ServerManager") # ----------------------------------------------------------------------------- # Helper Functions # ----------------------------------------------------------------------------- def _find_ports(num_ports: int, start_port: int = 7000) -> List[int]: """Find available ports on executor for server services.""" ports = [] conns = {conn.laddr.port for conn in psutil.net_connections(kind="inet")} i = start_port while len(ports) < num_ports: if i not in conns: ports.append(i) i += 1 return ports def _get_valid_vllm_parameters_task() -> Set[str]: """Task to get valid vLLM parameters on executor.""" from vllm.entrypoints.openai.cli_args import create_parser_for_docs parser = create_parser_for_docs() valid_args = set() for action in parser._actions: if action.dest not in [ "help", "host", "port", "served-model-name", "model", ]: valid_args.add(action.dest) return valid_args def _start_triton_server_task( triton_server_fn: Callable, model_name: str, wait_retries: int, wait_timeout: int, model_path: Optional[str] = None, ) -> List[tuple]: """Task to start Triton server process on a Spark executor.""" from pyspark import BarrierTaskContext from pytriton.client import ModelClient def _prepare_pytriton_env(): """Expose PyTriton to correct libpython3.11.so and Triton bundled libraries.""" ld_library_paths = [] # Add nvidia_pytriton.libs to LD_LIBRARY_PATH for path in sys.path: if os.path.isdir(path) and "site-packages" in path: libs_path = os.path.join(path, "nvidia_pytriton.libs") if os.path.isdir(libs_path): ld_library_paths.append(libs_path) break # Add ${CONDA_PREFIX}/lib to LD_LIBRARY_PATH for conda environments if os.path.exists(os.path.join(sys.prefix, "conda-meta")): conda_lib = os.path.join(sys.prefix, "lib") if os.path.isdir(conda_lib): ld_library_paths.append(conda_lib) if "LD_LIBRARY_PATH" in os.environ: ld_library_paths.append(os.environ["LD_LIBRARY_PATH"]) os.environ["LD_LIBRARY_PATH"] = ":".join(ld_library_paths) return None # Setup server function arguments tc = BarrierTaskContext.get() ports = _find_ports(num_ports=3) sig = inspect.signature(triton_server_fn) params = sig.parameters if model_path is not None: assert ( len(params) == 2 ), "Server function must accept (ports, model_path) when model_path is provided" args = (ports, model_path) else: assert len(params) == 1, "Server function must accept (ports) argument" args = (ports,) # Prepare and start server process _prepare_pytriton_env() hostname = socket.gethostname() process = Process(target=triton_server_fn, args=args) process.start() client = ModelClient(f"http://localhost:{ports[0]}", model_name) # Wait for server to start for _ in range(wait_retries): try: client.wait_for_model(wait_timeout) tc.barrier() client.close() return [(hostname, (process.pid, ports))] except Exception: if not process.is_alive(): # If process terminated due to an error, stop waiting break pass client.close() if process.is_alive(): # Terminate if timeout is exceeded to avoid dangling server processes process.terminate() raise TimeoutError( "Failure: Triton server startup failed or timed out. Check the executor logs for more info." 
) def _start_vllm_server_task( model_name: str, model_path: str, wait_retries: int, wait_timeout: int, **kwargs, ) -> List[tuple]: """Task to start vLLM server process on a Spark executor.""" from pyspark import BarrierTaskContext tc = BarrierTaskContext.get() port = _find_ports(num_ports=1)[0] hostname = socket.gethostname() # Build command for vLLM server cmd = [ sys.executable, "-m", "vllm.entrypoints.openai.api_server", "--model", model_path, "--served-model-name", model_name, "--port", str(port), ] # Add additional args from kwargs for key, value in kwargs.items(): if isinstance(value, bool) and value: cmd.append(f"--{key}") elif not isinstance(value, bool): cmd.append(f"--{key}") cmd.append(str(value)) logger.info(f"Starting vLLM server with command: {' '.join(cmd)}") # vLLM does CUDA init at import time. Forking will try to re-initialize CUDA if vLLM was imported before and throw an error. os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" # Start server process process = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stderr, text=True) # Wait for server to start health_url = f"http://localhost:{port}/health" for _ in range(wait_retries): try: time.sleep(wait_timeout) response = requests.get(health_url) if response.status_code == 200: tc.barrier() return [(hostname, (process.pid, [port]))] except Exception: if process.poll() is not None: # If process terminated due to an error, stop waiting break pass if process.poll() is None: # Terminate if timeout is exceeded to avoid dangling server processes process.terminate() raise TimeoutError( "Failure: vLLM server startup failed or timed out. Check the executor logs for more info." ) def _stop_server_task( server_pids_ports: Dict[str, Tuple[int, List[int]]], wait_retries: int, wait_timeout: int, ) -> List[bool]: """Task to stop a server process on a Spark executor.""" hostname = socket.gethostname() pid, _ = server_pids_ports.get(hostname, (None, None)) assert pid is not None, f"No server PID found for host {hostname}" try: process = psutil.Process(pid) process.terminate() process.wait(timeout=wait_timeout * wait_retries) return [True] except psutil.NoSuchProcess: return [True] except psutil.TimeoutExpired: try: process.kill() return [True] except: return [False] # ----------------------------------------------------------------------------- # ServerManager Classes # ----------------------------------------------------------------------------- class ServerManager: """ Base class for server management across a Spark cluster. Attributes: spark: Active SparkSession num_executors: Number of servers to manage (= # of executors) model_name: Name of the served model model_path: Optional path to model files server_pids_ports: Dictionary of hostname to (server process ID, ports) """ DEFAULT_WAIT_RETRIES = 24 DEFAULT_WAIT_TIMEOUT = 5 def __init__(self, model_name: str, model_path: Optional[str] = None): """ Initialize the server manager. 
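        Note: one server instance is managed per Spark executor; the executor
        count is determined from the active SparkSession.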
Args: model_name: Name of the model to serve model_path: Optional path to model file for server function to load from disk """ self.spark = SparkSession.getActiveSession() self.num_executors = self._get_num_executors() self.model_name = model_name self.model_path = model_path self._server_pids_ports: Dict[str, Tuple[int, List[int]]] = {} def _get_num_executors(self) -> int: """Get the number of executors in the cluster.""" return ( len( [ executor.host() for executor in self.spark._jsc.sc() .statusTracker() .getExecutorInfos() ] ) - 1 ) @property def host_to_http_url(self) -> Dict[str, str]: """Map hostname to client HTTP URL for server on that host.""" if not self._server_pids_ports: logger.warning("No urls available. Start servers first.") return None return { host: f"http://localhost:{ports[0]}" for host, (_, ports) in self._server_pids_ports.items() } def _get_node_rdd(self) -> RDD: """Create and configure RDD with stage-level scheduling for 1 task per executor.""" sc = self.spark.sparkContext node_rdd = sc.parallelize(list(range(self.num_executors)), self.num_executors) return self._use_stage_level_scheduling(node_rdd) def _use_stage_level_scheduling(self, rdd: RDD) -> RDD: """ Use stage-level scheduling to ensure each server instance maps to 1 executor. Adapted from https://github.com/NVIDIA/spark-rapids-ml/blob/main/python/src/spark_rapids_ml/core.py """ from pyspark.resource.profile import ResourceProfileBuilder from pyspark.resource.requests import TaskResourceRequests executor_cores = self.spark.conf.get("spark.executor.cores") assert executor_cores is not None, "spark.executor.cores is not set" executor_gpus = self.spark.conf.get("spark.executor.resource.gpu.amount") assert ( executor_gpus is not None ), "spark.executor.resource.gpu.amount is not set" spark_plugins = self.spark.conf.get("spark.plugins", " ") assert spark_plugins is not None spark_rapids_sql_enabled = self.spark.conf.get( "spark.rapids.sql.enabled", "true" ) assert spark_rapids_sql_enabled is not None task_cores = ( int(executor_cores) if "com.nvidia.spark.SQLPlugin" in spark_plugins and "true" == spark_rapids_sql_enabled.lower() else (int(executor_cores) // 2) + 1 ) task_gpus = float(executor_gpus) treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus) rp = ResourceProfileBuilder().require(treqs).build logger.info( f"Requesting stage-level resources: (cores={task_cores}, gpu={task_gpus})" ) return rdd.withResources(rp) def start_servers( self, start_server_fn: Callable, wait_retries: int = DEFAULT_WAIT_RETRIES, wait_timeout: int = DEFAULT_WAIT_TIMEOUT, **kwargs, ) -> Dict[str, Tuple[int, List[int]]]: """ Start servers across the cluster. 
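        Subclasses invoke this with a server-specific start task (e.g.
        _start_triton_server_task or _start_vllm_server_task), which is run
        once per executor in a barrier stage.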
Args: start_server_fn: Function used to start the server process wait_retries: Number of retries for waiting for server startup wait_timeout: Timeout in seconds for each retry **kwargs: Additional server-specific arguments Returns: Dictionary of hostname -> (server PID, [ports]) """ node_rdd = self._get_node_rdd() model_name = self.model_name model_path = self.model_path server_type = self.__class__.__name__.replace("ServerManager", "") logger.info(f"Starting {self.num_executors} {server_type} servers.") start_args = { "model_name": model_name, "wait_retries": wait_retries, "wait_timeout": wait_timeout, } if model_path is not None: start_args["model_path"] = model_path start_args.update(kwargs) self._server_pids_ports = ( node_rdd.barrier() .mapPartitions(lambda _: start_server_fn(**start_args)) .collectAsMap() ) return self._server_pids_ports def stop_servers( self, wait_retries: int = DEFAULT_WAIT_RETRIES, wait_timeout: int = DEFAULT_WAIT_TIMEOUT, ) -> List[bool]: """ Stop all servers across the cluster. Returns: List of booleans indicating success/failure of stopping each server """ if not self._server_pids_ports: logger.warning("No servers to stop.") return [] node_rdd = self._get_node_rdd() server_pids_ports = self._server_pids_ports server_type = self.__class__.__name__.replace("ServerManager", "") stop_success = ( node_rdd.barrier() .mapPartitions( lambda _: _stop_server_task( server_pids_ports=server_pids_ports, wait_retries=wait_retries, wait_timeout=wait_timeout, ) ) .collect() ) if all(stop_success): self._server_pids_ports.clear() logger.info( f"Successfully stopped {self.num_executors} {server_type} servers." ) else: logger.warning( f"{server_type} server termination failed or timed out. Check executor logs." ) return stop_success class TritonServerManager(ServerManager): """ Handle lifecycle of Triton server instances across Spark cluster. Example usage: >>> server_manager = TritonServerManager(model_name="my_model", model_path="/path/to/my_model") >>> # Define triton_server(ports, model_path) that contains PyTriton server logic >>> server_pids_ports = server_manager.start_servers(triton_server) >>> print(f"Servers started with PIDs/Ports: {server_pids_ports}") >>> host_to_http_url = server_manager.host_to_http_url >>> host_to_grpc_url = server_manager.host_to_grpc_url >>> # Define triton_fn() and predict_batch_udf(triton_fn) and run inference... >>> success = server_manager.stop_servers() >>> print(f"Server shutdown success: {success}") """ def __init__(self, model_name: str, model_path: Optional[str] = None): super().__init__(model_name, model_path) @property def host_to_grpc_url(self) -> Dict[str, str]: """Map hostname to client gRPC URL for Triton server on that host.""" if not self._server_pids_ports: logger.warning("No urls available. Start servers first.") return None return { host: f"grpc://localhost:{ports[1]}" for host, (_, ports) in self._server_pids_ports.items() } def start_servers( self, triton_server_fn: Callable, wait_retries: int = ServerManager.DEFAULT_WAIT_RETRIES, wait_timeout: int = ServerManager.DEFAULT_WAIT_TIMEOUT, ) -> Dict[str, Tuple[int, List[int]]]: """ Start Triton servers across the cluster. 
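        The provided server function must accept (ports,) if the manager was
        created without a model_path, or (ports, model_path) if one was
        provided; this signature is validated in _start_triton_server_task.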
Args: triton_server_fn: PyTriton server function defining the model and inference logic wait_retries: Number of retries for waiting for server startup wait_timeout: Timeout in seconds for each retry Returns: Dictionary of hostname -> (server PID, [ports]) """ return super().start_servers( start_server_fn=_start_triton_server_task, wait_retries=wait_retries, wait_timeout=wait_timeout, triton_server_fn=triton_server_fn, ) class VLLMServerManager(ServerManager): """ Handle lifecycle of vLLM server instances across Spark cluster. Example usage: >>> server_manager = VLLMServerManager(model_name="my_llm", model_path="/path/to/my_llm") >>> server_manager.start_servers( >>> tensor_parallel_size=1, >>> max_num_seqs=1024, >>> gpu_memory_utilization=0.85, >>> ) >>> print(f"Servers started with PIDs/Ports: {server_pids_ports}") >>> host_to_http_url = server_manager.host_to_http_url >>> # Define vllm_fn() and predict_batch_udf(vllm_fn) and run inference... >>> success = server_manager.stop_servers() >>> print(f"Server shutdown success: {success}") """ def __init__(self, model_name: str, model_path: str = None): super().__init__(model_name, model_path) self.vllm_valid_parameters = self._get_valid_vllm_parameters() def _get_valid_vllm_parameters(self) -> List[str]: """Get valid vLLM parameters on executor.""" rdd = self.spark.sparkContext.parallelize(list(range(1)), 1) return rdd.mapPartitions(lambda _: _get_valid_vllm_parameters_task()).collect() def _validate_vllm_kwargs(self, kwargs: Dict[str, Any]): """Validate vLLM parameters.""" for key in kwargs: if key not in self.vllm_valid_parameters: if key == "host" or key == "port": raise ValueError( f"Invalid vLLM parameter: {key}. Host and port are set by server manager." ) elif key == "served-model-name": raise ValueError( f"Invalid vLLM parameter: {key}. Served model name is set via model_name." ) elif key == "model": raise ValueError( f"Invalid vLLM parameter: {key}. Model path is set via model_path." ) else: raise ValueError(f"Invalid vLLM parameter: {key}") def start_servers( self, wait_retries: int = ServerManager.DEFAULT_WAIT_RETRIES, wait_timeout: int = ServerManager.DEFAULT_WAIT_TIMEOUT, **kwargs, ) -> Dict[str, Tuple[int, List[int]]]: """ Start vLLM OpenAI-compatible servers across the cluster. Args: wait_retries: Number of retries for waiting for server startup wait_timeout: Timeout in seconds for each retry **kwargs: Additional arguments to pass to vLLM server command line e.g. tensor_parallel_size, max_num_seqs, gpu_memory_utilization, etc. See https://docs.vllm.ai/en/stable/serving/openai_compatible_server.html#vllm-serve Returns: Dictionary of hostname -> (server PID, [port]) """ self._validate_vllm_kwargs(kwargs) return super().start_servers( start_server_fn=_start_vllm_server_task, wait_retries=wait_retries, wait_timeout=wait_timeout, **kwargs, ) ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/image_classification_tf.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "52d55e3f", "metadata": {}, "source": [ "\n", "\n", "# Pyspark TensorFlow Inference\n", "\n", "## Image classification\n", "This notebook demonstrates training and distributed inference for image classification on MNIST. 
\n", "Based on: https://www.tensorflow.org/tutorials/keras/save_and_load" ] }, { "cell_type": "markdown", "id": "5233632d", "metadata": {}, "source": [ "Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) " ] }, { "cell_type": "code", "execution_count": 1, "id": "c8b28f02", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:58:23.275397: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "2025-02-04 13:58:23.282713: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2025-02-04 13:58:23.290717: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2025-02-04 13:58:23.293187: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2025-02-04 13:58:23.299616: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2025-02-04 13:58:23.677341: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import subprocess\n", "import shutil\n", "import os\n", "\n", "import tensorflow as tf\n", "from tensorflow import keras" ] }, { "cell_type": "code", "execution_count": 2, "id": "e2e67086", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.17.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1738706304.084788 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706304.107153 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706304.109954 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n" ] } ], "source": [ "print(tf.version.VERSION)\n", "\n", "# Enable GPU memory growth\n", "gpus = tf.config.experimental.list_physical_devices('GPU')\n", "if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "id": "7e0c7ad6", "metadata": {}, "source": [ "### Load and preprocess dataset\n", "\n", "Load MNIST and create a train/test split." ] }, { "cell_type": "code", "execution_count": 3, "id": "5b007f7c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((60000, 28, 28), (10000, 28, 28))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n", "train_images.shape, test_images.shape" ] }, { "cell_type": "code", "execution_count": 4, "id": "7b7cedd1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((1000, 784), (1000, 784))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_labels = train_labels[:1000]\n", "test_labels = test_labels[:1000]\n", "\n", "train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0\n", "test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0\n", "\n", "train_images.shape, test_images.shape" ] }, { "cell_type": "markdown", "id": "867a4403", "metadata": {}, "source": [ "### Define a model" ] }, { "cell_type": "code", "execution_count": 5, "id": "746d94db", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/keras/src/layers/core/dense.py:87: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n", " super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n", "I0000 00:00:1738706304.278396 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1738706304.281131 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706304.283741 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706304.403175 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706304.404296 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706304.405232 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "2025-02-04 13:58:24.406153: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 40769 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n" ] }, { "data": { "text/html": [ "
Model: \"sequential\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"sequential\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
       "│ dense (Dense)                   │ (None, 512)            │       401,920 │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dropout (Dropout)               │ (None, 512)            │             0 │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dense_1 (Dense)                 │ (None, 10)             │         5,130 │\n",
       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m401,920\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m5,130\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 407,050 (1.55 MB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m407,050\u001b[0m (1.55 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 407,050 (1.55 MB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m407,050\u001b[0m (1.55 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Define a simple sequential model\n", "def create_model():\n", " model = tf.keras.Sequential([\n", " keras.layers.Dense(512, activation='relu', input_shape=(784,)),\n", " keras.layers.Dropout(0.2),\n", " keras.layers.Dense(10)\n", " ])\n", "\n", " model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n", "\n", " return model\n", "\n", "# Create a basic model instance\n", "model = create_model()\n", "\n", "# Display the model's architecture\n", "model.summary()" ] }, { "cell_type": "markdown", "id": "605d082a", "metadata": {}, "source": [ "### Save checkpoints during training" ] }, { "cell_type": "code", "execution_count": 6, "id": "dde1a855", "metadata": {}, "outputs": [], "source": [ "os.mkdir(\"models\") if not os.path.exists(\"models\") else None" ] }, { "cell_type": "code", "execution_count": 7, "id": "244746be", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1738706304.982690 3671754 service.cc:146] XLA service 0x7f1464019260 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n", "I0000 00:00:1738706304.982718 3671754 service.cc:154] StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6\n", "2025-02-04 13:58:24.999594: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", "2025-02-04 13:58:25.043847: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m26s\u001b[0m 868ms/step - loss: 2.4638 - sparse_categorical_accuracy: 0.0625" ] }, { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1738706305.619913 3671754 device_compiler.h:188] Compiled cluster using XLA! 
This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 17ms/step - loss: 1.6323 - sparse_categorical_accuracy: 0.4913 " ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:58:26.791107: I external/local_xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:393] ptxas warning : Registers are spilled to local memory in function 'gemm_fusion_dot_33', 4 bytes spill stores, 4 bytes spill loads\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 1: val_sparse_categorical_accuracy improved from -inf to 0.76100, saving model to models/training_1/checkpoint.model.keras\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 48ms/step - loss: 1.6179 - sparse_categorical_accuracy: 0.4965 - val_loss: 0.7533 - val_sparse_categorical_accuracy: 0.7610\n", "Epoch 2/10\n", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.3965 - sparse_categorical_accuracy: 0.9062\n", "Epoch 2: val_sparse_categorical_accuracy improved from 0.76100 to 0.80400, saving model to models/training_1/checkpoint.model.keras\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.4549 - sparse_categorical_accuracy: 0.8773 - val_loss: 0.6002 - val_sparse_categorical_accuracy: 0.8040\n", "Epoch 3/10\n", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.4427 - sparse_categorical_accuracy: 0.8438\n", "Epoch 3: val_sparse_categorical_accuracy improved from 0.80400 to 0.85100, saving model to models/training_1/checkpoint.model.keras\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.2924 - sparse_categorical_accuracy: 0.9289 - val_loss: 0.4876 - val_sparse_categorical_accuracy: 0.8510\n", "Epoch 4/10\n", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.3644 - sparse_categorical_accuracy: 0.9375\n", "Epoch 4: val_sparse_categorical_accuracy did not improve from 0.85100\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.2790 - sparse_categorical_accuracy: 0.9275 - val_loss: 0.4981 - val_sparse_categorical_accuracy: 0.8430\n", "Epoch 5/10\n", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - loss: 0.2368 - sparse_categorical_accuracy: 0.9375\n", "Epoch 5: val_sparse_categorical_accuracy did not improve from 0.85100\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1794 - sparse_categorical_accuracy: 0.9645 - val_loss: 0.4893 - val_sparse_categorical_accuracy: 0.8450\n", "Epoch 6/10\n", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - loss: 0.0830 - sparse_categorical_accuracy: 1.0000\n", "Epoch 6: val_sparse_categorical_accuracy improved from 0.85100 to 0.85400, saving model to models/training_1/checkpoint.model.keras\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1430 - sparse_categorical_accuracy: 0.9739 - val_loss: 0.4338 - 
val_sparse_categorical_accuracy: 0.8540\n", "Epoch 7/10\n", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.1518 - sparse_categorical_accuracy: 1.0000\n", "Epoch 7: val_sparse_categorical_accuracy improved from 0.85400 to 0.86200, saving model to models/training_1/checkpoint.model.keras\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0876 - sparse_categorical_accuracy: 0.9909 - val_loss: 0.4194 - val_sparse_categorical_accuracy: 0.8620\n", "Epoch 8/10\n", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - loss: 0.0209 - sparse_categorical_accuracy: 1.0000\n", "Epoch 8: val_sparse_categorical_accuracy improved from 0.86200 to 0.86800, saving model to models/training_1/checkpoint.model.keras\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0669 - sparse_categorical_accuracy: 0.9938 - val_loss: 0.4038 - val_sparse_categorical_accuracy: 0.8680\n", "Epoch 9/10\n", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.0211 - sparse_categorical_accuracy: 1.0000\n", "Epoch 9: val_sparse_categorical_accuracy improved from 0.86800 to 0.86900, saving model to models/training_1/checkpoint.model.keras\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0429 - sparse_categorical_accuracy: 0.9998 - val_loss: 0.4062 - val_sparse_categorical_accuracy: 0.8690\n", "Epoch 10/10\n", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.0283 - sparse_categorical_accuracy: 1.0000\n", "Epoch 10: val_sparse_categorical_accuracy did not improve from 0.86900\n", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0387 - sparse_categorical_accuracy: 0.9992 - val_loss: 0.4069 - val_sparse_categorical_accuracy: 0.8680\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "checkpoint_path = \"models/training_1/checkpoint.model.keras\"\n", "checkpoint_dir = os.path.dirname(checkpoint_path)\n", "\n", "# Create a callback that saves the model's weights\n", "cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\n", " monitor='val_sparse_categorical_accuracy',\n", " mode='max',\n", " save_best_only=True,\n", " verbose=1)\n", "\n", "# Train the model with the new callback\n", "model.fit(train_images, \n", " train_labels, \n", " epochs=10,\n", " validation_data=(test_images, test_labels),\n", " callbacks=[cp_callback]) # Pass callback to training\n", "\n", "# This may generate warnings related to saving the state of the optimizer.\n", "# These warnings (and similar warnings throughout this notebook)\n", "# are in place to discourage outdated usage, and can be ignored." 
] }, { "cell_type": "code", "execution_count": 8, "id": "310eae08", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['checkpoint.model.keras']" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "os.listdir(checkpoint_dir)" ] }, { "cell_type": "code", "execution_count": null, "id": "50eeb6e5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: models/mnist_model/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: models/mnist_model/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saved artifact at 'models/mnist_model'. The following endpoints are available:\n", "\n", "* Endpoint 'serve'\n", " args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 784), dtype=tf.float32, name='keras_tensor')\n", "Output Type:\n", " TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)\n", "Captures:\n", " 139734758151120: TensorSpec(shape=(), dtype=tf.resource, name=None)\n", " 139734413261904: TensorSpec(shape=(), dtype=tf.resource, name=None)\n", " 139739081696528: TensorSpec(shape=(), dtype=tf.resource, name=None)\n", " 139734413262096: TensorSpec(shape=(), dtype=tf.resource, name=None)\n" ] } ], "source": [ "# Export model in saved_model format\n", "model.export(\"models/mnist_model\")" ] }, { "cell_type": "code", "execution_count": 10, "id": "6d3bba9e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/keras/src/layers/core/dense.py:87: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n", " super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - 10ms/step - loss: 2.3876 - sparse_categorical_accuracy: 0.0840\n", "Untrained model, accuracy: 8.40%\n" ] } ], "source": [ "# Create a basic model instance\n", "model = create_model()\n", "\n", "# Evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Untrained model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "code", "execution_count": null, "id": "22ad1708", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - 704us/step - loss: 0.4062 - sparse_categorical_accuracy: 0.8690\n", "Restored model, accuracy: 86.90%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:713: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 10 variables. 
\n", " saveable.load_own_variables(weights_store.get(inner_path))\n" ] } ], "source": [ "# Load the weights from the checkpoint\n", "model.load_weights(checkpoint_path)\n", "\n", "# Re-evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "id": "1c097d63", "metadata": {}, "source": [ "### Checkpoint callback options" ] }, { "cell_type": "code", "execution_count": 12, "id": "cb336e89", "metadata": {}, "outputs": [], "source": [ "os.mkdir(\"models/training_2\") if not os.path.exists(\"models/training_2\") else None" ] }, { "cell_type": "code", "execution_count": 13, "id": "750b6deb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 5: saving model to models/training_2/cp-0005.weights.h5\n", "\n", "Epoch 10: saving model to models/training_2/cp-0010.weights.h5\n", "\n", "Epoch 15: saving model to models/training_2/cp-0015.weights.h5\n", "\n", "Epoch 20: saving model to models/training_2/cp-0020.weights.h5\n", "\n", "Epoch 25: saving model to models/training_2/cp-0025.weights.h5\n", "\n", "Epoch 30: saving model to models/training_2/cp-0030.weights.h5\n", "\n", "Epoch 35: saving model to models/training_2/cp-0035.weights.h5\n", "\n", "Epoch 40: saving model to models/training_2/cp-0040.weights.h5\n", "\n", "Epoch 45: saving model to models/training_2/cp-0045.weights.h5\n", "\n", "Epoch 50: saving model to models/training_2/cp-0050.weights.h5\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Include the epoch in the file name (uses `str.format`)\n", "checkpoint_path = \"models/training_2/cp-{epoch:04d}.weights.h5\"\n", "checkpoint_dir = os.path.dirname(checkpoint_path)\n", "\n", "batch_size = 32\n", "\n", "# Calculate the number of batches per epoch\n", "import math\n", "n_batches = len(train_images) / batch_size\n", "n_batches = math.ceil(n_batches) # round up the number of batches to the nearest whole integer\n", "\n", "# Create a callback that saves the model's weights every 5 epochs\n", "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n", " filepath=checkpoint_path, \n", " verbose=1, \n", " save_weights_only=True,\n", " save_freq=5*n_batches)\n", "\n", "# Create a new model instance\n", "model = create_model()\n", "\n", "# Save the weights using the `checkpoint_path` format\n", "model.save_weights(checkpoint_path.format(epoch=0))\n", "\n", "# Train the model with the new callback\n", "model.fit(train_images, \n", " train_labels,\n", " epochs=50, \n", " batch_size=batch_size, \n", " callbacks=[cp_callback],\n", " validation_data=(test_images, test_labels),\n", " verbose=0)" ] }, { "cell_type": "code", "execution_count": 14, "id": "1c43fd3d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['cp-0000.weights.h5',\n", " 'cp-0015.weights.h5',\n", " 'cp-0010.weights.h5',\n", " 'cp-0035.weights.h5',\n", " 'cp-0020.weights.h5',\n", " 'cp-0040.weights.h5',\n", " 'cp-0050.weights.h5',\n", " 'cp-0005.weights.h5',\n", " 'cp-0045.weights.h5',\n", " 'cp-0025.weights.h5',\n", " 'cp-0030.weights.h5']" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "os.listdir(checkpoint_dir)" ] }, { "cell_type": "code", "execution_count": 15, "id": "0d7ae715", "metadata": {}, "outputs": [], "source": [ "latest = \"models/training_2/cp-0030.weights.h5\"" ] }, { "cell_type": "code", "execution_count": 
16, "id": "d345c6f7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - 11ms/step - loss: 0.4827 - sparse_categorical_accuracy: 0.8740\n", "Restored model, accuracy: 87.40%\n" ] } ], "source": [ "# Create a new model instance\n", "model = create_model()\n", "\n", "# Load the previously saved weights\n", "model.load_weights(latest)\n", "\n", "# Re-evaluate the model from the latest checkpoint\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "id": "a86f4700", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": 17, "id": "7fcf07bb", "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.functions import predict_batch_udf\n", "from pyspark.sql.functions import struct, col, array, pandas_udf\n", "from pyspark.sql.types import *\n", "from pyspark.sql import SparkSession\n", "from pyspark import SparkConf\n", "import pandas as pd\n", "import json" ] }, { "cell_type": "markdown", "id": "50f02919", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific Spark configurations." ] }, { "cell_type": "code", "execution_count": 18, "id": "4c81d510", "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "markdown", "id": "c58f4df7", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." ] }, { "cell_type": "code", "execution_count": 19, "id": "2c022c24", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:58:33 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/04 13:58:33 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/02/04 13:58:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " \n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", " elif on_dataproc:\n", " conf.set(\"spark.executorEnv.TF_GPU_ALLOCATOR\", \"cuda_malloc_async\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", "\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "markdown", "id": "c81d0b1b", "metadata": {}, "source": [ "### Create Spark Dataframe" ] }, { "cell_type": "code", "execution_count": 20, "id": "49ff5203", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1000, 784)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# numpy array to pandas DataFrame\n", "test_pdf = pd.DataFrame(test_images)\n", "test_pdf.shape" ] }, { "cell_type": "code", "execution_count": 21, "id": "182ee0c7", "metadata": {}, "outputs": [], "source": [ "df = spark.createDataFrame(test_pdf).repartition(8)" ] }, { "cell_type": "markdown", "id": "d4e1c7ec-64fa-43c4-9bcf-0868a401d1f2", "metadata": {}, "source": [ "### Save as Parquet (784 columns of float)" ] }, { "cell_type": "code", "execution_count": 22, "id": "0061c39a-0871-429e-a4ff-751d26bf4b04", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:58:35 WARN package: Truncated the string representation of a plan since it was too large. 
This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n", "[Stage 0:> (0 + 8) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.05 ms, sys: 1.22 ms, total: 4.26 ms\n", "Wall time: 1.93 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "data_path_784 = \"spark-dl-datasets/mnist_784\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path_784 = \"dbfs:/FileStore/\" + data_path_784\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path_784)" ] }, { "cell_type": "markdown", "id": "18315afb-3fa2-4953-9297-52c04dd70c32", "metadata": {}, "source": [ "### Save as Parquet (1 column of 784 float)" ] }, { "cell_type": "code", "execution_count": 23, "id": "302c73ec", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1000, 1)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_pdf['data'] = test_pdf.values.tolist()\n", "pdf = test_pdf[['data']]\n", "pdf.shape" ] }, { "cell_type": "code", "execution_count": 24, "id": "5495901b", "metadata": {}, "outputs": [], "source": [ "df = spark.createDataFrame(pdf)" ] }, { "cell_type": "code", "execution_count": 25, "id": "5fa7faa8-c6bd-41b0-b5f7-fb121f0332e6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 875 μs, sys: 187 μs, total: 1.06 ms\n", "Wall time: 196 ms\n" ] } ], "source": [ "%%time\n", "data_path_1 = \"spark-dl-datasets/mnist_1\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path_1 = \"dbfs:/FileStore/\" + data_path_1\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path_1)" ] }, { "cell_type": "markdown", "id": "b366aaeb", "metadata": {}, "source": [ "## Inference using Spark DL API\n", "\n", "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n", "\n", "- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \n", "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function" ] }, { "cell_type": "markdown", "id": "4238fb28-d002-4b4d-9aa1-8af1fbd5d569", "metadata": {}, "source": [ "### 1 column of 784 float" ] }, { "cell_type": "code", "execution_count": 26, "id": "b9cf62f8-96b2-4716-80bd-bb93d5f939bd", "metadata": {}, "outputs": [], "source": [ "model_path = \"{}/models/training_1/checkpoint.model.keras\".format(os.getcwd())\n", "\n", "# For cloud environments, copy the model to the distributed file system.\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n", " dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/checkpoint.model.keras\"\n", " shutil.copy(model_path, dbfs_model_path)\n", " model_path = dbfs_model_path\n", "elif on_dataproc:\n", " # GCS is mounted at /mnt/gcs by the init script\n", " models_dir = \"/mnt/gcs/spark-dl/models\"\n", " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n", " gcs_model_path = models_dir + \"/checkpoint.model.keras\"\n", " shutil.copy(model_path, gcs_model_path)\n", " model_path = gcs_model_path" ] }, { "cell_type": "code", "execution_count": 27, "id": "b81fa297-d9d0-4600-880d-dbdcdf8bccc6", "metadata": {}, "outputs": [], "source": [ "def predict_batch_fn():\n", " import tensorflow as 
tf\n", "\n", " # Enable GPU memory growth to avoid CUDA OOM\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", "\n", " model = tf.keras.models.load_model(model_path)\n", " def predict(inputs: np.ndarray) -> np.ndarray:\n", " return model.predict(inputs)\n", " \n", " return predict" ] }, { "cell_type": "code", "execution_count": 28, "id": "72a689bd-dd82-492e-8740-1738a215325f", "metadata": {}, "outputs": [], "source": [ "mnist = predict_batch_udf(predict_batch_fn,\n", " return_type=ArrayType(FloatType()),\n", " batch_size=128,\n", " input_tensor_shapes=[[784]])" ] }, { "cell_type": "code", "execution_count": 29, "id": "60a70150-26b1-4145-9e7d-6e17389216b7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.read.parquet(data_path_1)\n", "len(df.columns)" ] }, { "cell_type": "code", "execution_count": 30, "id": "e027f0d2-0f65-47b7-a562-2f0965faceec", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+\n", "| data|\n", "+--------------------+\n", "|[0.0, 0.0, 0.0, 0...|\n", "|[0.0, 0.0, 0.0, 0...|\n", "|[0.0, 0.0, 0.0, 0...|\n", "|[0.0, 0.0, 0.0, 0...|\n", "|[0.0, 0.0, 0.0, 0...|\n", "+--------------------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "df.show(5)" ] }, { "cell_type": "code", "execution_count": 31, "id": "f0c3fb2e-469e-47bc-b948-8f6b0d7f6513", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 6:===================================================> (7 + 1) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 24.1 ms, sys: 11 ms, total: 35.2 ms\n", "Wall time: 5.52 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "preds = df.withColumn(\"preds\", mnist(struct(df.columns))).collect()" ] }, { "cell_type": "code", "execution_count": 32, "id": "cdfa229a-f4a9-4c11-a410-de4a21c02c82", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 21.1 ms, sys: 14.7 ms, total: 35.8 ms\n", "Wall time: 277 ms\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", mnist(*df.columns)).collect()" ] }, { "cell_type": "code", "execution_count": 33, "id": "5586ce49-6f93-4343-9b66-0dbb64972179", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 37.1 ms, sys: 8.46 ms, total: 45.6 ms\n", "Wall time: 216 ms\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", mnist(*[col(c) for c in df.columns])).collect()" ] }, { "cell_type": "markdown", "id": "004f1599-3c62-499e-9fd8-ed5cb0c90de4", "metadata": { "tags": [] }, "source": [ "#### Check predictions" ] }, { "cell_type": "code", "execution_count": 34, "id": "4f947dc0-6b18-4605-810b-e83250a161db", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " data \\\n", "0 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "1 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "2 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "3 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "4 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "5 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "6 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "7 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "8 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "9 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "\n", " preds \n", "0 [-4.6654954, -2.4895542, -0.5886033, 13.380537... \n", "1 [-2.273215, -7.5127845, 1.1983701, -3.540661, ... \n", "2 [-2.28909, 0.8308607, 0.31311005, 1.1683632, -... \n", "3 [-1.0551968, -6.5028114, 12.420729, 0.45280308... \n", "4 [-3.7887802, 3.9983602, -1.5343361, -0.3698440... \n", "5 [-4.499274, -1.7618222, 1.1183227, 3.946932, -... \n", "6 [-2.7540536, 4.8684144, 0.25152916, -0.4730078... \n", "7 [-1.8887109, 0.02717152, -6.0508857, 0.0875094... \n", "8 [0.9541265, -2.113048, -1.7508972, -5.4303794,... \n", "9 [-1.612412, -0.7655784, -4.473859, 2.0609212, ... " ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = df.withColumn(\"preds\", mnist(*df.columns)).limit(10).toPandas()\n", "preds" ] }, { "cell_type": "code", "execution_count": 35, "id": "de4964e0-d1f8-4753-afa1-a8f95ca3f151", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-4.6654954, -2.4895542, -0.5886033, 13.380537 , -6.652599 ,\n", " 2.8400383, -7.9901567, -0.7500452, -2.4487166, -4.349809 ],\n", " dtype=float32)" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sample = preds.iloc[0]\n", "sample.preds" ] }, { "cell_type": "code", "execution_count": 36, "id": "44e9a874-e301-4b72-8df7-bf1c5133c287", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 37, "id": "c60e5af4-fc1e-4575-a717-f304664235be", "metadata": {}, "outputs": [], "source": [ "prediction = np.argmax(sample.preds)\n", "img = np.array(sample.data).reshape(28,28)" ] }, { "cell_type": "code", "execution_count": null, "id": "eb45ecc9-d376-40c4-ad7b-2bd08ca5aaf6", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkfElEQVR4nO3dfXQUdZ7v8U/nqSEkaR7yLAFCFHRAcAYly/AgSiQEZUCYGUG9F7gziJiggI6KR0Udzsksrg7qIHjcHVhHEGWOyMoiDg9JGBRwwTCIM2QhJ0g4kIBcSYcAIaR/9w+uvbQkQDUdfkl4v86pc+iq37fqm6Lgk+qqrnYZY4wAALjKwmw3AAC4NhFAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAgAPdunXTpEmT/K8LCwvlcrlUWFgYsm24XC698MILIVsf0FwRQGgxlixZIpfL5Z/atGmjHj16KC8vT5WVlbbbc2TNmjUtJmTefvtt3X777UpKSpLb7VZ6eromT56s/fv3224NLVyE7QYAp1566SWlp6fr9OnT2rx5sxYuXKg1a9Zo9+7dio6Ovqq9DBkyRKdOnVJUVJSjujVr1mjBggUNhtCpU6cUEdF8/mkWFxcrPT1dP/vZz9ShQweVlZXp7bff1urVq/W3v/1NqamptltEC9V8jnLgMuXk5OjWW2+VJP36179Wp06d9Oqrr2rVqlWaMGFCgzU1NTVq165dyHsJCwtTmzZtQrrOUK/vSr355psXzBszZoxuvfVWvfPOO3r66actdIXWgLfg0OLdeeedkqSysjJJ0qRJkxQTE6PS0lKNHDlSsbGxeuCBByRJPp9P8+fPV69evdSmTRslJSVp6tSp+u677wLWaYzR3Llz1blzZ0VHR+uOO+7Q119/fcG2G7sGtG3bNo0cOVIdOnRQu3bt1KdPH7322mv+/hYsWCBJAW8pfq+ha0DFxcXKyclRXFycYmJiNGzYMG3dujVgzPdvUX722WeaNWuWEhIS1K5dO9177706evRowNiqqirt2bNHVVVVl7OLL9CtWzdJ0vHjx4OqByTOgNAKlJaWSpI6derkn3f27FllZ2dr0KBB+pd/+Rf/W3NTp07VkiVLNHnyZD366KMqKyvTH/7wBxUXF+uzzz5TZGSkJOn555/X3LlzNXLkSI0cOVJffvmlhg8frjNnzlyyn3Xr1umee+5RSkqKHnvsMSUnJ+sf//iHVq9erccee0xTp07VoUOHtG7dOv3pT3+65Pq+/vprDR48WHFxcXryyScVGRmpt956S0OHDlVRUZEyMzMDxk+fPl0dOnTQnDlztH//fs2fP195eXl6//33/WNWrlypyZMna/HixQE3VVzMsWPHVF9frwMHDuill16SJA0bNuyyaoEGGaCFWLx4sZFk1q9fb44ePWrKy8vN8uXLTadOnUzbtm3NwYMHjTHGTJw40UgyTz/9dED9X//6VyPJLF26NGD+2rVrA+YfOXLEREVFmbvvvtv4fD7/uGeeecZIMhMnTvTPKygoMJJMQUGBMcaYs2fPmvT0dNO1a1fz3XffBWzn/HXl5uaaxv75STJz5szxvx4zZoyJiooypaWl/nmHDh0ysbGxZsiQIRfsn6ysrIBtzZw504SHh5vjx49fMHbx4sUN9tAQt9ttJBlJplOnTub111+/7FqgIbwFhxYnKytLCQkJSktL0/jx4xUTE6OVK1fquuuuCxg3bdq0gNcrVqyQx+PRXXfdpW+//dY/9evXTzExMSooKJAkrV+/XmfOnNH06dMD3hqbMWPGJXsrLi5WWVmZZsyYofbt2wcsO39dl6u+vl5/+ctfNGbMGHXv3t0/PyUlRffff782b94sr9cbUPPQQw8FbGvw4MGqr6/XN9984583adIkGWMu++xHkj755BOtWbNGr7zyirp06aKamhrHPw9wPt6CQ4uzYMEC9ejRQxEREUpKSlLPnj0VFhb4u1RERIQ6d+4cMG/v3r2qqqpSYmJig+s9cuSIJPn/o77hhhsClickJKhDhw4X7e37twN79+59+T/QRRw9elQnT55Uz549L1h20003yefzqby8XL169fLP79KlS8C473v+4XUup+644w5J524CGT16tHr37q2YmBjl5eVd0Xpx7SKA0OL079/ffxdcY9xu9wWh5PP5lJiYqKVLlzZYk5CQELIebQoPD29wvjEmZNvIyMjQj3/8Yy1dupQAQtAIIFwzMjIytH79eg0cOFBt27ZtdFzXrl0lnTtjOv9tr6NHj17yLCIjI0OStHv3bmVlZTU67nLfjktISFB0dLRKSkouWLZnzx6FhYUpLS3tstYVaqdOnVJtba2VbaN14BoQrhm//OUvVV9fr9/+9rcXLDt79qz/luKsrCxFRkbqjTfeCDhrmD9//iW38ZOf/ETp6emaP3/+Bbcon7+u7z+TdKnbmMPDwzV8+HCtWrUq4MkDlZWVWrZsmQYNGqS4uLhL9vVDl3sb9tmzZxsM3S+++EJfffXVJc9EgYvhDAjXjNtvv11Tp05Vfn6+du7cqeHDhysyMlJ79+7VihUr9Nprr+nnP/+5EhIS9MQTTyg/P1/33HOPRo4cqeLiYn3yySeKj4+/6DbCwsK0cOFCjRo1SrfccosmT56slJQU7dmzR19//bU+/fRTSVK/fv0kSY8++qiys7MVHh6u8ePHN7jOuXPnat26dRo0aJAeeeQRRURE6K233lJtba3mzZsX1L643NuwT5w4obS0NN13333q1auX2rVrp6+++kqLFy+Wx+PRc889F9T2AYkAwjVm0aJF6tevn9566y0988wzioiIULdu3fTggw9q4MCB/nFz585VmzZttGjRIhUUFCgzM1N/+ctfdPfdd19yG9nZ2SooKNCLL76oV155RT6fTxkZGZoyZYp/zNixYzV9+nQtX75c7777rowxjQZQr1699Ne//lWzZ89Wfn6+fD6fMjMz9e67717wGaBQi46O1q9//WsVFBToz3/+s06dOqXU1FRNmDBBzz77rP8DqUAwXCaUVyYBALhMXAMCAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMCKZvc5IJ/Pp0OHDik2NjaopwcDAOwyxqi6ulqpqakXPJPxfM0ugA4dOmTt2VYAgNApLy+/4Kn052t2ARQbGytJGqSRilCk5W4AAE6dVZ02a43///PGNFkALViwQC+//LIqKirUt29fvfHGG+rfv/8l675/2y1CkYpwEUAA0OL8/+frXOoySpPchPD+++9r1qxZmjNnjr788kv17dtX2dnZ/i/8AgCgSQLo1Vdf1ZQpUzR58mT96Ec/0qJFixQdHa0//vGPTbE5AEALFPIAOnPmjHbs2BHwZVxhYWHKysrSli1bLhhfW1srr9cbMAEAWr+QB9C3336r+vp6JSUlBcxPSkpSRUXFBePz8/Pl8Xj8E3fAAcC1wfoHUWfPnq2qqir/VF5ebrslAMBVEPK74OLj4xUeHq7KysqA+ZWVlUpOTr5gvNvtltvtDnUbAIBmLuRnQFFRUe
rXr582bNjgn+fz+bRhwwYNGDAg1JsDALRQTfI5oFmzZmnixIm69dZb1b9/f82fP181NTWaPHlyU2wOANACNUkA3XfffTp69Kief/55VVRU6JZbbtHatWsvuDEBAHDtchljjO0mzuf1euXxeDRUo3kSAgC0QGdNnQq1SlVVVYqLi2t0nPW74AAA1yYCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWBFhuwHYZ37aN7i6cOe/v0RWeh3XlP7vRMc1vu6nHNdI0p7b/+i4JtzlfD9MPTjAcU3hp7c4run2nzWOayRJW3cFVwc4wBkQAMAKAggAYEXIA+iFF16Qy+UKmG688cZQbwYA0MI1yTWgXr16af369f+zkQguNQEAAjVJMkRERCg5ObkpVg0AaCWa5BrQ3r17lZqaqu7du+uBBx7QgQMHGh1bW1srr9cbMAEAWr+QB1BmZqaWLFmitWvXauHChSorK9PgwYNVXV3d4Pj8/Hx5PB7/lJaWFuqWAADNUMgDKCcnR7/4xS/Up08fZWdna82aNTp+/Lg++OCDBsfPnj1bVVVV/qm8vDzULQEAmqEmvzugffv26tGjh/bt29fgcrfbLbfb3dRtAACamSb/HNCJEydUWlqqlJSUpt4UAKAFCXkAPfHEEyoqKtL+/fv1+eef695771V4eLgmTJgQ6k0BAFqwkL8Fd/DgQU2YMEHHjh1TQkKCBg0apK1btyohISHUmwIAtGAuY4yx3cT5vF6vPB6Phmq0IlyRttuxqubnmY5rKm91flK7dsLLjmskqUtEW8c1/2v/XY5r/tRtneManFN8xhdU3eOP5zmuif5wW1DbQutz1tSpUKtUVVWluLi4RsfxLDgAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIKHkV4lR/J+6rim8OlXHNdEu6Ic1zR339afclzTxhXc71Z1cv7PYcaBexzX/DLxvxzX3B1d5bgmWPvqah3XPDH4l45rzpYfdFyD5o+HkQIAmjUCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsiLDdwLXCF+68pjU+2frlYz9yXLNhxiDHNfVtg/vd6rvrnT+B/br/POy45s2EcY5r7v7zHx3XBGvsf011XNPt+P7QN4JWjTMgAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCh5FeJan/+jfHNR88kui4Jjv6gOOanDlPOK6RpLoYl+Oa6/7joOOaiP07nNc4rjgnOYia+iBqKu/5aRBVV8+uny5xXDMmiAes+qqrHdeg9eAMCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCs4GGkV4mvpsZxzTs90xzXvJ0z1nFNfEGx4xpJ8p0+7bjmbFBbunrCExIc13x3V4bjmscf/sBxDdDacAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcByY4yef/55paSkqG3btsrKytLevXtD1S8AoJVwHEA1NTXq27evFixY0ODyefPm6fXXX9eiRYu0bds2tWvXTtnZ2TodxPUCAEDr5fgmhJycHOXk5DS4zBij+fPn69lnn9Xo0aMlSe+8846SkpL00Ucfafz48VfWLQCg1QjpNaCysjJVVFQoKyvLP8/j8SgzM1NbtmxpsKa2tlZerzdgAgC0fiENoIqKCklSUlJSwPykpCT/sh/Kz8+Xx+PxT2lpzm89BgC0PNbvgps9e7aqqqr8U3l5ue2WAABXQUgDKDk5WZJUWVkZML+ystK/7Ifcbrfi4uICJgBA6xfSAEpPT1dycrI2bNjgn+f1erVt2zYNGDAglJsCALRwju+CO3HihPbt2+d/XVZWpp07d6pjx47q0qWLZsyYoblz5+qGG25Qenq6nnvuOaWmpmrMmDGh7BsA0MI5DqDt27frjjvu8L+eNWuWJGnixIlasmSJnnzySdXU1Oihhx7S8ePHNWjQIK1du1Zt2rQJXdcAgBbPZYwxtps4n9frlcfj0VCNVoQr0nY7aKHC23uCqnt8x2bHNUPanAlqW1eDT76g6n73bV/HNduGpTiuqf/2mOMaNH9nTZ0KtUpVVVUXva5v/S44AMC1iQACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACscfx0D0BKUPdorqLohbTaGuBO7VtXEB1X3ed+oIKp4sjWc4QwIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKzgYaRAK3ZX28NB1c2d9YDjmrqYoDblWGLxWcc1bT7+ogk6wZXiDAgAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArOBhpGiV0jacDKpuxyTnNf3cQW3qqogJC665HY+/EeJOQmfOkR87rtnxMb9rN0f8rQAArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFTyMFK2S67OdQdW9NOp+xzW1STGOa6of9zqu+eyW5Y5rWqNnE7Y7rrnzwUeD2pbn3a1B1eHycAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcDySZMmyeVyBUwjRowIVb8AgFbCcQDV1NSob9++WrBgQaNjRowYocOHD/un995774qaBAC0Po5vQsjJyVFOTs5Fx7jdbiUnJwfdFACg9WuSa0CFhYVKTExUz549NW3aNB07dqzRsbW1tfJ6vQETAKD1C3kAjRgxQu+88442bNigf/7nf1ZRUZFycnJUX1/f4Pj8/Hx5PB7/lJaWFuqWAADNUMg/BzR+/Hj/n2+++Wb16dNHGRkZKiws1LBhwy4YP3v2bM2aNcv/2uv1EkIAcA1o8tuwu3fvrvj4eO3bt6/B5W63W3FxcQETAKD1a/IAOnjwoI4dO6aUlJSm3hQAoAVx/BbciRMnAs5mysrKtHPnTnXs2FEdO3bUiy++qHHjxik5OVmlpaV68skndf311ys7OzukjQMAWjbHAbR9+3bdcccd/tffX
7+ZOHGiFi5cqF27dunf//3fdfz4caWmpmr48OH67W9/K7fbHbquAQAtnuMAGjp0qIwxjS7/9NNPr6ghwKb6r0sc10R87Xw7HQpcjmtGRf3Ucc3+P/VwXCNJn2QudFzTOaJtUNtyKtIV7rjmdMfgrjZ4gqrC5eJZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAi5F/JDeAyXOSJ8o2W1NY6run6y68c10jSnW/NdFzz3/csCmpbuHZxBgQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVvAwUqAVc0VGBVfXtj7EnYTOrjPOe0vcXtMEneBKcQYEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFbwMFKgFSt545ag6v572MLQNhJCMx6f7rgm+vNtTdAJrhRnQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQ8jRasUFhsbXF17T4g7adiRu9Ic19w1/TPHNf+RuMBxzTlX53fTD04kOq6J2/KN45qzjitwNXAGBACwggACAFjhKIDy8/N12223KTY2VomJiRozZoxKSkoCxpw+fVq5ubnq1KmTYmJiNG7cOFVWVoa0aQBAy+cogIqKipSbm6utW7dq3bp1qqur0/Dhw1VTU+MfM3PmTH388cdasWKFioqKdOjQIY0dOzbkjQMAWjZHNyGsXbs24PWSJUuUmJioHTt2aMiQIaqqqtK//du/admyZbrzzjslSYsXL9ZNN92krVu36p/+6Z9C1zkAoEW7omtAVVVVkqSOHTtKknbs2KG6ujplZWX5x9x4443q0qWLtmzZ0uA6amtr5fV6AyYAQOsXdAD5fD7NmDFDAwcOVO/evSVJFRUVioqKUvv27QPGJiUlqaKiosH15Ofny+Px+Ke0NOe3pwIAWp6gAyg3N1e7d+/W8uXLr6iB2bNnq6qqyj+Vl5df0foAAC1DUB9EzcvL0+rVq7Vp0yZ17tzZPz85OVlnzpzR8ePHA86CKisrlZyc3OC63G633G53MG0AAFowR2dAxhjl5eVp5cqV2rhxo9LT0wOW9+vXT5GRkdqwYYN/XklJiQ4cOKABAwaEpmMAQKvg6AwoNzdXy5Yt06pVqxQbG+u/ruPxeNS2bVt5PB796le/0qxZs9SxY0fFxcVp+vTpGjBgAHfAAQACOAqghQsXSpKGDh0aMH/x4sWaNGmSJOn3v/+9wsLCNG7cONXW1io7O1tvvvlmSJoFALQeLmOMsd3E+bxerzwej4ZqtCJckbbbuSaE9b0pqLo9uTGOa5LT/q/jmiMlCY5rJt9Z6LhGkp7q9HVQdQhOn88nOa7p8ouvQt8IQuqsqVOhVqmqqkpxcXGNjuNZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAiqG9ERfPl6tfLcU3b3x8Jalv/nfFuUHWO9bk6m2nuak2d45pIV3hQ26qsr3VcM+dQjuOazq8F1x9aB86AAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKHkbaytR0jXFc8373fw1ya1FB1jU9n3xB1c08NNhxzW8S1zuuyf4813FNbGG045rqbo5LJEnps7cEUVXtuCJMO4PYDloLzoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoeRtrKRH+4zXHNLw4/HNS2jv64neMaXxDPL61z/nxVvf1//uC8SFLpbacd10z78VTHNek7dzmukTGOS+KdbwW4ajgDAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArXMYE8YTDJuT1euXxeDRUoxXhirTdDgDAobOmToVapaqqKsXFxTU6jjMgAIAVBBAAwApHAZSfn6/bbrtNsbGxSkxM1JgxY1RSUhIwZujQoXK5XAHTww8H930zAIDWy1EAFRUVKTc3V1u3btW6detUV1en4cOHq6amJmDclClTdPjwYf80b968kDYNAGj5HH0j6tq1awNeL1myRImJidqxY4eGDBninx8dHa3k5OTQdAgAaJWu6BpQVVWVJKljx44B85cuXar4+Hj17t1bs2fP1smTJxtdR21trbxeb8AEAGj9HJ0Bnc/n82nGjBkaOHCgevfu7Z9///33q2vXrkpNTdWuXbv01FNPqaSkRB9++GGD68nPz9eLL74YbBsAgBYq6M8BTZs2TZ988ok2b96szp07Nzpu48aNGjZsmPbt26eMjIwLltfW1qq2ttb/2uv1Ki0tjc8BAUALdbmfAwrqDCgvL0+rV6/Wpk2bLho+kpSZmSlJjQaQ2+2W2+0Opg0AQAvmKICMMZo+fbpWrlypwsJCpaenX7Jm586dkqSUlJSgGgQAtE6OAig3N1fLli3TqlWrFBsbq4qKCkmSx+NR27ZtVVpaqmXLlmnkyJHq1KmTdu3apZkzZ2rIkCHq06dPk/wAAICWydE1IJfL1eD8xYsXa9KkSSovL9eDDz6o3bt3q6amRmlpabr33nv17LPPXvR9wPPxLDgAaNma5BrQpbIqLS1NRUVFTlYJALhG8Sw4AIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVEbYb+CFjjCTprOokY7kZAIBjZ1Un6X/+P29Mswug6upqSdJmrbHcCQDgSlRXV8vj8TS63GUuFVFXmc/n06FDhxQbGyuXyxWwzOv1Ki0tTeXl5YqLi7PUoX3sh3PYD+ewH85hP5zTHPaDMUbV1dVKTU1VWFjjV3qa3RlQWFiYOnfufNExcXFx1/QB9j32wznsh3PYD+ewH86xvR8udubzPW5CAABYQQABAKxoUQHkdrs1Z84cud1u261YxX44h/1wDvvhHPbDOS1pPzS7mxAAANeGFnUGBABoPQggAIAVBBAAwAoCCABgBQEEALCixQTQggUL1K1bN7Vp00aZmZn64osvbLd01b3wwgtyuVwB04033mi7rSa3adMmjRo1SqmpqXK5XProo48Clhtj9PzzzyslJUVt27ZVVlaW9u7da6fZJnSp/TBp0qQLjo8RI0bYabaJ5Ofn67bbblNsbKwSExM1ZswYlZSUBIw5ffq0cnNz1alTJ8XExGjcuHGqrKy01HHTuJz9MHTo0AuOh4cffthSxw1rEQH0/vvv
a9asWZozZ46+/PJL9e3bV9nZ2Tpy5Ijt1q66Xr166fDhw/5p8+bNtltqcjU1Nerbt68WLFjQ4PJ58+bp9ddf16JFi7Rt2za1a9dO2dnZOn369FXutGldaj9I0ogRIwKOj/fee+8qdtj0ioqKlJubq61bt2rdunWqq6vT8OHDVVNT4x8zc+ZMffzxx1qxYoWKiop06NAhjR071mLXoXc5+0GSpkyZEnA8zJs3z1LHjTAtQP/+/U1ubq7/dX19vUlNTTX5+fkWu7r65syZY/r27Wu7DaskmZUrV/pf+3w+k5ycbF5++WX/vOPHjxu3223ee+89Cx1eHT/cD8YYM3HiRDN69Ggr/dhy5MgRI8kUFRUZY8793UdGRpoVK1b4x/zjH/8wksyWLVtstdnkfrgfjDHm9ttvN4899pi9pi5Dsz8DOnPmjHbs2KGsrCz/vLCwMGVlZWnLli0WO7Nj7969Sk1NVffu3fXAAw/owIEDtluyqqysTBUVFQHHh8fjUWZm5jV5fBQWFioxMVE9e/bUtGnTdOzYMdstNamqqipJUseOHSVJO3bsUF1dXcDxcOONN6pLly6t+nj44X743tKlSxUfH6/evXtr9uzZOnnypI32GtXsnob9Q99++63q6+uVlJQUMD8pKUl79uyx1JUdmZmZWrJkiXr27KnDhw/rxRdf1ODBg7V7927Fxsbabs+KiooKSWrw+Ph+2bVixIgRGjt2rNLT01VaWqpnnnlGOTk52rJli8LDw223F3I+n08zZszQwIED1bt3b0nnjoeoqCi1b98+YGxrPh4a2g+SdP/996tr165KTU3Vrl279NRTT6mkpEQffvihxW4DNfsAwv/Iycnx/7lPnz7KzMxU165d9cEHH+hXv/qVxc7QHIwfP97/55tvvll9+vRRRkaGCgsLNWzYMIudNY3c3Fzt3r37mrgOejGN7YeHHnrI/+ebb75ZKSkpGjZsmEpLS5WRkXG122xQs38LLj4+XuHh4RfcxVJZWank5GRLXTUP7du3V48ePbRv3z7brVjz/THA8XGh7t27Kz4+vlUeH3l5eVq9erUKCgoCvj8sOTlZZ86c0fHjxwPGt9bjobH90JDMzExJalbHQ7MPoKioKPXr108bNmzwz/P5fNqwYYMGDBhgsTP7Tpw4odLSUqWkpNhuxZr09HQlJycHHB9er1fbtm275o+PgwcP6tixY63q+DDGKC8vTytXrtTGjRuVnp4esLxfv36KjIwMOB5KSkp04MCBVnU8XGo/NGTnzp2S1LyOB9t3QVyO5cuXG7fbbZYsWWL+/ve/m4ceesi0b9/eVFRU2G7tqnr88cdNYWGhKSsrM5999pnJysoy8fHx5siRI7Zba1LV1dWmuLjYFBcXG0nm1VdfNcXFxeabb74xxhjzu9/9zrRv396sWrXK7Nq1y4wePdqkp6ebU6dOWe48tC62H6qrq80TTzxhtmzZYsrKysz69evNT37yE3PDDTeY06dP2249ZKZNm2Y8Ho8pLCw0hw8f9k8nT570j3n44YdNly5dzMaNG8327dvNgAEDzIABAyx2HXqX2g/79u0zL730ktm+fbspKyszq1atMt27dzdDhgyx3HmgFhFAxhjzxhtvmC5dupioqCjTv39/s3XrVtstXXX33XefSUlJMVFRUea6664z9913n9m3b5/ttppcQUGBkXTBNHHiRGPMuVuxn3vuOZOUlGTcbrcZNmyYKSkpsdt0E7jYfjh58qQZPny4SUhIMJGRkaZr165mypQpre6XtIZ+fklm8eLF/jGnTp0yjzzyiOnQoYOJjo429957rzl8+LC9ppvApfbDgQMHzJAhQ0zHjh2N2+02119/vfnNb35jqqqq7Db+A3wfEADAimZ/DQgA0DoRQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAV/w/hgVLrpVGHsAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "plt.title(\"Prediction: {}\".format(prediction))\n", "plt.imshow(img)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "39167347-0b99-4972-998c-e1230bf1d4d5", "metadata": {}, "source": [ "### 784 columns of float" ] }, { "cell_type": "code", "execution_count": 39, "id": "6bea332e-f6de-494f-a0db-795d9fe3e134", "metadata": {}, "outputs": [], "source": [ "def predict_batch_fn():\n", " import tensorflow as tf\n", " # Enable GPU memory growth\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", " \n", " model = tf.keras.models.load_model(model_path)\n", " def predict(inputs: np.ndarray) -> np.ndarray:\n", " return model.predict(inputs)\n", " \n", " return predict" ] }, { "cell_type": "code", "execution_count": 40, "id": "731d234c-549f-4df3-8a2b-312e63195396", "metadata": {}, "outputs": [], "source": [ "mnist = predict_batch_udf(predict_batch_fn,\n", " return_type=ArrayType(FloatType()),\n", " batch_size=128,\n", " input_tensor_shapes=[[784]])" ] }, { "cell_type": "code", "execution_count": null, "id": "a40fe207-6246-4b0e-abde-823979878d97", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "784" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.read.parquet(data_path_784)\n", "len(df.columns)" ] }, { "cell_type": "code", "execution_count": 42, "id": "10904f12-03e7-4518-8f12-2aa11989ddf5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 12:==============> (2 + 6) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 52.5 ms, sys: 22 ms, total: 74.5 ms\n", "Wall time: 5.72 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", mnist(struct(*df.columns))).collect()" ] }, { "cell_type": "code", "execution_count": 43, "id": "671128df-f0f4-4f54-b35c-d63a78c7f89a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 13:===========================================> (6 + 2) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 49.4 ms, sys: 31.9 ms, total: 81.2 ms\n", "Wall time: 1.34 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", mnist(array(*df.columns))).collect()" ] }, { "cell_type": "code", "execution_count": 44, "id": "ce35deaf-7d49-4f34-9bf9-b4e6fc5761f4", "metadata": {}, "outputs": [], "source": [ "# should raise ValueError\n", "# preds = df.withColumn(\"preds\", mnist(*df.columns)).collect()" ] }, { "cell_type": "markdown", "id": "01709833-484b-451f-9aa8-37be5b7baf14", "metadata": {}, "source": [ "### Check prediction" ] }, { "cell_type": "code", "execution_count": 45, "id": "f9119632-b284-45d7-a262-c262e034c15c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " 0 1 2 3 4 5 6 7 8 9 ... 775 776 777 778 \\\n", "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", "5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", "6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", "7 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", "8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", "9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", "\n", " 779 780 781 782 783 preds \n", "0 0.0 0.0 0.0 0.0 0.0 [-6.9618006, 1.2047814, -0.09570807, 0.0462105... \n", "1 0.0 0.0 0.0 0.0 0.0 [-5.2882323, 5.902014, -2.0389183, -1.2460864,... \n", "2 0.0 0.0 0.0 0.0 0.0 [-5.822013, -2.3333628, -2.4322102, -8.040086,... \n", "3 0.0 0.0 0.0 0.0 0.0 [-0.57203317, -1.2920653, -2.7234774, 0.914070... \n", "4 0.0 0.0 0.0 0.0 0.0 [-3.689301, 5.0702505, -0.23930073, -0.7988689... \n", "5 0.0 0.0 0.0 0.0 0.0 [8.268821, -2.070008, 1.722378, -1.8471404, -8... \n", "6 0.0 0.0 0.0 0.0 0.0 [5.59269, -3.1613479, 0.4734843, -0.7772096, -... \n", "7 0.0 0.0 0.0 0.0 0.0 [1.9852623, -5.166985, 0.86473066, -6.491789, ... \n", "8 0.0 0.0 0.0 0.0 0.0 [-2.800528, -4.2984514, 10.887824, -3.1346364,... \n", "9 0.0 0.0 0.0 0.0 0.0 [-3.7827752, -4.51145, -5.354035, 9.399383, -6... \n", "\n", "[10 rows x 785 columns]" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = df.withColumn(\"preds\", mnist(struct(df.columns))).limit(10).toPandas()\n", "preds" ] }, { "cell_type": "code", "execution_count": 46, "id": "7c067c62-03a6-461e-a1ff-4653276fbea1", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 47, "id": "a7084ad0-c021-4296-bad0-7a238971f53b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-6.9618006 , 1.2047814 , -0.09570807, 0.04621054, -5.8169513 ,\n", " -4.148872 , -5.17938 , 6.382909 , -0.11228667, 0.6022302 ],\n", " dtype=float32)" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sample = preds.iloc[0]\n", "sample.preds" ] }, { "cell_type": "code", "execution_count": 48, "id": "8167c832-93ef-4f50-873b-07b67c19ef53", "metadata": {}, "outputs": [], "source": [ "prediction = np.argmax(sample.preds)\n", "img = sample.drop('preds').to_numpy(dtype=float)\n", "img = np.array(img).reshape(28,28)" ] }, { "cell_type": "code", "execution_count": null, "id": "297811e1-aecb-4afd-9a6a-30c49e8881cc", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiTklEQVR4nO3dfXBV9b3v8c/O0+YpCYQ8S8CAAhYET1FyuCCipAlBHVF6KmrvBY4FpQHFHGsPTgVRZtJDTzmoTcE59xTaUxAP0wK3lKKAJBQKdEAYBqu5kMYCAwnImAQChIf9u39w2cdNArg2O3zz8H7NrJnstX7ftb5ZLPiw9lp7bZ9zzgkAgFssyroBAED7RAABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEe3H777Zo0aVLwdWlpqXw+n0pLSyO2DZ/Pp9dffz1i6wNaKgIIrcbSpUvl8/mCU4cOHdS3b19Nnz5d1dXV1u15sm7dulYTMl/d51dP3/rWt6zbQysWY90A4NUbb7yh7OxsnTt3Tlu3btWiRYu0bt067d+/X506dbqlvYwcOVJnz55VXFycp7p169appKSkyRA6e/asYmJazl/N//zP/2w0b9euXXrrrbeUl5dn0BHaipZzlANfU0FBge69915J0ve+9z11795dCxYs0Jo1a/TUU081WVNfX6/OnTtHvJeoqCh16NAhouuM9Ppu1ne/+91G86689Xit/Q18HbwFh1bvoYcekiRVVlZKkiZNmqQuXbqooqJCY8eOVXx8vJ555hlJUiAQ0MKFCzVgwAB16NBBaWlpeu655/Tll1+GrNM5p3nz5qlHjx7q1KmTHnzwQX3yySeNtn2ta0A7d+7U2LFj1a1bN3Xu3FmDBg3SW2+9FeyvpKREUujbW1c0dQ1oz549KigoUEJCgrp06aLRo0drx44dIWOuvEW5bds2FRUVKSUlRZ07d9bjjz+uEydOhIytra3VZ599ptra2q+zi0M0NDToN7/5jR544AH16NHDcz1wBWdAaPUqKiokSd27dw/Ou3jxovLz8zVixAj967/+a/Ctueeee05Lly7V5MmT9cILL6iyslI/+9nPtGfPHm3btk2xsbGSpNmzZ2vevHkaO3asxo4dq48//lh5eXk6f/78DfvZsGGDHnnkEWVkZOjFF19Uenq6Pv30U61du1YvvviinnvuOR09elQbNmxo8u2tq33yySe6//77lZCQoFdeeUWxsbF69913NWrUKJWVlSknJydk/IwZM9StWzfNmTNHn3/+uRYuXKjp06fr/fffD45ZtWqVJk+erCVLloTcVPF1rFu3TjU1NcFQB8LmgFZiyZIlTpLbuHGjO3HihDt8+LBbsWKF6969u+vYsaM7cuSIc865iRMnOknun//5n0Pq//jHPzpJbtmyZSHz169fHzL/+PHjLi4uzj388MMuEAgEx7366qtOkps4cWJw3ubNm50kt3nzZueccxcvXnTZ2dmuV69e7ssvvwzZzlfXVVhY6K7110+SmzNnTvD1uHHjXFxcnKuoqAjOO3r0qIuPj3cjR45stH9yc3NDtvXSSy+56OhoV1NT02jskiVLmuzhesaPH+/8fn+j3w/wirfg0Ork5uYqJSVFWVlZmjBhgrp06aJVq1bptttuCxk3bdq0kNcrV65UYmKivvWtb+mLL74ITkOGDFGXLl20efNmSdLGjRt1/vx5zZgxI+StsZkzZ96wtz179qiyslIzZ85U165dQ5Z9dV1f16VLl/Thhx9q3Lhx6t27d3B+RkaGnn76aW3dulV1dXUhNVOnTg3Z1v33369Lly7pb3/7W3DepEmT5JzzfPZTV1en3//+9xo7dmyj3w/wirfg0OqUlJSob9++iomJUVpamvr166eoqND/S8XExDS6PnHgwAHV1tYqNTW1yfUeP35ckoL/UN95550hy1NSUtStW7fr9nbl7cCBAwd+/V/oOk6cOKEzZ86oX79+jZbdddddCgQCOnz4sAYMGBCc37Nnz5BxV3q++jpXOH7zm9/o3LlzvP2GiCCA0OoMHTo0eBfctfj9/kahFAgElJqaqmXLljVZk5KSErEeLUVHRzc53zl30+tetmyZEhMT9cgjj9z0ugACCO1Gnz59tHHjRg0fPlwdO3a85rhevXpJunzG9NW3vU6cOHHDs4g+ffpIkvbv36/c3Nxrjvu6b8elpKSoU6dOKi8vb7Tss88+U1RUlLKysr7Wum7WsWPHtHnzZk2aNEl+v/+WbBNtG9eA0G585zvf0aVLl/Tmm282Wnbx4kXV1NRIunyNKTY2Vu+8807IWcPChQtvuI1vfvObys7O1sKFC4Pru+Kr67rymaSrx1wtOjpaeXl5WrNmjT7//PPg/Orqai1fvlwjRoxQQkLCDfu6Wji3Ya9YsUKBQIC33xAxnAGh3XjggQf03HPPqbi4WHv37lVeXp5iY2N14MABrVy5Um+99Za+/e1vKyUlRS+//LKKi4v1yCOPaOzYsdqzZ4/+8Ic/KDk5+brbiIqK0qJFi/Too4/qnnvu0eTJk5WRkaHPPvtMn3zyiT744ANJ0pAhQyRJL7zwgvLz8xUdHa0JEyY0uc558+Zpw4YNGjFihL7//e8rJiZG7777rhoaGjR//vyw9kU4t2EvW7ZMmZmZGjVqVFjbBK5GAKFdWbx4sYYMGaJ3331Xr776qmJiYnT77bfru9/9roYPHx4cN2/ePHXo0EGLFy/W5s2blZOTow8//FAPP/zwDbeRn5+vzZs3a+7cufrpT3+qQCCgPn36aMqUKcExTzzxhGbMmKEVK1bo17/+tZxz1wygAQMG6I9//KNmzZql4uJiBQIB5eTk6Ne//nWjzwA1l/Lycu3evVtFRUWNrq0B4fK5SFyZBADAI/4rAwAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMtLjPAQUCAR09elTx8fFhPT0YAGDLOadTp04pMzPzup8ba3EBdPTo0Vv2bCsAQPM5fPjwdb81t8UFUHx8vCRphMYqRrHG3QAAvLqoC9qqdcF/z6+l2QKopKREP/nJT1RVVaXBgwfrnXfe0dChQ29Yd+VttxjFKsZHAAFAq/P/n69zo8sozXITwvvvv6+ioiLNmTNHH3/8sQYPHqz8/PzgF34BANAsAbRgwQJNmTJFkydP1je+8Q0tXrxYnTp10i9+8Yvm2BwAoBWKeACdP39eu3fvDvkyrqioKOXm5mr79u2Nxjc0NKiuri5kAgC0fREPoC+++EKXLl1SWlpayPy0tDRVVVU1Gl9cXKzExMTgxB1wANA+mH8QddasWaqtrQ1Ohw8ftm4JAHALRPwuuOTkZEVHR6u6ujpkfnV1tdLT0xuN9/v9fL88ALRDET8DiouL05AhQ7Rp06bgvEAgoE2bNmnYsGGR3hwAoJVqls8BFRUVaeLEibr33ns1dOhQLVy4UPX19Zo8eXJzbA4A0Ao1SwA9+eSTOnHihGbPnq2qqirdc889Wr9+faMbEwAA7ZfPOeesm/iquro6JSYmapQe40kIANAKXX
QXVKo1qq2tVUJCwjXHmd8FBwBonwggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGAi4gH0+uuvy+fzhUz9+/eP9GYAAK1cTHOsdMCAAdq4ceN/bySmWTYDAGjFmiUZYmJilJ6e3hyrBgC0Ec1yDejAgQPKzMxU79699cwzz+jQoUPXHNvQ0KC6urqQCQDQ9kU8gHJycrR06VKtX79eixYtUmVlpe6//36dOnWqyfHFxcVKTEwMTllZWZFuCQDQAvmcc645N1BTU6NevXppwYIFevbZZxstb2hoUENDQ/B1XV2dsrKyNEqPKcYX25ytAQCawUV3QaVao9raWiUkJFxzXLPfHdC1a1f17dtXBw8ebHK53++X3+9v7jYAAC1Ms38O6PTp06qoqFBGRkZzbwoA0IpEPIBefvlllZWV6fPPP9ef/vQnPf7444qOjtZTTz0V6U0BAFqxiL8Fd+TIET311FM6efKkUlJSNGLECO3YsUMpKSmR3hQAoBWLeACtWLEi0qtEOxc9oJ/nmpqB3cLa1qkJ3j8G8D9uq/Rcs+1Ib881w3v81XPN1lV/57lGknq+tddzTeDMmbC2hfaLZ8EBAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAw0exfSAd8VfQd2Z5rpq7+veeahzvVeq6RpCj5PNcEFMaXCt+21XtNGKKmbwurrl9SoeeaPj/YHta20H5xBgQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMMHTsHFLueovPNcU/eEZzzUPj/+55xpJ+jJw1nPNfRtf8FwTdyTOc83+f/yZ55pw/fzx/+255q238z3XXDx8xHMN2g7OgAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJjgYaS4pQKnTnmu6f/mXz3X3HPb//JcI0kd1yd4run779s918Rk9/Jco3/0XhKu1OjTnmtcpw7N0AnaMs6AAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmOBhpGjxLp044bmmx3jvNbdSQ6/unmui5GuGTq6xLZ+7ZdtC+8UZEADABAEEADDhOYC2bNmiRx99VJmZmfL5fFq9enXIcuecZs+erYyMDHXs2FG5ubk6cOBApPoFALQRngOovr5egwcPVklJSZPL58+fr7fffluLFy/Wzp071blzZ+Xn5+vcuXM33SwAoO3wfBNCQUGBCgoKmlzmnNPChQv1ox/9SI899pgk6Ve/+pXS0tK0evVqTZgw4ea6BQC0GRG9BlRZWamqqirl5uYG5yUmJionJ0fbtzf9tcUNDQ2qq6sLmQAAbV9EA6iqqkqSlJaWFjI/LS0tuOxqxcXFSkxMDE5ZWVmRbAkA0EKZ3wU3a9Ys1dbWBqfDhw9btwQAuAUiGkDp6emSpOrq6pD51dXVwWVX8/v9SkhICJkAAG1fRAMoOztb6enp2rRpU3BeXV2ddu7cqWHDhkVyUwCAVs7zXXCnT5/WwYMHg68rKyu1d+9eJSUlqWfPnpo5c6bmzZunO++8U9nZ2XrttdeUmZmpcePGRbJvAEAr5zmAdu3apQcffDD4uqioSJI0ceJELV26VK+88orq6+s1depU1dTUaMSIEVq/fr06dOgQua4BAK2e5wAaNWqUnLv2gwp9Pp/eeOMNvfHGGzfVGNCWHc71e64JyPsDQsN9gGlS1EXPNYEu3n8ntG/md8EBANonAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJz0/DBnDzfH1PW7dwXfOPP3jjQVdxuz9phk7QlnEGBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQPIwVu0ul/yPFc839yFoSxpQ5h1ITngz/c67nmdm1vhk7QlnEGBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQPIwVu0tHcgOeaPjEdm6GTyMncdtG6BbQDnAEBAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwwcNIga+I7p7kueahwZ96rgnIea4JR9/fPx9e3YcfR7gToDHOgAAAJgggAIAJzwG0ZcsWPfroo8rMzJTP59Pq1atDlk+aNEk+ny9kGjNmTKT6BQC0EZ4DqL6+XoMHD1ZJSck1x4wZM0bHjh0LTu+9995NNQkAaHs834RQUFCggoKC647x+/1KT08PuykAQNvXLNeASktLlZqaqn79+mnatGk6efLkNcc2NDSorq4uZAIAtH0RD6AxY8boV7/6lTZt2qR/+Zd/UVlZmQoKCnTp0qUmxxcXFysxMTE4ZWVlRbolAEALFPHPAU2YMCH48913361BgwapT58+Ki0t1ejRoxuNnzVrloqKioKv6+rqCCEAaAea/Tbs3r17Kzk5WQcPHmxyud/vV0JCQsgEAGj7mj2Ajhw5opMnTyojI6O5NwUAaEU8vwV3+vTpkLOZyspK7d27V0lJSUpKStLcuXM1fvx4paenq6KiQq+88oruuOMO5efnR7RxAEDr5jmAdu3apQcffDD4+sr1m4kTJ2rRokXat2+ffvnLX6qmpkaZmZnKy8vTm2++Kb/fH7muAQCtnucAGjVqlJy79oMUP/jgg5tqCLBUOaO/55o1We80QyeR8Y3Zh8Kquxho+q5VIJJ4FhwAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwETEv5IbaM1Gjt1j3cI13VX6Pc81fapa7u8DcAYEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABA8jBb7i57dtC6PK57ni/14457mm32tfeq656LkCuHU4AwIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCh5GiTTr9DzlhVn7suSIg5
7nmO3u+57km869/8VwDtGScAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDBw0jR4kV3TfRc8z/nrm2GTiIn/adx1i0A5jgDAgCYIIAAACY8BVBxcbHuu+8+xcfHKzU1VePGjVN5eXnImHPnzqmwsFDdu3dXly5dNH78eFVXV0e0aQBA6+cpgMrKylRYWKgdO3Zow4YNunDhgvLy8lRfXx8c89JLL+l3v/udVq5cqbKyMh09elRPPPFExBsHALRunm5CWL9+fcjrpUuXKjU1Vbt379bIkSNVW1ur//iP/9Dy5cv10EMPSZKWLFmiu+66Szt27NDf//3fR65zAECrdlPXgGprayVJSUlJkqTdu3frwoULys3NDY7p37+/evbsqe3btze5joaGBtXV1YVMAIC2L+wACgQCmjlzpoYPH66BAwdKkqqqqhQXF6euXbuGjE1LS1NVVVWT6ykuLlZiYmJwysrKCrclAEArEnYAFRYWav/+/VqxYsVNNTBr1izV1tYGp8OHD9/U+gAArUNYH0SdPn261q5dqy1btqhHjx7B+enp6Tp//rxqampCzoKqq6uVnp7e5Lr8fr/8fn84bQAAWjFPZ0DOOU2fPl2rVq3SRx99pOzs7JDlQ4YMUWxsrDZt2hScV15erkOHDmnYsGGR6RgA0CZ4OgMqLCzU8uXLtWbNGsXHxwev6yQmJqpjx45KTEzUs88+q6KiIiUlJSkhIUEzZszQsGHDuAMOABDCUwAtWrRIkjRq1KiQ+UuWLNGkSZMkSf/2b/+mqKgojR8/Xg0NDcrPz9fPf/7ziDQLAGg7PAWQc+6GYzp06KCSkhKVlJSE3RTwVb5uXT3XPJt4KNythVkHwCueBQcAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMBHWN6ICLV1UmE+1jvaF8X8yFwhrW0B7xxkQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEzyMFC1e5TO3ea4JyIW3sTAeLJr36TjPNbE7/+K5JszfCGixOAMCAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABggoeRosVL3n/Rc83imt5hbevb8Z94rhmZctBzzZ8uxHmuAdoazoAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCY4GGkaPE6rv6z55r1+/4urG0t+EG+55r4g97/GmXoT55rgLaGMyAAgAkCCABgwlMAFRcX67777lN8fLxSU1M1btw4lZeXh4wZNWqUfD5fyPT8889HtGkAQOvnKYDKyspUWFioHTt2aMOGDbpw4YLy8vJUX18fMm7KlCk6duxYcJo/f35EmwYAtH6erp6uX78+5PXSpUuVmpqq3bt3a+TIkcH5nTp1Unp6emQ6BAC0STd1Dai2tlaSlJSUFDJ/2bJlSk5O1sCBAzVr1iydOXPmmutoaGhQXV1dyAQAaPvCvg07EAho5syZGj58uAYOHBic//TTT6tXr17KzMzUvn379MMf/lDl5eX67W9/2+R6iouLNXfu3HDbAAC0UmEHUGFhofbv36+tW7eGzJ86dWrw57vvvlsZGRkaPXq0Kioq1KdPn0brmTVrloqKioKv6+rqlJWVFW5bAIBWIqwAmj59utauXastW7aoR48e1x2bk5MjSTp48GCTAeT3++X3+8NpAwDQinkKIOecZsyYoVWrVqm0tFTZ2dk3rNm7d68kKSMjI6wGAQBtk6cAKiws1PLly7VmzRrFx8erqqpKkpSYmKiOHTuqoqJCy5cv19ixY9W9e3ft27dPL730kkaOHKlBgwY1yy8AAGidPAXQokWLJF3+sOlXLVmyRJMmTVJcXJw2btyohQsXqr6+XllZWRo/frx+9KMfRaxhAEDb4PktuOvJyspSWVnZTTUEAGgfeBo22qSLf/08rLq+08KrA+AdDyMFAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgIsa6gas55yRJF3VBcsbNAAA8u6gLkv773/NraXEBdOrUKUnSVq0z7gQAcDNOnTqlxMTEay73uRtF1C0WCAR09OhRxcfHy+fzhSyrq6tTVlaWDh8+rISEBKMO7bEfLmM/XMZ+uIz9cFlL2A/OOZ06dUqZmZmKirr2lZ4WdwYUFRWlHj16XHdMQkJCuz7ArmA/XMZ+uIz9cBn74TLr/XC9M58ruAkBAGCCAAIAmGhVAeT3+zVnzhz5/X7rVkyxHy5jP1zGfriM/XBZa9oPLe4mBABA+9CqzoAAAG0HAQQAMEEAAQBMEEAAABMEEADARKsJoJKSEt1+++3q0KGDcnJy9Oc//9m6pVvu9ddfl8/nC5n69+9v3Vaz27Jlix599FFlZmbK5/Np9erVIcudc5o9e7YyMjLUsWNH5ebm6sCBAzbNNqMb7YdJkyY1Oj7GjBlj02wzKS4u1n333af4+HilpqZq3LhxKi8vDxlz7tw5FRYWqnv37urSpYvGjx+v6upqo46bx9fZD6NGjWp0PDz//PNGHTetVQTQ+++/r6KiIs2ZM0cff/yxBg8erPz8fB0/fty6tVtuwIABOnbsWHDaunWrdUvNrr6+XoMHD1ZJSUmTy+fPn6+3335bixcv1s6dO9W5c2fl5+fr3Llzt7jT5nWj/SBJY8aMCTk+3nvvvVvYYfMrKytTYWGhduzYoQ0bNujChQvKy8tTfX19cMxLL72k3/3ud1q5cqXKysp09OhRPfHEE4ZdR97X2Q+SNGXKlJDjYf78+UYdX4NrBYYOHeoKCwuDry9duuQyMzNdcXGxYVe33pw5c9zgwYOt2zAlya1atSr4OhAIuPT0dPeTn/wkOK+mpsb5/X733nvvGXR4a1y9H5xzbuLEie6xxx4z6cfK8ePHnSRXVlbmnLv8Zx8bG+tWrlwZHPPpp586SW779u1WbTa7q/eDc8498MAD7sUXX7Rr6mto8WdA58+f1+7du5WbmxucFxUVpdzcXG3fvt2wMxsHDhxQZmamevfurWeeeUaHDh2ybslUZWWlqqqqQo6PxMRE5eTktMvjo7S0VKmpqerXr5+mTZumkydPWrfUrGprayVJSUlJkqTdu3frwoULIcdD//791bNnzzZ9PFy9H65YtmyZkpOTNXDgQM2aNUtnzpyxaO+aWtzTsK/2xRdf6NKlS0pLSwuZn5aWps8++8yoKxs5OTlaunSp+vXrp2PHjmnu3Lm6//77tX//fsXHx1u3Z6KqqkqSmjw+rixrL8aMGaMnnnhC2dnZqqio0KuvvqqC
ggJt375d0dHR1u1FXCAQ0MyZMzV8+HANHDhQ0uXjIS4uTl27dg0Z25aPh6b2gyQ9/fTT6tWrlzIzM7Vv3z798Ic/VHl5uX77298adhuqxQcQ/ltBQUHw50GDBiknJ0e9evXSf/3Xf+nZZ5817AwtwYQJE4I/33333Ro0aJD69Omj0tJSjR492rCz5lFYWKj9+/e3i+ug13Ot/TB16tTgz3fffbcyMjI0evRoVVRUqE+fPre6zSa1+LfgkpOTFR0d3egulurqaqWnpxt11TJ07dpVffv21cGDB61bMXPlGOD4aKx3795KTk5uk8fH9OnTtXbtWm3evDnk+8PS09N1/vx51dTUhIxvq8fDtfZDU3JyciSpRR0PLT6A4uLiNGTIEG3atCk4LxAIaNOmTRo2bJhhZ/ZOnz6tiooKZWRkWLdiJjs7W+np6SHHR11dnXbu3Nnuj48jR47o5MmTber4cM5p+vTpWrVqlT766CNlZ2eHLB8yZIhiY2NDjofy8nIdOnSoTR0PN9oPTdm7d68ktazjwfouiK9jxYoVzu/3u6VLl7q//OUvburUqa5r166uqqrKurVb6p/+6Z9caWmpq6ysdNu2bXO5ubkuOTnZHT9+3Lq1ZnXq1Cm3Z88et2fPHifJLViwwO3Zs8f97W9/c8459+Mf/9h17drVrVmzxu3bt8899thjLjs72509e9a488i63n44deqUe/nll9327dtdZWWl27hxo/vmN7/p7rzzTnfu3Dnr1iNm2rRpLjEx0ZWWlrpjx44FpzNnzgTHPP/8865nz57uo48+crt27XLDhg1zw4YNM+w68m60Hw4ePOjeeOMNt2vXLldZWenWrFnjevfu7UaOHGnceahWEUDOOffOO++4nj17uri4ODd06FC3Y8cO65ZuuSeffNJlZGS4uLg4d9ttt7knn3zSHTx40LqtZrd582YnqdE0ceJE59zlW7Ffe+01l5aW5vx+vxs9erQrLy+3bboZXG8/nDlzxuXl5bmUlBQXGxvrevXq5aZMmdLm/pPW1O8vyS1ZsiQ45uzZs+773/++69atm+vUqZN7/PHH3bFjx+yabgY32g+HDh1yI0eOdElJSc7v97s77rjD/eAHP3C1tbW2jV+F7wMCAJho8deAAABtEwEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBM/D/AaY3Zb7z6aAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "plt.title(\"Prediction: {}\".format(prediction))\n", "plt.imshow(img)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "d3dc87a7", "metadata": {}, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 50, "id": "cfc841c3", "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "id": "d1e63867", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 51, "id": "d7af3599", "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "id": "32cbe1cb", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 52, "id": "c3539d1b", "metadata": {}, "outputs": [], "source": [ "def triton_server(ports, model_path):\n", " import time\n", " import signal\n", " import numpy as np\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", " import tensorflow as tf\n", " from tensorflow import keras\n", "\n", " print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n", "\n", " # Enable GPU memory growth\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", "\n", " model = keras.models.load_model(model_path)\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " images = np.squeeze(inputs[\"images\"])\n", " print(f\"SERVER: Received batch of size {len(images)}.\")\n", " return {\n", " \"labels\": model.predict(images)\n", " }\n", "\n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"ImageClassifier\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"images\", dtype=np.float64, shape=(-1,)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"labels\", dtype=np.float32, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=128,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", 
" strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n", " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "id": "ce4c7701", "metadata": {}, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "id": "2695d9ab", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": null, "id": "4deae3b1", "metadata": {}, "outputs": [], "source": [ "model_name = \"ImageClassifier\"\n", "server_manager = TritonServerManager(model_name=model_name, model_path=model_path)" ] }, { "cell_type": "code", "execution_count": null, "id": "e56c84f4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server)" ] }, { "cell_type": "markdown", "id": "77847814", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "id": "e278fde0", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": null, "id": "68a9606e", "metadata": {}, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url # or server_manager.host_to_grpc_url" ] }, { "cell_type": "markdown", "id": "4d70bd6f", "metadata": {}, "source": [ "Define the Triton inference function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 57, "id": "92ba2e26", "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"Connecting to Triton model {model_name} at {url}.\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n", " result_data = client.infer_batch(inputs)\n", " return result_data[\"labels\"]\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 59, "id": "6658d2a1-ef7b-4ca1-9fb6-f2ac9050f3e5", "metadata": {}, "outputs": [], "source": [ "predict = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " input_tensor_shapes=[[784]],\n", " return_type=ArrayType(FloatType()),\n", " batch_size=128)" ] }, { "cell_type": "markdown", "id": "3842c263", "metadata": {}, "source": [ "#### Run inference" ] }, { "cell_type": "code", "execution_count": 58, "id": "43b93753-1d52-4060-9986-f24c30a67528", "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"StructType([StructField('data', ArrayType(DoubleType(), True), True)])" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.read.parquet(data_path_1)\n", "df.schema" ] }, { "cell_type": "code", "execution_count": 60, "id": "8397aa14-82fd-4351-a477-dc8e8b321fa2", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 19:> (0 + 8) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 19.8 ms, sys: 2.89 ms, total: 22.7 ms\n", "Wall time: 1.67 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", predict(struct(\"data\"))).collect()" ] }, { "cell_type": "code", "execution_count": 61, "id": "82698bd9-377a-4415-8971-835487f876cc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 19.8 ms, sys: 5.99 ms, total: 25.7 ms\n", "Wall time: 399 ms\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", predict(\"data\")).collect()" ] }, { "cell_type": "code", "execution_count": 62, "id": "419ad7bd-fa28-49d3-b98d-db9fba5aeaef", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 21:====================================> (5 + 3) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 9.07 ms, sys: 1.34 ms, total: 10.4 ms\n", "Wall time: 888 ms\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
datapreds
0[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-4.6654444, -2.4893682, -0.5888205, 13.380681...
1[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-2.2732146, -7.5127845, 1.1983705, -3.540661,...
2[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-2.2890894, 0.8308606, 0.31311002, 1.1683631,...
3[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-1.055197, -6.502811, 12.420727, 0.4528031, -...
4[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-3.7887795, 3.9983597, -1.5343359, -0.3698441...
5[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-4.4992743, -1.7618219, 1.1183226, 3.9469318,...
6[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-2.754053, 4.868414, 0.2515293, -0.47300792, ...
7[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-1.888711, 0.02717158, -6.050885, 0.08750934,...
8[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[0.9541264, -2.113048, -1.7508973, -5.4303784,...
9[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[-1.612412, -0.7655782, -4.473859, 2.0609212, ...
\n", "
" ], "text/plain": [ " data \\\n", "0 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "1 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "2 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "3 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "4 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "5 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "6 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "7 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "8 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "9 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "\n", " preds \n", "0 [-4.6654444, -2.4893682, -0.5888205, 13.380681... \n", "1 [-2.2732146, -7.5127845, 1.1983705, -3.540661,... \n", "2 [-2.2890894, 0.8308606, 0.31311002, 1.1683631,... \n", "3 [-1.055197, -6.502811, 12.420727, 0.4528031, -... \n", "4 [-3.7887795, 3.9983597, -1.5343359, -0.3698441... \n", "5 [-4.4992743, -1.7618219, 1.1183226, 3.9469318,... \n", "6 [-2.754053, 4.868414, 0.2515293, -0.47300792, ... \n", "7 [-1.888711, 0.02717158, -6.050885, 0.08750934,... \n", "8 [0.9541264, -2.113048, -1.7508973, -5.4303784,... \n", "9 [-1.612412, -0.7655782, -4.473859, 2.0609212, ... " ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", predict(col(\"data\"))).limit(10).toPandas()\n", "preds" ] }, { "cell_type": "code", "execution_count": 63, "id": "79d90a26", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 64, "id": "4ca495f5", "metadata": {}, "outputs": [], "source": [ "sample = preds.iloc[0]\n", "sample.preds\n", "\n", "prediction = np.argmax(sample.preds)\n", "img = np.array(sample.data).reshape(28,28)" ] }, { "cell_type": "code", "execution_count": 65, "id": "a5d10903", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkfElEQVR4nO3dfXQUdZ7v8U/nqSEkaR7yLAFCFHRAcAYly/AgSiQEZUCYGUG9F7gziJiggI6KR0Udzsksrg7qIHjcHVhHEGWOyMoiDg9JGBRwwTCIM2QhJ0g4kIBcSYcAIaR/9w+uvbQkQDUdfkl4v86pc+iq37fqm6Lgk+qqrnYZY4wAALjKwmw3AAC4NhFAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAgAPdunXTpEmT/K8LCwvlcrlUWFgYsm24XC698MILIVsf0FwRQGgxlixZIpfL5Z/atGmjHj16KC8vT5WVlbbbc2TNmjUtJmTefvtt3X777UpKSpLb7VZ6eromT56s/fv3224NLVyE7QYAp1566SWlp6fr9OnT2rx5sxYuXKg1a9Zo9+7dio6Ovqq9DBkyRKdOnVJUVJSjujVr1mjBggUNhtCpU6cUEdF8/mkWFxcrPT1dP/vZz9ShQweVlZXp7bff1urVq/W3v/1NqamptltEC9V8jnLgMuXk5OjWW2+VJP36179Wp06d9Oqrr2rVqlWaMGFCgzU1NTVq165dyHsJCwtTmzZtQrrOUK/vSr355psXzBszZoxuvfVWvfPOO3r66actdIXWgLfg0OLdeeedkqSysjJJ0qRJkxQTE6PS0lKNHDlSsbGxeuCBByRJPp9P8+fPV69evdSmTRslJSVp6tSp+u677wLWaYzR3Llz1blzZ0VHR+uOO+7Q119/fcG2G7sGtG3bNo0cOVIdOnRQu3bt1KdPH7322mv+/hYsWCBJAW8pfq+ha0DFxcXKyclRXFycYmJiNGzYMG3dujVgzPdvUX722WeaNWuWEhIS1K5dO9177706evRowNiqqirt2bNHVVVVl7OLL9CtWzdJ0vHjx4OqByTOgNAKlJaWSpI6derkn3f27FllZ2dr0KBB+pd/+Rf/W3NTp07VkiVLNHnyZD366KMqKyvTH/7wBxUXF+uzzz5TZGSkJOn555/X3LlzNXLkSI0cOVJffvmlhg8frjNnzlyyn3Xr1umee+5RSkqKHnvsMSUnJ+sf//iHVq9erccee0xTp07VoUOHtG7dOv3pT3+65Pq+/vprDR48WHFxcXryyScVGRmpt956S0OHDlVRUZEyMzMDxk+fPl0dOnTQnDlztH//fs2fP195eXl6//33/WNWrlypyZMna/HixQE3VVzMsWPHVF9frwMHDuill16SJA0bNuyyaoEGGaCFWLx4sZFk1q9fb44ePWrKy8vN8uXLTadOnUzbtm3NwYMHjTHGTJw40UgyTz/9dED9X//6VyPJLF26NGD+2rVrA+YfOXLEREVFmbvvvtv4fD7/uGeeecZIMhMnTvTPKygoMJJMQUGBMcaYs2fPmvT0dNO1a1fz3XffBWzn/HXl5uaaxv75STJz5szxvx4zZoyJiooypaWl/nmHDh0ysbGxZsiQIRfsn6ysrIBtzZw504SHh5vjx49fMHbx4sUN9tAQt9ttJBlJplOnTub111+/7FqgIbwFhxYnKytLCQkJSktL0/jx4xUTE6OVK1fquuuuCxg3bdq0gNcrVqyQx+PRXXfdpW+//dY/9evXTzExMSooKJAkrV+/XmfOnNH06dMD3hqbMWPGJXsrLi5WWVmZZsyYofbt2wcsO39dl6u+vl5/+ctfNGbMGHXv3t0/PyUlRffff782b94sr9cbUPPQQw8FbGvw4MGqr6/XN9984583adIkGWMu++xHkj755BOtWbNGr7zyirp06aKamhrHPw9wPt6CQ4uzYMEC9ejRQxEREUpKSlLPnj0VFhb4u1RERIQ6d+4cMG/v3r2qqqpSYmJig+s9cuSIJPn/o77hhhsClickJKhDhw4X7e37twN79+59+T/QRRw9elQnT55Uz549L1h20003yefzqby8XL169fLP79KlS8C473v+4XUup+644w5J524CGT16tHr37q2YmBjl5eVd0Xpx7SKA0OL079/ffxdcY9xu9wWh5PP5lJiYqKVLlzZYk5CQELIebQoPD29wvjEmZNvIyMjQj3/8Yy1dupQAQtAIIFwzMjIytH79eg0cOFBt27ZtdFzXrl0lnTtjOv9tr6NHj17yLCIjI0OStHv3bmVlZTU67nLfjktISFB0dLRKSkouWLZnzx6FhYUpLS3tstYVaqdOnVJtba2VbaN14BoQrhm//OUvVV9fr9/+9rcXLDt79qz/luKsrCxFRkbqjTfeCDhrmD9//iW38ZOf/ETp6emaP3/+Bbcon7+u7z+TdKnbmMPDwzV8+HCtWrUq4MkDlZWVWrZsmQYNGqS4uLhL9vVDl3sb9tmzZxsM3S+++EJfffXVJc9EgYvhDAjXjNtvv11Tp05Vfn6+du7cqeHDhysyMlJ79+7VihUr9Nprr+nnP/+5EhIS9MQTTyg/P1/33HOPRo4cqeLiYn3yySeKj4+/6DbCwsK0cOFCjRo1SrfccosmT56slJQU7dmzR19//bU+/fRTSVK/fv0kSY8++qiys7MVHh6u8ePHN7jOuXPnat26dRo0aJAeeeQRRURE6K233lJtba3mzZsX1L643NuwT5w4obS0NN13333q1auX2rVrp6+++kqLFy+Wx+PRc889F9T2AYkAwjVm0aJF6tevn9566y0988wzioiIULdu3fTggw9q4MCB/nFz585VmzZttGjRIhUUFCgzM1N/+ctfdPfdd19yG9nZ2SooKNCLL76oV155RT6fTxkZGZoyZYp/zNixYzV9+nQtX75c7777rowxjQZQr1699Ne//lWzZ89Wfn6+fD6fMjMz9e67717wGaBQi46O1q9//WsVFBToz3/+s06dOqXU1FRNmDBBzz77rP8DqUAwXCaUVyYBALhMXAMCAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMCKZvc5IJ/Pp0OHDik2NjaopwcDAOwyxqi6ulqpqakXPJPxfM0ugA4dOmTt2VYAgNApLy+/4Kn052t2ARQbGytJGqSRilCk5W4AAE6dVZ02a43///PGNFkALViwQC+//LIqKirUt29fvfHGG+rfv/8l675/2y1CkYpwEUAA0OL8/+frXOoySpPchPD+++9r1qxZmjNnjr788kv17dtX2dnZ/i/8AgCgSQLo1Vdf1ZQpUzR58mT96Ec/0qJFixQdHa0//vGPTbE5AEALFPIAOnPmjHbs2BHwZVxhYWHKysrSli1bLhhfW1srr9cbMAEAWr+QB9C3336r+vp6JSUlBcxPSkpSRUXFBePz8/Pl8Xj8E3fAAcC1wfoHUWfPnq2qqir/VF5ebrslAMBVEPK74OLj4xUeHq7KysqA+ZWVlUpOTr5gvNvtltvtDnUbAIBmLuRnQFFRUe
rXr582bNjgn+fz+bRhwwYNGDAg1JsDALRQTfI5oFmzZmnixIm69dZb1b9/f82fP181NTWaPHlyU2wOANACNUkA3XfffTp69Kief/55VVRU6JZbbtHatWsvuDEBAHDtchljjO0mzuf1euXxeDRUo3kSAgC0QGdNnQq1SlVVVYqLi2t0nPW74AAA1yYCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWBFhuwHYZ37aN7i6cOe/v0RWeh3XlP7vRMc1vu6nHNdI0p7b/+i4JtzlfD9MPTjAcU3hp7c4run2nzWOayRJW3cFVwc4wBkQAMAKAggAYEXIA+iFF16Qy+UKmG688cZQbwYA0MI1yTWgXr16af369f+zkQguNQEAAjVJMkRERCg5ObkpVg0AaCWa5BrQ3r17lZqaqu7du+uBBx7QgQMHGh1bW1srr9cbMAEAWr+QB1BmZqaWLFmitWvXauHChSorK9PgwYNVXV3d4Pj8/Hx5PB7/lJaWFuqWAADNUMgDKCcnR7/4xS/Up08fZWdna82aNTp+/Lg++OCDBsfPnj1bVVVV/qm8vDzULQEAmqEmvzugffv26tGjh/bt29fgcrfbLbfb3dRtAACamSb/HNCJEydUWlqqlJSUpt4UAKAFCXkAPfHEEyoqKtL+/fv1+eef695771V4eLgmTJgQ6k0BAFqwkL8Fd/DgQU2YMEHHjh1TQkKCBg0apK1btyohISHUmwIAtGAuY4yx3cT5vF6vPB6Phmq0IlyRttuxqubnmY5rKm91flK7dsLLjmskqUtEW8c1/2v/XY5r/tRtneManFN8xhdU3eOP5zmuif5wW1DbQutz1tSpUKtUVVWluLi4RsfxLDgAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIKHkV4lR/J+6rim8OlXHNdEu6Ic1zR339afclzTxhXc71Z1cv7PYcaBexzX/DLxvxzX3B1d5bgmWPvqah3XPDH4l45rzpYfdFyD5o+HkQIAmjUCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsiLDdwLXCF+68pjU+2frlYz9yXLNhxiDHNfVtg/vd6rvrnT+B/br/POy45s2EcY5r7v7zHx3XBGvsf011XNPt+P7QN4JWjTMgAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCh5FeJan/+jfHNR88kui4Jjv6gOOanDlPOK6RpLoYl+Oa6/7joOOaiP07nNc4rjgnOYia+iBqKu/5aRBVV8+uny5xXDMmiAes+qqrHdeg9eAMCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCs4GGkV4mvpsZxzTs90xzXvJ0z1nFNfEGx4xpJ8p0+7bjmbFBbunrCExIc13x3V4bjmscf/sBxDdDacAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcByY4yef/55paSkqG3btsrKytLevXtD1S8AoJVwHEA1NTXq27evFixY0ODyefPm6fXXX9eiRYu0bds2tWvXTtnZ2TodxPUCAEDr5fgmhJycHOXk5DS4zBij+fPn69lnn9Xo0aMlSe+8846SkpL00Ucfafz48VfWLQCg1QjpNaCysjJVVFQoKyvLP8/j8SgzM1NbtmxpsKa2tlZerzdgAgC0fiENoIqKCklSUlJSwPykpCT/sh/Kz8+Xx+PxT2lpzm89BgC0PNbvgps9e7aqqqr8U3l5ue2WAABXQUgDKDk5WZJUWVkZML+ystK/7Ifcbrfi4uICJgBA6xfSAEpPT1dycrI2bNjgn+f1erVt2zYNGDAglJsCALRwju+CO3HihPbt2+d/XVZWpp07d6pjx47q0qWLZsyYoblz5+qGG25Qenq6nnvuOaWmpmrMmDGh7BsA0MI5DqDt27frjjvu8L+eNWuWJGnixIlasmSJnnzySdXU1Oihhx7S8ePHNWjQIK1du1Zt2rQJXdcAgBbPZYwxtps4n9frlcfj0VCNVoQr0nY7aKHC23uCqnt8x2bHNUPanAlqW1eDT76g6n73bV/HNduGpTiuqf/2mOMaNH9nTZ0KtUpVVVUXva5v/S44AMC1iQACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACscfx0D0BKUPdorqLohbTaGuBO7VtXEB1X3ed+oIKp4sjWc4QwIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKzgYaRAK3ZX28NB1c2d9YDjmrqYoDblWGLxWcc1bT7+ogk6wZXiDAgAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArOBhpGiV0jacDKpuxyTnNf3cQW3qqogJC665HY+/EeJOQmfOkR87rtnxMb9rN0f8rQAArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFTyMFK2S67OdQdW9NOp+xzW1STGOa6of9zqu+eyW5Y5rWqNnE7Y7rrnzwUeD2pbn3a1B1eHycAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcDySZMmyeVyBUwjRowIVb8AgFbCcQDV1NSob9++WrBgQaNjRowYocOHD/un995774qaBAC0Po5vQsjJyVFOTs5Fx7jdbiUnJwfdFACg9WuSa0CFhYVKTExUz549NW3aNB07dqzRsbW1tfJ6vQETAKD1C3kAjRgxQu+88442bNigf/7nf1ZRUZFycnJUX1/f4Pj8/Hx5PB7/lJaWFuqWAADNUMg/BzR+/Hj/n2+++Wb16dNHGRkZKiws1LBhwy4YP3v2bM2aNcv/2uv1EkIAcA1o8tuwu3fvrvj4eO3bt6/B5W63W3FxcQETAKD1a/IAOnjwoI4dO6aUlJSm3hQAoAVx/BbciRMnAs5mysrKtHPnTnXs2FEdO3bUiy++qHHjxik5OVmlpaV68skndf311ys7OzukjQMAWjbHAbR9+3bdcccd/tffX
7+ZOHGiFi5cqF27dunf//3fdfz4caWmpmr48OH67W9/K7fbHbquAQAtnuMAGjp0qIwxjS7/9NNPr6ghwKb6r0sc10R87Xw7HQpcjmtGRf3Ucc3+P/VwXCNJn2QudFzTOaJtUNtyKtIV7rjmdMfgrjZ4gqrC5eJZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAi5F/JDeAyXOSJ8o2W1NY6run6y68c10jSnW/NdFzz3/csCmpbuHZxBgQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVvAwUqAVc0VGBVfXtj7EnYTOrjPOe0vcXtMEneBKcQYEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFbwMFKgFSt545ag6v572MLQNhJCMx6f7rgm+vNtTdAJrhRnQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQ8jRasUFhsbXF17T4g7adiRu9Ic19w1/TPHNf+RuMBxzTlX53fTD04kOq6J2/KN45qzjitwNXAGBACwggACAFjhKIDy8/N12223KTY2VomJiRozZoxKSkoCxpw+fVq5ubnq1KmTYmJiNG7cOFVWVoa0aQBAy+cogIqKipSbm6utW7dq3bp1qqur0/Dhw1VTU+MfM3PmTH388cdasWKFioqKdOjQIY0dOzbkjQMAWjZHNyGsXbs24PWSJUuUmJioHTt2aMiQIaqqqtK//du/admyZbrzzjslSYsXL9ZNN92krVu36p/+6Z9C1zkAoEW7omtAVVVVkqSOHTtKknbs2KG6ujplZWX5x9x4443q0qWLtmzZ0uA6amtr5fV6AyYAQOsXdAD5fD7NmDFDAwcOVO/evSVJFRUVioqKUvv27QPGJiUlqaKiosH15Ofny+Px+Ke0NOe3pwIAWp6gAyg3N1e7d+/W8uXLr6iB2bNnq6qqyj+Vl5df0foAAC1DUB9EzcvL0+rVq7Vp0yZ17tzZPz85OVlnzpzR8ePHA86CKisrlZyc3OC63G633G53MG0AAFowR2dAxhjl5eVp5cqV2rhxo9LT0wOW9+vXT5GRkdqwYYN/XklJiQ4cOKABAwaEpmMAQKvg6AwoNzdXy5Yt06pVqxQbG+u/ruPxeNS2bVt5PB796le/0qxZs9SxY0fFxcVp+vTpGjBgAHfAAQACOAqghQsXSpKGDh0aMH/x4sWaNGmSJOn3v/+9wsLCNG7cONXW1io7O1tvvvlmSJoFALQeLmOMsd3E+bxerzwej4ZqtCJckbbbuSaE9b0pqLo9uTGOa5LT/q/jmiMlCY5rJt9Z6LhGkp7q9HVQdQhOn88nOa7p8ouvQt8IQuqsqVOhVqmqqkpxcXGNjuNZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAiqG9ERfPl6tfLcU3b3x8Jalv/nfFuUHWO9bk6m2nuak2d45pIV3hQ26qsr3VcM+dQjuOazq8F1x9aB86AAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKHkbaytR0jXFc8373fw1ya1FB1jU9n3xB1c08NNhxzW8S1zuuyf4813FNbGG045rqbo5LJEnps7cEUVXtuCJMO4PYDloLzoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoeRtrKRH+4zXHNLw4/HNS2jv64neMaXxDPL61z/nxVvf1//uC8SFLpbacd10z78VTHNek7dzmukTGOS+KdbwW4ajgDAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArXMYE8YTDJuT1euXxeDRUoxXhirTdDgDAobOmToVapaqqKsXFxTU6jjMgAIAVBBAAwApHAZSfn6/bbrtNsbGxSkxM1JgxY1RSUhIwZujQoXK5XAHTww8H930zAIDWy1EAFRUVKTc3V1u3btW6detUV1en4cOHq6amJmDclClTdPjwYf80b968kDYNAGj5HH0j6tq1awNeL1myRImJidqxY4eGDBninx8dHa3k5OTQdAgAaJWu6BpQVVWVJKljx44B85cuXar4+Hj17t1bs2fP1smTJxtdR21trbxeb8AEAGj9HJ0Bnc/n82nGjBkaOHCgevfu7Z9///33q2vXrkpNTdWuXbv01FNPqaSkRB9++GGD68nPz9eLL74YbBsAgBYq6M8BTZs2TZ988ok2b96szp07Nzpu48aNGjZsmPbt26eMjIwLltfW1qq2ttb/2uv1Ki0tjc8BAUALdbmfAwrqDCgvL0+rV6/Wpk2bLho+kpSZmSlJjQaQ2+2W2+0Opg0AQAvmKICMMZo+fbpWrlypwsJCpaenX7Jm586dkqSUlJSgGgQAtE6OAig3N1fLli3TqlWrFBsbq4qKCkmSx+NR27ZtVVpaqmXLlmnkyJHq1KmTdu3apZkzZ2rIkCHq06dPk/wAAICWydE1IJfL1eD8xYsXa9KkSSovL9eDDz6o3bt3q6amRmlpabr33nv17LPPXvR9wPPxLDgAaNma5BrQpbIqLS1NRUVFTlYJALhG8Sw4AIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVEbYb+CFjjCTprOokY7kZAIBjZ1Un6X/+P29Mswug6upqSdJmrbHcCQDgSlRXV8vj8TS63GUuFVFXmc/n06FDhxQbGyuXyxWwzOv1Ki0tTeXl5YqLi7PUoX3sh3PYD+ewH85hP5zTHPaDMUbV1dVKTU1VWFjjV3qa3RlQWFiYOnfufNExcXFx1/QB9j32wznsh3PYD+ewH86xvR8udubzPW5CAABYQQABAKxoUQHkdrs1Z84cud1u261YxX44h/1wDvvhHPbDOS1pPzS7mxAAANeGFnUGBABoPQggAIAVBBAAwAoCCABgBQEEALCixQTQggUL1K1bN7Vp00aZmZn64osvbLd01b3wwgtyuVwB04033mi7rSa3adMmjRo1SqmpqXK5XProo48Clhtj9PzzzyslJUVt27ZVVlaW9u7da6fZJnSp/TBp0qQLjo8RI0bYabaJ5Ofn67bbblNsbKwSExM1ZswYlZSUBIw5ffq0cnNz1alTJ8XExGjcuHGqrKy01HHTuJz9MHTo0AuOh4cffthSxw1rEQH0/vvv
a9asWZozZ46+/PJL9e3bV9nZ2Tpy5Ijt1q66Xr166fDhw/5p8+bNtltqcjU1Nerbt68WLFjQ4PJ58+bp9ddf16JFi7Rt2za1a9dO2dnZOn369FXutGldaj9I0ogRIwKOj/fee+8qdtj0ioqKlJubq61bt2rdunWqq6vT8OHDVVNT4x8zc+ZMffzxx1qxYoWKiop06NAhjR071mLXoXc5+0GSpkyZEnA8zJs3z1LHjTAtQP/+/U1ubq7/dX19vUlNTTX5+fkWu7r65syZY/r27Wu7DaskmZUrV/pf+3w+k5ycbF5++WX/vOPHjxu3223ee+89Cx1eHT/cD8YYM3HiRDN69Ggr/dhy5MgRI8kUFRUZY8793UdGRpoVK1b4x/zjH/8wksyWLVtstdnkfrgfjDHm9ttvN4899pi9pi5Dsz8DOnPmjHbs2KGsrCz/vLCwMGVlZWnLli0WO7Nj7969Sk1NVffu3fXAAw/owIEDtluyqqysTBUVFQHHh8fjUWZm5jV5fBQWFioxMVE9e/bUtGnTdOzYMdstNamqqipJUseOHSVJO3bsUF1dXcDxcOONN6pLly6t+nj44X743tKlSxUfH6/evXtr9uzZOnnypI32GtXsnob9Q99++63q6+uVlJQUMD8pKUl79uyx1JUdmZmZWrJkiXr27KnDhw/rxRdf1ODBg7V7927Fxsbabs+KiooKSWrw+Ph+2bVixIgRGjt2rNLT01VaWqpnnnlGOTk52rJli8LDw223F3I+n08zZszQwIED1bt3b0nnjoeoqCi1b98+YGxrPh4a2g+SdP/996tr165KTU3Vrl279NRTT6mkpEQffvihxW4DNfsAwv/Iycnx/7lPnz7KzMxU165d9cEHH+hXv/qVxc7QHIwfP97/55tvvll9+vRRRkaGCgsLNWzYMIudNY3c3Fzt3r37mrgOejGN7YeHHnrI/+ebb75ZKSkpGjZsmEpLS5WRkXG122xQs38LLj4+XuHh4RfcxVJZWank5GRLXTUP7du3V48ePbRv3z7brVjz/THA8XGh7t27Kz4+vlUeH3l5eVq9erUKCgoCvj8sOTlZZ86c0fHjxwPGt9bjobH90JDMzExJalbHQ7MPoKioKPXr108bNmzwz/P5fNqwYYMGDBhgsTP7Tpw4odLSUqWkpNhuxZr09HQlJycHHB9er1fbtm275o+PgwcP6tixY63q+DDGKC8vTytXrtTGjRuVnp4esLxfv36KjIwMOB5KSkp04MCBVnU8XGo/NGTnzp2S1LyOB9t3QVyO5cuXG7fbbZYsWWL+/ve/m4ceesi0b9/eVFRU2G7tqnr88cdNYWGhKSsrM5999pnJysoy8fHx5siRI7Zba1LV1dWmuLjYFBcXG0nm1VdfNcXFxeabb74xxhjzu9/9zrRv396sWrXK7Nq1y4wePdqkp6ebU6dOWe48tC62H6qrq80TTzxhtmzZYsrKysz69evNT37yE3PDDTeY06dP2249ZKZNm2Y8Ho8pLCw0hw8f9k8nT570j3n44YdNly5dzMaNG8327dvNgAEDzIABAyx2HXqX2g/79u0zL730ktm+fbspKyszq1atMt27dzdDhgyx3HmgFhFAxhjzxhtvmC5dupioqCjTv39/s3XrVtstXXX33XefSUlJMVFRUea6664z9913n9m3b5/ttppcQUGBkXTBNHHiRGPMuVuxn3vuOZOUlGTcbrcZNmyYKSkpsdt0E7jYfjh58qQZPny4SUhIMJGRkaZr165mypQpre6XtIZ+fklm8eLF/jGnTp0yjzzyiOnQoYOJjo429957rzl8+LC9ppvApfbDgQMHzJAhQ0zHjh2N2+02119/vfnNb35jqqqq7Db+A3wfEADAimZ/DQgA0DoRQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAV/w/hgVLrpVGHsAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "plt.title(\"Prediction: {}\".format(prediction))\n", "plt.imshow(img)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "6377f41a-5654-410b-8bad-d392e9dce7b8", "metadata": { "tags": [] }, "source": [ "#### Stop Triton Server on each executor" ] }, { "cell_type": "code", "execution_count": null, "id": "d06de00e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 14:00:18,330 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-04 14:00:28,520 - INFO - Sucessfully stopped 1 servers. \n" ] }, { "data": { "text/plain": [ "[True]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 67, "id": "f612dc0b-538f-4ecf-81f7-ef6b58c493ab", "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "490fc849-e47a-48d7-accc-429ff1cced6b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-tf", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras_preprocessing_tf.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "7fcc021a", "metadata": {}, "source": [ "\n", "\n", "# Pyspark TensorFlow Inference\n", "\n", "### Classification using Keras Preprocessing Layers\n", "\n", "In this notebook, we demonstrate distributed inference using Keras preprocessing layers to classify structured data. \n", "From: https://www.tensorflow.org/tutorials/structured_data/preprocessing_layers" ] }, { "cell_type": "markdown", "id": "35203476", "metadata": {}, "source": [ "Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) " ] }, { "cell_type": "code", "execution_count": 1, "id": "01162f42-0637-4dfe-8d7d-b577e4ffd017", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:59:29.670948: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "2025-02-04 13:59:29.679838: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2025-02-04 13:59:29.689914: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2025-02-04 13:59:29.692851: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2025-02-04 13:59:29.700499: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2025-02-04 13:59:30.139239: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "import os\n", "import shutil\n", "import numpy as np\n", "import pandas as pd\n", "import tensorflow as tf\n", "\n", "from tensorflow.keras import layers" ] }, { "cell_type": "code", "execution_count": 2, "id": "0d586fb8", "metadata": {}, "outputs": [], "source": [ "os.mkdir('models') if not os.path.exists('models') else None" ] }, { "cell_type": "code", "execution_count": null, "id": "9fa3e1b7-58cd-45f9-9fee-85f25a31c3c6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.17.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1738706370.524690 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706370.550329 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706370.553239 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n" ] } ], "source": [ "print(tf.__version__)\n", "\n", "# Enable GPU memory growth\n", "gpus = tf.config.experimental.list_physical_devices('GPU')\n", "if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "id": "b2402b9a", "metadata": {}, "source": [ "#### Download dataset\n", "\n", "Download the PetFinder dataset from Kaggle, where each row describes a pet and the goal is to predict adoption speed."
] }, { "cell_type": "code", "execution_count": 4, "id": "9326b072-a53c-40c4-a6cb-bd4d3d644d03", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/petfinder-mini.zip\n", "\u001b[1m1668792/1668792\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 0us/step\n" ] } ], "source": [ "import pathlib\n", "import os\n", "dataset_url = 'http://storage.googleapis.com/download.tensorflow.org/data/petfinder-mini.zip'\n", "\n", "data_dir = tf.keras.utils.get_file('petfinder_mini.zip', dataset_url, extract=True, cache_dir='.')\n", "data_dir = pathlib.Path(data_dir)\n", "try:\n", " # pet-finder-mini might be under a parent a directory petfinder_mini_extracted. Check if this is the case:\n", " dataset = os.path.join(os.path.dirname(data_dir), 'petfinder_mini_extracted/petfinder-mini/petfinder-mini.csv')\n", " dataframe = pd.read_csv(dataset)\n", "except:\n", " dataset = os.path.join(os.path.dirname(data_dir), 'petfinder-mini/petfinder-mini.csv')\n", " dataframe = pd.read_csv(dataset)" ] }, { "cell_type": "code", "execution_count": 5, "id": "e98480ef-d13d-44c0-a227-e9a22f9bf2b0", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
TypeAgeBreed1GenderColor1Color2MaturitySizeFurLengthVaccinatedSterilizedHealthFeeDescriptionPhotoAmtAdoptionSpeed
0Cat3TabbyMaleBlackWhiteSmallShortNoNoHealthy100Nibble is a 3+ month old ball of cuteness. He ...12
1Cat1Domestic Medium HairMaleBlackBrownMediumMediumNot SureNot SureHealthy0I just found it alone yesterday near my apartm...20
2Dog1Mixed BreedMaleBrownWhiteMediumMediumYesNoHealthy0Their pregnant mother was dumped by her irresp...73
3Dog4Mixed BreedFemaleBlackBrownMediumShortYesNoHealthy150Good guard dog, very alert, active, obedience ...82
4Dog1Mixed BreedMaleBlackNo ColorMediumShortNoNoHealthy0This handsome yet cute boy is up for adoption....32
\n", "
" ], "text/plain": [ " Type Age Breed1 Gender Color1 Color2 MaturitySize \\\n", "0 Cat 3 Tabby Male Black White Small \n", "1 Cat 1 Domestic Medium Hair Male Black Brown Medium \n", "2 Dog 1 Mixed Breed Male Brown White Medium \n", "3 Dog 4 Mixed Breed Female Black Brown Medium \n", "4 Dog 1 Mixed Breed Male Black No Color Medium \n", "\n", " FurLength Vaccinated Sterilized Health Fee \\\n", "0 Short No No Healthy 100 \n", "1 Medium Not Sure Not Sure Healthy 0 \n", "2 Medium Yes No Healthy 0 \n", "3 Short Yes No Healthy 150 \n", "4 Short No No Healthy 0 \n", "\n", " Description PhotoAmt AdoptionSpeed \n", "0 Nibble is a 3+ month old ball of cuteness. He ... 1 2 \n", "1 I just found it alone yesterday near my apartm... 2 0 \n", "2 Their pregnant mother was dumped by her irresp... 7 3 \n", "3 Good guard dog, very alert, active, obedience ... 8 2 \n", "4 This handsome yet cute boy is up for adoption.... 3 2 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataframe.head()" ] }, { "cell_type": "markdown", "id": "27d844f1", "metadata": {}, "source": [ "### Prepare dataset" ] }, { "cell_type": "code", "execution_count": 6, "id": "e8efce25-a835-4cbd-b8a2-1418ba2c1d31", "metadata": {}, "outputs": [], "source": [ "# In the original dataset, `'AdoptionSpeed'` of `4` indicates\n", "# a pet was not adopted.\n", "dataframe['target'] = np.where(dataframe['AdoptionSpeed']==4, 0, 1)\n", "\n", "# Drop unused features.\n", "dataframe = dataframe.drop(columns=['AdoptionSpeed', 'Description'])" ] }, { "cell_type": "code", "execution_count": null, "id": "00d403cf-9ae7-4780-9fac-13d920d8b395", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/numpy/core/fromnumeric.py:59: FutureWarning: 'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.\n", " return bound(*args, **kwds)\n" ] } ], "source": [ "train, val, test = np.split(dataframe.sample(frac=1), [int(0.8*len(dataframe)), int(0.9*len(dataframe))])" ] }, { "cell_type": "code", "execution_count": 8, "id": "4206a56e-5403-42a9-805e-e037044e7995", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9229 training examples\n", "1154 validation examples\n", "1154 test examples\n" ] } ], "source": [ "print(len(train), 'training examples')\n", "print(len(val), 'validation examples')\n", "print(len(test), 'test examples')" ] }, { "cell_type": "markdown", "id": "a7fa64f8", "metadata": {}, "source": [ "Create an input pipeline which converts each dataset into a tf.data.Dataset with shuffling and batching." 
] }, { "cell_type": "code", "execution_count": 9, "id": "499ade5f-ac8a-47ca-a021-071239dfe97d", "metadata": {}, "outputs": [], "source": [ "def df_to_dataset(dataframe, shuffle=True, batch_size=32):\n", " df = dataframe.copy()\n", " labels = df.pop('target')\n", " df = {key: value.to_numpy()[:,tf.newaxis] for key, value in dataframe.items()}\n", " ds = tf.data.Dataset.from_tensor_slices((dict(df), labels))\n", " if shuffle:\n", " ds = ds.shuffle(buffer_size=len(dataframe))\n", " ds = ds.batch(batch_size)\n", " ds = ds.prefetch(batch_size)\n", " return ds" ] }, { "cell_type": "markdown", "id": "96065bed", "metadata": {}, "source": [ "Check the format of the data returned by the pipeline:" ] }, { "cell_type": "code", "execution_count": null, "id": "b9ec57c9-080e-4626-9e03-acf309cf3736", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1738706370.981571 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706370.984478 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706370.987280 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706371.105121 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706371.106231 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706371.107182 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "2025-02-04 13:59:31.108098: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 40337 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n" ] } ], "source": [ "batch_size = 5\n", "train_ds = df_to_dataset(train, batch_size=batch_size)" ] }, { "cell_type": "markdown", "id": "bdc8571c", "metadata": {}, "source": [ "(Note that OUT_OF_RANGE errors are safe to ignore: https://github.com/tensorflow/tensorflow/issues/62963)." 
] }, { "cell_type": "code", "execution_count": null, "id": "dfcbf268-4508-4eb8-abe1-acf1dbb97bd5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Every feature: ['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt', 'target']\n", "A batch of ages: tf.Tensor(\n", "[[ 4]\n", " [60]\n", " [24]\n", " [ 1]\n", " [ 2]], shape=(5, 1), dtype=int64)\n", "A batch of targets: tf.Tensor([1 1 1 1 1], shape=(5,), dtype=int64)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:59:31.170523: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] } ], "source": [ "[(train_features, label_batch)] = train_ds.take(1)\n", "print('Every feature:', list(train_features.keys()))\n", "print('A batch of ages:', train_features['Age'])\n", "print('A batch of targets:', label_batch )" ] }, { "cell_type": "markdown", "id": "d5a2d10c", "metadata": {}, "source": [ "### Apply Keras preprocessing layers\n", "\n", "We'll define a normalization layer for numeric features, and a category encoding for categorical features." ] }, { "cell_type": "code", "execution_count": 12, "id": "6c09dc4b-3a2a-44f5-b41c-821ec30b87b1", "metadata": {}, "outputs": [], "source": [ "def get_normalization_layer(name, dataset):\n", " # Create a Normalization layer for the feature.\n", " normalizer = layers.Normalization(axis=None)\n", "\n", " # Prepare a Dataset that only yields the feature.\n", " feature_ds = dataset.map(lambda x, y: x[name])\n", "\n", " # Learn the statistics of the data.\n", " normalizer.adapt(feature_ds)\n", "\n", " return normalizer" ] }, { "cell_type": "code", "execution_count": 13, "id": "59bb91dc-360a-4a89-a9ea-bebc1ddbf1b7", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:59:32.726183: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "photo_count_col = train_features['PhotoAmt']\n", "layer = get_normalization_layer('PhotoAmt', train_ds)\n", "layer(photo_count_col)" ] }, { "cell_type": "code", "execution_count": 14, "id": "4623b612-e924-472b-9ef4-c7f14f9f53c5", "metadata": {}, "outputs": [], "source": [ "def get_category_encoding_layer(name, dataset, dtype, max_tokens=None):\n", " # Create a layer that turns strings into integer indices.\n", " if dtype == 'string':\n", " index = layers.StringLookup(max_tokens=max_tokens)\n", " # Otherwise, create a layer that turns integer values into integer indices.\n", " else:\n", " index = layers.IntegerLookup(max_tokens=max_tokens)\n", "\n", " # Prepare a `tf.data.Dataset` that only yields the feature.\n", " feature_ds = dataset.map(lambda x, y: x[name])\n", "\n", " # Learn the set of possible values and assign them a fixed integer index.\n", " index.adapt(feature_ds)\n", "\n", " # Encode the integer indices.\n", " encoder = layers.CategoryEncoding(num_tokens=index.vocabulary_size())\n", "\n", " # Apply multi-hot encoding to the indices. 
The lambda function captures the\n", " # layer, so you can use them, or include them in the Keras Functional model later.\n", " return lambda feature: encoder(index(feature))" ] }, { "cell_type": "code", "execution_count": 15, "id": "0a40e9ee-20a5-4a42-8543-c267f99af55e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_type_col = train_features['Type']\n", "test_type_layer = get_category_encoding_layer(name='Type',\n", " dataset=train_ds,\n", " dtype='string')\n", "test_type_layer(test_type_col)" ] }, { "cell_type": "code", "execution_count": 16, "id": "ff63a5cc-71f4-428e-9299-a8018edc7648", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:59:34.294276: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_age_col = train_features['Age']\n", "test_age_layer = get_category_encoding_layer(name='Age',\n", " dataset=train_ds,\n", " dtype='int64',\n", " max_tokens=5)\n", "test_age_layer(test_age_col)" ] }, { "cell_type": "markdown", "id": "afefbcf2", "metadata": {}, "source": [ "### Preprocess selected features\n", "\n", "Apply the preprocessing helper class defined earlier. Add all the feature inputs to a list.\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "2b040b0e-d8ca-4cf0-917c-dd9a272e1f0a", "metadata": {}, "outputs": [], "source": [ "batch_size = 256\n", "train_ds = df_to_dataset(train, batch_size=batch_size)\n", "val_ds = df_to_dataset(val, shuffle=False, batch_size=batch_size)\n", "test_ds = df_to_dataset(test, shuffle=False, batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": 18, "id": "19df498e-4dd1-467a-8741-e1f5e15932a5", "metadata": {}, "outputs": [], "source": [ "all_inputs = {}\n", "encoded_features = []\n", "\n", "# Numerical features.\n", "for header in ['PhotoAmt', 'Fee']:\n", " numeric_col = tf.keras.Input(shape=(1,), name=header)\n", " normalization_layer = get_normalization_layer(header, train_ds)\n", " encoded_numeric_col = normalization_layer(numeric_col)\n", " all_inputs[header] = numeric_col\n", " encoded_features.append(encoded_numeric_col)" ] }, { "cell_type": "code", "execution_count": 19, "id": "1d12579f-34fb-40b0-a16a-3e13cfea8178", "metadata": {}, "outputs": [], "source": [ "age_col = tf.keras.Input(shape=(1,), name='Age', dtype='int64')\n", "\n", "encoding_layer = get_category_encoding_layer(name='Age',\n", " dataset=train_ds,\n", " dtype='int64',\n", " max_tokens=5)\n", "encoded_age_col = encoding_layer(age_col)\n", "all_inputs['Age'] = age_col\n", "encoded_features.append(encoded_age_col)" ] }, { "cell_type": "code", "execution_count": 20, "id": "bff286eb-7ad7-4d3a-8fa4-c729692d1425", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 13:59:34.588989: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n", "2025-02-04 13:59:35.029267: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] } ], "source": [ "categorical_cols = ['Type', 'Color1', 'Color2', 'Gender', 'MaturitySize',\n", " 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Breed1']\n", "\n", "for header in 
categorical_cols:\n", " categorical_col = tf.keras.Input(shape=(1,), name=header, dtype='string')\n", " encoding_layer = get_category_encoding_layer(name=header,\n", " dataset=train_ds,\n", " dtype='string',\n", " max_tokens=5)\n", " encoded_categorical_col = encoding_layer(categorical_col)\n", " all_inputs[header] = categorical_col\n", " encoded_features.append(encoded_categorical_col)" ] }, { "cell_type": "markdown", "id": "e0dfac0d", "metadata": {}, "source": [ "### Create, compile, and train model" ] }, { "cell_type": "code", "execution_count": 21, "id": "79247436-32d8-4738-a656-3f288c77001c", "metadata": {}, "outputs": [], "source": [ "all_features = tf.keras.layers.concatenate(encoded_features)\n", "x = tf.keras.layers.Dense(32, activation=\"relu\")(all_features)\n", "x = tf.keras.layers.Dropout(0.5)(x)\n", "output = tf.keras.layers.Dense(1)(x)\n", "\n", "model = tf.keras.Model(all_inputs, output)" ] }, { "cell_type": "code", "execution_count": 22, "id": "dbc85d3e-6d1e-4167-9516-b1182e880542", "metadata": {}, "outputs": [], "source": [ "model.compile(optimizer='adam',\n", " loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", " metrics=[\"accuracy\"],\n", " run_eagerly=True)" ] }, { "cell_type": "code", "execution_count": 23, "id": "bc9836c8-3c1a-41ad-8833-a946bafcfb00", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 16ms/step - accuracy: 0.3658 - loss: 0.7746 - val_accuracy: 0.6854 - val_loss: 0.5841\n", "Epoch 2/10\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 16ms/step - accuracy: 0.6270 - loss: 0.6023 - val_accuracy: 0.7383 - val_loss: 0.5593\n", "Epoch 3/10\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 16ms/step - accuracy: 0.6650 - loss: 0.5781 - val_accuracy: 0.7392 - val_loss: 0.5442\n", "Epoch 4/10\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 17ms/step - accuracy: 0.6609 - loss: 0.5744 - val_accuracy: 0.7418 - val_loss: 0.5329\n", "Epoch 5/10\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.6845 - loss: 0.5555 - val_accuracy: 0.7444 - val_loss: 0.5261\n", "Epoch 6/10\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.6910 - loss: 0.5465 - val_accuracy: 0.7513 - val_loss: 0.5198\n", "Epoch 7/10\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.7018 - loss: 0.5475 - val_accuracy: 0.7556 - val_loss: 0.5145\n", "Epoch 8/10\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.7026 - loss: 0.5410 - val_accuracy: 0.7496 - val_loss: 0.5099\n", "Epoch 9/10\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.7145 - loss: 0.5315 - val_accuracy: 0.7530 - val_loss: 0.5066\n", "Epoch 10/10\n", "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.7099 - loss: 0.5316 - val_accuracy: 0.7539 - val_loss: 0.5038\n" ] }, { "data": { "text/plain": [ "" ] }, 
"execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(train_ds, epochs=10, validation_data=val_ds)" ] }, { "cell_type": "code", "execution_count": null, "id": "fbccebaa-fc24-4a58-a032-222cef8fdf08", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.7416 - loss: 0.5196 \n", "Accuracy 0.7443674206733704\n" ] } ], "source": [ "loss, accuracy = model.evaluate(test_ds)\n", "print(\"Accuracy\", accuracy)" ] }, { "cell_type": "markdown", "id": "7534616c-8561-4869-b6e9-7254ebdb2c3f", "metadata": {}, "source": [ "### Save and reload model\n", "\n", "Demonstrate saving the trained model and reloading it for inference." ] }, { "cell_type": "code", "execution_count": 25, "id": "6bf0d024", "metadata": {}, "outputs": [], "source": [ "model.save('models/my_pet_classifier.keras')" ] }, { "cell_type": "code", "execution_count": 26, "id": "d1a7be62", "metadata": {}, "outputs": [], "source": [ "reloaded_model = tf.keras.models.load_model('models/my_pet_classifier.keras')" ] }, { "cell_type": "code", "execution_count": null, "id": "f3d2a2d5-fd4d-4320-bacc-fd4571cec709", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 27ms/step\n", "This particular pet had a 83.2 percent probability of getting adopted.\n" ] } ], "source": [ "sample = {\n", " 'Type': 'Cat',\n", " 'Age': 3,\n", " 'Breed1': 'Tabby',\n", " 'Gender': 'Male',\n", " 'Color1': 'Black',\n", " 'Color2': 'White',\n", " 'MaturitySize': 'Small',\n", " 'FurLength': 'Short',\n", " 'Vaccinated': 'No',\n", " 'Sterilized': 'No',\n", " 'Health': 'Healthy',\n", " 'Fee': 100,\n", " 'PhotoAmt': 2,\n", "}\n", "\n", "input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}\n", "predictions = reloaded_model.predict(input_dict)\n", "prob = tf.nn.sigmoid(predictions[0])\n", "\n", "print(\n", " \"This particular pet had a %.1f percent probability \"\n", " \"of getting adopted.\" % (100 * prob)\n", ")" ] }, { "cell_type": "markdown", "id": "f7bbfe69-93ed-4452-8985-c6685e0726c3", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": 28, "id": "fc8a0536", "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.functions import col, struct, pandas_udf\n", "from pyspark.ml.functions import predict_batch_udf\n", "from pyspark.sql.types import *\n", "from pyspark.sql import SparkSession\n", "from pyspark import SparkConf\n", "import json\n", "import pandas as pd" ] }, { "cell_type": "markdown", "id": "bb5aa875", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific Spark configurations." ] }, { "cell_type": "code", "execution_count": 29, "id": "7701420e", "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "markdown", "id": "5e231dbd", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." 
] }, { "cell_type": "code", "execution_count": 30, "id": "60dff1da", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:59:42 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/04 13:59:42 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/02/04 13:59:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " \n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", " elif on_dataproc:\n", " conf.set(\"spark.executorEnv.TF_GPU_ALLOCATOR\", \"cuda_malloc_async\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", "\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "markdown", "id": "fa2333d1", "metadata": {}, "source": [ "#### Create PySpark DataFrame" ] }, { "cell_type": "code", "execution_count": 31, "id": "3c64fd7b-3d1e-40f8-ab64-b5c13f8bbe77", "metadata": {}, "outputs": [], "source": [ "df = spark.createDataFrame(dataframe).repartition(8)" ] }, { "cell_type": "code", "execution_count": 32, "id": "1be8215b-5068-41b4-849c-1c3ea7bb108a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "data_path = \"spark-dl-datasets/petfinder-mini\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path)" ] }, { "cell_type": "markdown", "id": "7cec4e0e", "metadata": {}, "source": [ "#### Load and preprocess DataFrame" ] }, { "cell_type": "code", "execution_count": 33, "id": "0892f845", "metadata": {}, "outputs": [], "source": [ "df = spark.read.parquet(data_path).cache()" ] }, { "cell_type": "code", "execution_count": 34, "id": "952645dd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt', 'target']\n" ] } ], "source": [ "columns = df.columns\n", "print(columns)" ] }, { "cell_type": "code", "execution_count": 35, "id": "b9c24c0d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt']\n" ] } ], "source": [ "# remove label column\n", "columns.remove(\"target\")\n", "print(columns)" ] }, { "cell_type": "code", 
"execution_count": 36, "id": "d4dbde99-cf65-4c15-a163-754a0201a48d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n", "|Type|Age| Breed1|Gender|Color1| Color2|MaturitySize|FurLength|Vaccinated|Sterilized| Health|Fee|PhotoAmt|target|\n", "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n", "| Dog| 3| Mixed Breed| Male| Black|No Color| Small| Medium| Not Sure| Not Sure|Healthy| 0| 2| 0|\n", "| Dog| 9| Mixed Breed| Male| Gray|No Color| Medium| Short| Not Sure| No|Healthy| 0| 4| 1|\n", "| Cat| 4| Domestic Short Hair| Male| Black| Gray| Medium| Short| Not Sure| Not Sure|Healthy| 0| 4| 1|\n", "| Cat| 6| Domestic Short Hair| Male|Yellow| White| Medium| Short| No| No|Healthy| 0| 3| 1|\n", "| Cat| 6|Domestic Medium Hair| Male| Gray|No Color| Small| Medium| No| No|Healthy| 0| 4| 1|\n", "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "df.show(5)" ] }, { "cell_type": "markdown", "id": "824d7f97", "metadata": {}, "source": [ "## Inference using Spark DL API\n", "\n", "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n", "\n", "- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \n", "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function" ] }, { "cell_type": "code", "execution_count": 37, "id": "d62eb95a-54c6-44d2-9279-38fb65e0e160", "metadata": {}, "outputs": [], "source": [ "# get absolute path to model\n", "model_path = \"{}/models/my_pet_classifier.keras\".format(os.getcwd())\n", "\n", "# For cloud environments, copy the model to the distributed file system.\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n", " dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/my_pet_classifier.keras\"\n", " shutil.copy(model_path, dbfs_model_path)\n", " model_path = dbfs_model_path\n", "elif on_dataproc:\n", " # GCS is mounted at /mnt/gcs by the init script\n", " models_dir = \"/mnt/gcs/spark-dl/models\"\n", " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n", " gcs_model_path = models_dir + \"/my_pet_classifier.keras\"\n", " shutil.copy(model_path, gcs_model_path)\n", " model_path = gcs_model_path" ] }, { "cell_type": "code", "execution_count": 38, "id": "45665acf-50c8-445b-a985-b3dabd734709", "metadata": {}, "outputs": [], "source": [ "def predict_batch_fn():\n", " import tensorflow as tf\n", " import pandas as pd\n", " \n", " # Enable GPU memory growth to avoid CUDA OOM\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", "\n", " model = tf.keras.models.load_model(model_path)\n", "\n", " def predict(t, a, b, g, c1, c2, m, f, v, s, h, fee, p):\n", " inputs = {\n", " \"Type\": t,\n", " \"Age\": a,\n", " \"Breed1\": b,\n", " \"Gender\": g,\n", " \"Color1\": c1,\n", " \"Color2\": c2,\n", " \"MaturitySize\": m,\n", " \"FurLength\": f,\n", " 
\"Vaccinated\": v,\n", " \"Sterilized\": s,\n", " \"Health\": h,\n", " \"Fee\": fee,\n", " \"PhotoAmt\": p\n", " }\n", " # return model.predict(inputs)\n", " return pd.Series(np.squeeze(model.predict(inputs)))\n", "\n", " return predict" ] }, { "cell_type": "code", "execution_count": 39, "id": "815e3b5f-7914-4235-85fa-50153dcd3d30", "metadata": {}, "outputs": [], "source": [ "# need to pass the list of columns into the model_udf\n", "classify = predict_batch_udf(predict_batch_fn,\n", " return_type=FloatType(),\n", " batch_size=100)" ] }, { "cell_type": "code", "execution_count": 40, "id": "da03a0c6-2d39-425e-a9fa-57c139cca1ed", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 13:59:47 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Stage 5:> (0 + 8) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 19.8 ms, sys: 9.3 ms, total: 29.1 ms\n", "Wall time: 4.99 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(struct(*columns)))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 41, "id": "03990c76-7198-49a7-bb5d-6870be915fb3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 6:> (0 + 8) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 86.9 ms, sys: 13.7 ms, total: 101 ms\n", "Wall time: 1.56 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(*columns))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 42, "id": "edb93cf3-c248-40c9-b8dc-acc8f51786a9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 7:> (0 + 8) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 16.4 ms, sys: 4.46 ms, total: 20.9 ms\n", "Wall time: 1.52 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(*[col(c) for c in columns]))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 43, "id": "a91f19cb-f7f1-4669-aff1-be594bea5378", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\n", "|Type|Age| Breed1|Gender|Color1| Color2|MaturitySize|FurLength|Vaccinated|Sterilized| Health|Fee|PhotoAmt|target| preds|\n", "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\n", "| Dog| 3| Mixed Breed| Male| Black|No Color| Small| Medium| Not Sure| Not Sure| Healthy| 0| 2| 0| 0.4963937|\n", "| Dog| 9| Mixed Breed| Male| Gray|No Color| Medium| Short| Not Sure| No| Healthy| 0| 4| 1| 0.6780287|\n", "| Cat| 4| Domestic Short Hair| Male| Black| Gray| Medium| Short| Not Sure| Not Sure| Healthy| 0| 4| 1| 0.58800673|\n", "| Cat| 6| Domestic Short Hair| Male|Yellow| White| Medium| Short| No| No| Healthy| 0| 3| 1| 0.7378843|\n", "| Cat| 6|Domestic Medium Hair| Male| Gray|No Color| Small| Medium| No| 
No| Healthy| 0| 4| 1| 1.2695599|\n", "| Cat| 5|Domestic Medium Hair|Female| Gray|No Color| Medium| Medium| Yes| Not Sure| Healthy| 0| 1| 0|0.060457088|\n", "| Dog| 24| Beagle|Female| Black| Golden| Medium| Short| Not Sure| Not Sure|Minor Injury| 0| 1| 1| 0.28160828|\n", "| Cat| 29| Tabby| Male| Brown| Golden| Medium| Short| No| No| Healthy| 0| 1| 0| 0.6928505|\n", "| Dog| 9| Mixed Breed|Female| Black| Brown| Medium| Short| Yes| Yes| Healthy| 0| 2| 0|-0.10125986|\n", "| Dog| 2| Mixed Breed|Female| Cream| White| Medium| Short| No| No| Healthy| 0| 1| 0| 1.3703903|\n", "| Dog| 2| Mixed Breed| Male| Brown| White| Medium| Short| Yes| No| Healthy| 0| 1| 1| 1.3243997|\n", "| Dog| 60| Golden Retriever| Male| Brown| Yellow| Medium| Medium| Yes| Yes| Healthy| 0| 5| 1| 0.9026731|\n", "| Cat| 9| Siamese| Male| White|No Color| Medium| Short| Yes| No| Healthy| 0| 2| 1| 0.8207382|\n", "| Dog| 19| Doberman Pinscher|Female| Black| Brown| Large| Short| Yes| Yes| Healthy|500| 2| 1| 0.85343015|\n", "| Cat| 11| Domestic Short Hair| Male| Cream|No Color| Medium| Short| Yes| Yes| Healthy|100| 6| 0| 0.53920615|\n", "| Dog| 18| Mixed Breed|Female| Brown| White| Small| Short| Yes| No| Healthy| 0| 5| 0| 0.718272|\n", "| Dog| 4| Mixed Breed|Female| Brown| White| Medium| Medium| Not Sure| Not Sure| Healthy| 0| 3| 0| 0.16185221|\n", "| Dog| 96| Golden Retriever| Male|Golden|No Color| Large| Long| Yes| Yes| Healthy| 0| 2| 1| 0.8156965|\n", "| Dog| 54| Golden Retriever| Male|Golden|No Color| Large| Medium| Yes| No| Healthy|350| 20| 1| 3.5315154|\n", "| Cat| 5|Domestic Medium Hair|Female| Brown| White| Medium| Medium| No| No| Healthy| 0| 5| 1| 1.1725564|\n", "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "preds.show()" ] }, { "cell_type": "markdown", "id": "0c3e0390", "metadata": {}, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. 
\n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 44, "id": "2605d134-ef75-4d94-9b16-2c6d85f29bef", "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "id": "ea407357", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 45, "id": "7e1e716f", "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "id": "fcd28e7d", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 46, "id": "4666e618-8038-4dc5-9be7-793aedbf4500", "metadata": {}, "outputs": [], "source": [ "def triton_server(ports, model_path):\n", " import time\n", " import signal\n", " import numpy as np\n", " import tensorflow as tf\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", "\n", " print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n", " # Enable GPU memory growth\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", "\n", " model = tf.keras.models.load_model(model_path)\n", "\n", " def decode(input_tensor):\n", " return tf.convert_to_tensor(np.vectorize(lambda x: x.decode('utf-8'))(input_tensor))\n", "\n", " def identity(input_tensor):\n", " return tf.convert_to_tensor(input_tensor)\n", "\n", " input_transforms = {\n", " \"Type\": decode,\n", " \"Age\": identity,\n", " \"Breed1\": decode,\n", " \"Gender\": decode,\n", " \"Color1\": decode,\n", " \"Color2\": decode,\n", " \"MaturitySize\": decode,\n", " \"FurLength\": decode,\n", " \"Vaccinated\": decode,\n", " \"Sterilized\": decode,\n", " \"Health\": decode,\n", " \"Fee\": identity,\n", " \"PhotoAmt\": identity\n", " }\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " decoded_inputs = {k: input_transforms[k](v) for k, v in inputs.items()}\n", " print(f\"SERVER: Received batch of size {len(decoded_inputs['Type'])}.\")\n", " return {\n", " \"preds\": model.predict(decoded_inputs)\n", " }\n", "\n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"PetClassifier\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"Type\", dtype=np.bytes_, shape=(-1,)),\n", " Tensor(name=\"Age\", dtype=np.int64, shape=(-1,)),\n", " Tensor(name=\"Breed1\", dtype=np.bytes_, shape=(-1,)),\n", " Tensor(name=\"Gender\", dtype=np.bytes_, shape=(-1,)),\n", " Tensor(name=\"Color1\", 
dtype=np.bytes_, shape=(-1,)),\n", " Tensor(name=\"Color2\", dtype=np.bytes_, shape=(-1,)),\n", " Tensor(name=\"MaturitySize\", dtype=np.bytes_, shape=(-1,)),\n", " Tensor(name=\"FurLength\", dtype=np.bytes_, shape=(-1,)),\n", " Tensor(name=\"Vaccinated\", dtype=np.bytes_, shape=(-1,)),\n", " Tensor(name=\"Sterilized\", dtype=np.bytes_, shape=(-1,)),\n", " Tensor(name=\"Health\", dtype=np.bytes_, shape=(-1,)),\n", " Tensor(name=\"Fee\", dtype=np.int64, shape=(-1,)),\n", " Tensor(name=\"PhotoAmt\", dtype=np.int64, shape=(-1,)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"preds\", dtype=np.float32, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=128,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n", " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "id": "617525a5", "metadata": {}, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "id": "fc93a43a", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": null, "id": "c9b98208", "metadata": {}, "outputs": [], "source": [ "model_name = \"PetClassifier\"\n", "server_manager = TritonServerManager(model_name=model_name, model_path=model_path)" ] }, { "cell_type": "code", "execution_count": null, "id": "228401f7", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server)" ] }, { "cell_type": "markdown", "id": "cb560288", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "id": "5d28b1ca", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": null, "id": "d1234a02", "metadata": {}, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url # or server_manager.host_to_grpc_url" ] }, { "cell_type": "markdown", "id": "3c9ef706", "metadata": {}, "source": [ "Define the Triton inference function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 51, "id": "e50b5fc8", "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " import numpy as np\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"CLIENT: Connecting to {model_name} at {url}\")\n", "\n", " def infer_batch(t, a, b, g, c1, c2, m, f, v, s, h, fee, p):\n", 
" \n", " def encode(value):\n", " return np.vectorize(lambda x: x.encode(\"utf-8\"))(value).astype(np.bytes_)\n", "\n", " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n", " encoded_inputs = {\n", " \"Type\": encode(t), \n", " \"Age\": a, \n", " \"Breed1\": encode(b), \n", " \"Gender\": encode(g),\n", " \"Color1\": encode(c1),\n", " \"Color2\": encode(c2),\n", " \"MaturitySize\": encode(m),\n", " \"FurLength\": encode(f),\n", " \"Vaccinated\": encode(v),\n", " \"Sterilized\": encode(s),\n", " \"Health\": encode(h),\n", " \"Fee\": fee,\n", " \"PhotoAmt\": p\n", " }\n", " result_data = client.infer_batch(**encoded_inputs)\n", " return result_data[\"preds\"]\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 54, "id": "2ffb020e-dc93-456b-bee6-405611eee1e1", "metadata": {}, "outputs": [], "source": [ "# need to pass the list of columns into the model_udf\n", "classify = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " input_tensor_shapes=[[1]] * len(columns),\n", " return_type=FloatType(),\n", " batch_size=64)" ] }, { "cell_type": "markdown", "id": "2edd887f", "metadata": {}, "source": [ "#### Load and preprocess DataFrame" ] }, { "cell_type": "code", "execution_count": 52, "id": "fe8dc3e6-f1b1-4a24-85f4-0a5ecabef4c5", "metadata": {}, "outputs": [], "source": [ "df = spark.read.parquet(data_path)" ] }, { "cell_type": "code", "execution_count": 53, "id": "4cfb3f34-a215-4781-91bf-2bec85e15633", "metadata": {}, "outputs": [], "source": [ "columns = df.columns\n", "# remove label column\n", "columns.remove(\"target\")" ] }, { "cell_type": "markdown", "id": "b75e6f20-f06c-4f4c-ada1-c562e078ed4b", "metadata": {}, "source": [ "#### Run inference" ] }, { "cell_type": "code", "execution_count": 55, "id": "e6ff0356-becd-421f-aebb-272497d5ad6a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 12:> (0 + 8) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 17.3 ms, sys: 7.75 ms, total: 25.1 ms\n", "Wall time: 6.35 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(struct(*columns)))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 56, "id": "ce18ee7c-5958-4986-b200-6d986fcc6243", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 15.2 ms, sys: 4.2 ms, total: 19.4 ms\n", "Wall time: 5.86 s\n" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(*columns))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 57, "id": "0888ce40-b2c4-4aed-8ccb-6a8bcd00abc8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 14:===========================================> (6 + 2) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 93.4 ms, sys: 3.4 ms, total: 96.8 ms\n", "Wall time: 5.87 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"preds\", classify(*[col(c) for c in columns]))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 58, "id": "d45812b5-f584-41a4-a821-2b59e065671c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\n", "|Type|Age| Breed1|Gender|Color1| Color2|MaturitySize|FurLength|Vaccinated|Sterilized| Health|Fee|PhotoAmt|target| preds|\n", "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\n", "| Dog| 3| Mixed Breed| Male| Black|No Color| Small| Medium| Not Sure| Not Sure| Healthy| 0| 2| 0| 0.4963937|\n", "| Dog| 9| Mixed Breed| Male| Gray|No Color| Medium| Short| Not Sure| No| Healthy| 0| 4| 1| 0.6780287|\n", "| Cat| 4| Domestic Short Hair| Male| Black| Gray| Medium| Short| Not Sure| Not Sure| Healthy| 0| 4| 1| 0.58800673|\n", "| Cat| 6| Domestic Short Hair| Male|Yellow| White| Medium| Short| No| No| Healthy| 0| 3| 1| 0.7378843|\n", "| Cat| 6|Domestic Medium Hair| Male| Gray|No Color| Small| Medium| No| No| Healthy| 0| 4| 1| 1.2695599|\n", "| Cat| 5|Domestic Medium Hair|Female| Gray|No Color| Medium| Medium| Yes| Not Sure| Healthy| 0| 1| 0|0.060457088|\n", "| Dog| 24| Beagle|Female| Black| Golden| Medium| Short| Not Sure| Not Sure|Minor Injury| 0| 1| 1| 0.28160828|\n", "| Cat| 29| Tabby| Male| Brown| Golden| Medium| Short| No| No| Healthy| 0| 1| 0| 0.6928505|\n", "| Dog| 9| Mixed Breed|Female| Black| Brown| Medium| Short| Yes| Yes| Healthy| 0| 2| 0|-0.10125986|\n", "| Dog| 2| Mixed Breed|Female| Cream| White| Medium| Short| No| No| Healthy| 0| 1| 0| 1.3703903|\n", "| Dog| 2| Mixed Breed| Male| Brown| White| Medium| Short| Yes| No| Healthy| 0| 1| 1| 1.3243997|\n", "| Dog| 60| Golden Retriever| Male| Brown| Yellow| Medium| Medium| Yes| Yes| Healthy| 0| 5| 1| 0.9026731|\n", "| Cat| 9| Siamese| Male| White|No Color| Medium| Short| Yes| No| Healthy| 0| 2| 1| 0.8207382|\n", "| Dog| 19| Doberman Pinscher|Female| Black| Brown| Large| Short| Yes| Yes| Healthy|500| 2| 1| 0.85343015|\n", "| Cat| 11| Domestic Short Hair| Male| Cream|No Color| Medium| Short| Yes| Yes| Healthy|100| 6| 0| 0.53920615|\n", "| Dog| 18| Mixed Breed|Female| Brown| White| Small| Short| Yes| No| Healthy| 0| 5| 0| 0.718272|\n", "| Dog| 4| Mixed Breed|Female| Brown| White| Medium| Medium| Not Sure| Not Sure| Healthy| 0| 3| 0| 0.16185221|\n", "| Dog| 96| Golden Retriever| Male|Golden|No Color| Large| Long| Yes| Yes| Healthy| 0| 2| 1| 0.8156965|\n", "| Dog| 54| Golden Retriever| Male|Golden|No Color| Large| Medium| Yes| No| Healthy|350| 20| 1| 3.5315154|\n", "| Cat| 5|Domestic Medium Hair|Female| Brown| White| Medium| Medium| No| No| Healthy| 0| 5| 1| 1.1725564|\n", "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "preds.show()" ] }, { "cell_type": "markdown", "id": "63135aa0-b44c-4dda-8050-8cad320afe88", "metadata": { "tags": [] }, "source": [ "#### Stop Triton Server on each executor" ] }, { "cell_type": "code", "execution_count": 59, "id": "6914f44f-677f-4db3-be09-783df8d11b8a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 14:00:18,330 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-04 14:00:28,520 - INFO - Sucessfully stopped 1 servers. 
\n" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 60, "id": "f8c6ee43-8891-4446-986e-1447c5d48bac", "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "e611126e-d8c3-40ac-bf16-b911f6d7b39f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-tf", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras_resnet50_tf.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "8e6810cc-5982-4293-bfbd-c91ef0aca204", "metadata": {}, "source": [ "\n", "\n", "# PySpark Tensorflow Inference\n", "\n", "### Flower Recognition with Keras Resnet50\n", "\n", "In this notebook, we demonstrate distribute inference with Resnet50 on the Databricks flower photos dataset. \n", "From: https://docs.databricks.com/_static/notebooks/deep-learning/keras-metadata.html" ] }, { "cell_type": "markdown", "id": "858e3a8d", "metadata": {}, "source": [ "Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) " ] }, { "cell_type": "code", "execution_count": 1, "id": "cf329ac8-0763-44bc-b0f6-b634b7dc480e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 14:00:35.457924: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "2025-02-04 14:00:35.465639: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2025-02-04 14:00:35.473515: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2025-02-04 14:00:35.475792: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2025-02-04 14:00:35.482106: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2025-02-04 14:00:35.843263: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "import os\n", "import shutil\n", "import subprocess\n", "import time\n", "import json\n", "import pandas as pd\n", "from PIL import Image\n", "import numpy as np\n", "import uuid\n", " \n", "import tensorflow as tf\n", "from tensorflow.keras.applications.resnet50 import ResNet50" ] }, { "cell_type": "code", "execution_count": 2, "id": "532d562d", "metadata": {}, "outputs": [], "source": [ "os.mkdir('models') if not os.path.exists('models') else None" ] }, { "cell_type": "code", "execution_count": 3, "id": "75175140", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.17.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1738706436.174805 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706436.197467 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706436.200398 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n" ] } ], "source": [ "print(tf.__version__)\n", "\n", "# Enable GPU memory growth\n", "gpus = tf.config.experimental.list_physical_devices('GPU')\n", "if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "id": "02fe61b8", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": 4, "id": "b474339c", "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.functions import col, struct, pandas_udf, PandasUDFType\n", "from pyspark.ml.functions import predict_batch_udf\n", "from pyspark.sql.types import *\n", "from pyspark.sql import SparkSession\n", "from pyspark import SparkConf\n", "from typing import Iterator, Tuple" ] }, { "cell_type": "markdown", "id": "e182cacb", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific Spark configurations." ] }, { "cell_type": "code", "execution_count": 5, "id": "564b1d33", "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "markdown", "id": "016cdd0b", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." ] }, { "cell_type": "code", "execution_count": 6, "id": "44d72768", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 14:00:36 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/04 14:00:36 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/02/04 14:00:37 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " \n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", " elif on_dataproc:\n", " conf.set(\"spark.executorEnv.TF_GPU_ALLOCATOR\", \"cuda_malloc_async\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", " conf.set(\"spark.driver.memory\", \"8g\")\n", " conf.set(\"spark.executor.memory\", \"8g\")\n", "\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"512\")\n", "conf.set(\"spark.sql.parquet.columnarReaderBatchSize\", \"1024\")\n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "markdown", "id": "61c406fa", "metadata": {}, "source": [ "Define the input and output directories." ] }, { "cell_type": "code", "execution_count": 7, "id": "c566dc17", "metadata": {}, "outputs": [], "source": [ "os.mkdir(\"spark-dl-datasets\") if not os.path.exists(\"spark-dl-datasets\") else None\n", "data_path = \"spark-dl-datasets/flowers_{uuid}.parquet\".format(uuid=str(uuid.uuid1()))\n", "local_file_path = f\"{os.getcwd()}/{data_path}\"\n", "output_file_path = \"predictions/predictions\"" ] }, { "cell_type": "markdown", "id": "968d08a7-66b9-444f-b362-d8df692aef1c", "metadata": {}, "source": [ "### Prepare trained model and data for inference" ] }, { "cell_type": "markdown", "id": "da083168-137f-492c-8769-d8f1e2111756", "metadata": {}, "source": [ "Load the ResNet-50 Model and broadcast the weights." ] }, { "cell_type": "code", "execution_count": 8, "id": "2ddc715a-cdbc-4c49-93e9-58c9d88511da", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1738706437.771948 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706437.774792 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706437.777387 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706437.894244 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706437.895287 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706437.896207 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "2025-02-04 14:00:37.897142: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 40337 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n" ] } ], "source": [ "model = ResNet50()\n", "bc_model_weights = sc.broadcast(model.get_weights())" ] }, { "cell_type": "markdown", "id": "77dddfa3-e8df-4e8e-8251-64457f1ebf80", "metadata": {}, "source": [ "Load the data and save the datasets to one Parquet file." ] }, { "cell_type": "code", "execution_count": 9, "id": "c0738bec-97d4-4946-8c49-5e6d07ff1afc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image count: 3670\n" ] } ], "source": [ "import pathlib\n", "dataset_url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\"\n", "data_dir = tf.keras.utils.get_file(origin=dataset_url,\n", " fname='flower_photos',\n", " untar=True)\n", "data_dir = pathlib.Path(data_dir)\n", "image_count = len(list(data_dir.glob('*/*.jpg')))\n", "print(f\"Image count: {image_count}\")" ] }, { "cell_type": "code", "execution_count": 10, "id": "d54f470a-d308-4426-8ed0-33f95155bb4f", "metadata": {}, "outputs": [], "source": [ "import os\n", "files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(data_dir) for f in filenames if os.path.splitext(f)[1] == '.jpg']\n", "files = files[:2048]" ] }, { "cell_type": "code", "execution_count": 11, "id": "64f94ee0-f1ea-47f6-a77e-be8da5d1b87a", "metadata": {}, "outputs": [], "source": [ "image_data = []\n", "for file in files:\n", " img = Image.open(file)\n", " img = img.resize([224, 224])\n", " data = np.asarray(img, dtype=\"float32\").reshape([224*224*3])\n", "\n", " image_data.append({\"data\": data})" ] }, { "cell_type": "code", "execution_count": 12, "id": "b4ae1a98", "metadata": {}, "outputs": [], "source": [ "pd.DataFrame(image_data, columns=['data']).to_parquet(data_path)\n", "\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " shutil.copy(local_file_path, \"/dbfs/FileStore/{}\".format(data_path))\n", " data_path = \"/dbfs/FileStore/{}\".format(data_path)\n", "elif on_dataproc:\n", " data_dir = \"/mnt/gcs/spark-dl/spark-dl-datasets\"\n", " os.mkdir(data_dir) if not os.path.exists(data_dir) else None\n", " shutil.copy(local_file_path, \"/mnt/gcs/spark-dl/\" + data_path)\n", " data_path = \"file:///mnt/gcs/spark-dl/\" + data_path" ] }, { "cell_type": "markdown", "id": "f2414b0f-58f2-4e4a-9d09-8ea95b38d413", "metadata": {}, "source": [ "### Save Model\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "670328e3-7274-4d78-b315-487750166a3f", "metadata": {}, "outputs": [], "source": [ "model_path = 'models/resnet50_model.keras'\n", "model.save(model_path)\n", "\n", "# For cloud 
environments, copy the model to the distributed file system.\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n", " dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/resnet50_model.keras\"\n", " shutil.copy(model_path, dbfs_model_path)\n", " model_path = dbfs_model_path\n", "elif on_dataproc:\n", " # GCS is mounted at /mnt/gcs by the init script\n", " models_dir = \"/mnt/gcs/spark-dl/models\"\n", " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n", " gcs_model_path = models_dir + \"/resnet50_model.keras\"\n", " shutil.copy(model_path, gcs_model_path)\n", " model_path = gcs_model_path" ] }, { "cell_type": "markdown", "id": "b827ad56-1af0-41b7-be68-94bd203a2a70", "metadata": {}, "source": [ "### Load the data into Spark DataFrames" ] }, { "cell_type": "code", "execution_count": 14, "id": "8ddc22d0-b88a-4906-bd47-bf247e34feeb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2048\n" ] } ], "source": [ "df = spark.read.parquet(data_path)\n", "print(df.count())" ] }, { "cell_type": "markdown", "id": "865929b0-b016-4de4-996d-7f16176cf49c", "metadata": { "tags": [] }, "source": [ "### Model inference via Pandas UDF" ] }, { "cell_type": "markdown", "id": "b1f5a747", "metadata": {}, "source": [ "Define the function to parse the input data." ] }, { "cell_type": "code", "execution_count": 15, "id": "a67b3128-13c1-44f1-a0c0-7cf7a836fee3", "metadata": {}, "outputs": [], "source": [ "def parse_image(image_data):\n", " image = tf.image.convert_image_dtype(\n", " image_data, dtype=tf.float32) * (2. / 255) - 1\n", " image = tf.reshape(image, [224, 224, 3])\n", " return image" ] }, { "cell_type": "markdown", "id": "024e4ba2", "metadata": {}, "source": [ "Define the function for model inference." ] }, { "cell_type": "code", "execution_count": 16, "id": "7b33185f-6d1e-4ca9-9757-fdc3d736496b", "metadata": {}, "outputs": [], "source": [ "@pandas_udf(ArrayType(FloatType()))\n", "def pandas_predict_udf(iter: Iterator[Tuple[pd.Series]]) -> Iterator[pd.Series]:\n", "\n", " # Enable GPU memory growth to avoid CUDA OOM\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", "\n", " batch_size = 64\n", " model = ResNet50(weights=None)\n", " model.set_weights(bc_model_weights.value)\n", " for image_batch in iter:\n", " images = np.vstack(image_batch)\n", " dataset = tf.data.Dataset.from_tensor_slices(images)\n", " dataset = dataset.map(parse_image, num_parallel_calls=8).prefetch(\n", " 5000).batch(batch_size)\n", " preds = model.predict(dataset)\n", " yield pd.Series(list(preds))" ] }, { "cell_type": "markdown", "id": "08190547", "metadata": {}, "source": [ "Run model inference and save the results to Parquet." 
] }, { "cell_type": "code", "execution_count": 17, "id": "ad8c05da-db38-45ef-81d0-1f862f575ced", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 49.7 ms, sys: 17.6 ms, total: 67.3 ms\n", "Wall time: 15.1 s\n" ] } ], "source": [ "%%time\n", "predictions_1 = df.select(pandas_predict_udf(col(\"data\")).alias(\"prediction\"))\n", "results = predictions_1.collect()" ] }, { "cell_type": "code", "execution_count": 18, "id": "08cb2a10", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 6:============================================> (3 + 1) / 4]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------------------------------------+\n", "| prediction|\n", "+----------------------------------------------------------------------------------------------------+\n", "|[1.2964063E-4, 2.4653607E-4, 6.7508765E-5, 1.2236452E-4, 5.7346635E-5, 3.9642912E-4, 7.033199E-6,...|\n", "|[4.4486973E-5, 3.5260408E-4, 4.684452E-5, 8.12069E-5, 3.179397E-5, 1.9187202E-4, 7.887208E-6, 1.3...|\n", "|[1.059436E-4, 2.2737762E-4, 3.0225037E-5, 6.550149E-5, 2.3658315E-5, 3.7172026E-4, 3.353684E-6, 2...|\n", "|[2.0393689E-5, 2.2818097E-4, 7.841931E-5, 6.991323E-5, 4.704759E-5, 9.822018E-5, 5.5858673E-6, 2....|\n", "|[1.13108545E-4, 2.3128217E-4, 5.283139E-5, 1.0866656E-4, 4.0229144E-5, 3.7223354E-4, 5.5677583E-6...|\n", "|[9.1271184E-5, 2.0681013E-4, 4.5193243E-5, 7.6812066E-5, 3.2361808E-5, 3.399333E-4, 3.8415465E-6,...|\n", "|[1.0792112E-4, 3.7743401E-4, 7.618583E-5, 1.24259E-4, 4.7426664E-5, 3.3307416E-4, 1.0592865E-5, 9...|\n", "|[2.2220212E-5, 2.7357432E-4, 3.8200575E-5, 6.235621E-5, 1.7954999E-5, 1.7249273E-4, 6.021971E-6, ...|\n", "|[1.1044029E-4, 2.8961376E-4, 4.2384647E-5, 1.0728626E-4, 3.0468744E-5, 4.796082E-4, 6.4537376E-6,...|\n", "|[9.68494E-5, 2.0567125E-4, 7.450887E-5, 1.13256065E-4, 4.609738E-5, 2.8675792E-4, 5.603957E-6, 5....|\n", "|[7.420906E-5, 3.2883475E-4, 1.3444667E-4, 1.7758778E-4, 8.4717096E-5, 2.2534849E-4, 1.3623082E-5,...|\n", "|[8.755989E-5, 2.7312606E-4, 3.59614E-5, 7.7967066E-5, 2.3571063E-5, 3.6875304E-4, 3.5629025E-6, 3...|\n", "|[9.7425895E-5, 2.7611412E-4, 5.74094E-5, 1.1035101E-4, 3.8303257E-5, 3.4981826E-4, 6.167147E-6, 4...|\n", "|[6.92996E-5, 2.5326438E-4, 5.063317E-5, 1.1494952E-4, 3.0212495E-5, 2.7857954E-4, 5.0324948E-6, 5...|\n", "|[4.2184765E-5, 2.4904116E-4, 1.237565E-4, 1.4271903E-4, 7.3208634E-5, 1.6054673E-4, 7.938735E-6, ...|\n", "|[2.719573E-5, 3.8372327E-4, 1.291892E-4, 1.5711001E-4, 7.3108524E-5, 8.553368E-5, 1.2617156E-5, 1...|\n", "|[3.0565643E-5, 3.55542E-4, 1.5949155E-4, 2.1368133E-4, 8.043127E-5, 1.02662845E-4, 1.3859853E-5, ...|\n", "|[3.311506E-5, 2.8069926E-4, 1.7956384E-4, 2.0205336E-4, 1.3665091E-4, 1.0115404E-4, 3.409792E-5, ...|\n", "|[4.573667E-5, 2.888326E-4, 2.3792271E-4, 2.460216E-4, 1.2164583E-4, 1.3814335E-4, 1.6352218E-5, 2...|\n", "|[1.2279079E-4, 2.8073761E-4, 6.365874E-5, 1.0251792E-4, 4.3527238E-5, 3.914249E-4, 8.236801E-6, 6...|\n", "+----------------------------------------------------------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "predictions_1.show(truncate=100)" ] }, { "cell_type": "code", "execution_count": 19, "id": "40799f8e-443e-40ca-919b-391f901cb3f4", "metadata": {}, "outputs": [ { 
"name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "predictions_1.write.mode(\"overwrite\").parquet(output_file_path + \"_1\")" ] }, { "cell_type": "markdown", "id": "e7a69aa9", "metadata": {}, "source": [ "## Inference using Spark DL API\n", "\n", "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n", "\n", "- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \n", "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function" ] }, { "cell_type": "code", "execution_count": 20, "id": "dda88b46-6300-4bf7-bc10-7403f4fbbf92", "metadata": {}, "outputs": [], "source": [ "def predict_batch_fn():\n", " import tensorflow as tf\n", " from tensorflow.keras.applications.resnet50 import ResNet50\n", "\n", " # Enable GPU memory growth\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", "\n", " model = ResNet50()\n", " def predict(inputs):\n", " inputs = inputs * (2. / 255) - 1\n", " return model.predict(inputs)\n", " return predict" ] }, { "cell_type": "code", "execution_count": 21, "id": "cff0e851-563d-40b6-9d05-509c22b3b7f9", "metadata": {}, "outputs": [], "source": [ "classify = predict_batch_udf(predict_batch_fn,\n", " input_tensor_shapes=[[224, 224, 3]],\n", " return_type=ArrayType(FloatType()),\n", " batch_size=50)" ] }, { "cell_type": "code", "execution_count": 22, "id": "aa7c156f-e2b3-4837-9427-ccf3a5720412", "metadata": {}, "outputs": [], "source": [ "df = spark.read.parquet(data_path)" ] }, { "cell_type": "code", "execution_count": 23, "id": "80bc50ad-eaf5-4fce-a354-5e17d65e2da5", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 61.7 ms, sys: 23.1 ms, total: 84.8 ms\n", "Wall time: 16.7 s\n" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "predictions_2 = df.select(classify(struct(\"data\")).alias(\"prediction\"))\n", "results = predictions_2.collect()" ] }, { "cell_type": "code", "execution_count": 24, "id": "41cace80-7a4b-4929-8e63-9c83f9745e02", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 141 ms, sys: 22.2 ms, total: 163 ms\n", "Wall time: 16 s\n" ] } ], "source": [ "%%time\n", "predictions_2 = df.select(classify(\"data\").alias(\"prediction\"))\n", "results = predictions_2.collect()" ] }, { "cell_type": "code", "execution_count": 25, "id": "56a2ec8a-de09-4d7c-9666-1b3c76f10657", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 52.8 ms, sys: 14.3 ms, total: 67.1 ms\n", "Wall time: 15.5 s\n" ] } ], "source": [ "%%time\n", "predictions_2 = df.select(classify(col(\"data\")).alias(\"prediction\"))\n", "results = predictions_2.collect()" ] }, { "cell_type": "code", "execution_count": 26, "id": "2dcf3791", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 13:===========================================> (3 + 1) / 
4]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------------------------------------+\n", "| prediction|\n", "+----------------------------------------------------------------------------------------------------+\n", "|[1.293178E-4, 2.4644283E-4, 6.760039E-5, 1.2260793E-4, 5.7431564E-5, 3.9597694E-4, 7.0522524E-6, ...|\n", "|[4.4487308E-5, 3.5378174E-4, 4.6667028E-5, 8.102564E-5, 3.168566E-5, 1.9189132E-4, 7.903805E-6, 1...|\n", "|[1.0566196E-4, 2.2684377E-4, 3.00564E-5, 6.5251304E-5, 2.3520754E-5, 3.7116173E-4, 3.331476E-6, 2...|\n", "|[2.0337258E-5, 2.2749524E-4, 7.8351426E-5, 6.991163E-5, 4.7081656E-5, 9.8092445E-5, 5.564894E-6, ...|\n", "|[1.12979564E-4, 2.3172122E-4, 5.2946547E-5, 1.0876398E-4, 4.0259067E-5, 3.7143996E-4, 5.5940513E-...|\n", "|[9.093228E-5, 2.0639994E-4, 4.5151268E-5, 7.666316E-5, 3.2264295E-5, 3.387436E-4, 3.832487E-6, 4....|\n", "|[1.0783461E-4, 3.7850672E-4, 7.660902E-5, 1.2446321E-4, 4.7591406E-5, 3.3328883E-4, 1.067249E-5, ...|\n", "|[2.2258617E-5, 2.7345872E-4, 3.814439E-5, 6.229726E-5, 1.79387E-5, 1.7259057E-4, 6.0371217E-6, 1....|\n", "|[1.1067773E-4, 2.8997674E-4, 4.2570035E-5, 1.0747747E-4, 3.0524247E-5, 4.7921995E-4, 6.489833E-6,...|\n", "|[9.676251E-5, 2.0588847E-4, 7.467098E-5, 1.1326933E-4, 4.6123736E-5, 2.8609246E-4, 5.627118E-6, 5...|\n", "|[7.4104944E-5, 3.290917E-4, 1.3448784E-4, 1.7742367E-4, 8.463227E-5, 2.2462371E-4, 1.3614881E-5, ...|\n", "|[8.7211796E-5, 2.7337394E-4, 3.5953894E-5, 7.7924225E-5, 2.3554327E-5, 3.67775E-4, 3.5652213E-6, ...|\n", "|[9.7237185E-5, 2.762026E-4, 5.7450008E-5, 1.1019135E-4, 3.831896E-5, 3.4878452E-4, 6.1574788E-6, ...|\n", "|[6.938849E-5, 2.5376282E-4, 5.0565883E-5, 1.14880335E-4, 3.0061366E-5, 2.7866007E-4, 5.024482E-6,...|\n", "|[4.2096388E-5, 2.4889092E-4, 1.2363133E-4, 1.4304162E-4, 7.337785E-5, 1.6042824E-4, 7.959722E-6, ...|\n", "|[2.730248E-5, 3.851789E-4, 1.293143E-4, 1.5753493E-4, 7.302161E-5, 8.547956E-5, 1.26348905E-5, 1....|\n", "|[3.0354899E-5, 3.5562844E-4, 1.6008675E-4, 2.1440513E-4, 8.062159E-5, 1.02023136E-4, 1.3876455E-5...|\n", "|[3.3083066E-5, 2.8158593E-4, 1.7979987E-4, 2.0232225E-4, 1.3704685E-4, 1.0091762E-4, 3.4243407E-5...|\n", "|[4.5485373E-5, 2.878148E-4, 2.3707838E-4, 2.4493985E-4, 1.21028905E-4, 1.3738636E-4, 1.6280053E-5...|\n", "|[1.22468E-4, 2.809503E-4, 6.3342835E-5, 1.021957E-4, 4.3373006E-5, 3.905496E-4, 8.212427E-6, 6.20...|\n", "+----------------------------------------------------------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "predictions_2.show(truncate=100)" ] }, { "cell_type": "code", "execution_count": 27, "id": "fc511eae", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "predictions_2.write.mode(\"overwrite\").parquet(output_file_path + \"_2\")" ] }, { "cell_type": "markdown", "id": "878ca7fb", "metadata": {}, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. 
\n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 28, "id": "2605d134-ef75-4d94-9b16-2c6d85f29bef", "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "id": "cdded12d", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 29, "id": "a2475d41", "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "id": "1f6701dc", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 30, "id": "8c8c0744-0558-4dac-bbfe-8bdde4b2af2d", "metadata": {}, "outputs": [], "source": [ "def triton_server(ports):\n", " import time\n", " import signal\n", " import numpy as np\n", " import tensorflow as tf\n", " from tensorflow.keras.applications import ResNet50\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", "\n", " print(f\"SERVER: Initializing ResNet on worker {TaskContext.get().partitionId()}.\")\n", "\n", " # Enable GPU memory growth\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", " \n", " model = ResNet50()\n", " normalization_layer = tf.keras.layers.Rescaling(scale=2./255, offset=-1)\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " images = inputs[\"images\"]\n", " normalized_images = normalization_layer(images)\n", " return {\n", " \"preds\": model.predict(normalized_images),\n", " }\n", "\n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"ResNet50\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"images\", dtype=np.float32, shape=(224, 224, 3)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"preds\", dtype=np.float32, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=100,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n", " print(\"SERVER: Received SIGTERM. 
Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "id": "d74f7037", "metadata": {}, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "id": "4bf99bde", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": null, "id": "2309a55c", "metadata": {}, "outputs": [], "source": [ "model_name = \"ResNet50\"\n", "server_manager = TritonServerManager(model_name=model_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "205fa1e8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server)" ] }, { "cell_type": "markdown", "id": "e49ebdbe", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "id": "55c42174", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": null, "id": "9e4ff20e", "metadata": {}, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url # or server_manager.host_to_grpc_url" ] }, { "cell_type": "markdown", "id": "481dbd42", "metadata": {}, "source": [ "Define the Triton inference function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 35, "id": "a5ab49bb", "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " import numpy as np\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"CLIENT: Connecting to {model_name} at {url}\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n", " result_data = client.infer_batch(inputs)\n", " return result_data[\"preds\"]\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 37, "id": "9fabcaeb-5a44-42bb-8097-5dbc2d0cee3e", "metadata": {}, "outputs": [], "source": [ "classify = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " input_tensor_shapes=[[224, 224, 3]],\n", " return_type=ArrayType(FloatType()),\n", " batch_size=50)" ] }, { "cell_type": "markdown", "id": "fcd2328e", "metadata": {}, "source": [ "#### Load DataFrame" ] }, { "cell_type": "code", "execution_count": 36, "id": "bbfc9009", "metadata": {}, "outputs": [], "source": [ "df = spark.read.parquet(data_path)" ] }, { "cell_type": "markdown", "id": "8c07365c-0a14-49b3-9bd8-cfb35f48b089", "metadata": {}, "source": [ "#### Run inference" ] }, { "cell_type": "code", "execution_count": 38, "id": "e595473d-1a5d-46a6-a6ba-89d2ea903de9", "metadata": {}, 
"outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 60.9 ms, sys: 21.3 ms, total: 82.2 ms\n", "Wall time: 18.4 s\n" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "predictions_3 = df.select(classify(struct(\"data\")).alias(\"prediction\"))\n", "results = predictions_3.collect()" ] }, { "cell_type": "code", "execution_count": 39, "id": "5f66d468-e0b1-4589-8606-b3848063a823", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 46.3 ms, sys: 16.1 ms, total: 62.4 ms\n", "Wall time: 12.3 s\n" ] } ], "source": [ "%%time\n", "predictions_3 = df.select(classify(\"data\").alias(\"prediction\"))\n", "results = predictions_3.collect()" ] }, { "cell_type": "code", "execution_count": 40, "id": "632c4c3a-fa52-4c3d-b71e-7526286e353a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 57.5 ms, sys: 16.4 ms, total: 73.9 ms\n", "Wall time: 12.4 s\n" ] } ], "source": [ "%%time\n", "predictions_3 = df.select(classify(col(\"data\")).alias(\"prediction\"))\n", "results = predictions_3.collect()" ] }, { "cell_type": "code", "execution_count": 41, "id": "49870e39", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 22:===========================================> (3 + 1) / 4]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+----------------------------------------------------------------------------------------------------+\n", "| prediction|\n", "+----------------------------------------------------------------------------------------------------+\n", "|[1.293178E-4, 2.4644283E-4, 6.760039E-5, 1.2260793E-4, 5.7431564E-5, 3.9597694E-4, 7.0522524E-6, ...|\n", "|[4.4487308E-5, 3.5378174E-4, 4.6667028E-5, 8.102564E-5, 3.168566E-5, 1.9189132E-4, 7.903805E-6, 1...|\n", "|[1.0566196E-4, 2.2684377E-4, 3.00564E-5, 6.5251304E-5, 2.3520754E-5, 3.7116173E-4, 3.331476E-6, 2...|\n", "|[2.0337258E-5, 2.2749524E-4, 7.8351426E-5, 6.991163E-5, 4.7081656E-5, 9.8092445E-5, 5.564894E-6, ...|\n", "|[1.12979564E-4, 2.3172122E-4, 5.2946547E-5, 1.0876398E-4, 4.0259067E-5, 3.7143996E-4, 5.5940513E-...|\n", "|[9.093228E-5, 2.0639994E-4, 4.5151268E-5, 7.666316E-5, 3.2264295E-5, 3.387436E-4, 3.832487E-6, 4....|\n", "|[1.0783461E-4, 3.7850672E-4, 7.660902E-5, 1.2446321E-4, 4.7591406E-5, 3.3328883E-4, 1.067249E-5, ...|\n", "|[2.2258617E-5, 2.7345872E-4, 3.814439E-5, 6.229726E-5, 1.79387E-5, 1.7259057E-4, 6.0371217E-6, 1....|\n", "|[1.1067773E-4, 2.8997674E-4, 4.2570035E-5, 1.0747747E-4, 3.0524247E-5, 4.7921995E-4, 6.489833E-6,...|\n", "|[9.676251E-5, 2.0588847E-4, 7.467098E-5, 1.1326933E-4, 4.6123736E-5, 2.8609246E-4, 5.627118E-6, 5...|\n", "|[7.4104944E-5, 3.290917E-4, 1.3448784E-4, 1.7742367E-4, 8.463227E-5, 2.2462371E-4, 1.3614881E-5, ...|\n", "|[8.7211796E-5, 2.7337394E-4, 3.5953894E-5, 7.7924225E-5, 2.3554327E-5, 3.67775E-4, 3.5652213E-6, ...|\n", "|[9.7237185E-5, 2.762026E-4, 5.7450008E-5, 1.1019135E-4, 3.831896E-5, 3.4878452E-4, 6.1574788E-6, ...|\n", "|[6.938849E-5, 2.5376282E-4, 5.0565883E-5, 1.14880335E-4, 3.0061366E-5, 2.7866007E-4, 5.024482E-6,...|\n", "|[4.2096388E-5, 2.4889092E-4, 1.2363133E-4, 1.4304162E-4, 7.337785E-5, 1.6042824E-4, 7.959722E-6, ...|\n", "|[2.730248E-5, 3.851789E-4, 1.293143E-4, 1.5753493E-4, 7.302161E-5, 
8.547956E-5, 1.26348905E-5, 1....|\n", "|[3.0354899E-5, 3.5562844E-4, 1.6008675E-4, 2.1440513E-4, 8.062159E-5, 1.02023136E-4, 1.3876455E-5...|\n", "|[3.3083066E-5, 2.8158593E-4, 1.7979987E-4, 2.0232225E-4, 1.3704685E-4, 1.0091762E-4, 3.4243407E-5...|\n", "|[4.5485373E-5, 2.878148E-4, 2.3707838E-4, 2.4493985E-4, 1.21028905E-4, 1.3738636E-4, 1.6280053E-5...|\n", "|[1.22468E-4, 2.809503E-4, 6.3342835E-5, 1.021957E-4, 4.3373006E-5, 3.905496E-4, 8.212427E-6, 6.20...|\n", "+----------------------------------------------------------------------------------------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "predictions_3.show(truncate=100)" ] }, { "cell_type": "code", "execution_count": 42, "id": "86cd59f9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "predictions_3.write.mode(\"overwrite\").parquet(output_file_path + \"_3\")" ] }, { "cell_type": "markdown", "id": "4dc06b7e-f750-40b5-9208-a035db11d937", "metadata": { "tags": [] }, "source": [ "#### Stop Triton Server on each executor" ] }, { "cell_type": "code", "execution_count": 43, "id": "bbfcaa51-3b9f-43ff-a4a8-4b46766115b8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 14:03:34,747 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-04 14:03:39,935 - INFO - Sucessfully stopped 1 servers. \n" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 44, "id": "0d88639b-d934-4eb4-ae2f-cc13b9b10456", "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "df8cc28a-34d7-479c-be7e-9a380d39e25e", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-tf", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/text_classification_tf.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "2cd2accf-5877-4136-a243-7a33a13ce2b4", "metadata": {}, "source": [ "\n", "\n", "# Pyspark TensorFlow Inference\n", "\n", "### Text Classification\n", "In this notebook, we demonstrate training a model to perform sentiment analysis, and using the trained model for distributed inference. 
\n", "Based on: https://www.tensorflow.org/tutorials/keras/text_classification" ] }, { "cell_type": "markdown", "id": "bc72d0ed", "metadata": {}, "source": [ "Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) " ] }, { "cell_type": "code", "execution_count": 1, "id": "76f0f5df-502f-444e-b2ee-1122e1dea870", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 14:05:12.899608: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "2025-02-04 14:05:12.907256: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2025-02-04 14:05:12.915374: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2025-02-04 14:05:12.917743: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2025-02-04 14:05:12.924372: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2025-02-04 14:05:13.295411: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "import os\n", "import re\n", "import shutil\n", "import string\n", "import matplotlib.pyplot as plt\n", "\n", "import tensorflow as tf\n", "from tensorflow.keras import layers, losses" ] }, { "cell_type": "code", "execution_count": 2, "id": "a364ad5f-b269-45b5-ab8b-d8f34fb642b7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.17.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1738706713.692042 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706713.716276 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706713.719037 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n" ] } ], "source": [ "print(tf.__version__)\n", "\n", "# Enable GPU memory growth\n", "gpus = tf.config.experimental.list_physical_devices('GPU')\n", "if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "id": "b64bb471", "metadata": {}, "source": [ "### Download and explore the dataset" ] }, { "cell_type": "code", "execution_count": 3, "id": "d229c1b6-3967-46b5-9ea8-68f4b42dd211", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "dataset = load_dataset(\"imdb\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "88f9a92e", "metadata": {}, "outputs": [], "source": [ "# Create directories for our data\n", "base_dir = \"spark-dl-datasets/imdb\"\n", "if os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False):\n", " # For databricks, use the driver disk rather than Workspace (much faster)\n", " base_dir = \"/local_disk0/\" + base_dir\n", "\n", "train_dir = base_dir + \"/train\"\n", "test_dir = base_dir + \"/test\"" ] }, { "cell_type": "code", "execution_count": 5, "id": "3f984d5a", "metadata": {}, "outputs": [], "source": [ "# Create directories for positive (1) and negative (0) reviews\n", "for split in [\"train\", \"test\"]:\n", " split_dir = os.path.join(base_dir, split)\n", " pos_dir = split_dir + \"/pos\"\n", " neg_dir = split_dir + \"/neg\"\n", "\n", " os.makedirs(pos_dir, exist_ok=True)\n", " os.makedirs(neg_dir, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 6, "id": "6cd2328a", "metadata": {}, "outputs": [], "source": [ "def write_reviews_to_files(dataset_split, split_name):\n", " for idx, example in enumerate(dataset_split):\n", " label_dir = \"pos\" if example[\"label\"] == 1 else \"neg\"\n", " dir_path = os.path.join(base_dir, split_name, label_dir)\n", "\n", " file_path = dir_path + f\"/review_{idx}.txt\"\n", " with open(file_path, \"w\", encoding=\"utf-8\") as f:\n", " f.write(example[\"text\"])\n", "\n", "# Write train and test sets\n", "write_reviews_to_files(dataset[\"train\"], \"train\")\n", "write_reviews_to_files(dataset[\"test\"], \"test\")" ] }, { "cell_type": "markdown", "id": "b02fde64", "metadata": {}, "source": [ "There are 25,000 examples in the training folder, of which we will use 80% (or 20,000) for training, and 5,000 for validation." ] }, { "cell_type": "code", "execution_count": 7, "id": "5c357f22", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 25000 files belonging to 2 classes.\n", "Using 20000 files for training.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1738706719.326625 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706719.329542 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706719.332409 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706719.451656 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706719.452700 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1738706719.453630 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "2025-02-04 14:05:19.454569: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 40337 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Found 25000 files belonging to 2 classes.\n", "Using 5000 files for validation.\n", "Found 25000 files belonging to 2 classes.\n" ] } ], "source": [ "batch_size = 32\n", "seed = 42\n", "\n", "raw_train_ds = tf.keras.utils.text_dataset_from_directory(\n", " str(train_dir),\n", " batch_size=batch_size,\n", " validation_split=0.2,\n", " subset=\"training\",\n", " seed=seed,\n", ")\n", "\n", "raw_val_ds = tf.keras.utils.text_dataset_from_directory(\n", " str(train_dir),\n", " batch_size=batch_size,\n", " validation_split=0.2,\n", " subset=\"validation\",\n", " seed=seed,\n", ")\n", "\n", "raw_test_ds = tf.keras.utils.text_dataset_from_directory(\n", " str(test_dir),\n", " batch_size=batch_size\n", ")" ] }, { "cell_type": "markdown", "id": "02994994", "metadata": {}, "source": [ "We can take a look at a sample of the dataset (note that OUT_OF_RANGE errors are safe to ignore: https://github.com/tensorflow/tensorflow/issues/62963):" ] }, { "cell_type": "code", "execution_count": 8, "id": "1d528a95", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Review b'I was really, really disappointed with this movie. it started really well, and built up some great atmosphere and suspense, but when it finally got round to revealing the \"monster\"...it turned out to be just some psycho with skin problems......again. Whoop-de-do. Yet another nutjob movie...like we don\\'t already have enough of them.
<br /><br />
To be fair, the \"creep\" is genuinely unsettling to look at, and the way he moves and the strange sounds he makes are pretty creepy, but I\\'m sick of renting film like this only to discover that the monster is human, albeit a twisted, demented, freakish one. When I saw all the tell-tale rats early on I was hoping for some kind of freaky rat-monster hybrid thing...it was such a let down when the Creep was revealed.
<br /><br />
On top of this, some of the stuff in this movie makes no sense. (Spoiler)
<br /><br />
Why the hell does the Creep kill the security Guard? Whats the point, apart from sticking a great honking sign up that says \"HI I\\'m A PSYCHO AND I LIVE DOWN HERE!\"? Its stupid, and only seems to happen to prevent Franka Potente\\'s character from getting help.
<br /><br />
what the hells he been eating down there? I got the impression he was effectively walled in, and only the unexpected opening into that tunnel section let him loose...so has he been munching rats all that time, and if so why do they hang around him so much? Why is he so damn hard to kill? He\\'s thin, malnourished and not exactly at peak performance...but seems to keep going despite injuries that are equivalent to those that .cripple the non-psycho characters in the film.
<br /><br />
The DVD commentary says we are intended to empathise with Creep, but I just find him loathsome. Its an effective enough movie, but it wasted so many opportunities that it makes me sick.'\n", "Label 0\n", "Review b\"This has the absolute worst performance from Robert Duval who sounds just like William Buckley throughout the entire film. His hammy melodramatic acting takes away from any dramatic interest. I'm not sure if this was deliberate scene stealing or inadvertent but it's the only thing I can recall from a truly forgettable film. This picture should be shown in every amateur acting class of an example of what not to do. Thank God, Duvall went on to bigger and better things and stopped trying to effect a cultured accent. He is a good character actor but that's about it. Klaus is so much better. His performance is muted and noteworthy.\"\n", "Label 0\n", "Review b'A long time ago, in a galaxy far, far away.....There was a boy who was only two years old when the original \"Star Wars\" film was released. He doesn\\'t remember first seeing the movie, but he also doesn\\'t remember life before it. He does remember the first \"Star Wars\" themed gift he got...a shoebox full of action figures from the original set. He was too young to fully appreciate how special that gift would be. But years later, he would get what to this day goes down as one of the best gifts he\\'s ever received: another box full of action figures, ten of the final twelve he needed to complete his collection. It\\'s now legendary in this boy\\'s family how the last action figure he needed, Anakin Skywalker, stopped being produced and carried in stores, and how this boy went for about ten years (until he got into college) trying to track one down and finally bought it from someone on his dorm floor for a bag of beer nuggets (don\\'t ask...it\\'s a Northern Illinois University thing).
<br /><br />
I can\\'t review \"Star Wars\" as a movie. It represents absolutely everything good, fun and magical about my childhood. There\\'s no separating it in my mind from Christmases, birthdays, summers and winters growing up. In the winter, my friends and I would build snow forts and pretend we were on Hoth (I was always Han Solo). My friends\\' dad built them a kick-ass tree house, and that served as the Ewok village. They also had a huge pine tree whose bottom branches were high enough to create a sort of cave underneath it, and this made a great spot to pretend we were in Yoda\\'s home. I am unabashedly dorky when it comes to \"Star Wars\" and I think people either just understand that or they don\\'t. I don\\'t get the appeal of \"Lord of the Rings\" or \"Star Trek\" but I understand the rabid flocks of fans that follow them because I am a rabid fan of George Lucas\\'s films.
<br /><br />
I feel no need to defend my opinion of these movies as some of the greatest of all time. Every time I put them in the DVD player, I feel like I\\'m eight years old again, when life was simple and the biggest problem I had was figuring out how I was going to track down a figure of Anakin Skywalker.
<br /><br />
Grade (for the entire trilogy): A+'\n", "Label 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 14:05:20.533703: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] } ], "source": [ "for text_batch, label_batch in raw_train_ds.take(1):\n", " for i in range(3):\n", " print(\"Review\", text_batch.numpy()[i])\n", " print(\"Label\", label_batch.numpy()[i])" ] }, { "cell_type": "markdown", "id": "4bca98b1", "metadata": {}, "source": [ "Notice the reviews contain raw text (with punctuation and occasional HTML tags like \\
<br />\\). We will show how to handle these in the following section." ] }, { "cell_type": "code", "execution_count": 9, "id": "f8921ed2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Label 0 corresponds to neg\n", "Label 1 corresponds to pos\n" ] } ], "source": [ "print(\"Label 0 corresponds to\", raw_train_ds.class_names[0])\n", "print(\"Label 1 corresponds to\", raw_train_ds.class_names[1])" ] }, { "cell_type": "markdown", "id": "f6cf0e47", "metadata": {}, "source": [ "### Prepare the dataset for training\n", "\n", "Next, we will standardize, tokenize, and vectorize the data using the tf.keras.layers.TextVectorization layer. \n", "We will write a custom standardization function to remove the HTML." ] }, { "cell_type": "code", "execution_count": 10, "id": "cb141709-fcc1-4cee-bc98-9c89aaba8648", "metadata": {}, "outputs": [], "source": [ "def custom_standardization(input_data):\n", " lowercase = tf.strings.lower(input_data)\n", " stripped_html = tf.strings.regex_replace(lowercase, \"<br />
\", \" \")\n", " return tf.strings.regex_replace(\n", " stripped_html, \"[%s]\" % re.escape(string.punctuation), \"\"\n", " )" ] }, { "cell_type": "markdown", "id": "b35e36a2", "metadata": {}, "source": [ "Next, we will create a TextVectorization layer to standardize, tokenize, and vectorize our data." ] }, { "cell_type": "code", "execution_count": 11, "id": "d4e80ea9-536a-4ebc-8b35-1eca73dbba7d", "metadata": {}, "outputs": [], "source": [ "max_features = 10000\n", "sequence_length = 250\n", "\n", "vectorize_layer = layers.TextVectorization(\n", " standardize=custom_standardization,\n", " max_tokens=max_features,\n", " output_mode=\"int\",\n", " output_sequence_length=sequence_length,\n", ")" ] }, { "cell_type": "markdown", "id": "879fbc3f", "metadata": {}, "source": [ "Next, we will call adapt to fit the state of the preprocessing layer to the dataset. This will cause the model to build an index of strings to integers." ] }, { "cell_type": "code", "execution_count": 12, "id": "ad1e5d81-7dae-4b08-b520-ca45501b9510", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 14:05:22.003236: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] } ], "source": [ "# Make a text-only dataset (without labels), then call adapt\n", "train_text = raw_train_ds.map(lambda x, y: x)\n", "vectorize_layer.adapt(train_text)" ] }, { "cell_type": "markdown", "id": "ad1e5d81-7dae-4b08-b520-ca45501b9510", "metadata": {}, "source": [ "Let's create a function to see the result of using this layer to preprocess some data." ] }, { "cell_type": "code", "execution_count": 13, "id": "80f243f5-edd3-4e1c-bddc-abc1cc6673ef", "metadata": {}, "outputs": [], "source": [ "def vectorize_text(text, label):\n", " text = tf.expand_dims(text, -1)\n", " return vectorize_layer(text), label" ] }, { "cell_type": "code", "execution_count": 14, "id": "8f37e95c-515c-4edb-a1ee-fc47be5df4b9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Review tf.Tensor(b\"To describe this film as garbage is unfair. At least rooting through garbage can be an absorbing hobby. This flick was neither absorbing nor entertaining.
<br /><br />
Kevin Bacon can act superbly given the chance, so no doubt had an IRS bill to settle when he agreed to this dire screenplay. The mad scientist story of 'Hollow Man' has been told before, been told better, and been told without resorting to so many ludicrously expensive special effects.
<br /><br />
Most of those special effects seem to be built around the transparent anatomical dolls of men, women and dogs you could buy in the early seventies. In the UK they were marketed as 'The Transparent Man (/Woman/Dog)' which is maybe where they got the title for this film.
<br /><br />
Clever special effects, dire script, non-existent plot.
<br /><br />
\", shape=(), dtype=string)\n", "Label neg\n", "Vectorized review (, )\n" ] } ], "source": [ "# retrieve a batch (of 32 reviews and labels) from the dataset\n", "text_batch, label_batch = next(iter(raw_train_ds))\n", "first_review, first_label = text_batch[0], label_batch[0]\n", "print(\"Review\", first_review)\n", "print(\"Label\", raw_train_ds.class_names[first_label])\n", "print(\"Vectorized review\", vectorize_text(first_review, first_label))" ] }, { "cell_type": "markdown", "id": "680f53bb", "metadata": {}, "source": [ "We can lookup the token (string) that each integer corresponds to by calling .get_vocabulary() on the layer." ] }, { "cell_type": "code", "execution_count": 15, "id": "60c9208a-39ac-4e6c-a603-61038cdf3d10", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1287 ---> nowhere\n", " 313 ---> house\n", "Vocabulary size: 10000\n" ] } ], "source": [ "print(\"1287 ---> \",vectorize_layer.get_vocabulary()[1287])\n", "print(\" 313 ---> \",vectorize_layer.get_vocabulary()[313])\n", "print('Vocabulary size: {}'.format(len(vectorize_layer.get_vocabulary())))" ] }, { "cell_type": "code", "execution_count": 16, "id": "3cf90d4b-8dae-44b2-b32b-80cb0092c430", "metadata": {}, "outputs": [], "source": [ "train_ds = raw_train_ds.map(vectorize_text)\n", "val_ds = raw_val_ds.map(vectorize_text)\n", "test_ds = raw_test_ds.map(vectorize_text)" ] }, { "cell_type": "markdown", "id": "b3db3f77", "metadata": {}, "source": [ "### Configure the dataset for performance\n", "\n", "These are two important methods you should use when loading data to make sure that I/O does not become blocking.\n", "\n", "`.cache()` keeps data in memory after it's loaded off disk. This will ensure the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache, which is more efficient to read than many small files.\n", "\n", "`.prefetch()` overlaps data preprocessing and model execution while training." ] }, { "cell_type": "code", "execution_count": 17, "id": "115a5aba-8a00-458f-be25-0aae9f55de22", "metadata": {}, "outputs": [], "source": [ "AUTOTUNE = tf.data.AUTOTUNE\n", "\n", "train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)\n", "val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)\n", "test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)" ] }, { "cell_type": "markdown", "id": "0d6d6692", "metadata": {}, "source": [ "### Create the model" ] }, { "cell_type": "code", "execution_count": 18, "id": "d64f4495-102d-4244-9b42-1ba9976a366e", "metadata": {}, "outputs": [], "source": [ "embedding_dim = 16" ] }, { "cell_type": "code", "execution_count": 19, "id": "3dc95d22-935f-4091-b0ee-da95174eb9a0", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"sequential\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
       "│ embedding (Embedding)           │ ?                      │   0 (unbuilt) │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dropout (Dropout)               │ ?                      │   0 (unbuilt) │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ global_average_pooling1d        │ ?                      │   0 (unbuilt) │\n",
       "│ (GlobalAveragePooling1D)        │                        │               │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dropout_1 (Dropout)             │ ?                      │   0 (unbuilt) │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dense (Dense)                   │ ?                      │   0 (unbuilt) │\n",
       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ embedding (\u001b[38;5;33mEmbedding\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ global_average_pooling1d │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "│ (\u001b[38;5;33mGlobalAveragePooling1D\u001b[0m) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dropout_1 (\u001b[38;5;33mDropout\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense (\u001b[38;5;33mDense\u001b[0m) │ ? │ \u001b[38;5;34m0\u001b[0m (unbuilt) │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = tf.keras.Sequential([\n", " layers.Embedding(max_features, embedding_dim),\n", " layers.Dropout(0.2),\n", " layers.GlobalAveragePooling1D(),\n", " layers.Dropout(0.2),\n", " layers.Dense(1, activation='sigmoid')])\n", "\n", "model.summary()" ] }, { "cell_type": "code", "execution_count": 20, "id": "d9059b93-7666-46db-bf15-517c4c205df9", "metadata": {}, "outputs": [], "source": [ "model.compile(loss=losses.BinaryCrossentropy(),\n", " optimizer='adam',\n", " metrics=[tf.metrics.BinaryAccuracy(threshold=0.5)])" ] }, { "cell_type": "markdown", "id": "f8b66d33", "metadata": {}, "source": [ "#### Train model" ] }, { "cell_type": "code", "execution_count": 21, "id": "b1d5959f-1bd8-48da-9815-8239599519b2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1738706722.621647 3744883 service.cc:146] XLA service 0x334cd320 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n", "I0000 00:00:1738706722.621667 3744883 service.cc:154] StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6\n", "2025-02-04 14:05:22.635317: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", "2025-02-04 14:05:22.689182: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m262/625\u001b[0m \u001b[32m━━━━━━━━\u001b[0m\u001b[37m━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 578us/step - binary_accuracy: 0.5299 - loss: 0.6904" ] }, { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1738706723.175401 3744883 device_compiler.h:188] Compiled cluster using XLA! 
This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - binary_accuracy: 0.5692 - loss: 0.6832 - val_binary_accuracy: 0.7020 - val_loss: 0.6195\n", "Epoch 2/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 455us/step - binary_accuracy: 0.7588 - loss: 0.5825 - val_binary_accuracy: 0.7954 - val_loss: 0.5009\n", "Epoch 3/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 536us/step - binary_accuracy: 0.8293 - loss: 0.4681 - val_binary_accuracy: 0.8352 - val_loss: 0.4253\n", "Epoch 4/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 516us/step - binary_accuracy: 0.8523 - loss: 0.3967 - val_binary_accuracy: 0.8516 - val_loss: 0.3802\n", "Epoch 5/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 448us/step - binary_accuracy: 0.8692 - loss: 0.3524 - val_binary_accuracy: 0.8592 - val_loss: 0.3522\n", "Epoch 6/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 530us/step - binary_accuracy: 0.8810 - loss: 0.3199 - val_binary_accuracy: 0.8658 - val_loss: 0.3324\n", "Epoch 7/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 489us/step - binary_accuracy: 0.8919 - loss: 0.2945 - val_binary_accuracy: 0.8666 - val_loss: 0.3188\n", "Epoch 8/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 509us/step - binary_accuracy: 0.8975 - loss: 0.2744 - val_binary_accuracy: 0.8720 - val_loss: 0.3085\n", "Epoch 9/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 389us/step - binary_accuracy: 0.9042 - loss: 0.2565 - val_binary_accuracy: 0.8756 - val_loss: 0.3017\n", "Epoch 10/10\n", "\u001b[1m625/625\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 410us/step - binary_accuracy: 0.9121 - loss: 0.2409 - val_binary_accuracy: 0.8750 - val_loss: 0.2972\n" ] } ], "source": [ "epochs = 10\n", "history = model.fit(\n", " train_ds,\n", " validation_data=val_ds,\n", " epochs=epochs)" ] }, { "cell_type": "markdown", "id": "4c8d8f2a", "metadata": {}, "source": [ "#### Evaluate the model" ] }, { "cell_type": "code", "execution_count": 22, "id": "656afe07-354f-4ff2-8e3e-d02bad6c5958", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m782/782\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 573us/step - binary_accuracy: 0.8719 - loss: 0.3147\n", "Loss: 0.3172186613082886\n", "Accuracy: 0.8701599836349487\n" ] } ], "source": [ "loss, accuracy = model.evaluate(test_ds)\n", "\n", "print(\"Loss: \", loss)\n", "print(\"Accuracy: \", accuracy)" ] }, { "cell_type": "markdown", "id": "b2a307ce", "metadata": {}, "source": [ "Create a plot of accuracy and loss over time:" ] }, { "cell_type": "code", "execution_count": 23, "id": "a01d0f13-d0b8-4d78-9ddc-ede5ed402446", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['binary_accuracy', 'loss', 'val_binary_accuracy', 'val_loss'])" ] }, "execution_count": 23, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "history_dict = history.history\n", "history_dict.keys()" ] }, { "cell_type": "code", "execution_count": 24, "id": "1f7484c3-3cdf-46d5-b95d-80316f0e6240", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABUfUlEQVR4nO3deZyNdf/H8deZGbNhBoNZzDBI9n2LuS2VQmVJNEoMdetO1qQbt52izRYhLbRKaSyVXZRQ3EmppNzZwqAwYx/OXL8/rt8cc8wYs5yZa+ac9/PxOA/nXOc65/ocM3Xevtf3e31shmEYiIiIiLgJL6sLEBEREXElhRsRERFxKwo3IiIi4lYUbkRERMStKNyIiIiIW1G4EREREbeicCMiIiJuReFGRERE3IrCjYiIiLgVhRsRC/Tu3Zvo6OgcvXb8+PHYbDbXFlTAHDhwAJvNxsKFC/P1uJs2bcJms7Fp0ybHtqz+rPKq5ujoaHr37u3S98yKhQsXYrPZOHDgQL4fWyS3FG5E0rDZbFm6pf3yE8mtrVu3Mn78eM6cOWN1KSJuwcfqAkQKknfffdfp8TvvvMO6devSba9evXqujvP666+TkpKSo9eOHj2aESNG5Or4knW5+Vll1datW5kwYQK9e/emRIkSTs/t3bsXLy/9O1QkOxRuRNJ45JFHnB5/8803rFu3Lt326124cIHAwMAsH6dIkSI5qg/Ax8cHHx/9p5tfcvOzcgU/Pz9Ljy9SGOmfAyLZ1Lp1a2rVqsV3331Hy5YtCQwM5D//+Q8Ay5cv59577yUiIgI/Pz8qV67MpEmTsNvtTu9x/TyO1PkaL7/8MvPnz6dy5cr4+fnRuHFjduzY4fTajObc2Gw2BgwYwLJly6hVqxZ+fn7UrFmT1atXp6t/06ZNNGrUCH9/fypXrsxrr72W5Xk8mzdvplu3bpQvXx4/Pz+ioqJ46qmnuHjxYrrPV6xYMY4cOULnzp0pVqwYZcqUYdiwYen+Ls6cOUPv3r0JDg6mRIkSxMXFZen0zH//+19sNhtvv/12uufWrFmDzWbjs88+A+DgwYM8+eSTVK1alYCAAEJCQujWrVuW5pNkNOcmqzX/+OOP9O7dm0qVKuHv709YWBiPPvoof//9t2Of8ePH88wzzwBQsWJFx6nP1NoymnPzxx9/0K1bN0qVKkVgYCC33XYbn3/+udM+qfOHPvroI5577jkiIyPx9/fnzjvvZN++fTf93DcyZ84catasiZ+fHxEREfTv3z/dZ//999954IEHCAsLw9/fn8jISLp3705iYqJjn3Xr1vGPf/yDEiVKUKxYMapWrer470gkt/TPP5Ec+Pvvv2nfvj3du3fnkUceITQ0FDAnYRYrVoyhQ4dSrFgxvvjiC8aOHUtSUhIvvfTSTd/3gw8+4OzZs/zrX//CZrPx4osv0qVLF/7444+bjiB8/fXXxMfH8+STT1K8eHFeeeUVHnjgAQ4dOkRISAgA33//Pe3atSM8PJwJEyZgt9uZOHEiZcqUydLn/vjjj7lw4QL9+vUjJCSE7du3M2vWLP78808+/vhjp33tdjtt27aladOmvPzyy6xfv56pU6dSuXJl+vXrB4BhGHTq1Imvv/6aJ554gurVq7N06VLi4uJuWkujRo2oVKkSH330Ubr9Fy9eTMmSJWnbti0AO3bsYOvWrXTv3p3IyEgOHDjA3Llzad26Nb/88ku2Rt2yU/O6dev4448/6NOnD2FhYfz888/Mnz+fn3/+mW+++QabzUaXLl347bffWLRoEdOnT6d06dIAN/yZHD9+nObNm3PhwgUGDRpESEgIb7/9Nh07dmTJkiXcf//9Tvs///zzeHl5MWzYMBITE3nxxRfp0aMH3377bZY/c6rx48czYcIE2rRpQ79+/di7dy9z585lx44dbNmyhSJFipCcnEzbtm25fPkyAwcOJCwsjCNHjvDZZ59x5swZgoOD+fnnn7nvvvuoU6cOEydOxM/Pj3379rFly5Zs1ySSIUNEbqh///7G9f+ZtGrVygCMefPmpdv/woUL6bb961//MgIDA41Lly45tsXFxRkVKlRwPN6/f78BGCEhIcapU6cc25cvX24AxqeffurYNm7cuHQ1AYavr6+xb98+x7YffvjBAIxZs2Y5tnXo0MEIDAw0jhw54tj2+++/Gz4+PuneMyMZfb4pU6YYNpvNOHjwoNPnA4yJEyc67Vu/fn2jYcOGjsfLli0zAOPFF190bLt69arRokULAzAWLFiQaT0jR440ihQp4vR3dvnyZaNEiRLGo48+mmnd27ZtMwDjnXfecWzbuHGjARgbN250+ixpf1bZqTmj4y5atMgAjK+++sqx7aWXXjIAY//+/en2r1ChghEXF+d4PGTIEAMwNm/e7Nh29uxZo2LFikZ0dLRht9udPkv16tWNy5cvO/adOXOmARi7d+9Od6y0FixY4FTTiRMnDF9fX+Puu+92HMMwDGP27NkGYLz11luGYRjG999/bwDGxx9/fMP3nj59ugEYJ0+ezLQGkZzSaSmRHPDz86NPnz7ptgcEBDjunz17lr/++osWLVpw4cIFfv3115u+b2xsLCVLlnQ8btGiBWCehriZNm3aULlyZcfjOnXqEBQU5Hit3W5n/fr1dO7cmYiICMd+t9xyC+3bt7/p+4Pz5zt//jx//fUXzZs3xzAMvv/++3T7P/HEE06PW7Ro4fRZVq5ciY+Pj2MkB8Db25uBAwdmqZ7Y2FiuXLlCfHy8Y9vatWs5c+YMsbGxGdZ95coV/v77b2655RZKlCjBzp07s3SsnNSc9riXLl3ir7/+4rbbbgPI9nHTHr9Jkyb84x//cGwrVqwYjz/+OAcOHOCXX35x2r9Pnz74+vo6Hmfndyqt9evXk5yczJAhQ5wmOPft25egoCDHabHg4GDAPDV44cKFDN8rddL08uXL83yytngmhRuRHChXrpzTF0aqn3/+mfvvv5/g4GCCgoIoU6aMYzJy2vkGN1K+fHmnx6lB5/Tp09l+berrU1974sQJLl68yC233JJuv4y2ZeTQoUP07t2bUqVKOebRtGrVCkj/+fz9/dOdWklbD5hzYcLDwylWrJjTflWrVs1SPXXr1qVatWosXrzYsW3x4sWULl2aO+64w7Ht4sWLjB07lqioKPz8/ChdujRlypThzJkzWfq5pJWdmk+dOsXgwYMJDQ0lICCAMmXKULFiRSBrvw83On5Gx0pdwXfw4EGn7bn5nbr+uJD+c/r6+lKpUiXH8xUrVmTo0KG88cYblC5d
mrZt2/Lqq686fd7Y2FhiYmL45z//SWhoKN27d+ejjz5S0BGX0ZwbkRxI+y/yVGfOnKFVq1YEBQUxceJEKleujL+/Pzt37mT48OFZ+h+3t7d3htsNw8jT12aF3W7nrrvu4tSpUwwfPpxq1apRtGhRjhw5Qu/evdN9vhvV42qxsbE899xz/PXXXxQvXpwVK1bw0EMPOa0oGzhwIAsWLGDIkCE0a9aM4OBgbDYb3bt3z9Mv1AcffJCtW7fyzDPPUK9ePYoVK0ZKSgrt2rXLty/yvP69yMjUqVPp3bs3y5cvZ+3atQwaNIgpU6bwzTffEBkZSUBAAF999RUbN27k888/Z/Xq1SxevJg77riDtWvX5tvvjrgvhRsRF9m0aRN///038fHxtGzZ0rF9//79FlZ1TdmyZfH3989wpUxWVs/s3r2b3377jbfffptevXo5tq9bty7HNVWoUIENGzZw7tw5p5GQvXv3Zvk9YmNjmTBhAp988gmhoaEkJSXRvXt3p32WLFlCXFwcU6dOdWy7dOlSji6al9WaT58+zYYNG5gwYQJjx451bP/999/TvWd2rjhdoUKFDP9+Uk97VqhQIcvvlR2p77t3714qVark2J6cnMz+/ftp06aN0/61a9emdu3ajB49mq1btxITE8O8efN49tlnAfDy8uLOO+/kzjvvZNq0aUyePJlRo0axcePGdO8lkl06LSXiIqn/2kz7L+Lk5GTmzJljVUlOvL29adOmDcuWLePo0aOO7fv27WPVqlVZej04fz7DMJg5c2aOa7rnnnu4evUqc+fOdWyz2+3MmjUry+9RvXp1ateuzeLFi1m8eDHh4eFO4TK19utHKmbNmpVuWbora87o7wtgxowZ6d6zaNGiAFkKW/fccw/bt29n27Ztjm3nz59n/vz5REdHU6NGjax+lGxp06YNvr6+vPLKK06f6c033yQxMZF7770XgKSkJK5ever02tq1a+Pl5cXly5cB83Td9erVqwfg2EckNzRyI+IizZs3p2TJksTFxTFo0CBsNhvvvvtung7/Z9f48eNZu3YtMTEx9OvXD7vdzuzZs6lVqxa7du3K9LXVqlWjcuXKDBs2jCNHjhAUFMQnn3yS7bkbaXXo0IGYmBhGjBjBgQMHqFGjBvHx8dmejxIbG8vYsWPx9/fnscceS3dF3/vuu493332X4OBgatSowbZt21i/fr1jiXxe1BwUFETLli158cUXuXLlCuXKlWPt2rUZjuQ1bNgQgFGjRtG9e3eKFClChw4dHKEnrREjRrBo0SLat2/PoEGDKFWqFG+//Tb79+/nk08+ybOrGZcpU4aRI0cyYcIE2rVrR8eOHdm7dy9z5syhcePGjrllX3zxBQMGDKBbt27ceuutXL16lXfffRdvb28eeOABACZOnMhXX33FvffeS4UKFThx4gRz5swhMjLSaaK0SE4p3Ii4SEhICJ999hlPP/00o0ePpmTJkjzyyCPceeedjuutWK1hw4asWrWKYcOGMWbMGKKiopg4cSJ79uy56WquIkWK8OmnnzrmT/j7+3P//fczYMAA6tatm6N6vLy8WLFiBUOGDOG9997DZrPRsWNHpk6dSv369bP8PrGxsYwePZoLFy44rZJKNXPmTLy9vXn//fe5dOkSMTExrF+/Pkc/l+zU/MEHHzBw4EBeffVVDMPg7rvvZtWqVU6r1QAaN27MpEmTmDdvHqtXryYlJYX9+/dnGG5CQ0PZunUrw4cPZ9asWVy6dIk6derw6aefOkZP8sr48eMpU6YMs2fP5qmnnqJUqVI8/vjjTJ482XEdprp169K2bVs+/fRTjhw5QmBgIHXr1mXVqlWOlWIdO3bkwIEDvPXWW/z111+ULl2aVq1aMWHCBMdqK5HcsBkF6Z+VImKJzp078/PPP2c4H0REpLDRnBsRD3N9q4Tff/+dlStX0rp1a2sKEhFxMY3ciHiY8PBwR7+jgwcPMnfuXC5fvsz3339PlSpVrC5PRCTXNOdGxMO0a9eORYsWkZCQgJ+fH82aNWPy5MkKNiLiNjRyIyIiIm5Fc25ERETErSjciIiIiFvxuDk3KSkpHD16lOLFi2frkuciIiJiHcMwOHv2LBERETe9WKXHhZujR48SFRVldRkiIiKSA4cPHyYyMjLTfTwu3BQvXhww/3KCgoIsrkZERESyIikpiaioKMf3eGY8LtyknooKCgpSuBERESlksjKlRBOKRURExK0o3IiIiIhbUbgRERERt+Jxc25ERMS17HY7V65csboMcQO+vr43XeadFQo3IiKSI4ZhkJCQwJkzZ6wuRdyEl5cXFStWxNfXN1fvo3AjIiI5khpsypYtS2BgoC6MKrmSepHdY8eOUb58+Vz9PinciIhIttntdkewCQkJsboccRNlypTh6NGjXL16lSJFiuT4fTShWEREsi11jk1gYKDFlYg7ST0dZbfbc/U+CjciIpJjOhUlruSq3yedlnIRux02b4ZjxyA8HFq0AG9vq6sSERHxPBq5cYH4eIiOhttvh4cfNv+Mjja3i4iI+4uOjmbGjBlZ3n/Tpk3YbLY8X2m2cOFCSpQokafHKIgUbnIpPh66doU//3TefuSIuV0BR0Qkc3Y7bNoEixaZf+ZyukWmbDZbprfx48fn6H137NjB448/nuX9mzdvzrFjxwgODs7R8SRzOi2VC3Y7DB4MhpH+OcMAmw2GDIFOnXSKSkQkI/Hx5v9H0/4DMTISZs6ELl1cf7xjx4457i9evJixY8eyd+9ex7ZixYo57huGgd1ux8fn5l+VZcqUyVYdvr6+hIWFZes1knUaucmFzZvTj9ikZRhw+LC5n4iIOLNi5DssLMxxCw4OxmazOR7/+uuvFC9enFWrVtGwYUP8/Pz4+uuv+d///kenTp0IDQ2lWLFiNG7cmPXr1zu97/WnpWw2G2+88Qb3338/gYGBVKlShRUrVjiev/60VOrpozVr1lC9enWKFStGu3btnMLY1atXGTRoECVKlCAkJIThw4cTFxdH586ds/V3MHfuXCpXroyvry9Vq1bl3XffdTxnGAbjx4+nfPny+Pn5ERERwaBBgxzPz5kzhypVquDv709oaChdu3bN1rHzi8JNLqT5nXPJfiIinuJmI99gjnzn5SmqGxkxYgTPP/88e/bsoU6dOpw7d4577rmHDRs28P3339OuXTs6dOjAoUOHMn2fCRMm8OCDD/Ljjz9yzz330KNHD06dOnXD/S9cuMDLL7/Mu+++y1dffcWhQ4cYNmyY4/kXXniB999/nwULFrBlyxaSkpJYtmxZtj7b0qVLGTx4ME8//TQ//fQT//rXv+jTpw8bN24E4JNPPmH69Om89tpr/P777yxbtozatWsD8N///pdBgwYxceJE9u7dy+rVq2nZsmW2jp9vDA+TmJhoAEZiYmKu32vjRsMw/zPM/LZxY64PJSJSoFy8eNH45ZdfjIsXL+bo9QXh/58LFiwwgoOD09S00QCMZcuW3fS1NWvWNGbNmuV4XKFCBWP69OmOx4AxevRox+Nz584ZgLFq1SqnY50+fdpRC2D
s27fP8ZpXX33VCA0NdTwODQ01XnrpJcfjq1evGuXLlzc6deqU5c/YvHlzo2/fvk77dOvWzbjnnnsMwzCMqVOnGrfeequRnJyc7r0++eQTIygoyEhKSrrh8XIrs9+r7Hx/a+QmF1q0MM8N32hZvs0GUVHmfiIick1BHvlu1KiR0+Nz584xbNgwqlevTokSJShWrBh79uy56chNnTp1HPeLFi1KUFAQJ06cuOH+gYGBVK5c2fE4PDzcsX9iYiLHjx+nSZMmjue9vb1p2LBhtj7bnj17iImJcdoWExPDnj17AOjWrRsXL16kUqVK9O3bl6VLl3L16lUA7rrrLipUqEClSpXo2bMn77//PhcuXMjW8fOLwk0ueHubk94gfcBJfTxjhiYTi4hcLzzctfu5UtGiRZ0eDxs2jKVLlzJ58mQ2b97Mrl27qF27NsnJyZm+z/XtA2w2GykpKdna38jovF0eioqKYu/evcyZM4eAgACefPJJWrZsyZUrVyhevDg7d+5k0aJFhIeHM3bsWOrWrVsgG6cq3ORSly6wZAmUK+e8PTLS3J4Xs/1FRAq7wjTyvWXLFnr37s39999P7dq1CQsL48CBA/laQ3BwMKGhoezYscOxzW63s3Pnzmy9T/Xq1dmyZYvTti1btlCjRg3H44CAADp06MArr7zCpk2b2LZtG7t37wbAx8eHNm3a8OKLL/Ljjz9y4MABvvjii1x8sryhpeAu0KWLudxbVygWEcma1JHvrl3NIJN2gKKgjXxXqVKF+Ph4OnTogM1mY8yYMZmOwOSVgQMHMmXKFG655RaqVavGrFmzOH36dLZaFjzzzDM8+OCD1K9fnzZt2vDpp58SHx/vWP21cOFC7HY7TZs2JTAwkPfee4+AgAAqVKjAZ599xh9//EHLli0pWbIkK1euJCUlhapVq+bVR84xhRsX8faG1q2trkJEpPBIHfnO6Do3M2YUnJHvadOm8eijj9K8eXNKly7N8OHDSUpKyvc6hg8fTkJCAr169cLb25vHH3+ctm3b4p2NBNi5c2dmzpzJyy+/zODBg6lYsSILFiyg9f9/gZUoUYLnn3+eoUOHYrfbqV27Np9++ikhISGUKFGC+Ph4xo8fz6VLl6hSpQqLFi2iZs2aefSJc85m5PcJPYslJSURHBxMYmIiQUFBVpcjIlIoXbp0if3791OxYkX8/f1z9V7qzZczKSkpVK9enQcffJBJkyZZXY5LZPZ7lZ3vb43ciIiIpTTynTUHDx5k7dq1tGrVisuXLzN79mz279/Pww8/bHVpBY4mFIuIiBQCXl5eLFy4kMaNGxMTE8Pu3btZv3491atXt7q0AkcjNyIiIoVAVFRUupVOkjGN3IiIiIhbUbgRERERt6JwIyIiIm5F4UZERETcisKNiIiIuBWFGxEREXErCjciIiLZ1Lp1a4YMGeJ4HB0dzYwZMzJ9jc1mY9myZbk+tqveJzPjx4+nXr16eXqMvKRwIyIiHqNDhw60a9cuw+c2b96MzWbjxx9/zPb77tixg8cffzy35Tm5UcA4duwY7du3d+mx3I3CjYiIeIzHHnuMdevW8WfaTp3/b8GCBTRq1Ig6depk+33LlClDYGCgK0q8qbCwMPz8/PLlWIWVwo2IiHiM++67jzJlyrBw4UKn7efOnePjjz/mscce4++//+ahhx6iXLlyBAYGUrt2bRYtWpTp+15/Wur333+nZcuW+Pv7U6NGDdatW5fuNcOHD+fWW28lMDCQSpUqMWbMGK5cuQLAwoULmTBhAj/88AM2mw2bzeao+frTUrt37+aOO+4gICCAkJAQHn/8cc6dO+d4vnfv3nTu3JmXX36Z8PBwQkJC6N+/v+NYWZGSksLEiROJjIzEz8+PevXqsXr1asfzycnJDBgwgPDwcPz9/alQoQJTpkwBwDAMxo8fT/ny5fHz8yMiIoJBgwZl+dg5ofYLIiLiEoYBFy5Yc+zAQLDZbr6fj48PvXr1YuHChYwaNQrb/7/o448/xm6389BDD3Hu3DkaNmzI8OHDCQoK4vPPP6dnz55UrlyZJk2a3PQYKSkpdOnShdDQUL799lsSExOd5uekKl68OAsXLiQiIoLdu3fTt29fihcvzr///W9iY2P56aefWL16NevXrwcgODg43XucP3+etm3b0qxZM3bs2MGJEyf45z//yYABA5wC3MaNGwkPD2fjxo3s27eP2NhY6tWrR9++fW/+lwbMnDmTqVOn8tprr1G/fn3eeustOnbsyM8//0yVKlV45ZVXWLFiBR999BHly5fn8OHDHD58GIBPPvmE6dOn8+GHH1KzZk0SEhL44YcfsnTcHDM8TGJiogEYiYmJVpciIlJoXbx40fjll1+MixcvOradO2cYZsTJ/9u5c1mvfc+ePQZgbNy40bGtRYsWxiOPPHLD19x7773G008/7XjcqlUrY/DgwY7HFSpUMKZPn24YhmGsWbPG8PHxMY4cOeJ4ftWqVQZgLF269IbHeOmll4yGDRs6Ho8bN86oW7duuv3Svs/8+fONkiVLGufS/AV8/vnnhpeXl5GQkGAYhmHExcUZFSpUMK5everYp1u3bkZsbOwNa7n+2BEREcZzzz3ntE/jxo2NJ5980jAMwxg4cKBxxx13GCkpKenea+rUqcatt95qJCcn3/B4qTL6vUqVne9vnZYSERGPUq1aNZo3b85bb70FwL59+9i8eTOPPfYYAHa7nUmTJlG7dm1KlSpFsWLFWLNmDYcOHcrS++/Zs4eoqCgiIiIc25o1a5Zuv8WLFxMTE0NYWBjFihVj9OjRWT5G2mPVrVuXokWLOrbFxMSQkpLC3r17Hdtq1qyJt7e343F4eDgnTpzI0jGSkpI4evQoMTExTttjYmLYs2cPYJ762rVrF1WrVmXQoEGsXbvWsV+3bt24ePEilSpVom/fvixdupSrV69m63Nml8KNiIi4RGAgnDtnzS27c3kfe+wxPvnkE86ePcuCBQuoXLkyrVq1AuCll15i5syZDB8+nI0bN7Jr1y7atm1LcnKyy/6utm3bRo8ePbjnnnv47LPP+P777xk1apRLj5FWkSJFnB7bbDZSUlJc9v4NGjRg//79TJo0iYsXL/Lggw/StWtXwOxmvnfvXubMmUNAQABPPvkkLVu2zNacn+zSnBsREXEJmw3SDCAUaA8++CCDBw/mgw8+4J133qFfv36O+TdbtmyhU6dOPPLII4A5h+a3336jRo0aWXrv6tWrc/jwYY4dO0Z4eDgA33zzjdM+W7dupUKFCowaNcqx7eDBg077+Pr6Yrfbb3qshQsXcv78ecfozZYtW/Dy8qJq1apZqvdmgoKCiIiIYMuWLY4AmHqctHOQgoKCiI2NJTY2lq5du9KuXTtOnTpFqVKlCAgIoEOHDnTo0IH+/ftTrVo1du/eTYMGDVxS4/UUbkRExOMUK1aM2NhYRo4cSVJSEr1793Y8V6VKFZYsWcLWrVspWbIk06ZN4/jx41kON23atOHWW28lLi6Ol156iaSkJKcQk3qMQ4cO8eGHH9K4cWM+//xzli5d6rRPdHQ0+/
fvZ9euXURGRlK8ePF0S8B79OjBuHHjiIuLY/z48Zw8eZKBAwfSs2dPQkNDc/aXk4FnnnmGcePGUblyZerVq8eCBQvYtWsX77//PgDTpk0jPDyc+vXr4+Xlxccff0xYWBglSpRg4cKF2O12mjZtSmBgIO+99x4BAQFUqFDBZfVdT6elRETEIz322GOcPn2atm3bOs2PGT16NA0aNKBt27a0bt2asLAwOnfunOX39fLyYunSpVy8eJEmTZrwz3/+k+eee85pn44dO/LUU08xYMAA6tWrx9atWxkzZozTPg888ADt2rXj9ttvp0yZMhkuRw8MDGTNmjWcOnWKxo0b07VrV+68805mz56dvb+Mmxg0aBBDhw7l6aefpnbt2qxevZoVK1ZQpUoVwFz59eKLL9KoUSMaN27MgQMHWLlyJV5eXpQoUYLXX3+dmJgY6tSpw/r16/n0008JCQlxaY1p2QzDMPLs3QugpKQkgoODSUxMJCgoyOpyREQKpUuXLrF//34qVqyIv7+/1eWIm8js9yo7398auRERERG3onAjIiIibkXhRkRERNyKwo2IiIi4FYUbERHJMQ9bkyJ5zFW/Two3IiKSbalXvL1gVadMcUupV2hO2yoiJ3QRPxc6fhySkuD/l/2LiLgtb29vSpQo4ehPFBgY6LjCr0hOpKSkcPLkSQIDA/HxyV08UbhxkRUr4OGHoWFD2LTJvAy5iIg7CwsLA8hyA0aRm/Hy8qJ8+fK5DsoKNy5Svz5cvQpffQWrVsE991hdkYhI3rLZbISHh1O2bNk8bYIonsPX1xcvr9zPmFG4cZGoKBg0CF56CUaMgLZtIZenDEVECgVvb+9cz5EQcSVNKHahESOgRAnYvRs++MDqakRERDyTwo0LlSoFI0ea90ePhkuXrK1HRETEEyncuNjAgVCuHBw6BHPnWl2NiIiI51G4cbGAAJgwwbz/7LOQmGhtPSIiIp5G4SYPxMVB9epw6hS8+KLV1YiIiHgWhZs84OMDU6aY96dPh2PHrK1HRETEkyjc5JGOHaF5c7h48dppKhEREcl7Cjd5xGaDF14w77/xBuzda209IiIinkLhJg/94x/QoQPY7TBqlNXViIiIeAbLw82rr75KdHQ0/v7+NG3alO3bt2e6/5kzZ+jfvz/h4eH4+flx6623snLlynyqNvsmTwYvL/jkE/j2W6urERERcX+WhpvFixczdOhQxo0bx86dO6lbty5t27a9YRO25ORk7rrrLg4cOMCSJUvYu3cvr7/+OuXKlcvnyrOuVi1z9RTA8OFgGNbWIyIi4u5shmHd123Tpk1p3Lgxs2fPBsx251FRUQwcOJARI0ak23/evHm89NJL/PrrrxQpUiRHx0xKSiI4OJjExESCgoJyVX9WHT4MVarA5cuwciW0b58vhxUREXEb2fn+tmzkJjk5me+++442bdpcK8bLizZt2rBt27YMX7NixQqaNWtG//79CQ0NpVatWkyePBm73X7D41y+fJmkpCSnW35LbaoJ5uhNJuWKiIhILlkWbv766y/sdjuhoaFO20NDQ0lISMjwNX/88QdLlizBbrezcuVKxowZw9SpU3n22WdveJwpU6YQHBzsuEVFRbn0c2SVmmqKiIjkD8snFGdHSkoKZcuWZf78+TRs2JDY2FhGjRrFvHnzbviakSNHkpiY6LgdPnw4Hyu+plQpM+CAmmqKiIjkJcvCTenSpfH29ub48eNO248fP05YWFiGrwkPD+fWW2/F29vbsa169eokJCSQnJyc4Wv8/PwICgpyulll0CA11RQREclrloUbX19fGjZsyIYNGxzbUlJS2LBhA82aNcvwNTExMezbt4+UlBTHtt9++43w8HB8fX3zvObcSttU87nn1FRTREQkL1h6Wmro0KG8/vrrvP322+zZs4d+/fpx/vx5+vTpA0CvXr0YOXKkY/9+/fpx6tQpBg8ezG+//cbnn3/O5MmT6d+/v1UfIdtSm2r+/Te89JLV1YiIiLgfHysPHhsby8mTJxk7diwJCQnUq1eP1atXOyYZHzp0CC+va/krKiqKNWvW8NRTT1GnTh3KlSvH4MGDGT58uFUfIdtSm2p27gzTpkH//hAebnVVIiIi7sPS69xYwYrr3FzPMMzWDFu3wr/+BZnMhxYREREKyXVuPJnNBs8/b95XU00RERHXUrixSIsW15pqjh5tdTUiIiLuQ+HGQqlNNZcsUVNNERERV1G4sZCaaoqIiLiewo3FJkwAPz/48ktYvdrqakRERAo/hRuLRUXBwIHmfTXVFBERyT2FmwJg5EgIDlZTTREREVdQuCkASpUyAw7AmDFw+bK19YiIiBRmCjcFRGpTzYMH1VRTREQkNxRuCoi0TTWffVZNNUVERHJK4aYAUVNNERGR3FO4KUB8fMwL+4HZVPPYMWvrERERKYwUbgqYTp2gWTO4ePHaaSoRERHJOoWbAsZmgxdeMO+/8Qb89pu19YiIiBQ2CjcFUNqmmqNGWV2NiIhI4aJwU0CpqaaIiEjOKNwUULVqQa9e5n011RQREck6hZsCTE01RUREsk/hpgArX/5aU80RIyAlxdp6RERECgOFmwIutanmjz+qqaaIiEhWKNwUcGmbao4eraaaIiIiN6NwUwioqaaIiEjWKdwUAgEBMH68eV9NNUVERDKncFNI9O4N1aqpqaaIiMjNKNwUEj4+MGWKeX/6dDXVFBERuRGFm0IktanmhQswcWLeHMNuh02bYNEi80+7PW+OIyIiklcUbgqRtE01X3/d9U014+MhOhpuvx0eftj8Mzra3C4iIlJYKNwUMi1awH33ub6pZnw8dO0Kf/7pvP3IEXO7Ao6IiBQWCjeF0JQp5ijOkiWwfXvu389uh8GDM+5flbptyBCdohIRkcJB4aYQqlUL4uLM+65oqrl5c/oRm7QMAw4fNvcTEREp6BRuCqnUppqbNsGaNbl7r6yuvNIKLRERKQwUbgqptE01hw/PXVPN8HDX7iciImIlhZtCzFVNNVu0gMhIcx5PRmw2iIoy9xMRESnoFG4KsVKlYMQI835ummp6e8PMmeb96wNO6uMZM8z9RERECjqFm0Ju0CCIiMh9U80uXczVV+XKOW+PjDS3d+mSuzpFRETyi80wcrvWpnBJSkoiODiYxMREgoKCrC7HJd54A/r2hZAQ+N//zFNVOWW3m6uijh0z59i0aKERGxERsV52vr81cuMG0jbVfPnl3L2Xtze0bg0PPWT+qWAjIiKFjcKNG0jbVHPaNC3ZFhERz6Zw4ybyo6mmiIhIYaBw4ybyuqmmiIhIYaFw40byqqmmiIhIYaJw42Zc3VRTRESksFG4cTOubqopIiJS2CjcuCFXNtUUEREpbBRu3FD58jBggHk/t001RUREChuFGzflqqaaIiIihY3CjZsKCbnWVHPMmJw31RQRESlsFG7cWGpTzQMHYN48q6sRERHJHwo3biww0JxcDDBpEiQmWluPiIhIflC4cXOubKopIiJSGCjcuDkfH5g82byvppoiIuIJFG48QOfOcNttaqopIiKeQeHGA6ipp
oiIeBKFGw/RsuW1ppqjR1tdjYiISN5RuPEgqU01P/5YTTVFRMR9Kdx4kFq1oFcv876aaoqIiLtSuPEwEyeqqaaIiLg3hRsPk7ap5ogRaqopIiLuR+HGA6U21fzhB1i0yOpqREREXEvhxgOlbao5erSaaoqIiHtRuPFQaqopIiLuSuHGQwUGwvjx5n011RQREXeicOPB+vSBqlXVVFNERNyLwo0H8/ExL+wHZlPNhARr6xEREXEFhRsPp6aaIiLibhRuPFzapprz56uppoiIFH4KN6KmmiIi4lYKRLh59dVXiY6Oxt/fn6ZNm7I9k66OCxcuxGazOd38/f3zsVr3NHnytaaaO3ZYXY2IiEjOWR5uFi9ezNChQxk3bhw7d+6kbt26tG3blhMnTtzwNUFBQRw7dsxxO3jwYD5W7J5q11ZTTRERcQ+Wh5tp06bRt29f+vTpQ40aNZg3bx6BgYG89dZbN3yNzWYjLCzMcQsNDc3Hit1XalPNjRth7VqrqxEREckZS8NNcnIy3333HW3atHFs8/Lyok2bNmzbtu2Grzt37hwVKlQgKiqKTp068fPPP+dHuW4vbVPNp582V1CJiIgUNpaGm7/++gu73Z5u5CU0NJSEG1x0pWrVqrz11lssX76c9957j5SUFJo3b86ff/6Z4f6XL18mKSnJ6SY3NnIklCkDP/8MPXuqa7iIiBQ+lp+Wyq5mzZrRq1cv6tWrR6tWrYiPj6dMmTK89tprGe4/ZcoUgoODHbeoqKh8rrhwCQmB+Hjw9TX/1OopEREpbCwNN6VLl8bb25vjx487bT9+/DhhYWFZeo8iRYpQv3599u3bl+HzI0eOJDEx0XE7fPhwrut2d//4B7zxhnl/yhR4+21r6xEREckOS8ONr68vDRs2ZMOGDY5tKSkpbNiwgWbNmmXpPex2O7t37yY8PDzD5/38/AgKCnK6yc317AmjRpn3+/aFzZutrUdERCSrLD8tNXToUF5//XXefvtt9uzZQ79+/Th//jx9+vQBoFevXowcOdKx/8SJE1m7di1//PEHO3fu5JFHHuHgwYP885//tOojuK2JE6FrV7hyBe6/H/73P6srEhERuTkfqwuIjY3l5MmTjB07loSEBOrVq8fq1asdk4wPHTqEl9e1DHb69Gn69u1LQkICJUuWpGHDhmzdupUaNWpY9RHclpeXeUrq4EHzwn733QfbtkGJElZXJiIicmM2w/Csy7UlJSURHBxMYmKiTlFl0bFj0KQJ/PkntGkDK1dCkSJWVyUiIp4kO9/flp+WkoIvPBw+/RSKFoX162HgQF3BWERECi6FG8mSevVg0SKz/9Rrr8Err1hdkYiISMYUbiTLOnSAl1827w8dCp9/bm09IiIiGVG4kWx56ilzaXhKCnTvDj/+aHVFIiIizhRuJFtsNnj1VbjjDjh3zhzNuUGnDBEREUso3Ei2FSkCS5bArbfCoUPQuTNcvGh1VSIiIiaFG8mRkiXhs8+gVCn49lvo00crqEREpGBQuJEcq1LFbK5ZpAgsXgzjx1tdkYiIiMKN5FKrVubScDDbNbz/vrX1iIiIKNxIrvXpA//+t3n/0Udh61Zr6xEREc+mcCMuMWWKObE4Odn888ABiwsSERGPpXAjLuHlBe+9B/Xrw8mTZpPNpCSrqxIREU+kcCMuU7So2YMqIgJ+/hliY+HqVaurEhERT6NwIy5VrhysWAEBAbB6tdmmQUREJD8p3IjLNWxonqICmDXLvKKxiIhIflG4kTzRpYs5yRhg8GBYs8baekRExHMo3EieGT4cevcGux0efBB++cXqikRExBMo3EiesdnMC/y1bGmunLrvPnMllYiISF5SuJE85esLn3wClSvD/v3mNXAuXbK6KhERcWc5CjeHDx/mzz//dDzevn07Q4YMYf78+S4rTNxH6dJmk83gYPPqxX37qsmmiIjknRyFm4cffpiNGzcCkJCQwF133cX27dsZNWoUEydOdGmB4h6qVYMlS8Db21xJNXmy1RWJiIi7ylG4+emnn2jSpAkAH330EbVq1WLr1q28//77LFy40JX1iRtp0wbmzDHvjx4NH31kbT0iIuKechRurly5gp+fHwDr16+nY8eOAFSrVo1jx465rjpxO48/Dk89Zd6Pi4Pt262tR0RE3E+Owk3NmjWZN28emzdvZt26dbRr1w6Ao0ePEhIS4tICxf289BLce685sbhjRzh0yOqKRETEneQo3Lzwwgu89tprtG7dmoceeoi6desCsGLFCsfpKpEb8faGRYugTh04fhw6dICzZ62uSkRE3IXNMHK2bsVut5OUlETJkiUd2w4cOEBgYCBly5Z1WYGulpSURHBwMImJiQQFBVldjkc7dAiaNLkWcJYuNYOPiIjI9bLz/Z2jkZuLFy9y+fJlR7A5ePAgM2bMYO/evQU62EjBUr48LF8O/v5mN/F//9vqikRExB3kKNx06tSJd955B4AzZ87QtGlTpk6dSufOnZk7d65LCxT31rQpvP22eX/aNNClkkREJLdyFG527txJixYtAFiyZAmhoaEcPHiQd955h1deecWlBYr7e/BBSL08Uv/+sGGDtfWIiEjhlqNwc+HCBYoXLw7A2rVr6dKlC15eXtx2220cPHjQpQWKZxg9Gnr0gKtXoWtX2Ls3d+9nt8OmTebE5U2bzMciIuIZchRubrnlFpYtW8bhw4dZs2YNd999NwAnTpzQJF3JEZsN3ngDmjeHM2fMJpt//52z94qPh+houP12ePhh88/oaHO7iIi4vxyFm7FjxzJs2DCio6Np0qQJzZo1A8xRnPr167u0QPEc/v7miqnoaNi3Dx54AJKTs/ce8fHmyE+a1mcAHDliblfAERFxfzleCp6QkMCxY8eoW7cuXl5mRtq+fTtBQUFUq1bNpUW6kpaCF3w//wzNmpnXvunTB9580xzZuRm73QxG1webVDYbREaa3cm15FxEpHDJ86XgAGFhYdSvX5+jR486OoQ3adKkQAcbKRxq1jT7Tnl5wYIF5hWNs2Lz5hsHGzA7kR8+bO4nIiLuK0fhJiUlhYkTJxIcHEyFChWoUKECJUqUYNKkSaSkpLi6RvFA7drBzJnm/REjzNNVN5PVtmZqfyYi4t58cvKiUaNG8eabb/L8888TExMDwNdff8348eO5dOkSzz33nEuLFM80YIC5amr2bHjkEXPEpUGDG+8fHp61983qfiIiUjjlaM5NREQE8+bNc3QDT7V8+XKefPJJjhw54rICXU1zbgqXq1fNlVNr1kBEhNlFvFy5jPdNnXNz5Ih5Cup6mnMjIlJ45fmcm1OnTmU4t6ZatWqcOnUqJ28pkiEfH1i8GGrUgKNHzS7i589nvK+397VTWddPQE59PGOGgo2IiLvLUbipW7cus2fPTrd99uzZ1KlTJ9dFiaQVHAyffQZlysDOndCzJ9xoaleXLrBkSfrRnchIc3uXLnlfr4iIWCtHp6W+/PJL7r33XsqXL++4xs22bds4fPgwK1eudLRmKIh0Wqrw2rrVvCBfcrI5yXjKlBvva7ebc3SOHTPn2LRo
oREbEZHCLM9PS7Vq1YrffvuN+++/nzNnznDmzBm6dOnCzz//zLvvvpujokVupnlzeOst8/7zz5vLxG/E2xtat4aHHjL/VLAREfEcOb6IX0Z++OEHGjRogL0AN/LRyE3hN3YsTJoERYrAunXQqpXVFYmISF7Ll4v4iVhl/Hizk/iVK+Ycmn37rK5IREQKEoUbKXS8vGDhQmjSBE6dMpeKnz5tdVUiIlJQKNxIoRQQAMuXQ1SUeaG/bt3MkRwREZFsXaG4y03W0Z45cyY3tYhkS1iYuUQ8JgY2bDCvaDxvXtaabIqIiPvKVrgJDg6+6fO9evXKVUEi2VGnDixaZF7cb/58qFYNnnrK6qpERMRKLl0tVRhotZR7mj4dhg41R22WL4cOHayuSEREXEmrpcTjDBkC//qX2VPqoYfghx+srkhERKyicCNuwWaDWbPgzjvN3lMdOkBCgtVViYiIFRRuxG0UKQIffwxVq8Lhw9CpE1y8aHVVIiKS3xRuxK2ULGmuoCpVCrZvhx494MIFq6sSEZH8pHAjbueWW2DpUnMkZ+lSaNjQ7CYuIiKeQeFG3FLLlrB6NUREwK+/QtOmZrPNAtz2TEREXEThRtzWHXfAjz/CAw/A1aswcqQ54fjQIasrExGRvKRwI24tJMScZPzWW1CsGHz5pXnhvw8+sLoyERHJKwo34vZsNujTB3btgttug8REc6Jxjx6gjiEiIu5H4UY8RuXKsHkzjB8P3t7m6E3duvDVV1ZXJiIirqRwIx7FxwfGjYOvv4ZKlcz5N61bm/NxkpOtrk5ERFxB4UY80m23maepHn3UbNnw/PPQrJm5skpERAo3hRvxWMWLw5tvwiefmBf927kTGjSAefPMwCMiIoWTwo14vC5dYPduuOsus11Dv37QsSOcOGF1ZSIikhMKNyKYF/tbvRqmTwc/P7OFQ+3a8PnnVlcmIiLZpXAj8v+8vGDIENixwww2J07AfffBk0+qP5WISGGicCNyndq1zaabTz1lPp47V/2pREQKE4UbkQz4+8O0abB27bX+VLfdBi+8oP5UIiIFncKNSCbuuutaf6orV2DECPWnEhEp6BRuRG7iRv2pFi2yujIREclIgQg3r776KtHR0fj7+9O0aVO2b9+epdd9+OGH2Gw2OnfunLcFisfLqD/Vww/DI4+oP5WISEFjebhZvHgxQ4cOZdy4cezcuZO6devStm1bTtzkIiMHDhxg2LBhtGjRIp8qFUnfn+r999WfSkSkoLE83EybNo2+ffvSp08fatSowbx58wgMDOStt9664Wvsdjs9evRgwoQJVKpUKR+rFVF/KhGRgs7ScJOcnMx3331HmzZtHNu8vLxo06YN27Ztu+HrJk6cSNmyZXnsscdueozLly+TlJTkdBNxhYz6UzVvrv5UIiJWszTc/PXXX9jtdkJDQ522h4aGkpCQkOFrvv76a958801ef/31LB1jypQpBAcHO25RUVG5rlsk1fX9qb77Tv2pRESsZvlpqew4e/YsPXv25PXXX6d06dJZes3IkSNJTEx03A4fPpzHVYon6tLFXDLepo36U4mIWM3HyoOXLl0ab29vjh8/7rT9+PHjhIWFpdv/f//7HwcOHKBDhw6ObSkpKQD4+Piwd+9eKleu7PQaPz8//Pz88qB6EWflysGaNfDKK+b1cFL7U731Ftx7r9XViYh4DktHbnx9fWnYsCEbNmxwbEtJSWHDhg00a9Ys3f7VqlVj9+7d7Nq1y3Hr2LEjt99+O7t27dIpJ7Fc2v5UtWpd60/Vv7/6U4mI5BdLR24Ahg4dSlxcHI0aNaJJkybMmDGD8+fP06dPHwB69epFuXLlmDJlCv7+/tSqVcvp9SVKlABIt13ESrVrmwHnP/8xO43PmQNffGEuHW/QwOrqRETcm+XhJjY2lpMnTzJ27FgSEhKoV68eq1evdkwyPnToEF5ehWpqkAhwrT9V+/bQu/e1/lSTJsGwYeZ1ckRExPVshuFZazqSkpIIDg4mMTGRoKAgq8sRD/H33/D44xAfbz5u1QreeQfKl7e2LhGRwiI7398aEhHJByEhsGRJ9vtT2e2waZO536ZN6kguIpIVCjci+SS7/ani4yE6Gm6/3dzv9tvNx6mjPyIikjGFG5F8lpX+VPHx0LUr/Pmn82uPHDG3K+CIiNyYwo2IBVL7U23e7Nyf6j//MS8COHhwxlc4Tt02ZIhOUYmI3IjCjYiFmjVz7k81ZQrUq5d+xCYtw4DDh81gJCIi6SnciFjs+v5Uv/2WtdcdO5a3dYmIFFYKNyIFRGp/qoYNs7Z/eHje1iMiUlgp3IgUIOXKwbZtEBx8431sNoiKghYt8q8uEZHCROFGpIApUsS8Hs6NGAbMmKErHIuI3IjCjUgB1KWLOQenXLn0zwUHw+7dZlNOERFJT+0XRAowu91cFfX77/Df/8LKlddWUvn6Qo8e5rLxunWtrVNEJK9l5/tb4UakELlyxbyA3/Tp8O2317bffjs89RTcey+oz6yIuCP1lhJxU0WKQGwsfPONOfE4Ntace7NxI3TsCFWrwuzZcO6c1ZWKiFhH4UakkLrtNvjwQ/jjD/j3v6FECdi3DwYOhMhIeOYZOHjQ6ipFRPKfwo1IIVe+PLzwgjkX59VX4dZbzaacL79stnbo1g22bs24nYOIiDtSuBFxE0WLwpNPwp498Pnn0KYNpKTAkiUQEwNNm8IHH5jzdkRE3JnCjYib8fKCe+6BdevMJeOPPQZ+frBjh7m6Kjra7GH1999WVyoikjcUbkTcWK1a8MYbZqPNSZMgLAyOHjW7j0dFwRNPmCM9IiLuROFGxAOUKQOjR8OBA/DOO1C/Ply8CK+9BjVqQPv2sGaN5uWIiHtQuBHxIH5+0LMnfPcdfPkl3H+/2atq9Wpo1w5q1jQDz4ULVlcqIpJzCjciHshmg5YtzQsC7tsHQ4ZA8eLmKaonnjBPWf3nP3DkiNWViohkn8KNiIerVMm84vGff5p/VqwIp06Zk46jo81JyDt2WF2liEjWKdyICABBQeYIzu+/myM6LVvC1avm8vEmTeAf/zCXlV+9anWlIiKZU7gRESfe3uZcnC+/NOfm9Oxptn3YssW8IOAtt8DUqXDmjNWViohkTOFGRG6oQQNzddXBgzBmDJQubd4fNsxs8TBokDlnR0SkIFG4EZGbCg+HiRPh0CHzujm1asH58zBrltnuoWNHs3mnlpKLSEGgcCMiWRYQYF7x+McfzSsg33uvGWg+/RTuuMO8fs7ChXDpktWViognU7gRkWyz2czeVZ99Br/+ava0CgyEH36APn2gQgUYPx6OH7e6UhHxRAo3IpIrVaua3cj//NPsTh4ZCSdOwIQJZsfyPn3M0CMikl8UbkTEJUqWhH//G/74AxYvhttug+Rk8zRVvXrmaaulS3XKSkTynsKNiLhUkSLw4IOwbZt5697dXF6+cSN06WKuuOraFd57D06ftrpaEXFHNsPwrPUNSUlJBAcHk5iYSFBQkNXliHiEw4fNU1fvvefc0sHbG1q3hk6dzFv58paVKCIFXHa+vxV
uRCTfGIZ5YcBly2D5cvjpJ+fnGzQwQ07nzlC7tjlxWUQEFG4ypXAjkv/sdti8GY4dM6+Z06KFOWqzb58ZcpYvh6+/dr5OTsWK14JOTAz4+FhWvogUAAo3mVC4Eclf8fEweLC5mipVZCTMnGnOwUl14oS5tHz5cli71nnicUgIdOhghp277zaXnYuIZ1G4yYTCjUj+iY83Jw9f/3+Z1NNNS5Y4B5xU58+bAWfZMjPwnDp17bmAADPgdOoE990HZcrkWfkiUoAo3GRC4UYkf9jtEB3tPGKTls1mjuDs32+eorqRq1fNU1ap83QOHLj2nJeX2a089fRVpUquq19EChaFm0wo3Ijkj02b4Pbbb77fxo3miqmsMAyz9UNq0Pn+e+fna9e+FnQaNNCEZBF3kp3vb13nRkTyxLFjrt0PzLBSty6MGwc7d5qjODNnmiHK2xt274Znn4VGjcwWEAMHwvr1cOVKjj6CiBRSCjcikifCw127X0YqVIBBg+CLL8wJye+8Y87hCQw0r60zezbcdReULQuPPAIffwxnz+b8eCJSOOi0lIjkidQ5N0eOpJ9QDFmfc5MTFy/Chg3m6asVK+DkyWvP+fqaTT87dYKOHSEszLXHFpG8oTk3mVC4Eck/qaulwDng3Gy1lCvZ7fDNN2bQWbbMvLZO2jpuu82co9Opk9kEVEQKJoWbTCjciOSvjK5zExUFM2bkfbC5nmHAnj3Xgs6OHc7PV6tmBp3OnaFxY3M1logUDAo3mVC4Ecl/N7pCsdWOHDFPWy1bZq7aSjvxODzcPG3VqZPZ0dzPz7IyRQSFm0wp3IhIRhITYdUqM+isXOk88bh4cWjf3hzRad8eSpSwqEgRD6ZwkwmFGxG5mcuXzev0pF5PJ+1ydR8faNXK7HfVpIl501WSRfKewk0mFG5EJDtSUuC//702T2fPnvT7VKx4Leg0bQr166v/lYirKdxkQuFGRHLjt9/MCwNu327eMgo73t7m1ZKbNr0WeqpXLxjzjEQKK4WbTCjciIgrJSaaIzvbt8O335q3hIT0+xUrZl45OXV0p0kTKFdOLSJEskrhJhMKNyKSlwzDXIX17bfXRnd27DA7nV8vPNx5dKdxY9D/lkQypnCTCYUbEclvdrt5+ip1dGf7drMPlt3uvJ/NZl5rJ+3oTu3a5lWVRTydwk0mFG5EpCC4cMFs/pk6uvPtt2Yj0Ov5+ZkdztNOWK5USaezxPMo3GRC4UZECqoTJ8xTWGlPaZ0+nX6/UqWcR3caN9ZydHF/CjeZULgRkcLCMMxeWKlBZ/t2+P578zo816tUKf1y9ICA/K9ZJK8o3GRC4UZECrPkZPjxR+f5O7/+mn4/H5/0y9GrVdNydCm8FG4yoXAjIu7mzJlry9FTQ09Gy9GLF3dejt64sZajS+GhcJMJhRsRyamC2gD0eoZhdmFPO7rz3/9mvBw9ONi8wGCNGs5/VqigruhSsCjcZELhRkRyIj4eBg82Q0OqyEiYORO6dLGurqyy2+GXX5xHd376Kf1y9FQBAeZprOtDT+XKUKRI/tYuAgo3mVK4EZHsio+Hrl3NEZG0Uk/nLFlSOALO9S5fht9/N0PPnj3X/ty715zbk5EiRaBKlfSjPbfeqgnMkrcUbjKhcCMi2WG3Q3S084hNWjabOYKzf3/BPEWVE1evmp/n+tCzZ0/Gp7bA/HuoVCl96KlWTVddFtdQuMmEwo2IZMemTXD77Tffb+NGaN06r6uxVkqKGfKuDz2//JLx9XhSRUZmPK+ndOn8q10Kv+x8f/vkU00iIoXSsWOu3a8w8/KC8uXNW7t217YbhnkBwoxGeo4dMwPRn3/CunXO71emTMahJyJCK7gkdxRuREQyER7u2v3ckc0GoaHm7fpRrtOnzevwXB98DhyAkyfN21dfOb8mKCjj0BMdrRVckjU6LSUikonUOTdHjqSfUAzuOecmP5w/b05c/uUX5+Dzv/9lvoKratX0weeWW7SCyxNozk0mFG5EJLtSV0uBc8Ap7KulCqLUFVypp7XSruDKqO0EmKEyIsK8IGFmt8DA/P0s4loKN5lQuBGRnMjoOjdRUTBjhoJNfrDbM17B9csvN17Bdb0SJZzDTmRk+gBUurROfRVUCjeZULgRkZwqLFco9iSGAUePmqHzyJEb37IagIoUydookL9/3n4uSa/QhZtXX32Vl156iYSEBOrWrcusWbNo0qRJhvvGx8czefJk9u3bx5UrV6hSpQpPP/00PXv2zNKxFG5ERDyLYUBiYubh58gRc8VXVr8RQ0JuHoBCQrTqy5UK1VLwxYsXM3ToUObNm0fTpk2ZMWMGbdu2Ze/evZQtWzbd/qVKlWLUqFFUq1YNX19fPvvsM/r06UPZsmVp27atBZ9AREQKMpvNPCVVogTUrHnj/a5cMUflbhR+UkeHLl2Cv/82bz/+eOP38/NLPwp0/amwiAjw9XX1JxbLR26aNm1K48aNmT17NgApKSlERUUxcOBARowYkaX3aNCgAffeey+TJk266b4auRERkZwyDHN5+81GgU6ezPp7lilzLeyUKQOlSmV+CwryzBGhQjNyk5yczHfffcfIkSMd27y8vGjTpg3btm276esNw+CLL75g7969vPDCC3lZqoiICDbbtZBRu/aN97t8OeNRoLRzg44eNfdLvd7Prl1Zq8Hb++YBKO0tJMT8MzjYcyZLWxpu/vrrL+x2O6GhoU7bQ0ND+fXXX2/4usTERMqVK8fly5fx9vZmzpw53HXXXRnue/nyZS6nWT+YlJTkmuJFRERuwM/PvD5SdPSN9zEM89RW2rDz999w6lT6W+r2ixfNie2pgSg7bDYoWTJrQSjtrUQJ8LF8Ekv2FLJyTcWLF2fXrl2cO3eODRs2MHToUCpVqkTrDBq7TJkyhQkTJuR/kSIiIpmw2cyl56VLQ926WXvNxYvmabGMAlDaEHT97dw5M0ylPs6u4OCbh6C0t9KlzVNsVrF0zk1ycjKBgYEsWbKEzp07O7bHxcVx5swZli9fnqX3+ec//8nhw4dZs2ZNuucyGrmJiorSnBsREfEYycnXQtGNAlBGt8TEnB2vfn3YudO1n6HQzLnx9fWlYcOGbNiwwRFuUlJS2LBhAwMGDMjy+6SkpDgFmLT8/Pzw8/NzRbkiIiKFkq/vtf5f2XHlCpw5c/MQdH1gsrrju+WnpYYOHUpcXByNGjWiSZMmzJgxg/Pnz9OnTx8AevXqRbly5ZgyZQpgnmZq1KgRlStX5vLly6xcuZJ3332XuXPnWvkxREQKDV2MULKqSBHz9FJ2TzFZfQU9y8NNbGwsJ0+eZOzYsSQkJFCvXj1Wr17tmGR86NAhvNJM7z5//jxPPvkkf/75JwEBAVSrVo333nuP2NhYqz6CiEihkVEbichImDlTbSTEdaxeqm75dW7ym65zIyKeKrUB6PX/11cDUCkMsvP97SEr3kVEPJvdbo7YZPTP2dRtQ4aY+4kUdgo3IiIeYPNm51NR1zMMOHzY3E+ksFO4ERHxAMeOuXY/kYJM4UZExAOEh7t2P5GCTOFGRMQDtGhhro
q60SoWmw2iosz9RAo7hRsREQ/g7W0u94b0ASf18YwZut6NuAeFGxERD9Gli7ncu1w55+2RkVoGLu7F8ov4iYhI/unSBTp10hWKxb0p3IiIeBhvb2jd2uoqRPKOTkuJiIiIW1G4EREREbeicCMiIiJuRXNuRESk0LLbNTla0lO4ERGRQik+3mwGmrZnVmSkeT0fLWv3bDotJSIihU58PHTtmr4Z6JEj5vb4eGvqkoJB4UZERAoVu90csTGM9M+lbhsyxNxPPJPCjYiIFCqbN6cfsUnLMODwYXM/8UwKNyIiUqgcO+ba/cT9KNyIiEihEh7u2v3E/SjciIhIodKihbkq6vru5qlsNoiKMvcTz6RwIyIihYq3t7ncG9IHnNTHM2boejeeTOFGREQKnS5dYMkSKFfOeXtkpLld17nxbLqIn4iIFEpdukCnTrpCsaSncCMiIoWWtze0bm11FVLQ6LSUiIiIuBWFGxEREXErOi0lIiJiMXU3dy2FGxEREQupu7nr6bSUiIiIRdTdPG8o3IiIiFhA3c3zjsKNiIiIBdTdPO8o3IiIiFhA3c3zjsKNiIiIBdTdPO8o3IiIiFhA3c3zjsKNiIiIBdTdPO8o3IiIiFhE3c3zhi7iJyIiYiF1N3c9hRsRERGLuUt384LSRkLhRkRERHKtILWR0JwbERERyZWC1kZC4UZERERyrCC2kVC4ERERkRwriG0kFG5EREQkxwpiGwmFGxEREcmxgthGQuFGREREcqwgtpFQuBEREZEcK4htJBRuREREJFcKWhsJXcRPREREcq0gtZFQuBERERGXKChtJHRaSkRERNyKwo2IiIi4FYUbERERcSsKNyIiIuJWFG5ERETErSjciIiIiFtRuBERERG3onAjIiIibkXhRkRERNyKx12h2DAMAJKSkiyuRERERLIq9Xs79Xs8Mx4Xbs6ePQtAVFSUxZWIiIhIdp09e5bg4OBM97EZWYlAbiQlJYWjR49SvHhxbNf3ZhfATMdRUVEcPnyYoKAgq8vxePp5FCz6eRQ8+pkULHn18zAMg7NnzxIREYGXV+azajxu5MbLy4vIyEiryygUgoKC9D+KAkQ/j4JFP4+CRz+TgiUvfh43G7FJpQnFIiIi4lYUbkRERMStKNxIOn5+fowbNw4/Pz+rSxH08yho9PMoePQzKVgKws/D4yYUi4iIiHvTyI2IiIi4FYUbERERcSsKNyIiIuJWFG5ERETErSjciMOUKVNo3LgxxYsXp2zZsnTu3Jm9e/daXZYAzz//PDabjSFDhlhdikc7cuQIjzzyCCEhIQQEBFC7dm3++9//Wl2WR7Lb7YwZM4aKFSsSEBBA5cqVmTRpUpb6DknuffXVV3To0IGIiAhsNhvLli1zet4wDMaOHUt4eDgBAQG0adOG33//Pd/qU7gRhy+//JL+/fvzzTffsG7dOq5cucLdd9/N+fPnrS7No+3YsYPXXnuNOnXqWF2KRzt9+jQxMTEUKVKEVatW8csvvzB16lRKlixpdWke6YUXXmDu3LnMnj2bPXv28MILL/Diiy8ya9Ysq0vzCOfPn6du3bq8+uqrGT7/4osv8sorrzBv3jy+/fZbihYtStu2bbl06VK+1Kel4HJDJ0+epGzZsnz55Ze0bNnS6nI80rlz52jQoAFz5szh2WefpV69esyYMcPqsjzSiBEj2LJlC5s3b7a6FAHuu+8+QkNDefPNNx3bHnjgAQICAnjvvfcsrMzz2Gw2li5dSufOnQFz1CYiIoKnn36aYcOGAZCYmEhoaCgLFy6ke/fueV6TRm7khhITEwEoVaqUxZV4rv79+3PvvffSpk0bq0vxeCtWrKBRo0Z069aNsmXLUr9+fV5//XWry/JYzZs3Z8OGDfz2228A/PDDD3z99de0b9/e4spk//79JCQkOP1/Kzg4mKZNm7Jt27Z8qcHjGmdK1qSkpDBkyBBiYmKoVauW1eV4pA8//JCdO3eyY8cOq0sR4I8//mDu3LkMHTqU//znP+zYsYNBgwbh6+tLXFyc1eV5nBEjRpCUlES1atXw9vbGbrfz3HPP0aNHD6tL83gJCQkAhIaGOm0PDQ11PJfXFG4kQ/379+enn37i66+/troUj3T48GEGDx7MunXr8Pf3t7ocwQz8jRo1YvLkyQDUr1+fn376iXnz5incWOCjjz7i/fff54MPPqBmzZrs2rWLIUOGEBERoZ+H6LSUpDdgwAA+++wzNm7cSGRkpNXleKTvvvuOEydO0KBBA3x8fPDx8eHLL7/klVdewcfHB7vdbnWJHic8PJwaNWo4batevTqHDh2yqCLP9swzzzBixAi6d+9O7dq16dmzJ0899RRTpkyxujSPFxYWBsDx48edth8/ftzxXF5TuBEHwzAYMGAAS5cu5YsvvqBixYpWl+Sx7rzzTnbv3s2uXbsct0aNGtGjRw927dqFt7e31SV6nJiYmHSXRvjtt9+oUKGCRRV5tgsXLuDl5fwV5u3tTUpKikUVSaqKFSsSFhbGhg0bHNuSkpL49ttvadasWb7UoNNS4tC/f38++OADli9fTvHixR3nRoODgwkICLC4Os9SvHjxdHOdihYtSkhIiOZAWeSpp56iefPmTJ48mQcffJDt27czf/585s+fb3VpHqlDhw4899xzlC9fnpo1a/L9998zbdo0Hn30UatL8wjnzp1j3759jsf79+9n165dlCpVivLlyzNkyBCeffZZqlSpQsWKFRkzZgwRERGOFVV5zhD5f0CGtwULFlhdmhiG0apVK2Pw4MFWl+HRPv30U6NWrVqGn5+fUa1aNWP+/PlWl+SxkpKSjMGDBxvly5c3/P39jUqVKhmjRo0yLl++bHVpHmHjxo0Zfl/ExcUZhmEYKSkpxpgxY4zQ0FDDz8/PuPPOO429e/fmW326zo2IiIi4Fc25EREREbeicCMiIiJuReFGRERE3IrCjYiIiLgVhRsRERFxKwo3IiIi4lYUbkRERMStKNyIiEey2WwsW7bM6jJEJA8o3IhIvuvduzc2my3drV27dlaXJiJuQL2lRMQS7dq1Y8GCBU7b/Pz8LKpGRNyJRm5ExBJ+fn6EhYU53UqWLAmYp4zmzp1L+/btCQgIoFKlSixZssTp9bt37+aOO+4gICCAkJAQHn/8cc6dO+e0z1tvvUXNmjXx8/MjPDycAQMGOD3/119/cf/99xMYGEiVKlVYsWKF47nTp0/To0cPypQpQ0BAAFWqVEkXxkSkYFK4EZECacyYMTzwwAP88MMP9OjRg+7du7Nnzx4Azp8/T9u2bSlZsiQ7duzg448/Zv369U7hZe7cufTv35/HH3+c3bt3s2LFCm655RanY0yYMIEHH3yQH3/8kXvuuYcePXpw6tQpx/F/+eUXVq1axZ49e5g7dy6lS5fOv78AEcm5fGvRKSLy/+Li4gxvb2+jaNGiTrfnnnvOMAyzQ/0TTzzh9JqmTZsa/fr1MwzDMObPn2+ULFnSOHfunOP5zz//3PDy8jISEhIMwzCMiIgIY9SoUTesATBGjx7teHzu3DkDMFatWmUYhmF06
NDB6NOnj2s+sIjkK825ERFL3H777cydO9dpW6lSpRz3mzVr5vRcs2bN2LVrFwB79uyhbt26FC1a1PF8TEwMKSkp7N27F5vNxtGjR7nzzjszraFOnTqO+0WLFiUoKIgTJ04A0K9fPx544AF27tzJ3XffTefOnWnevHmOPquI5C+FGxGxRNGiRdOdJnKVgICALO1XpEgRp8c2m42UlBQA2rdvz8GDB1m5ciXr1q3jzjvvpH///rz88ssur1dEXEtzbkSkQPrmm2/SPa5evToA1atX54cffuD8+fOO57ds2YKXlxdVq1alePHiREdHs2HDhlzVUKZMGeLi4njvvfeYMWMG8+fPz9X7iUj+0MiNiFji8uXLJCQkOG3z8fFxTNr9+OOPadSoEf/4xz94//332b59O2+++SYAPXr0YNy4ccTFxTF+/HhOnjzJwIED6dmzJ6GhoQCMHz+eJ554grJly9K+fXvOnj3Lli1bGDhwYJbqGzt2LA0bNqRmzZpcvnyZzz77zBGuRKRgU7gREUusXr2a8PBwp21Vq1bl119/BcyVTB9++CFPPvkk4eHhLFq0iBo1agAQGBjImjVrGDx4MI0bNyYwMJAHHniAadOmOd4rLi6OS5cuMX36dIYNG0bp0qXp2rVrluvz9fVl5MiRHDhwgICAAFq0aMGHH37ogk8uInnNZhiGYXURIiJp2Ww2li5dSufOna0uRUQKIc25EREREbeicCMiIiJuRXNuRKTA0dlyEckNjdyIiIiIW1G4EREREbeicCMiIiJuReFGRERE3IrCjYiIiLgVhRsRERFxKwo3IiIi4lYUbkRERMStKNyIiIiIW/k/BPLpH1Zbfp0AAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "acc = history_dict['binary_accuracy']\n", "val_acc = history_dict['val_binary_accuracy']\n", "loss = history_dict['loss']\n", "val_loss = history_dict['val_loss']\n", "\n", "epochs = range(1, len(acc) + 1)\n", "\n", "# \"bo\" is for \"blue dot\"\n", "plt.plot(epochs, loss, 'bo', label='Training loss')\n", "# b is for \"solid blue line\"\n", "plt.plot(epochs, val_loss, 'b', label='Validation loss')\n", "plt.title('Training and validation loss')\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Loss')\n", "plt.legend()\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 25, "id": "af51178e-fe0b-40ca-9260-2190fb52d960", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkAAAAHHCAYAAABXx+fLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABcgUlEQVR4nO3dd1hT598G8DuEvVWQJQKideJWioraSovaUnGideCottaNtmoVR63SWqu4qrV11Um1aP1Vq1WqdY+690RRFBQHCCpKOO8f5000EjCBwEnI/bmuXCZPTk6+J9Dm5pxnyARBEEBERERkQsykLoCIiIiopDEAERERkclhACIiIiKTwwBEREREJocBiIiIiEwOAxARERGZHAYgIiIiMjkMQERERGRyGICIiIjI5DAAEelB79694evrW6jXTpo0CTKZTL8FGZjr169DJpNh2bJlJfq+u3btgkwmw65du1Rt2v6siqtmX19f9O7dW6/7JCLdMQBRqSaTybS6vfoFSVRU+/fvx6RJk/Do0SOpSyGifJhLXQBRcVqxYoXa419//RXbt2/P0169evUivc/PP/+M3NzcQr12/PjxGDNmTJHen7RXlJ+Vtvbv34/Jkyejd+/ecHZ2Vnvu4sWLMDPj355EUmMAolKtR48eao8PHjyI7du352l/3ZMnT2Bra6v1+1hYWBSqPgAwNzeHuTn/UywpRflZ6YOVlZWk728ssrKyYGdnJ3UZVIrxzxAyeS1btkStWrVw9OhRNG/eHLa2tvjqq68AAH/88Qc++OADeHp6wsrKCv7+/pgyZQoUCoXaPl7vV6LsPzJjxgwsWrQI/v7+sLKyQqNGjXDkyBG112rqAySTyTB48GBs3LgRtWrVgpWVFWrWrImtW7fmqX/Xrl1o2LAhrK2t4e/vj59++knrfkV79uxB586dUbFiRVhZWcHb2xsjRozA06dP8xyfvb09kpOTER4eDnt7e7i6umLUqFF5PotHjx6hd+/ecHJygrOzMyIjI7W6FPTff/9BJpNh+fLleZ7btm0bZDIZ/vzzTwDAjRs38Pnnn6Nq1aqwsbFBuXLl0LlzZ1y/fv2N76OpD5C2NZ86dQq9e/dGpUqVYG1tDXd3d/Tt2xf3799XbTNp0iR88cUXAAA/Pz/VZVZlbZr6AF27dg2dO3dG2bJlYWtri7fffhubN29W20bZn+m3337D1KlTUaFCBVhbW6NVq1a4cuXKG49bl8/s0aNHGDFiBHx9fWFlZYUKFSqgV69eSEtLU23z7NkzTJo0CW+99Rasra3h4eGBDh064OrVq2r1vn55WVPfKuXv19WrV9G2bVs4ODige/fuALT/HQWACxcuoEuXLnB1dYWNjQ2qVq2KcePGAQB27twJmUyGDRs25Hnd6tWrIZPJcODAgTd+jlR68M9OIgD3799HmzZt0LVrV/To0QNubm4AgGXLlsHe3h5RUVGwt7fHP//8gwkTJiAjIwPff//9G/e7evVqPH78GJ9++ilkMhmmT5+ODh064Nq1a288E7F3717Ex8fj888/h4ODA+bMmYOOHTsiKSkJ5cqVAwAcP34crVu3hoeHByZPngyFQoGvv/4arq6uWh33unXr8OTJEwwcOBDlypXD4cOHMXfuXNy6dQvr1q1T21ahUCA0NBSBgYGYMWMGduzYgR9++AH+/v4YOHAgAEAQBLRr1w579+7FZ599hurVq2PDhg2IjIx8Yy0NGzZEpUqV8Ntvv+XZPi4uDmXKlEFoaCgA4MiRI9i/fz+6du2KChUq4Pr161iwYAFatmyJc+fO6XT2Tpeat2/fjmvXrqFPnz5wd3fH2bNnsWjRIpw9exYHDx6ETCZDhw4dcOnSJaxZswazZs2Ci4sLAOT7M0lNTUWTJk3w5MkTDB06FOXKlcPy5cvx0UcfYf369Wjfvr3a9t9++y3MzMwwatQopKenY/r06ejevTsOHTpU4HFq+5llZmYiODgY58+fR9++fVG/fn2kpaVh06ZNuHXrFlxcXKBQKPDhhx8iISEBXbt2xbBhw/D48WNs374dZ86cgb+/v9afv1JOTg5CQ0PRrFkzzJgxQ1WPtr+jp06dQnBwMCwsLDBgwAD4+vri6tWr+N///oepU6eiZcuW8Pb2xqpVq/J8pqtWrYK/vz+CgoJ0rpuMmEBkQgYNGiS8/mvfokULAYCwcOHCPNs/efIkT9unn34q2NraCs+ePVO1RUZGCj4+PqrHiYmJAgChXLlywoMHD1Ttf/zxhwBA+N///qdqmzhxYp6aAAiWlpbClStXVG0nT54UAAhz585VtYWFhQm2trZCcnKyqu3y5cuCubl5nn1qoun4YmJiBJlMJty4cUPt+AAIX3/9tdq29erVExo0aKB6vHHjRgGAMH36dFVbTk6OEBwcLAAQli5dWmA9Y8eOFSwsLNQ+s+zsbMHZ2Vno27dvgXUfOHBAACD8+uuvqradO3cKAISdO3eqHcurPytdatb0vmvWrBEACLt371a1ff/99wIAITExMc/2Pj4+QmRkpOrx8OHDBQDCnj17VG2PHz8W/Pz8BF9fX0GhUKgdS/Xq1YXs7GzVtrNnzxYACKdPn87zXq/S9jObMGGCAECIj4/Ps31ubq4gCIKwZMkSAYAwc+bMfLfR9NkLwsv/Nl79XJW/X2PGjNGqbk2/o82bNxccHBzU2l6tRxDE3y8rKyvh0aNHqra7d+8K5ubmwsSJE/O8D5VuvARGBLFfRp8+ffK029jYqO4/fvwYaWlpCA4OxpMnT3DhwoU37jciIgJlypRRPQ4OD
gYgXvJ4k5CQELW/pGvXrg1HR0fVaxUKBXbs2IHw8HB4enqqtqtcuTLatGnzxv0D6seXlZWFtLQ0NGnSBIIg4Pjx43m2/+yzz9QeBwcHqx3Lli1bYG5urjojBAByuRxDhgzRqp6IiAi8ePEC8fHxqra///4bjx49QkREhMa6X7x4gfv376Ny5cpwdnbGsWPHtHqvwtT86vs+e/YMaWlpePvttwFA5/d99f0bN26MZs2aqdrs7e0xYMAAXL9+HefOnVPbvk+fPrC0tFQ91vZ3StvP7Pfff0edOnXynCUBoLqs+vvvv8PFxUXjZ1SUKR1e/Rloqju/39F79+5h9+7d6Nu3LypWrJhvPb169UJ2djbWr1+vaouLi0NOTs4b+wVS6cMARATAy8tL7UtF6ezZs2jfvj2cnJzg6OgIV1dX1f8o09PT37jf1/9nrAxDDx8+1Pm1ytcrX3v37l08ffoUlStXzrOdpjZNkpKS0Lt3b5QtW1bVr6dFixYA8h6ftbV1nss4r9YDiP1MPDw8YG9vr7Zd1apVtaqnTp06qFatGuLi4lRtcXFxcHFxwbvvvqtqe/r0KSZMmABvb29YWVnBxcUFrq6uePTokVY/l1fpUvODBw8wbNgwuLm5wcbGBq6urvDz8wOg3e9Dfu+v6b2UIxNv3Lih1l7Y3yltP7OrV6+iVq1aBe7r6tWrqFq1ql4775ubm6NChQp52rX5HVWGvzfVXa1aNTRq1AirVq1Sta1atQpvv/221v/NUOnBPkBEUP8rU+nRo0do0aIFHB0d8fXXX8Pf3x/W1tY4duwYRo8erdVQarlcrrFdEIRifa02FAoF3nvvPTx48ACjR49GtWrVYGdnh+TkZPTu3TvP8eVXj75FRERg6tSpSEtLg4ODAzZt2oRu3bqpfdkOGTIES5cuxfDhwxEUFAQnJyfIZDJ07dq1WIe4d+nSBfv378cXX3yBunXrwt7eHrm5uWjdunWxD61XKuzvRUl/ZvmdCXq907ySlZVVnukBdP0d1UavXr0wbNgw3Lp1C9nZ2Th48CDmzZun837I+DEAEeVj165duH//PuLj49G8eXNVe2JiooRVvVS+fHlYW1trHAGkzaig06dP49KlS1i+fDl69eqlat++fXuha/Lx8UFCQgIyMzPVzqhcvHhR631ERERg8uTJ+P333+Hm5oaMjAx07dpVbZv169cjMjISP/zwg6rt2bNnhZp4UNuaHz58iISEBEyePBkTJkxQtV++fDnPPnW5DOTj46Px81FeYvXx8dF6XwXR9jPz9/fHmTNnCtyXv78/Dh06hBcvXuTbmV95Zur1/b9+Rqsg2v6OVqpUCQDeWDcAdO3aFVFRUVizZg2ePn0KCwsLtcurZDp4CYwoH8q/tF/9y/r58+f48ccfpSpJjVwuR0hICDZu3Ijbt2+r2q9cuYK//vpLq9cD6scnCAJmz55d6Jratm2LnJwcLFiwQNWmUCgwd+5crfdRvXp1BAQEIC4uDnFxcfDw8FALoMraXz/jMXfu3HzPLuijZk2fFwDExsbm2ady/hptAlnbtm1x+PBhtSHYWVlZWLRoEXx9fVGjRg1tD6VA2n5mHTt2xMmTJzUOF1e+vmPHjkhLS9N45kS5jY+PD+RyOXbv3q32vC7//Wj7O+rq6ormzZtjyZIlSEpK0liPkouLC9q0aYOVK1di1apVaN26tWqkHpkWngEiykeTJk1QpkwZREZGYujQoZDJZFixYoXeLkHpw6RJk/D333+jadOmGDhwIBQKBebNm4datWrhxIkTBb62WrVq8Pf3x6hRo5CcnAxHR0f8/vvvWvVPyk9YWBiaNm2KMWPG4Pr166hRowbi4+N17h8TERGBCRMmwNraGv369ctzaeTDDz/EihUr4OTkhBo1auDAgQPYsWOHanqA4qjZ0dERzZs3x/Tp0/HixQt4eXnh77//1nhGsEGDBgCAcePGoWvXrrCwsEBYWJjGif3GjBmDNWvWoE2bNhg6dCjKli2L5cuXIzExEb///rveZo3W9jP74osvsH79enTu3Bl9+/ZFgwYN8ODBA2zatAkLFy5EnTp10KtXL/z666+IiorC4cOHERwcjKysLOzYsQOff/452rVrBycnJ3Tu3Blz586FTCaDv78//vzzT9y9e1frmnX5HZ0zZw6aNWuG+vXrY8CAAfDz88P169exefPmPP8t9OrVC506dQIATJkyRfcPk0qHEh93RiSh/IbB16xZU+P2+/btE95++23BxsZG8PT0FL788kth27ZtbxxarRzq+/333+fZJwC1Ibf5DYMfNGhQnte+PoRaEAQhISFBqFevnmBpaSn4+/sLv/zyizBy5EjB2to6n0/hpXPnzgkhISGCvb294OLiIvTv31813P71Ycp2dnZ5Xq+p9vv37ws9e/YUHB0dBScnJ6Fnz57C8ePHtRoGr3T58mUBgABA2Lt3b57nHz58KPTp00dwcXER7O3thdDQUOHChQt5Ph9thsHrUvOtW7eE9u3bC87OzoKTk5PQuXNn4fbt23l+poIgCFOmTBG8vLwEMzMztSHxmn6GV69eFTp16iQ4OzsL1tbWQuPGjYU///xTbRvlsaxbt06tXdOwck20/cyUn8fgwYMFLy8vwdLSUqhQoYIQGRkppKWlqbZ58uSJMG7cOMHPz0+wsLAQ3N3dhU6dOglXr15VbXPv3j2hY8eOgq2trVCmTBnh008/Fc6cOaP175cgaP87KgiCcObMGdXPx9raWqhataoQHR2dZ5/Z2dlCmTJlBCcnJ+Hp06cFfm5UeskEwYD+nCUivQgPD8fZs2c19k8hMnU5OTnw9PREWFgYFi9eLHU5JBH2ASIycq8vCXD58mVs2bIFLVu2lKYgIgO3ceNG3Lt3T61jNZkengEiMnIeHh6q9alu3LiBBQsWIDs7G8ePH0eVKlWkLo/IYBw6dAinTp3ClClT4OLiUujJK6l0YCdoIiPXunVrrFmzBikpKbCyskJQUBCmTZvG8EP0mgULFmDlypWoW7eu2mKsZJp4BoiIiIhMDvsAERERkclhACIiIiKTwz5AGuTm5uL27dtwcHAo0srGREREVHIEQcDjx4/h6en5xklEGYA0uH37Nry9vaUug4iIiArh5s2bqFChQoHbMABp4ODgAED8AB0dHSWuhoiIiLSRkZEBb29v1fd4QRiANFBe9nJ0dGQAIiIiMjLadF9hJ2giIiIyOQxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOZwJmoiIiEqEQgHs2QPcuQN4eADBwYBcLk0tDEBERERU7OLjgWHDgFu3XrZVqADMng106FDy9fASGBERERWr+HigUyf18AMAyclie3x8ydfEAERERETFRqEQz/wIQt7nlG3Dh4vblSQGICIiIio2e/bkPfPzKkEAbt4UtytJDEBERERUbO7c0e92+sIARERERMXGw0O/2+kLR4EREREZOEMaPq6r4GBxtFdysuZ+QDKZ+HxwcMnWxTNAREREBiw+HvD1
Bd55B/j4Y/FfX19pRk4VhlwuDnUHxLDzKuXj2NiSD3QMQERERAbKEIePF0aHDsD69YCXl3p7hQpiuxTzAMkEQdMJKdOWkZEBJycnpKenw9HRUepyiIjIBCkU4pme/EZQKS8dJSYaz+Ww4r6Up8v3N/sAERERGSBdho+3bFliZRWJXG44tfISGBERkQEy1OHjpQUDEBERkQEy1OHjpYXkAWj+/Pnw9fWFtbU1AgMDcfjw4Xy3ffHiBb7++mv4+/vD2toaderUwdatW4u0TyIiIkOkHD7++sgpJZkM8PYu+eHjpYWkASguLg5RUVGYOHEijh07hjp16iA0NBR3797VuP348ePx008/Ye7cuTh37hw+++wztG/fHsePHy/0PomIqPRSKIBdu4A1a8R/S3q9qaIw1OHjpYYgocaNGwuDBg1SPVYoFIKnp6cQExOjcXsPDw9h3rx5am0dOnQQunfvXuh9apKeni4AENLT07V+DRERGZbffxeEChUEQewuLN4qVBDbjYmm4/D2Nr7jKAm6fH9Ldgbo+fPnOHr0KEJCQlRtZmZmCAkJwYEDBzS+Jjs7G9bW1mptNjY22Lt3b6H3SUREpU9pmT8HEOfIuX4d2LkTWL1a/DcxUZq5c0oTyYbBp6WlQaFQwM3NTa3dzc0NFy5c0Pia0NBQzJw5E82bN4e/vz8SEhIQHx8Pxf+f0yzMPgExWGVnZ6seZ2RkFPawiIhIYgoFMGyY5mUXBEG8fDR8ONCunfFcPjKk4eOlheSdoHUxe/ZsVKlSBdWqVYOlpSUGDx6MPn36wMysaIcRExMDJycn1c3b21tPFRMRUUnTZf4cMl2SBSAXFxfI5XKkpqaqtaempsLd3V3ja1xdXbFx40ZkZWXhxo0buHDhAuzt7VGpUqVC7xMAxo4di/T0dNXt5s2bRTw6IiKSCufPIW1IFoAsLS3RoEEDJCQkqNpyc3ORkJCAoKCgAl9rbW0NLy8v5OTk4Pfff0e7du2KtE8rKys4Ojqq3YiIyDhx/hzShqRLYURFRSEyMhINGzZE48aNERsbi6ysLPTp0wcA0KtXL3h5eSEmJgYAcOjQISQnJ6Nu3bpITk7GpEmTkJubiy+//FLrfRIRUemmnD8nOVlzPyDlGlqcP8e0SRqAIiIicO/ePUyYMAEpKSmoW7cutm7dqurEnJSUpNa/59mzZxg/fjyuXbsGe3t7tG3bFitWrICzs7PW+yQiojcr7kUri5Ny/pxOncSw82oI4vw5pMTV4DXgavBEZMri48VRVK92JK5QQQwVxjT0WtNxeHuL4ceYjoO0p8v3NwOQBgxARGSqlPPnvP7NoDxzsn69cYUHYz6TRbpjACoiBiAiMkUKBeDrm/8QcmXfmcREhggyTLp8fxvVPEBERFR8OH8OmRIGICIiAsD5c8i0MAAREREAzp9DpoUBiIiIALycP0fZ4fl1Mpk4iorz51BpwABEREQAXs6fA+QNQZw/h0obBiAiIlLp0EEc6u7lpd5eoYLxDYEnKoikM0ETEZHh6dABaNeO8+dQ6cYAREREecjlQMuWUldBVHx4CYyIiIhMDs8AERHpEZdeIDIODEBERHpSWhYRJTIFvARGRKQHykVEX19KIjlZbI+Pl6YuItKMAYiIqIgUCvHMj6alpZVtw4eL2xGRYWAAIiIqIi4iSmR8GICIiIqIi4gSGR8GICKiIuIiokTGhwGIiKiIuIgokfFhACIiKiIuIkpkfBiAiIj0gIuIEhkXToRIRKQnXESU9CUnB0hLA+7efXm7dw94+BAwMwPMzcWbXK7+75vaivr8621mZvlf+jV0DEBERHrERURJk9xcMby8GmZeDTevtz14IHXF2itsEOvdGxg8WMK6pXtrIiIi4yQIwOPH2oWZu3fFszm6ToRpZga4uADly4s3V1egTBnxvRUK8SxRTs7L+5ratLmvzbaaJvlUUm6Tna3b8YWG6ra9vjEAEZFB4CKiJLWnT7ULM8rHun7hA2KAUYYZZbB5NeC8+rhMGcP5byA3V/xvVJ8hq1IlaY+JAYiIJMdFRKm4PH0K3LghzsRdUJi5exfIzNR9//b22oUZV1fxbI6lpf6PsSSYmYk3CwupK9EfBiAikpRyEdHXT7ErFxHlCCoqyIsXQFIScP06kJgo3l69n5Ki2/4sLbULM8p/bW2L46ioJMgEoaAre6YpIyMDTk5OSE9Ph6Ojo9TlEJVaCgXg65v/OloymXgmKDHRcC4FUMlSKMQwrAw1rwedW7fEyzMFcXAAKlYE3N3zDzPKm4OD8Y5qIt2+v3kGiIgko8siohxZVToJApCaqvnszfXr4tmdFy8K3oe1NeDnJ4bpV/9V3i9blqGG8mIAIiLJcBHR0k8QxCHdmsKN8t9nzwreh4WFeAbn9XCjfOzmxoBDumMAIiLJcBHR0iEjI/8+ONevi8PFC2JmJl7qfP3MjfK+pycvgZL+MQARkWSUi4gmJ2ueZ0TZB4iLiErryRNxJFV+AUebSfvc3TWHG19fcaFYYx0dRcaLAYiIJKNcRLRTJzHsvBqCSuMiogqF2J/l1VtOTt42Kdtff+7RI7GPzpuUK6c53Pj5AT4+gI1NMX+4RDpiACIiSSkXEdU0D1BsrOEPgX/8GLh4EbhwQf2Wmpo3VBjzmFsHh7x9b1697+AgdYVEumEAIiLJGfoiooIghrPXQ87Fi+Llu6IwNxc7+Spvrz8uqfb8nrO3FwNOmTLsaEylCwMQERkEQ1hE9Nkz4PJlzUEnKyv/17m7A9Wqqd+8vMR+LQUFEXNzhgoiqTAAEZFJEQRx+YPXQ86FC2KH3vwuU5mbA1WqiOGmatWXQadqVcDZuSSPgIj0gQGIiEqlFy+Aa9c09895+DD/1zk7A9Wr5z2j4+dXutZBIjJ1DEBEZNQePdIccq5cETseayKTiYHm9bM51aqJSyPwshRR6ccAREQGLzdXXBJBU9+cgha7tLXNeyanWjWgcmUOyyYydQxARGQwnjwBLl3KG3QuXQKePs3/dV5eec/kKDsim5mVXP1EZDwYgIioxAmC2OH41Cn12+XL+XdCtrR82Qn51dtbbwFvWPSZiCgPBiAiI6dQGO78OYC4TtTp0+pB5/Tp/NeHKltWcydkX19xJBYRkT7wfydERiw+XvMMyrNnl/wMygqF2PH49bM6169r3t7SEqhRA6hdW/3m5laiZRORiWIAIjJS8fHiGlqvXzJKThbb168vvhCUlpb3rM6ZM+JEgppUqJA36Lz1FoeVE5F0ZIJgzKvTFI+MjAw4OTkhPT0djuxcQAZIoRAvCb165udVylXUExOLdjns+XNxpNXrZ3Vu39a8va0tUKuWetAJCBAvaxERFTddvr95BojICO3Zk3/4AcSzQjdvittps7yEIIjDyV8POufPixMKalKpUt6zOpUqGVb/IyKi/DAAERmhO3cKv93Tp8DZs+odkk+dEi9raeLk9PJMjjLo1KrF1b+JyLgxABEZIQ8P7baTyYBNm/IONc/NzbutmZk4l87rZ3W8vTkzMhGVPpJPETZ//nz4+vrC2toagYGBOHz
4cIHbx8bGomrVqrCxsYG3tzdGjBiBZ6/0vJw0aRJkMpnarVq1asV9GEQlKjhY7ONTUDCRyYBu3YB27YDoaGDdOrE/T24u4OICtGoFjBgBLF0KHD0KZGYC584Ba9cCX30FfPghULEiww8RlU6SngGKi4tDVFQUFi5ciMDAQMTGxiI0NBQXL15E+fLl82y/evVqjBkzBkuWLEGTJk1w6dIl9O7dGzKZDDNnzlRtV7NmTezYsUP12JyTh1ApI5cDs2YBnTvnv40giKOs8htqzmBDRKZM0mQwc+ZM9O/fH3369AEALFy4EJs3b8aSJUswZsyYPNvv378fTZs2xccffwwA8PX1Rbdu3XDo0CG17czNzeHu7l78B0AkgeRkYNkyYMkSzc87OAD9+om3qlU51JyISBPJAtDz589x9OhRjB07VtVmZmaGkJAQHDhwQONrmjRpgpUrV+Lw4cNo3Lgxrl27hi1btqBnz55q212+fBmenp6wtrZGUFAQYmJiULFixXxryc7ORnZ2tupxRkZGEY+OSL+ePwf+/BNYvBjYuvVlHx4HB6BrV6B+fXE5CE9Pw5sJmojIEEkWgNLS0qBQKOD22rSvbm5uuHDhgsbXfPzxx0hLS0OzZs0gCAJycnLw2Wef4auvvlJtExgYiGXLlqFq1aq4c+cOJk+ejODgYJw5cwYO+QxbiYmJweTJk/V3cER6cuGCGHqWLwfu3XvZHhwsnuHp1Amws5OuPiIiYyV5J2hd7Nq1C9OmTcOPP/6IY8eOIT4+Hps3b8aUKVNU27Rp0wadO3dG7dq1ERoaii1btuDRo0f47bff8t3v2LFjkZ6errrdvHmzJA6HSKPMTPHyVtOm4ppYM2aI4cfNDRg9WuzIvHs3EBnJ8ENEVFiSnQFycXGBXC5HamqqWntqamq+/Xeio6PRs2dPfPLJJwCAgIAAZGVlYcCAARg3bhzMzPLmOWdnZ7z11lu4cuVKvrVYWVnBysqqCEdDVDSCABw6JJ7tWbtWDEGAeCmrbVvxbE/btuzPQ0SkL5KdAbK0tESDBg2QkJCgasvNzUVCQgKCgoI0vubJkyd5Qo78/zs75LeiR2ZmJq5evQoPbSdOISpB9+4BM2eKEwsGBQG//CKGn8qVgZgYIClJnMenXTuGHyIifZJ0FFhUVBQiIyPRsGFDNG7cGLGxscjKylKNCuvVqxe8vLwQExMDAAgLC8PMmTNRr149BAYG4sqVK4iOjkZYWJgqCI0aNQphYWHw8fHB7du3MXHiRMjlcnTr1k2y4yR6lUIBbN8unu3544+XS03Y2Ih9evr1A5o35zB1IqLiJGkAioiIwL179zBhwgSkpKSgbt262Lp1q6pjdFJSktoZn/Hjx0Mmk2H8+PFITk6Gq6srwsLCMHXqVNU2t27dQrdu3XD//n24urqiWbNmOHjwIFxdXUv8+Ihedf262Ldn2TJxnS6lBg2ATz4RJy10cpKqOiIi08LV4DXgavCkL8+eARs3imd7EhLEvj4AUKYM0KOHeLanTh1JSyQiKjW4GjyRxE6eFEPPypXAw4cv21u1Es/2hIcD1taSlUdEZPIYgIj0JD0dWLNGDD7//feyvUIFoE8f8ebnJ119RET0EgMQUREIgjgnz+LFwPr1wNOnYruFhThyq18/4L33ODMzEZGhYQAiKoQ7d8TZmZcsAS5fftleo4YYenr2BNjvnojIcDEAEWkpJwfYskWcq2fLFnE4OwDY2wMREWLfnsBADl8nIjIGDEBEb3DpknimZ/lyICXlZXuTJuLZni5dxBBERETGgwGISIOsLLFPz+LFwJ49L9tdXcU1uPr2FdfpIiIi48QARPT/BEEcvbV4sTiaKyNDbDczA1q3Fs/2fPghYGkpbZ1ERFR0DEBkshQK8ezOpUvAqVPiaK7Tp18+7+cnhp7ISHEoOxERlR4MQGSS4uOBIUOA27fV2y0sgM6dxeDTsqV49oeIiEofBiAyOfHxQMeOmp978UJ87t13S7YmIiIqWfz7lkyKQgEMGpT/8zIZMHz4yyHuRERUOjEAkUlZvVp9KPvrBEFcqf3VkV9ERFT6MACRyTh1Chg8WLtt79wp3lqIiEhaDEBkEg4eBFq0eDm0/U08PIq3HiIikhYDEJV6O3YAISHAo0dAUBDg6Zn/chUyGeDtDQQHl2iJRERUwhiAqFT74w/ggw/EmZ3ffx/Yvh2YO1d87vUQpHwcG8vV24mISjsGICq1Vq4Uh7Q/fw506ABs2gTY2Yn3168HvLzUt69QQWzv0EGaeomIqORwHiAqlX788eVw98hIcQV381d+2zt0ANq1E0d73bkj9vkJDuaZHyIiU8EARKVOTAzw1Vfi/SFDxEtammZ0lsvF2Z6JiMj08BIYlRqCAIwZ8zL8REcDs2dzOQsiIsqLZ4CoVMjNFS95LVwoPp4xAxg5UtqaiIjIcDEAkdF78QLo3Vuc5VkmA376CejfX+qqiIjIkDEAkVF79gyIiBBHeJmbAytWAF27Sl0VEREZOgYgMlqPH4sjuXbuBKytxSHsH3wgdVVERGQMGIDIKD14ALRtCxw6BNjbA3/+KS51QUREpA0GIDI6KSnirM6nTwNlywJbtwKNGkldFRERGRMGIDIqN26I63pduSJOXvj330CtWlJXRURExoYBiIzGhQvAe+8Bt24Bvr7iIqf+/lJXRURExohTxJFROH4caN5cDD/VqwN79zL8EBFR4TEAkcHbtw945x3g3j2gQQNg9+68C5kSERHpggGIDNrff4sdntPTxcVKExIAFxepqyIiImPHAEQG6/ffgQ8/BJ48AVq3Fkd7OTlJXRUREZUGDEBkkJYvB7p0EZe56NwZ+OMPwNZW6qqIiKi0YAAigzN3rri2V24u0K8fsGYNYGkpdVVERFSaMACRwRAE4JtvgKFDxccjRgA//wzI5dLWRUREpQ8DEBkEQQC++AKIjhYfT5oE/PCDuLo7ERGRvnEiRJKcQgEMHCie7QGAWbOA4cMlLYmIiEo5BiCS1PPnQK9eQFwcYGYmhqC+faWuioiISjsGIJLM06dAp07Ali2AhQWwapU44ouIiKi4MQCRJDIygI8+Av79F7CxAeLjxbl+iIiISgIDEJW4+/fFsPPff4CDA7B5szjLMxERUUlhAKISdfu2uKL7uXPikhZbt4rrexEREZUkBiAqMdeuASEhQGIi4OkJ7NghruxORERU0jgPEJWIc+fEy1yJiUClSsDevQw/REQkHQYgKnZHjwLNm4uXv2rWBPbsAfz8pK6KiIhMGQMQFavdu4F33hE7PjdqJI768vSUuioiIjJ1DEBUbLZsAUJDgcePgRYtgIQEoFw5qasiIiIygAA0f/58+Pr6wtraGoGBgTh8+HCB28fGxqJq1aqwsbGBt7c3RowYgWfPnhVpn6R/v/0GtGsHPHsGfPAB8Ndf4pB3IiIiQyBpAIqLi0NUVBQmTpyIY8eOoU6dOggNDcXdu3c1br969WqMGTMGEydOxPnz57F48WLExcXhq6++KvQ+Sf8WLwa6dQNycoCuXYENG8TJDomIiAyFTBAEQa
o3DwwMRKNGjTBv3jwAQG5uLry9vTFkyBCMGTMmz/aDBw/G+fPnkZCQoGobOXIkDh06hL179xZqn5pkZGTAyckJ6enpcHR0LOphmpRZs4CoKPF+//7AggWAXC5tTUREZBp0+f6W7AzQ8+fPcfToUYSEhLwsxswMISEhOHDggMbXNGnSBEePHlVd0rp27Rq2bNmCtm3bFnqfAJCdnY2MjAy1G+lGEICJE1+Gn1GjgJ9+YvghIiLDJNlEiGlpaVAoFHBzc1Nrd3Nzw4ULFzS+5uOPP0ZaWhqaNWsGQRCQk5ODzz77THUJrDD7BICYmBhMnjy5iEdkunJzxeAze7b4+JtvgK++AmQyaesiIiLKj+SdoHWxa9cuTJs2DT/++COOHTuG+Ph4bN68GVOmTCnSfseOHYv09HTV7ebNm3qquPRTKIBPPnkZfubMAcaNY/ghIiLDJtkZIBcXF8jlcqSmpqq1p6amwt3dXeNroqOj0bNnT3zyyScAgICAAGRlZWHAgAEYN25cofYJAFZWVrCysiriEZme7GygRw9g/XrAzAxYsgSIjJS6KiIiojeT7AyQpaUlGjRooNahOTc3FwkJCQgKCtL4midPnsDMTL1k+f93MhEEoVD7pMJ58kQc5r5+PWBpCaxbx/BDRETGQ9LFUKOiohAZGYmGDRuicePGiI2NRVZWFvr06QMA6NWrF7y8vBATEwMACAsLw8yZM1GvXj0EBgbiypUriI6ORlhYmCoIvWmfVHRZWUDr1uJ6Xra2wMaN4grvRERExkLSABQREYF79+5hwoQJSElJQd26dbF161ZVJ+akpCS1Mz7jx4+HTCbD+PHjkZycDFdXV4SFhWHq1Kla75OKbtYsMfw4OQGbNwNNm0pdERERkW4knQfIUHEeoPw9ewb4+AB37wIrVwLdu0tdERERkahY5wHy9fXF119/jaSkpEIXSMZr5Uox/Hh7A126SF0NERFR4egcgIYPH474+HhUqlQJ7733HtauXYvs7OziqI0MTG4u8MMP4v3hwwELC0nLISIiKrRCBaATJ07g8OHDqF69OoYMGQIPDw8MHjwYx44dK44ayUBs2QJcuAA4Oopz/xARERmrQg+Dr1+/PubMmYPbt29j4sSJ+OWXX9CoUSPUrVsXS5YsAbsWlT4zZoj/fvqpGIKIiIiMVaFHgb148QIbNmzA0qVLsX37drz99tvo168fbt26ha+++go7duzA6tWr9VkrSejIEeDffwFzc2DoUKmrISIiKhqdA9CxY8ewdOlSrFmzBmZmZujVqxdmzZqFatWqqbZp3749GjVqpNdCSVrKvj/dugEVKkhbCxERUVHpHIAaNWqE9957DwsWLEB4eDgsNPSE9fPzQ9euXfVSIEnv+nVxpmcAGDlS0lKIiIj0QucAdO3aNfj4+BS4jZ2dHZYuXVroosiwxMaKI8Deew+oU0fqaoiIiIpO507Qd+/exaFDh/K0Hzp0CP/9959eiiLD8fAh8Msv4v1Ro6SthYiISF90DkCDBg3CzZs387QnJydj0KBBeimKDMdPP4lrfwUEcL0vIiIqPXQOQOfOnUP9+vXztNerVw/nzp3TS1FkGLKzgTlzxPujRgEymbT1EBER6YvOAcjKygqpqal52u/cuQNzc0nXViU9W7MGuHMH8PQE2KediIhKE50D0Pvvv4+xY8ciPT1d1fbo0SN89dVXeI/XSEoNQXg58eGwYYClpbT1EBER6ZPOp2xmzJiB5s2bw8fHB/Xq1QMAnDhxAm5ublixYoXeCyRpbNsGnD0L2NsDAwZIXQ0REZF+6RyAvLy8cOrUKaxatQonT56EjY0N+vTpg27dummcE4iMk/LsT//+gLOzpKUQERHpnUzgol15ZGRkwMnJCenp6XA0wUWvjh8H6tcH5HLg6lXgDdM+ERERGQRdvr8L3Wv53LlzSEpKwvPnz9XaP/roo8LukgyEctmLLl00hx+FAtizR+wg7eEBBAeLYYmIiMhYFGom6Pbt2+P06dOQyWSqVd9l/z9GWqFQ6LdCKlE3bwJr14r3NS17ER8vdoq+detlW4UKwOzZQIcOJVMjERFRUek8CmzYsGHw8/PD3bt3YWtri7Nnz2L37t1o2LAhdu3aVQwlUkmaPVs8w/POO0CDBurPxccDnTqphx8ASE4W2+PjS65OIiKiotA5AB04cABff/01XFxcYGZmBjMzMzRr1gwxMTEYOnRocdRIJSQ9HVi0SLz/+rIXCoV45kdTjzFl2/Dh4nZERESGTucApFAo4ODgAABwcXHB7du3AQA+Pj64ePGifqujEvXzz8Djx0CNGkDr1urP7dmT98zPqwRBvHy2Z0/x1khERKQPOvcBqlWrFk6ePAk/Pz8EBgZi+vTpsLS0xKJFi1CpUqXiqJFKwPPn4qrvgNj3x+y1aHznjnb70XY7IiIiKekcgMaPH4+srCwAwNdff40PP/wQwcHBKFeuHOLi4vReIJWM334T+/K4uQHdu+d93sNDu/1oux0REZGU9DIP0IMHD1CmTBnVSDBjZ2rzAAkCUK8ecPIkMHUq8NVXebdRKABfXzEkafqNkcnE0WCJiRwST0RE0tDl+1unPkAvXryAubk5zpw5o9ZetmzZUhN+TFFCghh+bG2Bzz7TvI1cLo4QA/KuCq98HBvL8ENERMZBpwBkYWGBihUrcq6fUka57EW/fkDZsvlv16EDsH494OWl3l6hgtjOeYCIiMhY6HwJbPHixYiPj8eKFStQtqBvSyNmSpfATp0C6tQROz1fvgxo04+dM0ETEZEhKtalMObNm4crV67A09MTPj4+sLOzU3v+2LFjuu6SJDRzpvhvx47ahR9ADDstWxZbSURERMVO5wAUHh5eDGWQFJKTgdWrxfualr0gIiIqrXQOQBMnTiyOOkgCc+cCL16Il7ACA6WuhoiIqOToPBM0lQ6PHwMLF4r3X1/2goiIqLTT+QyQmZlZgUPeOULMOCxeLK799dZbwIcfSl0NERFRydI5AG3YsEHt8YsXL3D8+HEsX74ckydP1lthVHxycoBZs8T7mpa9ICIiKu30MhM0AKxevRpxcXH4448/9LE7SZX2YfBr1wLdugGursCNG4CNjdQVERERFV2xzQRdkLfffhsJCQn62h0VE0EAvv9evD94MMMPERGZJr0EoKdPn2LOnDnwen2KYDI4//4LHDsGWFsDn38udTVERETS0LkP0OuLngqCgMePH8PW1hYrV67Ua3Gkf8plL/r0AVxcpK2FiIhIKjoHoFmzZqkFIDMzM7i6uiIwMBBlypTRa3GkX+fOAZs3i4uXjhghdTVERETS0TkA9e7duxjKoJKgXPYiPByoUkXSUoiIiCSlcx+gpUuXYt26dXna161bh+XLl+ulKNK/lBRgxQrxPic+JCIiU6dzAIqJiYGLhs4j5cuXx7Rp0/RSFOnfvHnA8+dAUBDQpInU1RAREUlL5wCUlJQEPz+/PO0+Pj5ISkrSS1GkX1lZwI8/ivd59oeIiKgQAah8+fI4depUnvaTJ0+iXLlyeimK9GvpUuDhQ8DfH2jXTupqi
IiIpKdzAOrWrRuGDh2KnTt3QqFQQKFQ4J9//sGwYcPQtWvX4qiRikCheNn5OSoKkMulrYeIiMgQ6DwKbMqUKbh+/TpatWoFc3Px5bm5uejVqxf7ABmgDRuAxESgXDmAA/iIiIhEOgcgS0tLxMXF4ZtvvsGJEydgY2ODgIAA+Pj4FEd9VASvLnvx+eeAra209RARERkKnQOQUpUqVVCFk8kYtH37gMOHASsrYNAgqashIiIyHDr3AerYsSO+++67PO3Tp09H586d9VIU6Ydy2YtevQA3N2lrISIiMiQ6B6Ddu3ejbdu2edrbtGmD3bt3F6qI+fPnw9fXF9bW1ggMDMThw4fz3bZly5aQyWR5bh988IFqm969e+d5vnXr1oWqzVhdvAhs2iTej4qSthYiIiJDo/MlsMzMTFhaWuZpt7CwQEZGhs4FxMXFISoqCgsXLkRgYCBiY2MRGhqKixcvonz58nm2j4+Px/Pnz1WP79+/jzp16uQ5+9S6dWssXbpU9djKykrn2ozZrFliH6CwMKBaNamrISIiMiw6nwEKCAhAXFxcnva1a9eiRo0aOhcwc+ZM9O/fH3369EGNGjWwcOFC2NraYsmSJRq3L1u2LNzd3VW37du3w9bWNk8AsrKyUtvOlBZqvXsXUK5KwokPiYiI8tL5DFB0dDQ6dOiAq1ev4t133wUAJCQkYPXq1Vi/fr1O+3r+/DmOHj2KsWPHqtrMzMwQEhKCAwcOaLWPxYsXo2vXrrCzs1Nr37VrF8qXL48yZcrg3XffxTfffJPvRI3Z2dnIzs5WPS7MmSxD8uOPwLNnQKNGQHCw1NUQEREZHp3PAIWFhWHjxo24cuUKPv/8c4wcORLJycn4559/ULlyZZ32lZaWBoVCAbfXeui6ubkhJSXlja8/fPgwzpw5g08++UStvXXr1vj111+RkJCA7777Dv/++y/atGkDhUKhcT8xMTFwcnJS3by9vXU6DkPy5Akwf754f9QoQCaTth4iIiJDVKhh8B988IGq03FGRgbWrFmDUaNG4ejRo/mGjOKwePFiBAQEoHHjxmrtr85IHRAQgNq1a8Pf3x+7du1Cq1at8uxn7NixiHqlp3BGRobRhqBffwXS0gBfX6BDB6mrISIiMkw6nwFS2r17NyIjI+Hp6YkffvgB7777Lg4ePKjTPlxcXCCXy5GamqrWnpqaCnd39wJfm5WVhbVr16Jfv35vfJ9KlSrBxcUFV65c0fi8lZUVHB0d1W7G6NVlL0aMAMwLPcsTERFR6aZTAEpJScG3336LKlWqoHPnznB0dER2djY2btyIb7/9Fo0aNdLpzS0tLdGgQQMkJCSo2nJzc5GQkICgoKACX7tu3TpkZ2ejR48eb3yfW7du4f79+/Dw8NCpPmPzv/8Bly8DZcoAfftKXQ0REZHh0joAhYWFoWrVqjh16hRiY2Nx+/ZtzJ07t8gFREVF4eeff8by5ctx/vx5DBw4EFlZWejTpw8AoFevXmqdpJUWL16M8PDwPB2bMzMz8cUXX+DgwYO4fv06EhIS0K5dO1SuXBmhoaFFrteQKSc+HDgQsLeXthYiIiJDpvVFkr/++gtDhw7FwIED9boERkREBO7du4cJEyYgJSUFdevWxdatW1Udo5OSkmBmpp7TLl68iL179+Lvv//Osz+5XI5Tp05h+fLlePToETw9PfH+++9jypQppXouoAMHxKUvLC2BwYOlroaIiMiwyQRBELTZ8ODBg1i8eDHi4uJQvXp19OzZE127doWHhwdOnjxZqDmADFVGRgacnJyQnp5uNP2BOnUCfv9dvPS1eLHU1RAREZU8Xb6/tb4E9vbbb+Pnn3/GnTt38Omnn2Lt2rXw9PREbm4utm/fjsePHxe5cCqcq1eB+HjxPpe9ICIiejOdR4HZ2dmhb9++2Lt3L06fPo2RI0fi22+/Rfny5fHRRx8VR430BsplL9q2BWrWlLoaIiIiw1foYfAAULVqVUyfPh23bt3CmjVr9FUT6eD+fUC5agiXvSAiItJOkQKQklwuR3h4ODYplx+nErNgAfD0KVC/PtCypdTVEBERGQe9BCCSxrNngHImAi57QUREpD0GICO2cqW48nvFiuIoMCIiItIOA5CRys0FfvhBvD98OGBhIWk5RERERoUByEht2QJcuAA4OQGffCJ1NURERMaFAchIKZe9+PRTwMFB2lqIiIiMDQOQETpyBPj3X3G196FDpa6GiIjI+DAAGSFl35+PPwa8vKSthYiIyBgxABmZ69eBdevE+yNHSloKERGR0WIAMjKxseIIsPffB2rXlroaIiIi48QAZEQePgR++UW8z2UviIiICo8ByIj89BOQlSWe+QkJkboaIiIi48UAZCSys4E5c8T7XPaCiIioaBiAjMSaNcCdO+Kor4gIqashIiIybgxARkAQXk58OGwYYGkpbT1ERETGjgHICGzbBpw9K874PGCA1NUQEREZPwYgI6A8+9O/v7j2FxERERUNA5CBO34cSEgA5HLx8hcREREVHQOQgVMuexERAVSsKG0tREREpQUDkAG7eRNYu1a8z2UviIiI9IcByIDNng0oFMC77wL160tdDRERUenBAGSg0tOBRYvE+1z2goiISL8YgAzUzz8Djx8DNWoArVtLXQ0REVHpwgBkgJ4/F1d9B7jsBRERUXFgADJAv/0GJCcD7u7Axx9LXQ0REVHpwwBkYF5d9mLoUMDKStp6iIiISiMGIAOTkACcPAnY2QGffip1NURERKUTA5CBUZ796dcPKFtW2lqIiIhKKwYgA3LqlLjwqZkZMHy41NUQERGVXgxABkS57EWnToCfn7S1EBERlWYMQAbi1i1g9WrxPic+JCIiKl4MQAZi7lwgJwdo3hxo1EjqaoiIiEo3BiADkJEBLFwo3ufZHyIiouLHAGQAFi8WQ1DVqsAHH0hdDRERUenHACSxFy9eLnsxcqQ4AoyIiIiKF79uJbZ+PZCUBJQvD/TsKXU1REREpoEBSEKvLnsxeDBgbS1tPURERKaCAUhCu3YBx44BNjbAwIFSV0NERGQ6GIAkpDz706cP4OIibS1ERESmhAFIImfPAlu2ADIZMGKE1NUQERGZFgYgicycKf7bvj1QubK0tRAREZkaBiAJ3LkDrFwp3ufEh0RERCWPAUgC8+YBz58DTZoAQUFSV0NERGR6GIBKWGYmsGCBeJ9nf4iIiKTBAFTCli4FHj4U+/189JHU1RAREZkmgwhA8+fPh6+vL6ytrREYGIjDhw/nu23Lli0hk8ny3D54ZREtQRAwYcIEeHh4wMbGBiEhIbh8+XJJHEqBcnKAWbPE+1FRgFwubT1ERESmSvIAFBcXh6ioKEycOBHHjh1DnTp1EBoairt372rcPj4+Hnfu3FHdzpw5A7lcjs6dO6u2mT59OubMmYOFCxfi0KFDsLOzQ2hoKJ49e1ZSh6XRhg1AYiJQrhwQGSlpKURERCZN8gA0c+ZM9O/fH3369EGNGjWwcOFC2NraYsmSJRq3L1u2LNzd3VW37du3w9bWVhWABEFAbGwsxo8fj3bt2qF2
7dr49ddfcfv2bWzcuLEEjyyv48fFeX8GDQJsbSUthYiIyKRJGoCeP3+Oo0ePIiQkRNVmZmaGkJAQHDhwQKt9LF68GF27doWdnR0AIDExESkpKWr7dHJyQmBgYL77zM7ORkZGhtqtOEybBpw/DwwdWiy7JyIiIi1JGoDS0tKgUCjg5uam1u7m5oaUlJQ3vv7w4cM4c+YMPvnkE1Wb8nW67DMmJgZOTk6qm7e3t66HorWqVcVLYERERCQdyS+BFcXixYsREBCAxo0bF2k/Y8eORXp6uup28+ZNPVVIREREhkjSAOTi4gK5XI7U1FS19tTUVLi7uxf42qysLKxduxb9+vVTa1e+Tpd9WllZwdHRUe1GREREpZekAcjS0hINGjRAQkKCqi03NxcJCQkIesMUyevWrUN2djZ69Oih1u7n5wd3d3e1fWZkZODQoUNv3CcRERGZBnOpC4iKikJkZCQaNmyIxo0bIzY2FllZWejTpw8AoFevXvDy8kJMTIza6xYvXozw8HCUe61DjUwmw/Dhw/HNN9+gSpUq8PPzQ3R0NDw9PREeHl5Sh0VEREQGTPIAFBERgXv37mHChAlISUlB3bp1sXXrVlUn5qSkJJiZqZ+ounjxIvbu3Yu///5b4z6//PJLZGVlYcCAAXj06BGaNWuGrVu3wtrautiPh4iIiAyfTBAEQeoiDE1GRgacnJyQnp7O/kBERERGQpfvb6MeBUZERERUGAxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOQxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOQxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOQxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOQxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOQxAREREZHLMpS6AiIhMi0KhwIsXL6Qug4yQhYUF5HK5XvbFAERERCVCEASkpKTg0aNHUpdCRszZ2Rnu7u6QyWRF2g8DEBERlQhl+ClfvjxsbW2L/AVGpkUQBDx58gR3794FAHh4eBRpfwxARERU7BQKhSr8lCtXTupyyEjZ2NgAAO7evYvy5csX6XIYO0ETEVGxU/b5sbW1lbgSMnbK36Gi9iNjACIiohLDy15UVPr6HZI8AM2fPx++vr6wtrZGYGAgDh8+XOD2jx49wqBBg+Dh4QErKyu89dZb2LJli+r5SZMmQSaTqd2qVatW3IdBRESkFV9fX8TGxmq9/a5duyCTydh5XM8k7QMUFxeHqKgoLFy4EIGBgYiNjUVoaCguXryI8uXL59n++fPneO+991C+fHmsX78eXl5euHHjBpydndW2q1mzJnbs2KF6bG7Ork5ERKWBQgHs2QPcuQN4eADBwYCeRkXn8aYzDRMnTsSkSZN03u+RI0dgZ2en9fZNmjTBnTt34OTkpPN7Uf4kTQYzZ85E//790adPHwDAwoULsXnzZixZsgRjxozJs/2SJUvw4MED7N+/HxYWFgDEJP06c3NzuLu7F2vtRERUsuLjgWHDgFu3XrZVqADMng106KD/97tz547qflxcHCZMmICLFy+q2uzt7VX3BUGAQqHQ6g9uV1dXneqwtLTkd1oxkOwS2PPnz3H06FGEhIS8LMbMDCEhIThw4IDG12zatAlBQUEYNGgQ3NzcUKtWLUybNg0KhUJtu8uXL8PT0xOVKlVC9+7dkZSUVKzHQkRExSs+HujUST38AEBystgeH6//93R3d1fdnJycIJPJVI8vXLgABwcH/PXXX2jQoAGsrKywd+9eXL16Fe3atYObmxvs7e3RqFEjtSsSQN5LYDKZDL/88gvat28PW1tbVKlSBZs2bVI9//olsGXLlsHZ2Rnbtm1D9erVYW9vj9atW6sFtpycHAwdOhTOzs4oV64cRo8ejcjISISHh+d7vPfv30e3bt3g5eUFW1tbBAQEYM2aNWrb5ObmYvr06ahcuTKsrKxQsWJFTJ06VfX8rVu30K1bN5QtWxZ2dnZo2LAhDh06VIhPv/hJFoDS0tKgUCjg5uam1u7m5oaUlBSNr7l27RrWr18PhUKBLVu2IDo6Gj/88AO++eYb1TaBgYFYtmwZtm7digULFiAxMRHBwcF4/PhxvrVkZ2cjIyND7UZERIZBoRDP/AhC3ueUbcOHi9uVtDFjxuDbb7/F+fPnUbt2bWRmZqJt27ZISEjA8ePH0bp1a4SFhb3xD/HJkyejS5cuOHXqFNq2bYvu3bvjwYMH+W7/5MkTzJgxAytWrMDu3buRlJSEUaNGqZ7/7rvvsGrVKixduhT79u1DRkYGNm7cWGANz549Q4MGDbB582acOXMGAwYMQM+ePdX65o4dOxbffvstoqOjce7cOaxevVr1PZ6ZmYkWLVogOTkZmzZtwsmTJ/Hll18iNzdXi09SAoJEkpOTBQDC/v371dq/+OILoXHjxhpfU6VKFcHb21vIyclRtf3www+Cu7t7vu/z8OFDwdHRUfjll1/y3WbixIkCgDy39PR0HY+KiIg0efr0qXDu3Dnh6dOnOr92505BEKNOwbedO/VetsrSpUsFJyenV2raKQAQNm7c+MbX1qxZU5g7d67qsY+PjzBr1izVYwDC+PHjVY8zMzMFAMJff/2l9l4PHz5U1QJAuHLliuo18+fPF9zc3FSP3dzchO+//171OCcnR6hYsaLQrl07bQ9ZEARB+OCDD4SRI0cKgiAIGRkZgpWVlfDzzz9r3Pann34SHBwchPv37+v0Hroq6HcpPT1d6+9vyc4Aubi4QC6XIzU1Va09NTU132udHh4eeOutt9QmPqpevTpSUlLw/Plzja9xdnbGW2+9hStXruRby9ixY5Genq663bx5sxBHRERExeGVKzt62U6fGjZsqPY4MzMTo0aNQvXq1eHs7Ax7e3ucP3/+jWeAateurbpvZ2cHR0dH1YzHmtja2sLf31/12MPDQ7V9eno6UlNT0bhxY9XzcrkcDRo0KLAGhUKBKVOmICAgAGXLloW9vT22bdumqv38+fPIzs5Gq1atNL7+xIkTqFevHsqWLVvg+xgKyQKQpaUlGjRogISEBFVbbm4uEhISEBQUpPE1TZs2xZUrV9ROp126dAkeHh6wtLTU+JrMzExcvXq1wCmzrays4OjoqHYjIiLDoO2KB0VcGaFQXh/NNWrUKGzYsAHTpk3Dnj17cOLECQQEBOT7R7qScmCPkkwmK/DSkabtBU3XCHXw/fffY/bs2Rg9ejR27tyJEydOIDQ0VFW7chbm/LzpeUMj6TxAUVFR+Pnnn7F8+XKcP38eAwcORFZWlmpUWK9evTB27FjV9gMHDsSDBw8wbNgwXLp0CZs3b8a0adM
waNAg1TajRo3Cv//+i+vXr2P//v1o37495HI5unXrVuLHR0RERRccLI72ym9UukwGeHuL20lt37596N27N9q3b4+AgAC4u7vj+vXrJVqDk5MT3NzccOTIEVWbQqHAsWPHCnzdvn370K5dO/To0QN16tRBpUqVcOnSJdXzVapUgY2NjdqJi1fVrl0bJ06cKLDvkiGRdBh8REQE7t27hwkTJiAlJQV169bF1q1bVR2qkpKSYGb2MqN5e3tj27ZtGDFiBGrXrg0vLy8MGzYMo0ePVm2j7IF+//59uLq6olmzZjh48KDOww6JiMgwyOXiUPdOncSw8+qJDmUoio0tvvmAdFGlShXEx8cjLCwMMpkM0dHRknQCHjJkCGJiYlC5cmVUq1YNc+fOxcOHDwuc26hKlSpYv3499u/fjzJlymDmzJlITU1FjRo1AADW1tYYPXo0vvzyS1haWqJp06a4d+8ezp49i379+qFbt26YNm0awsPDERMTAw8PDxw/fhyenp75XtmRkuQzBA4ePBiDBw/W+NyuXbvytAUFBeHgwYP57m/t2rX6Ko2IiAxEhw7A+vWa5wGKjS2eeYAKY+bMmejbty+aNGkCFxcXjB49WpKRxaNHj0ZKSgp69eoFuVyOAQMGIDQ0tMDFQ8ePH49r164hNDQUtra2GDBgAMLDw5Genq7aJjo6Gubm5pgwYQJu374NDw8PfPbZZwDEri1///03Ro4cibZt2yInJwc1atTA/Pnzi/14C0MmFPWiYSmUkZEBJycnpKensz8QEZEePHv2DImJifDz84O1tXWh91OSM0GXJrm5uahevTq6dOmCKVOmSF1OkRT0u6TL97fkZ4CIiIi0JZcDLVtKXYXhu3HjBv7++2+0aNEC2dnZmDdvHhITE/Hxxx9LXZrBkHwxVCIiItIvMzMzLFu2DI0aNULTpk1x+vRp7NixA9WrV5e6NIPBM0BERESljLe3N/bt2yd1GQaNZ4CIiIjI5DAAERERkclhACIiIiKTwwBEREREJocBiIiIiEwOAxARERGZHAYgIiKiYtSyZUsMHz5c9djX1xexsbEFvkYmk2Hjxo1Ffm997ac0YgAiIiLSICwsDK1bt9b43J49eyCTyXDq1Cmd93vkyBEMGDCgqOWpmTRpEurWrZun/c6dO2jTpo1e36u0YAAiIiLSoF+/fti+fTtuvbr66v9bunQpGjZsiNq1a+u8X1dXV9ja2uqjxDdyd3eHlZVVibyXsWEAIiIi0uDDDz+Eq6srli1bptaemZmJdevWoV+/frh//z66desGLy8v2NraIiAgAGvWrClwv69fArt8+TKaN28Oa2tr1KhRA9u3b8/zmtGjR+Ott96Cra0tKlWqhOjoaLx48QIAsGzZMkyePBknT56ETCaDTCZT1fz6JbDTp0/j3XffhY2NDcqVK4cBAwYgMzNT9Xzv3r0RHh6OGTNmwMPDA+XKlcOgQYNU76XJ1atX0a5dO7i5ucHe3h6NGjXCjh071LbJzs7G6NGj4e3tDSsrK1SuXBmLFy9WPX/27Fl8+OGHcHR0hIODA4KDg3H16tUCP8ei4lIYREQkCUEAnjwp+fe1tQVksjdvZ25ujl69emHZsmUYN24cZP//onXr1kGhUKBbt27IzMxEgwYNMHr0aDg6OmLz5s3o2bMn/P390bhx4ze+R25uLjp06AA3NzccOnQI6enpav2FlBwcHLBs2TJ4enri9OnT6N+/PxwcHPDll18iIiICZ86cwdatW1XBw8nJKc8+srKyEBoaiqCgIBw5cgR3797FJ598gsGDB6uFvJ07d8LDwwM7d+7ElStXEBERgbp166J///4ajyEzMxNt27bF1KlTYWVlhV9//RVhYWG4ePEiKlasCADo1asXDhw4gDlz5qBOnTpITExEWloaACA5ORnNmzdHy5Yt8c8//8DR0RH79u1DTk7OGz+/IhEoj/T0dAGAkJ6ertf95uQIws6dgrB6tfhvTo5ed09EZLCePn0qnDt3Tnj69KmqLTNTEMQYVLK3zEzt6z5//rwAQNi5c6eqLTg4WOjRo0e+r/nggw+EkSNHqh63aNFCGDZsmOqxj4+PMGvWLEEQBGHbtm2Cubm5kJycrHr+r7/+EgAIGzZsyPc9vv/+e6FBgwaqxxMnThTq1KmTZ7tX97No0SKhTJkyQuYrH8DmzZsFMzMzISUlRRAEQYiMjBR8fHyEnFe+oDp37ixERETkW4smNWvWFObOnSsIgiBcvHhRACBs375d47Zjx44V/Pz8hOfPn2u1b02/S0q6fH/zElgJiY8HfH2Bd94BPv5Y/NfXV2wnIiLDVK1aNTRp0gRLliwBAFy5cgV79uxBv379AAAKhQJTpkxBQEAAypYtC3t7e2zbtg1JSUla7f/8+fPw9vaGp6enqi0oKCjPdnFxcWjatCnc3d1hb2+P8ePHa/0er75XnTp1YGdnp2pr2rQpcnNzcfHiRVVbzZo1IZfLVY89PDxw9+7dfPebmZmJUaNGoXr16nB2doa9vT3Onz+vqu/EiROQy+Vo0aKFxtefOHECwcHBsLCw0Ol4ioqXwEpAfDzQqZP4t8erkpPF9vXrgQ4dpKmNiEgqtrbAK91PSvR9ddGvXz8MGTIE8+fPx9KlS+Hv76/6Mv/+++8xe/ZsxMbGIiAgAHZ2dhg+fDieP3+ut3oPHDiA7t27Y/LkyQgNDYWTkxPWrl2LH374QW/v8arXg4hMJkNubm6+248aNQrbt2/HjBkzULlyZdjY2KBTp06qz8DGxqbA93vT88WFAaiYKRTAsGF5ww8gtslkwPDhQLt2wCuBm4io1JPJgFdORhisLl26YNiwYVi9ejV+/fVXDBw4UNUfaN++fWjXrh169OgBQOzTc+nSJdSoUUOrfVevXh03b97EnTt34OHhAQA4ePCg2jb79++Hj48Pxo0bp2q7ceOG2jaWlpZQKBRvfK9ly5YhKytLdRZo3759MDMzQ9WqVbWqV5N9+/ahd+/eaN++PQDxjND169dVzwcEBCA3Nxf//vsvQkJC8ry+du3aWL58OV68eFGiZ4F4CayY7dkDaBhBqSIIwM2b4nZERGR47O3tERERgbFjx+LOnTvo3bu36rkqVapg+/bt2L9/P86fP49PP/0UqampWu87JCQEb731FiIjI3Hy5Ens2bNHLego3yMpKQlr167F1atXMWfOHGzYsEFtG19fXyQmJuLEiRNIS0tDdnZ2nvfq3r07rK2tERkZiTNnzmDnzp0YMmQIevbsCTc3N90+lNfqi4+Px4kTJ3Dy5El8/PHHameMfH19ERkZib59+2Ljxo1ITEzErl278NtvvwEABg8ejIyMDHTt2hX//fcfLl++jBUrVqhdlisODEDF7M4d/W5HREQlr1+/fnj48CFCQ0PV+uuMHz8e9evXR2hoKFq2bAl3d3eEh4drvV8zMzNs2LABT58+RePGjfHJJ59g6tSpatt89NFHGDFiBAYPHoy6deti//79iI6OVtumY8eOaN26Nd555x24urpqHIpva2uLbdu24cGDB2jUqBE6deqEVq
1aYd68ebp9GK+ZOXMmypQpgyZNmiAsLAyhoaGoX7++2jYLFixAp06d8Pnnn6NatWro378/srKyAADlypXDP//8g8zMTLRo0QINGjTAzz//XOxng2SCoOnijGnLyMiAk5MT0tPT4ejoWKR97doldnh+k507gZYti/RWREQG69mzZ0hMTISfnx+sra2lLoeMWEG/S7p8f/MMUDELDgYqVMh/zgmZDPD2FrcjIiKiksEAVMzkcmD2bPH+6yFI+Tg2lh2giYiIShIDUAno0EEc6u7lpd5eoQKHwBMREUmBw+BLSIcO4lD3PXvEDs8eHuJlL575ISIiKnkMQCVILmdHZyIiIkPAS2BERFRiOPCYikpfv0MMQEREVOyUc7o8kWL5dypVlL9DRZ0niJfAiIio2Mnlcjg7O6sW1bS1tVUtJ0GkDUEQ8OTJE9y9exfOzs5qC7YWBgMQERGVCHd3dwAocGVxojdxdnZW/S4VBQMQERGVCJlMBg8PD5QvXx4vXryQuhwyQhYWFkU+86PEAERERCVKLpfr7UuMqLDYCZqIiIhMDgMQERERmRwGICIiIjI57AOkgXKSpYyMDIkrISIiIm0pv7e1mSyRAUiDx48fAwC8vb0lroSIiIh09fjxYzg5ORW4jUzgvOR55Obm4vbt23BwcOBEXfnIyMiAt7c3bt68CUdHR6nLMXn8eRgW/jwMC38ehqU4fx6CIODx48fw9PSEmVnBvXx4BkgDMzMzVKhQQeoyjIKjoyP/h2JA+PMwLPx5GBb+PAxLcf083nTmR4mdoImIiMjkMAARERGRyWEAokKxsrLCxIkTYWVlJXUpBP48DA1/HoaFPw/DYig/D3aCJiIiIpPDM0BERERkchiAiIiIyOQwABEREZHJYQAiIiIik8MARFqLiYlBo0aN4ODggPLlyyM8PBwXL16Uuiz6f99++y1kMhmGDx8udSkmLTk5GT169EC5cuVgY2ODgIAA/Pfff1KXZZIUCgWio6Ph5+cHGxsb+Pv7Y8qUKVqtE0VFt3v3boSFhcHT0xMymQwbN25Ue14QBEyYMAEeHh6wsbFBSEgILl++XGL1MQCR1v79918MGjQIBw8exPbt2/HixQu8//77yMrKkro0k3fkyBH89NNPqF27ttSlmLSHDx+iadOmsLCwwF9//YVz587hhx9+QJkyZaQuzSR99913WLBgAebNm4fz58/ju+++w/Tp0zF37lypSzMJWVlZqFOnDubPn6/x+enTp2POnDlYuHAhDh06BDs7O4SGhuLZs2clUh+HwVOh3bt3D+XLl8e///6L5s2bS12OycrMzET9+vXx448/4ptvvkHdunURGxsrdVkmacyYMdi3bx/27NkjdSkE4MMPP4SbmxsWL16sauvYsSNsbGywcuVKCSszPTKZDBs2bEB4eDgA8eyPp6cnRo4ciVGjRgEA0tPT4ebmhmXLlqFr167FXhPPAFGhpaenAwDKli0rcSWmbdCgQfjggw8QEhIidSkmb9OmTWjYsCE6d+6M8uXLo169evj555+lLstkNWnSBAkJCbh06RIA4OTJk9i7dy/atGkjcWWUmJiIlJQUtf9vOTk5ITAwEAcOHCiRGrgYKhVKbm4uhg8fjqZNm6JWrVpSl2Oy1q5di2PHjuHIkSNSl0IArl27hgULFiAqKgpfffUVjhw5gqFDh8LS0hKRkZFSl2dyxowZg4yMDFSrVg1yuRwKhQJTp05F9+7dpS7N5KWkpAAA3Nzc1Nrd3NxUzxU3BiAqlEGDBuHMmTPYu3ev1KWYrJs3b2LYsGHYvn07rK2tpS6HIP5h0LBhQ0ybNg0AUK9ePZw5cwYLFy5kAJLAb7/9hlWrVmH16tWoWbMmTpw4geHDh8PT05M/D+IlMNLd4MGD8eeff2Lnzp2oUKGC1OWYrKNHj+Lu3buoX78+zM3NYW5ujn///Rdz5syBubk5FAqF1CWaHA8PD9SoUUOtrXr16khKSpKoItP2xRdfYMyYMejatSsCAgLQs2dPjBgxAjExMVKXZvLc3d0BAKmpqWrtqampqueKGwMQaU0QBAwePBgbNmzAP//8Az8/P6lLMmmtWrXC6dOnceLECdWtYcOG6N69O06cOAG5XC51iSanadOmeaaGuHTpEnx8fCSqyLQ9efIEZmbqX3NyuRy5ubkSVURKfn5+cHd3R0JCgqotIyMDhw4dQlBQUInUwEtgpLVBgwZh9erV+OOPP+Dg4KC6Tuvk5AQbGxuJqzM9Dg4Oefpf2dnZoVy5cuyXJZERI0agSZMmmDZtGrp06YLDhw9j0aJFWLRokdSlmaSwsDBMnToVFStWRM2aNXH8+HHMnDkTffv2lbo0k5CZmYkrV66oHicmJuLEiRMoW7YsKlasiOHDh+Obb75BlSpV4Ofnh+joaHh6eqpGihU7gUhLADTeli5dKnVp9P9atGghDBs2TOoyTNr//vc/oVatWoKVlZVQrVo1YdGiRVKXZLIyMjKEYcOGCRUrVhSsra2FSpUqCePGjROys7OlLs0k7Ny5U+N3RmRkpCAIgpCbmytER0cLbm5ugpWVldCqVSvh4sWLJVYf5wEiIiIik8M+QERERGRyGICIiIjI5DAAERERkclhACIiIiKTwwBEREREJocBiIiIiEwOAxARERGZHAYgIqJ8yGQybNy4UeoyiKgYMAARkUHq3bs3ZDJZnlvr1q2lLo2ISgGuBUZEBqt169ZYunSpWpuVlZVE1RBRacIzQERksKysrODu7q52K1OmDADx8tSCBQvQpk0b2NjYoFKlSli/fr3a60+fPo13330XNjY2KFeuHAYMGIDMzEy1bZYsWYKaNWvCysoKHh4eGDx4sNrzaWlpaN++PWxtbVGlShVs2rRJ9dzDhw/RvXt3uLq6wsbGBlWqVMkT2IjIMDEAEZHRio6ORseOHXHy5El0794dXbt2xfnz5wEAWVlZCA0NRZkyZXDkyBGsW7cOO3bsUAs4CxYswKBBgzBgwACcPn0amzZtQuXKldXeY/LkyejSpQtOnTqFtm3bonv37njw4IHq/c+dO4e//voL58+fx4IFC+Di4lJyHwARFV6JLbtKRKSDyMhIQS6XC3Z2dmq3qVOnCoIgCACEzz77TO01gYGBwsCBAwVBEIRFixYJZcqUETIzM1XPb968WTAzMxNSUlIEQRAET09PYdy4cfnWAEAYP3686nFmZqYAQPjrr78EQRCEsLAwoU+fPvo5YCIqUewDREQG65133sGCBQvU2sqWLau6HxQUpPZcUFAQTpw4AQA4f/486tSpAzs7O9XzTZs2RW5uLi5evAiZTIbbt2+jVatWBdZQu3Zt1X07Ozs4Ojri7t27AICBAweiY8eOOHbsGN5//32Eh4ejSZMmhTpWIipZDEBEZLDs7OzyXJLSFxsbG622s7CwUHssk8mQm5sLAGjTpg1u3LiBLVu2YPv27WjVqhUGDRqEGTNm6L1eItIv9gEiIqN18ODBPI+rV68OAKhevTpOnjyJrKws1fP79u2DmZkZqlatCgcHB/j6+iIhIaFINbi6uiIyMhIrV
65EbGwsFi1aVKT9EVHJ4BkgIjJY2dnZSElJUWszNzdXdTRet24dGjZsiGbNmmHVqlU4fPgwFi9eDADo3r07Jk6ciMjISEyaNAn37t3DkCFD0LNnT7i5uQEAJk2ahM8++wzly5dHmzZt8PjxY+zbtw9DhgzRqr4JEyagQYMGqFmzJrKzs/Hnn3+qAhgRGTYGICIyWFu3boWHh4daW9WqVXHhwgUA4gittWvX4vPPP4eHhwfWrFmDGjVqAABsbW2xbds2DBs2DI0aNYKtrS06duyImTNnqvYVGRmJZ8+eYdasWRg1ahRcXFzQqVMnreuztLTE2LFjcf36ddjY2CA4OBhr167Vw5ETUXGTCYIgSF0EEZGuZDIZNmzYgPDwcKlLISIjxD5AREREZHIYgIiIiMjksA8QERklXr0noqLgGSAiIiIyOQxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOf8Hw4fJoTNf+rYAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(epochs, acc, 'bo', label='Training acc')\n", "plt.plot(epochs, val_acc, 'b', label='Validation acc')\n", "plt.title('Training and validation accuracy')\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Accuracy')\n", "plt.legend(loc='lower right')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "7865d6f2", "metadata": {}, "source": [ "### Export the model\n", "\n", "We can export the model including the TextVectorization layer inside the model to conduct inference on raw text." ] }, { "cell_type": "code", "execution_count": 26, "id": "93b0a42c-437e-41bb-99e7-d58cb8036a3a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m782/782\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - accuracy: 0.4935 - binary_accuracy: 0.0000e+00 - loss: 0.0000e+00\n", "{'accuracy': 0.5, 'binary_accuracy': 0.0, 'loss': 0.0}\n" ] } ], "source": [ "export_model = tf.keras.Sequential([\n", " vectorize_layer,\n", " model,\n", " layers.Activation('sigmoid')\n", "])\n", "\n", "export_model.compile(\n", " loss=losses.BinaryCrossentropy(from_logits=False), optimizer=\"adam\", metrics=['accuracy']\n", ")\n", "\n", "# Test it with `raw_test_ds`, which yields raw strings\n", "metrics = export_model.evaluate(raw_test_ds, return_dict=True)\n", "print(metrics)" ] }, { "cell_type": "markdown", "id": "d0795584", "metadata": {}, "source": [ "Conduct inference on new data:" ] }, { "cell_type": "code", "execution_count": 27, "id": "8939539b-a600-48b1-a55e-3f1087f4a855", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 61ms/step\n" ] }, { "data": { "text/plain": [ "array([[0.67346764],\n", " [0.634105 ],\n", " [0.61044645]], dtype=float32)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "examples = tf.constant([\n", " \"The movie was great!\",\n", " \"The movie was okay.\",\n", " \"The movie was terrible...\"\n", "])\n", "\n", "export_model.predict(examples)" ] }, { "cell_type": "markdown", "id": "f6b40a59-8d3b-44ec-a4f7-92c5742a0c1c", "metadata": {}, "source": [ "### Save Model" ] }, { "cell_type": "code", "execution_count": 28, "id": "3e520822", "metadata": {}, "outputs": [], "source": [ "os.mkdir('models') if not os.path.exists('models') else None" ] }, { "cell_type": "code", "execution_count": 29, "id": "7f22cc32-2708-4808-8e76-99024da87a21", "metadata": {}, "outputs": [], "source": [ "export_model.save('models/text_model.keras')" ] }, { "cell_type": "markdown", "id": "e0461f74-fdd0-4f30-9f44-0be7ad00d9b0", "metadata": {}, "source": [ "### Load model" ] }, { "cell_type": "code", "execution_count": 30, "id": "c9cf2c7f-5e86-4ff8-984e-dd0ed7a3ece9", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential_1\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"sequential_1\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
       "│ text_vectorization              │ (None, 250)            │             0 │\n",
       "│ (TextVectorization)             │                        │               │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ sequential (Sequential)         │ (None, 1)              │       160,017 │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ activation (Activation)         │ (None, 1)              │             0 │\n",
       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ text_vectorization │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m250\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "│ (\u001b[38;5;33mTextVectorization\u001b[0m) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ sequential (\u001b[38;5;33mSequential\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m160,017\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ activation (\u001b[38;5;33mActivation\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 160,017 (625.07 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m160,017\u001b[0m (625.07 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 160,017 (625.07 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m160,017\u001b[0m (625.07 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# register callables as custom objects before loading\n", "custom_objects = {\"vectorize_layer\": vectorize_layer, \"custom_standardization\": custom_standardization}\n", "with tf.keras.utils.custom_object_scope(custom_objects):\n", " new_model = tf.keras.models.load_model('models/text_model.keras', compile=False)\n", "\n", "new_model.summary()" ] }, { "cell_type": "markdown", "id": "242a4f7e-fa45-4d21-b103-fe3718bc0f10", "metadata": {}, "source": [ "### Predict" ] }, { "cell_type": "code", "execution_count": 31, "id": "531680b2-42ef-4205-9a38-6995aee9f340", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 59ms/step\n" ] }, { "data": { "text/plain": [ "array([[0.67346764],\n", " [0.634105 ],\n", " [0.61044645]], dtype=float32)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_model.predict(examples)" ] }, { "cell_type": "markdown", "id": "a82ae387-1587-4175-b4b2-66586e4668f7", "metadata": {}, "source": [ "## PySpark" ] }, { "cell_type": "code", "execution_count": 32, "id": "d6d515c2-ce53-4af5-a936-ae91fdecea99", "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.functions import predict_batch_udf\n", "from pyspark.sql.functions import struct, col, array, pandas_udf\n", "from pyspark.sql.types import ArrayType, FloatType, DoubleType\n", "from pyspark.sql import SparkSession\n", "from pyspark import SparkConf\n", "import pandas as pd\n", "import json" ] }, { "cell_type": "markdown", "id": "39c35256", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific Spark configurations." ] }, { "cell_type": "code", "execution_count": 33, "id": "31de0c5f", "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "markdown", "id": "55ad7f00", "metadata": {}, "source": [ "#### Create Spark Session\n", "\n", "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n", "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)." ] }, { "cell_type": "code", "execution_count": 34, "id": "6b653c43", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 14:05:31 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "25/02/04 14:05:31 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/02/04 14:05:31 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\n" ] } ], "source": [ "conf = SparkConf()\n", "\n", "if 'spark' not in globals():\n", " if on_standalone:\n", " import socket\n", " \n", " conda_env = os.environ.get(\"CONDA_PREFIX\")\n", " hostname = socket.gethostname()\n", " conf.setMaster(f\"spark://{hostname}:7077\")\n", " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n", " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n", " elif on_dataproc:\n", " conf.set(\"spark.executorEnv.TF_GPU_ALLOCATOR\", \"cuda_malloc_async\")\n", "\n", " conf.set(\"spark.executor.cores\", \"8\")\n", " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n", " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", "\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n", "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n", "sc = spark.sparkContext" ] }, { "cell_type": "markdown", "id": "53b39d27", "metadata": {}, "source": [ "Load the IMDB dataset. We'll perform inference on the first sentence of each sample." ] }, { "cell_type": "code", "execution_count": 35, "id": "ef3309eb", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "dataset = load_dataset(\"imdb\", split=\"test\")\n", "dataset = dataset.to_pandas().drop(columns=\"label\")" ] }, { "cell_type": "markdown", "id": "3a7672d1", "metadata": {}, "source": [ "#### Create PySpark DataFrame" ] }, { "cell_type": "code", "execution_count": 36, "id": "bb05466f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "StructType([StructField('text', StringType(), True)])" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.createDataFrame(dataset).repartition(8)\n", "df.schema" ] }, { "cell_type": "code", "execution_count": 37, "id": "3f0a594b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 14:05:36 WARN TaskSetManager: Stage 0 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n", " \r" ] }, { "data": { "text/plain": [ "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.
<br /><br />The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.
<br /><br />The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.
<br /><br />I really got nothing much left to say except, give us back CKY2K, cause Bam suck..
<br /><br />
I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.take(1)" ] }, { "cell_type": "code", "execution_count": 38, "id": "9d9db063", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 14:05:37 WARN TaskSetManager: Stage 3 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n" ] } ], "source": [ "data_path = \"spark-dl-datasets/imdb_test\"\n", "if on_databricks:\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path)" ] }, { "cell_type": "markdown", "id": "2f78a16a", "metadata": {}, "source": [ "#### Load and Preprocess PySpark DataFrame\n", "\n", "Define our preprocess function. We'll take the first sentence of each sample as our input for sentiment analysis." ] }, { "cell_type": "code", "execution_count": 39, "id": "1c081557", "metadata": {}, "outputs": [], "source": [ "@pandas_udf(\"string\")\n", "def preprocess(text: pd.Series) -> pd.Series:\n", " return pd.Series([s.split(\".\")[0] for s in text])" ] }, { "cell_type": "code", "execution_count": 40, "id": "60af570a", "metadata": {}, "outputs": [], "source": [ "# Limit to N rows, since this can be slow\n", "df = spark.read.parquet(data_path).limit(512).repartition(8)" ] }, { "cell_type": "code", "execution_count": 41, "id": "a690f6df", "metadata": {}, "outputs": [], "source": [ "input_df = df.select(preprocess(col(\"text\")).alias(\"lines\")).cache()" ] }, { "cell_type": "markdown", "id": "01166d97", "metadata": {}, "source": [ "## Inference using Spark DL API\n", "\n", "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n", "\n", "- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \n", "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function" ] }, { "cell_type": "code", "execution_count": 42, "id": "7b7a8395-e2ae-4c3c-bf57-763dfde600ad", "metadata": {}, "outputs": [], "source": [ "text_model_path = \"{}/models/text_model.keras\".format(os.getcwd())\n", "\n", "# For cloud environments, copy the model to the distributed file system.\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n", " dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/text_model.keras\"\n", " shutil.copy(text_model_path, dbfs_model_path)\n", " text_model_path = dbfs_model_path\n", "elif on_dataproc:\n", " # GCS is mounted at /mnt/gcs by the init script\n", " models_dir = \"/mnt/gcs/spark-dl/models\"\n", " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n", " gcs_model_path = models_dir + \"/text_model.keras\"\n", " shutil.copy(text_model_path, gcs_model_path)\n", " text_model_path = gcs_model_path" ] }, { "cell_type": "code", "execution_count": 43, "id": "8c0524cf-3a75-4fb8-8025-f0654acce13e", "metadata": {}, "outputs": [], "source": [ "def predict_batch_fn():\n", " # since this function runs on the executor, any required imports should be added inside the function.\n", " import re\n", " import string\n", " import tensorflow as tf\n", " from tensorflow.keras import layers\n", "\n", " # Enable GPU memory growth to avoid CUDA OOM\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if 
gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", "\n", " def custom_standardization(input_data):\n", " lowercase = tf.strings.lower(input_data)\n", " stripped_html = tf.strings.regex_replace(lowercase, \"
\", \" \")\n", " return tf.strings.regex_replace(\n", " stripped_html, \"[%s]\" % re.escape(string.punctuation), \"\"\n", " )\n", "\n", " max_features = 10000\n", " sequence_length = 250\n", "\n", " vectorize_layer = layers.TextVectorization(\n", " standardize=custom_standardization,\n", " max_tokens=max_features,\n", " output_mode=\"int\",\n", " output_sequence_length=sequence_length,\n", " )\n", "\n", " custom_objects = {\"vectorize_layer\": vectorize_layer,\n", " \"custom_standardization\": custom_standardization}\n", " with tf.keras.utils.custom_object_scope(custom_objects):\n", " model = tf.keras.models.load_model(text_model_path)\n", "\n", " def predict(inputs):\n", " return model.predict(inputs)\n", "\n", " return predict" ] }, { "cell_type": "code", "execution_count": 44, "id": "0d603644-d938-4c87-aa8a-2512251638d5", "metadata": {}, "outputs": [], "source": [ "classify = predict_batch_udf(predict_batch_fn,\n", " return_type=FloatType(),\n", " batch_size=256)" ] }, { "cell_type": "code", "execution_count": 45, "id": "0b480622-8dc1-4879-933e-c43112768630", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 9:> (0 + 8) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 6.81 ms, sys: 3.75 ms, total: 10.6 ms\n", "Wall time: 4.62 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "predictions = input_df.withColumn(\"preds\", classify(struct(\"lines\")))\n", "results = predictions.collect()" ] }, { "cell_type": "code", "execution_count": 46, "id": "31b0a262-387e-4a5e-a60e-b9b8ee456199", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 4.58 ms, sys: 0 ns, total: 4.58 ms\n", "Wall time: 142 ms\n" ] } ], "source": [ "%%time\n", "predictions = input_df.withColumn(\"preds\", classify(\"lines\"))\n", "results = predictions.collect()" ] }, { "cell_type": "code", "execution_count": 47, "id": "7ef9e431-59f5-4b29-9f79-ae16a9cfb0b9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 903 μs, sys: 4.09 ms, total: 5 ms\n", "Wall time: 222 ms\n" ] } ], "source": [ "%%time\n", "predictions = input_df.withColumn(\"preds\", classify(col(\"lines\")))\n", "results = predictions.collect()" ] }, { "cell_type": "code", "execution_count": 48, "id": "9a325ee2-3268-414a-bb75-a5fcf794f512", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------------------------------------+----------+\n", "| lines| preds|\n", "+--------------------------------------------------------------------------------+----------+\n", "|The only reason I'm even giving this movie a 4 is because it was made in to a...| 0.571606|\n", "|Awkward disaster mishmash has a team of scavengers coming across the overturn...| 0.6264358|\n", "|Here is a fantastic concept for a film - a series of meteors crash into a sma...| 0.6764294|\n", "| I walked out of the cinema having suffered this film after 30 mins| 0.6258814|\n", "|A wildly uneven film where the major problem is the uneasy mix of comedy and ...|0.63658905|\n", "|Leonard Rossiter and Frances de la Tour carry this film, not without a strugg...| 0.633625|\n", "| A good cast|0.65998995|\n", "|Yet again, I appear to be the only person on planet Earth who is capable of c...| 0.6435825|\n", "|As a serious horror fan, I get that certain marketing ploys are used to sell ...| 
0.6453945|\n", "|Upon writing this review I have difficulty trying to think of what to write a...|0.61587423|\n", "| Simply awful| 0.594154|\n", "|I am a fan of Ed Harris' work and I really had high expectations about this film| 0.6366444|\n", "| Well|0.65976477|\n", "| This is a new approach to comedy| 0.6555772|\n", "| It's been mentioned by others the inane dialogue in this series and I agree| 0.6534178|\n", "|One of the most boring movies I've ever had to sit through, it's completely f...| 0.5919746|\n", "|This movie was playing on Lifetime Movie Network last month and I decided to ...| 0.6527056|\n", "| 1983's \"Frightmare\" is an odd little film|0.64622015|\n", "| 'Felony' is a B-movie|0.64882356|\n", "| This movie defines the word \"confused\"|0.63689107|\n", "+--------------------------------------------------------------------------------+----------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "predictions.show(truncate=80)" ] }, { "cell_type": "markdown", "id": "ad9b07e6", "metadata": {}, "source": [ "## Using Triton Inference Server\n", "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n", "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n", "\n", "The process looks like this:\n", "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n", "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n", "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "markdown", "id": "889a1623", "metadata": {}, "source": [ "First we'll cleanup the vocabulary layer of the model to remove non-ASCII characters. This ensures the inputs can be properly serialized and sent to Triton." 
] }, { "cell_type": "code", "execution_count": 49, "id": "f4f14c8f", "metadata": {}, "outputs": [], "source": [ "import unicodedata\n", "\n", "def normalize_vocabulary(vocab):\n", " # Normalize each word in the vocabulary to remove non-ASCII characters\n", " normalized_vocab = [\n", " unicodedata.normalize('NFKD', word).encode('ascii', 'ignore').decode('utf-8')\n", " for word in vocab\n", " ]\n", " normalized_vocab = filter(lambda x: x != '', normalized_vocab)\n", " normalized_vocab = list(set(normalized_vocab)) \n", "\n", "\n", " return normalized_vocab\n", "\n", "vocab = vectorize_layer.get_vocabulary()\n", "normalized_vocab = normalize_vocabulary(vocab)\n", "\n", "# Reassign the cleaned vocabulary to the TextVectorization layer\n", "vectorize_layer.set_vocabulary(normalized_vocab)" ] }, { "cell_type": "code", "execution_count": 50, "id": "9614a192", "metadata": {}, "outputs": [], "source": [ "# Save the model with the cleaned vocabulary\n", "triton_model_path = '{}/models/text_model_cleaned.keras'.format(os.getcwd())\n", "export_model.save(triton_model_path)\n", "\n", "# For cloud environments, copy the model to the distributed file system.\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n", " dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/text_model_cleaned.keras\"\n", " shutil.copy(triton_model_path, dbfs_model_path)\n", " triton_model_path = dbfs_model_path\n", "elif on_dataproc:\n", " # GCS is mounted at /mnt/gcs by the init script\n", " models_dir = \"/mnt/gcs/spark-dl/models\"\n", " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n", " gcs_model_path = models_dir + \"/text_model_cleaned.keras\"\n", " shutil.copy(triton_model_path, gcs_model_path)\n", " triton_model_path = gcs_model_path" ] }, { "cell_type": "code", "execution_count": 51, "id": "32d0142a", "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "id": "edddffb9", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 52, "id": "444bad3f", "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import TritonServerManager" ] }, { "cell_type": "markdown", "id": "f0923a56", "metadata": {}, "source": [ "Define the Triton Server function:" ] }, { "cell_type": "code", "execution_count": 53, "id": "a4d37d33", "metadata": {}, "outputs": [], "source": [ "def triton_server(ports, model_path):\n", " import time\n", " import signal\n", " import numpy as np\n", " import tensorflow as tf\n", " from pytriton.decorators import batch\n", " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n", " from pytriton.triton import Triton, TritonConfig\n", " from pyspark import TaskContext\n", " from tensorflow.keras import layers \n", "\n", " \n", " print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n", " # Enable GPU memory growth\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " if gpus:\n", " try:\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " except RuntimeError as e:\n", " print(e)\n", "\n", " def custom_standardization(input_data):\n", " lowercase = tf.strings.lower(input_data)\n", " stripped_html = tf.strings.regex_replace(lowercase, \"
\", \" \")\n", " return tf.strings.regex_replace(\n", " stripped_html, \"[%s]\" % re.escape(string.punctuation), \"\"\n", " )\n", "\n", " max_features = 10000\n", " sequence_length = 250\n", "\n", " vectorize_layer = layers.TextVectorization(\n", " standardize=custom_standardization,\n", " max_tokens=max_features,\n", " output_mode=\"int\",\n", " output_sequence_length=sequence_length,\n", " )\n", "\n", " custom_objects = {\"vectorize_layer\": vectorize_layer,\n", " \"custom_standardization\": custom_standardization}\n", "\n", " with tf.keras.utils.custom_object_scope(custom_objects):\n", " model = tf.keras.models.load_model(model_path)\n", "\n", " @batch\n", " def _infer_fn(**inputs):\n", " sentences = inputs[\"text\"]\n", " print(f\"SERVER: Received batch of size {len(sentences)}.\")\n", " decoded_sentences = tf.convert_to_tensor(np.vectorize(lambda x: x.decode('utf-8'))(sentences))\n", " return {\n", " \"preds\": model.predict(decoded_sentences)\n", " }\n", " \n", " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n", " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n", " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n", " triton.bind(\n", " model_name=\"TextModel\",\n", " infer_func=_infer_fn,\n", " inputs=[\n", " Tensor(name=\"text\", dtype=np.bytes_, shape=(-1,)),\n", " ],\n", " outputs=[\n", " Tensor(name=\"preds\", dtype=np.float32, shape=(-1,)),\n", " ],\n", " config=ModelConfig(\n", " max_batch_size=128,\n", " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n", " ),\n", " strict=True,\n", " )\n", "\n", " def _stop_triton(signum, frame):\n", " # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n", " print(\"SERVER: Received SIGTERM. 
Stopping Triton server.\")\n", " triton.stop()\n", "\n", " signal.signal(signal.SIGTERM, _stop_triton)\n", "\n", " print(\"SERVER: Serving inference\")\n", " triton.serve()" ] }, { "cell_type": "markdown", "id": "d340e231", "metadata": {}, "source": [ "#### Start Triton servers" ] }, { "cell_type": "markdown", "id": "fcdb7c5a", "metadata": {}, "source": [ "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n", "- Find available ports for HTTP/gRPC/metrics\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": 55, "id": "4d5dc419", "metadata": {}, "outputs": [], "source": [ "model_name = \"TextModel\"\n", "server_manager = TritonServerManager(model_name=model_name, model_path=triton_model_path)" ] }, { "cell_type": "code", "execution_count": null, "id": "20198644", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n", "server_manager.start_servers(triton_server)" ] }, { "cell_type": "markdown", "id": "e1477f4b", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "id": "798c2815", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": null, "id": "813d42cf", "metadata": {}, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url # or server_manager.host_to_grpc_url" ] }, { "cell_type": "markdown", "id": "f16617e3", "metadata": {}, "source": [ "Define the Triton inference function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 58, "id": "0ad47438", "metadata": {}, "outputs": [], "source": [ "def triton_fn(model_name, host_to_url):\n", " import socket\n", " import numpy as np\n", " from pytriton.client import ModelClient\n", "\n", " url = host_to_url[socket.gethostname()]\n", " print(f\"CLIENT: Connecting to {model_name} at {url}\")\n", "\n", " def infer_batch(inputs):\n", " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n", " encoded_inputs = np.vectorize(lambda x: x.encode(\"utf-8\"))(inputs).astype(np.bytes_)\n", " encoded_inputs = np.expand_dims(encoded_inputs, axis=1)\n", " result_data = client.infer_batch(encoded_inputs)\n", " \n", " return result_data[\"preds\"]\n", " \n", " return infer_batch" ] }, { "cell_type": "code", "execution_count": 61, "id": "8e06d33f-5cef-4a48-afc3-5d468f8ec2b4", "metadata": {}, "outputs": [], "source": [ "classify = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " return_type=FloatType(),\n", " batch_size=64)" ] }, { "cell_type": "markdown", "id": "91974885", "metadata": {}, "source": [ "#### Load and preprocess DataFrame" ] }, { "cell_type": "code", "execution_count": 59, "id": "41106a02-236e-4cb3-ac51-76aa64b663c2", "metadata": {}, "outputs": [], "source": [ "df = spark.read.parquet(data_path).limit(512).repartition(8)" ] }, { "cell_type": "code", 
"execution_count": 60, "id": "e851870b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/02/04 14:05:48 WARN CacheManager: Asked to cache already cached data.\n" ] } ], "source": [ "input_df = df.select(preprocess(col(\"text\")).alias(\"lines\")).cache()" ] }, { "cell_type": "code", "execution_count": 62, "id": "d89e74ad-e551-4bfa-ad08-98725878630a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 24:==============> (2 + 6) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.92 ms, sys: 4.06 ms, total: 6.97 ms\n", "Wall time: 1.03 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "predictions = input_df.withColumn(\"preds\", classify(struct(\"lines\")))\n", "results = predictions.collect()" ] }, { "cell_type": "code", "execution_count": 63, "id": "b4fa7fc9-341c-49a6-9af2-e316f2355d67", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.39 ms, sys: 2.15 ms, total: 3.53 ms\n", "Wall time: 237 ms\n" ] } ], "source": [ "%%time\n", "predictions = input_df.withColumn(\"preds\", classify(\"lines\"))\n", "results = predictions.collect()" ] }, { "cell_type": "code", "execution_count": 64, "id": "564f999b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 862 μs, sys: 2.77 ms, total: 3.63 ms\n", "Wall time: 225 ms\n" ] } ], "source": [ "%%time\n", "predictions = input_df.withColumn(\"preds\", classify(col(\"lines\")))\n", "results = predictions.collect()" ] }, { "cell_type": "code", "execution_count": 65, "id": "9222e8a9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------------------------------------+----------+\n", "| lines| preds|\n", "+--------------------------------------------------------------------------------+----------+\n", "|The only reason I'm even giving this movie a 4 is because it was made in to a...|0.67212176|\n", "|Awkward disaster mishmash has a team of scavengers coming across the overturn...|0.63807774|\n", "|Here is a fantastic concept for a film - a series of meteors crash into a sma...|0.65471745|\n", "| I walked out of the cinema having suffered this film after 30 mins| 0.6527998|\n", "|A wildly uneven film where the major problem is the uneasy mix of comedy and ...| 0.6405446|\n", "|Leonard Rossiter and Frances de la Tour carry this film, not without a strugg...|0.63534474|\n", "| A good cast|0.64761806|\n", "|Yet again, I appear to be the only person on planet Earth who is capable of c...|0.66956663|\n", "|As a serious horror fan, I get that certain marketing ploys are used to sell ...|0.62346375|\n", "|Upon writing this review I have difficulty trying to think of what to write a...| 0.681598|\n", "| Simply awful| 0.6537583|\n", "|I am a fan of Ed Harris' work and I really had high expectations about this film| 0.6382922|\n", "| Well|0.65424603|\n", "| This is a new approach to comedy| 0.6628315|\n", "| It's been mentioned by others the inane dialogue in this series and I agree|0.63345987|\n", "|One of the most boring movies I've ever had to sit through, it's completely f...| 0.6459369|\n", "|This movie was playing on Lifetime Movie Network last month and I decided to ...|0.65335083|\n", "| 1983's \"Frightmare\" is an odd little film|0.65602964|\n", "| 'Felony' is a B-movie| 0.6583404|\n", "| This movie defines the 
word \"confused\"| 0.6217103|\n", "+--------------------------------------------------------------------------------+----------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "predictions.show(truncate=80)" ] }, { "cell_type": "markdown", "id": "d45e8981-ca44-429b-9b37-e04035c3a86b", "metadata": { "tags": [] }, "source": [ "#### Stop Triton Server on each executor" ] }, { "cell_type": "code", "execution_count": 66, "id": "a71ac9b6-47a2-4306-bc40-9ce7b4e968ec", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-04 14:05:50,166 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n", "2025-02-04 14:06:00,351 - INFO - Sucessfully stopped 1 servers. \n" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 67, "id": "54a90574-7cbb-487b-b7a8-dcda0e6e301f", "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "88e3bfea-a825-46eb-b8c2-921a932c0089", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-tf", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/tf_requirements.txt ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -r requirements.txt tensorflow[and-cuda] tf-keras ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/torch_requirements.txt ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-r requirements.txt torch<=2.5.1 torchvision torch-tensorrt tensorrt --extra-index-url https://download.pytorch.org/whl/cu121 sentence_transformers sentencepiece nvidia-modelopt[all] --extra-index-url https://pypi.nvidia.com ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/vllm/qwen-2.5-14b-tensor-parallel_vllm.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "# PySpark LLM Inference: Qwen-2.5-14b Data Structuring\n", "\n", "In this notebook, we demonstrate distributed batch inference with [Qwen-2.5](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct), using open weights on Huggingface.\n", "\n", "The Qwen-2.5-14b-instruct is an instruction-fine-tuned version of the Qwen-2.5-14b base model. We'll show how to use the model to prepare unstructured text data into a structured schema for downstream tasks.\n", "\n", "**Note:** This example demonstrates **tensor parallelism**, which requires multiple GPUs per node. For standalone users, make sure to use a Spark worker with 2 GPUs. If you follow the Databricks or Dataproc instructions, make sure to include the `tp` argument to the cluster startup scripts." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific configurations." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [] }, "outputs": [], "source": [ "import os\n", "\n", "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [] }, "outputs": [], "source": [ "# For cloud environments, load the model to the distributed file system.\n", "if on_databricks:\n", " models_dir = \"/dbfs/FileStore/spark-dl-models\"\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n", " model_path = f\"{models_dir}/qwen2.5-14b\"\n", "elif on_dataproc:\n", " models_dir = \"/mnt/gcs/spark-dl-models\"\n", " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n", " model_path = f\"{models_dir}/qwen2.5-14b\"\n", "else:\n", " model_path = os.path.abspath(\"qwen2.5-14b\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Download the model from huggingface hub." 
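,
    "\n",
    "\n",
    "For reference, a minimal sketch of what this download can look like with `huggingface_hub` (shown only for illustration and assuming the `snapshot_download` API; the following cell performs the notebook's actual download):\n",
    "\n",
    "```python\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "# Fetch the open weights from the model repo linked above into the local model_path.\n",
    "snapshot_download(repo_id=\"Qwen/Qwen2.5-14B-Instruct\", local_dir=model_path)\n",
    "```"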
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f75ef1f2071f413da5ae502589293c62", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 18 files: 0%| | 0/18 [00:00system\\n\"),\n", " lit(system_prompt),\n", " lit(\"<|im_end|>\\n<|im_start|>user\\n\"),\n", " lit(\"Analyze this review: \"),\n", " col(\"value\"),\n", " lit(\"<|im_end|>\\n<|im_start|>assistant\\n\")\n", " ).alias(\"prompt\")\n", ")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|im_start|>system\n", "You are a specialized review analysis AI that categorizes product reviews into precise sentiment categories.\n", "IMPORTANT: Your response must contain ONLY valid JSON and nothing else - no explanations, no additional text.\n", "For each review, analyze and output EXACTLY this JSON structure:\n", "{\n", " \"primary_sentiment\": [EXACTLY ONE OF: \"positive\", \"negative\", \"neutral\", \"mixed\"],\n", " \"sentiment_score\": [integer between 1-10, where 1 is extremely negative and 10 is extremely positive],\n", " \"purchase_intention\": [EXACTLY ONE OF: \"will repurchase\", \"might repurchase\", \"will not repurchase\", \"recommends alternatives\", \"uncertain\"]\n", "}\n", "\n", "Do not include any text before or after the JSON. The response should start with '{' and end with '}' with no trailing characters, comments, or explanations.\n", "<|im_end|>\n", "<|im_start|>user\n", "Analyze this review: Installing the game was a struggle (because of games for windows live bugs).Some championship races and cars can only be \"unlocked\" by buying them as an addon to the game. I paid nearly 30 dollars when the game was new. I don't like the idea that I have to keep paying to keep playing.I noticed no improvement in the physics or graphics compared to Dirt 2.I tossed it in the garbage and vowed never to buy another codemasters game. I'm really tired of arcade style rally/racing games anyway.I'll continue to get my fix from Richard Burns Rally, and you should to. :)http://www.amazon.com/Richard-Burns-Rally-PC/dp/B000C97156/ref=sr_1_1?ie=UTF8&qid;=1341886844&sr;=8-1&keywords;=richard+burns+rallyThank you for reading my review! If you enjoyed it, be sure to rate it as helpful.<|im_end|>\n", "<|im_start|>assistant\n", "\n" ] } ], "source": [ "print(df.take(1)[0].prompt)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "data_path = \"spark-dl-datasets/amazon_video_game_reviews\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using vLLM Server\n", "In this section, we demonstrate integration with [vLLM Serving](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html), an open-source server with an OpenAI-compatible completions endpoint for LLMs. 
\n", "\n", "The process looks like this:\n", "- Distribute a server startup task across the Spark cluster, instructing each node to launch a vLLM server process.\n", "- Define a vLLM inference function, which sends inference request to the local server on a given node.\n", "- Wrap the vLLM inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the vLLM server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "tags": [] }, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import VLLMServerManager" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are currently some hanging issues with vLLM's `torch.compile` on Databricks, which we are working to resolve. For now we will enforce eager mode on Databricks, which disables compilation at some performance cost." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "enforce_eager = True if on_databricks else False" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Start vLLM servers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `VLLMServerManager` will handle the lifecycle of vLLM server instances across the Spark cluster:\n", "- Find available ports for HTTP\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "model_name = \"qwen-2.5-14b\"\n", "server_manager = VLLMServerManager(model_name=model_name, model_path=model_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can pass any of the supported [vLLM serve CLI arguments](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#vllm-serve) as key-word arguments when starting the servers. 
Note that this can take some time, as it includes loading the model from disk, Torch compilation, and capturing CUDA graphs.\n", "\n", "Here, we set `tensor_parallel_size` to the number of GPUs per node:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-03-20 01:04:42,978 - INFO - Requesting stage-level resources: (cores=13, gpu=2.0)\n", "2025-03-20 01:04:42,979 - INFO - Starting 2 VLLM servers.\n", " \r" ] }, { "data": { "text/plain": [ "{'spark-dl-inference-vllm-tp-w-0': (35438, [7000]),\n", " 'spark-dl-inference-vllm-tp-w-1': (35288, [7000])}" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tensor_parallel_size = int(spark.conf.get(\"spark.executor.resource.gpu.amount\"))\n", "server_manager.start_servers(tensor_parallel_size=tensor_parallel_size,\n", " gpu_memory_utilization=0.95,\n", " max_model_len=6600,\n", " task=\"generate\",\n", " enforce_eager=enforce_eager,\n", " wait_retries=100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "tags": [] }, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "tags": [] }, "outputs": [], "source": [ "def vllm_fn(model_name, host_to_url):\n", " import socket\n", " import json\n", " import requests\n", "\n", " url = host_to_url[socket.gethostname()]\n", " \n", " def predict(inputs):\n", " print(inputs)\n", " response = requests.post(\n", " \"http://localhost:7000/v1/completions\",\n", " json={\n", " \"model\": model_name,\n", " \"prompt\": inputs.tolist(),\n", " \"max_tokens\": 50,\n", " \"temperature\": 0.7,\n", " \"top_p\": 0.8,\n", " \"repetition_penalty\": 1.05,\n", " }\n", " )\n", " result_dicts = [json.loads(o[\"text\"]) for o in response.json()[\"choices\"]]\n", " return result_dicts\n", " \n", " return predict" ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "tags": [] }, "outputs": [], "source": [ "generate = predict_batch_udf(partial(vllm_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " return_type=StructType([\n", " StructField(\"primary_sentiment\", StringType()),\n", " StructField(\"sentiment_score\", IntegerType()),\n", " StructField(\"purchase_intention\", StringType())\n", " ]),\n", " batch_size=32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load DataFrame" ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "df = spark.read.parquet(data_path).repartition(16)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Run Inference" ] }, { "cell_type": "code", "execution_count": 57, "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 31:=================================================> (14 + 2) / 16]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 29.6 ms, sys: 6.89 ms, total: 36.5 ms\n", "Wall time: 33 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn\n", "preds = df.withColumn(\"outputs\", 
generate(col(\"prompt\"))).select(\"prompt\", \"outputs.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 58, "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 34:=================================================> (14 + 2) / 16]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 25.6 ms, sys: 6.73 ms, total: 32.3 ms\n", "Wall time: 32 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"outputs\", generate(col(\"prompt\"))).select(\"prompt\", \"outputs.*\")\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 59, "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 37:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+--------------------------------------------------+-----------------+---------------+-------------------+\n", "| prompt|primary_sentiment|sentiment_score| purchase_intention|\n", "+--------------------------------------------------+-----------------+---------------+-------------------+\n", "|<|im_start|>system\\nYou are a specialized revie...| positive| 9| will repurchase|\n", "|<|im_start|>system\\nYou are a specialized revie...| positive| 9| will repurchase|\n", "|<|im_start|>system\\nYou are a specialized revie...| positive| 8| will repurchase|\n", "|<|im_start|>system\\nYou are a specialized revie...| negative| 4|will not repurchase|\n", "|<|im_start|>system\\nYou are a specialized revie...| mixed| 6| might repurchase|\n", "+--------------------------------------------------+-----------------+---------------+-------------------+\n", "only showing top 5 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "preds.show(5, truncate=50)" ] }, { "cell_type": "code", "execution_count": 60, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Review: <|im_start|>system\n", "You are a specialized review analysis AI that categorizes product reviews into precise sentiment categories.\n", "IMPORTANT: Your response must contain ONLY valid JSON and nothing else - no explanations, no additional text.\n", "For each review, analyze and output EXACTLY this JSON structure:\n", "{\n", " \"primary_sentiment\": [EXACTLY ONE OF: \"positive\", \"negative\", \"neutral\", \"mixed\"],\n", " \"sentiment_score\": [integer between 1-10, where 1 is extremely negative and 10 is extremely positive],\n", " \"purchase_intention\": [EXACTLY ONE OF: \"will repurchase\", \"might repurchase\", \"will not repurchase\", \"recommends alternatives\", \"uncertain\"]\n", "}\n", "\n", "Do not include any text before or after the JSON. The response should start with '{' and end with '}' with no trailing characters, comments, or explanations.\n", "<|im_end|>\n", "<|im_start|>user\n", "Analyze this review: I have never played anything like this since. Everything from Sly Racoon, to Ratchet and Clank, owe it to this.Wicked witch Gruntilda takes Banjo's sister to hey layer, miles away in a realistic 3D cartoon world.Banjo is a bear with Kazooie a bird in his backpack that can help him jump and fly and basically you learn to do lots of things with it. You solve puzzles via action and collect tolkens across lovely maps. Mumbo Jumbo transforms Banjo into some other creatures along the way. You can fly. It was amazing. 
A full adventure all the way to end. We played it for months and I have NEVER played anything like it again. The makers of Donkey Kong released it at the best time. It is now up to the future generations to make adventure concepts better than this one. This is one of the best N64 games ever.<|im_end|>\n", "<|im_start|>assistant\n", "\n", "Sentiment: positive, Score: 9, Status: will repurchase\n" ] } ], "source": [ "sample = results[0]\n", "print(\"Review:\", sample[\"prompt\"])\n", "print(f\"Sentiment: {sample['primary_sentiment']}, Score: {sample['sentiment_score']}, Status: {sample['purchase_intention']}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Shut down server on each executor" ] }, { "cell_type": "code", "execution_count": 61, "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-03-20 01:19:32,218 - INFO - Requesting stage-level resources: (cores=13, gpu=2.0)\n", "2025-03-20 01:19:33,872 - INFO - Successfully stopped 2 VLLM servers. \n" ] }, { "data": { "text/plain": [ "[True, True]" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "tags": [] }, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-vllm", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/vllm/qwen-2.5-7b_vllm.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "# PySpark LLM Inference: Qwen-2.5 Text Summarization\n", "\n", "In this notebook, we demonstrate distributed batch inference with [Qwen-2.5](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct), using open weights on Huggingface.\n", "\n", "The Qwen-2.5-7b-instruct is an instruction-fine-tuned version of the Qwen-2.5-7b base model. We'll show how to use the model to perform text summarization.\n", "\n", "**Note:** Running this model on GPU with 16-bit precision requires **~16GB** of GPU RAM. Make sure your instances have sufficient GPU capacity." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n", "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n", "import os\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n", "\n", "# vLLM does CUDA init at import time. Forking will try to re-initialize CUDA if vLLM was imported before and throw an error.\n", "os.environ[\"VLLM_WORKER_MULTIPROC_METHOD\"] = \"spawn\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the cluster environment to handle any platform-specific configurations." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n", "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n", "on_standalone = not (on_databricks or on_dataproc)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# For cloud environments, load the model to the distributed file system.\n", "if on_databricks:\n", " models_dir = \"/dbfs/FileStore/spark-dl-models\"\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n", " model_path = f\"{models_dir}/qwen-2.5-7b\"\n", "elif on_dataproc:\n", " models_dir = \"/mnt/gcs/spark-dl-models\"\n", " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n", " model_path = f\"{models_dir}/qwen-2.5-7b\"\n", "else:\n", " model_path = os.path.abspath(\"qwen-2.5-7b\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Download the model from huggingface hub." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "eeb0daf2bd7948bebd94ce2a9a5a01b8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 14 files: 0%| | 0/14 [00:00system\\n\"),\n", " lit(system_prompt),\n", " lit(\"<|im_end|>\\n<|im_start|>user\\n\"),\n", " col(\"value\"),\n", " lit(\"<|im_end|>\\n<|im_start|>assistant\\n\")\n", " ).alias(\"prompt\")\n", ")" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|im_start|>system\n", "You are a knowledgeable AI assistant. Your job is to create a 1 sentence summary \n", "of a research abstract that captures the main objective, methodology, and key findings, using clear \n", "language while preserving technical accuracy and quantitative results.<|im_end|>\n", "<|im_start|>user\n", " The problem of statistical learning is to construct a predictor of a random\n", "variable $Y$ as a function of a related random variable $X$ on the basis of an\n", "i.i.d. training sample from the joint distribution of $(X,Y)$. Allowable\n", "predictors are drawn from some specified class, and the goal is to approach\n", "asymptotically the performance (expected loss) of the best predictor in the\n", "class. We consider the setting in which one has perfect observation of the\n", "$X$-part of the sample, while the $Y$-part has to be communicated at some\n", "finite bit rate. The encoding of the $Y$-values is allowed to depend on the\n", "$X$-values. Under suitable regularity conditions on the admissible predictors,\n", "the underlying family of probability distributions and the loss function, we\n", "give an information-theoretic characterization of achievable predictor\n", "performance in terms of conditional distortion-rate functions. 
The ideas are\n", "illustrated on the example of nonparametric regression in Gaussian noise.\n", "<|im_end|>\n", "<|im_start|>assistant\n", "\n" ] } ], "source": [ "print(df.take(1)[0].prompt)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "data_path = \"spark-dl-datasets/arxiv_abstracts\"\n", "if on_databricks:\n", " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n", " data_path = \"dbfs:/FileStore/\" + data_path\n", "\n", "df.write.mode(\"overwrite\").parquet(data_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using vLLM Server\n", "In this section, we demonstrate integration with [vLLM Serving](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html), an open-source server with an OpenAI-compatible completions endpoint for LLMs. \n", "\n", "The process looks like this:\n", "- Distribute a server startup task across the Spark cluster, instructing each node to launch a vLLM server process.\n", "- Define a vLLM inference function, which sends inference request to the local server on a given node.\n", "- Wrap the vLLM inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n", "- Finally, distribute a shutdown signal to terminate the vLLM server processes on each node.\n", "\n", "\"drawing\"" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import the helper class from server_utils.py:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "sc.addPyFile(\"server_utils.py\")\n", "\n", "from server_utils import VLLMServerManager" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Start vLLM servers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `VLLMServerManager` will handle the lifecycle of vLLM server instances across the Spark cluster:\n", "- Find available ports for HTTP\n", "- Deploy a server on each node via stage-level scheduling\n", "- Gracefully shutdown servers across nodes" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "model_name = \"qwen-2.5-7b\"\n", "server_manager = VLLMServerManager(model_name=model_name, model_path=model_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can pass any of the supported [vLLM serve CLI arguments](https://docs.vllm.ai/en/stable/serving/openai_compatible_server.html#vllm-serve) as key-word arguments when starting the servers. Note that this can take some time, as it includes loading the model from disk, Torch compilation, and capturing CUDA graphs." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[2025-03-24 11:37:57] INFO server_utils.py:359: Requesting stage-level resources: (cores=5, gpu=1.0)\n", "[2025-03-24 11:37:57] INFO server_utils.py:390: Starting 1 VLLM servers.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/plain": [ "{'cb4ae00-lcedt': (4022579, [7000])}" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.start_servers(gpu_memory_utilization=0.95,\n", " max_model_len=6600,\n", " task=\"generate\",\n", " enforce_eager=enforce_eager,\n", " wait_retries=60)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Define client function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Get the hostname -> url mapping from the server manager:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "host_to_http_url = server_manager.host_to_http_url" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the vLLM inference function, which returns a predict function for batch inference through the server:" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "def vllm_fn(model_name, host_to_url):\n", " import socket\n", " import numpy as np\n", " import requests\n", "\n", " url = host_to_url[socket.gethostname()]\n", " \n", " def predict(inputs):\n", " response = requests.post(\n", " f\"{url}/v1/completions\",\n", " json={\n", " \"model\": model_name,\n", " \"prompt\": inputs.tolist(),\n", " \"max_tokens\": 128,\n", " \"temperature\": 0.7,\n", " \"top_p\": 0.8,\n", " \"repetition_penalty\": 1.05,\n", " }\n", " )\n", " return np.array([r[\"text\"] for r in response.json()[\"choices\"]])\n", " \n", " return predict" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "generate = predict_batch_udf(partial(vllm_fn, model_name=model_name, host_to_url=host_to_http_url),\n", " return_type=StringType(),\n", " batch_size=32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load DataFrame\n", "\n", "We'll parallelize over a small set of prompts for demonstration." 
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "df = spark.read.parquet(data_path).limit(256).repartition(8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Run Inference" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 11:==================================================> (7 + 1) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 7.53 ms, sys: 2.19 ms, total: 9.72 ms\n", "Wall time: 13.9 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "# first pass caches model/fn and does JIT compilation\n", "preds = df.withColumn(\"outputs\", generate(col(\"prompt\")))\n", "results = preds.collect()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 17:===========================================> (6 + 2) / 8]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 10.7 ms, sys: 3.65 ms, total: 14.3 ms\n", "Wall time: 6.26 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "%%time\n", "preds = df.withColumn(\"outputs\", generate(col(\"prompt\")))\n", "results = preds.collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sample output:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Q: <|im_start|>system\n", "You are a knowledgeable AI assistant. Your job is to create a 1 sentence summary \n", "of a research abstract that captures the main objective, methodology, and key findings, using clear \n", "language while preserving technical accuracy and quantitative results.<|im_end|>\n", "<|im_start|>user\n", " Images can be segmented by first using a classifier to predict an affinity\n", "graph that reflects the degree to which image pixels must be grouped together\n", "and then partitioning the graph to yield a segmentation. Machine learning has\n", "been applied to the affinity classifier to produce affinity graphs that are\n", "good in the sense of minimizing edge misclassification rates. However, this\n", "error measure is only indirectly related to the quality of segmentations\n", "produced by ultimately partitioning the affinity graph. We present the first\n", "machine learning algorithm for training a classifier to produce affinity graphs\n", "that are good in the sense of producing segmentations that directly minimize\n", "the Rand index, a well known segmentation performance measure. The Rand index\n", "measures segmentation performance by quantifying the classification of the\n", "connectivity of image pixel pairs after segmentation. 
By using the simple graph\n", "partitioning algorithm of finding the connected components of the thresholded\n", "affinity graph, we are able to train an affinity classifier to directly\n", "minimize the Rand index of segmentations resulting from the graph partitioning.\n", "Our learning algorithm corresponds to the learning of maximin affinities\n", "between image pixel pairs, which are predictive of the pixel-pair connectivity.\n", "<|im_end|>\n", "<|im_start|>assistant\n", " \n", "\n", "A: The research presents a machine learning algorithm that trains an affinity classifier to directly minimize the Rand index of image segmentations by producing affinity graphs optimized for pixel-pair connectivity, using a simple graph partitioning method. \n", "\n" ] } ], "source": [ "print(f\"Q: {results[0].prompt} \\n\")\n", "print(f\"A: {results[0].outputs} \\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Shut down server on each executor" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[2025-03-24 11:38:49] INFO server_utils.py:359: Requesting stage-level resources: (cores=5, gpu=1.0)\n", "[2025-03-24 11:38:50] INFO server_utils.py:447: Successfully stopped 1 VLLM servers.\n" ] }, { "data": { "text/plain": [ "[True]" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "server_manager.stop_servers()" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n", " spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spark-dl-vllm", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/ML+DL-Examples/Spark-DL/dl_inference/vllm_requirements.txt ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. pyspark>=3.4.0 datasets vllm ipywidgets jupyterlab ================================================ FILE: examples/ML+DL-Examples/Spark-Rapids-ML/pca/README.md ================================================ # Spark-Rapids-ML PCA example This is an example of the GPU accelerated PCA algorithm from the [Spark-Rapids-ML](https://github.com/NVIDIA/spark-rapids-ml) library, which provides PySpark ML compatible algorithms powered by RAPIDS cuML. The notebook uses PCA to reduce a random dataset with 2048 feature dimensions to 3 dimensions. We train both the GPU and CPU algorithms for comparison. 
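
The GPU path is a drop-in replacement for the Spark ML estimator. Below is a minimal sketch for illustration only (the notebook is the full example; it assumes a DataFrame `df` with an array-of-float `features` column, as generated in the notebook):

```python
# Minimal sketch (not taken from the notebook): swapping the import switches PCA from CPU to GPU.
# Assumes `df` is a Spark DataFrame with an array-of-float "features" column.
from spark_rapids_ml.feature import PCA  # GPU-accelerated; CPU equivalent: from pyspark.ml.feature import PCA

pca = PCA(k=3, inputCol="features")
pca.setOutputCol("pca_features")

model = pca.fit(df)                       # trains with RAPIDS cuML on GPU
model.transform(df).select("pca_features").show(5, truncate=False)
```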
## Build Please refer to the Spark-Rapids-ML [README](https://github.com/NVIDIA/spark-rapids-ml/blob/HEAD/python) to setup the RAPIDS conda environment and install Spark-Rapids-ML dependencies. ## Download RAPIDS Jar from Maven Central Download the [Spark-Rapids plugin](https://nvidia.github.io/spark-rapids/docs/download.html#download-rapids-accelerator-for-apache-spark-v24081). For Spark-RAPIDS-ML version 26.02.0, download the RAPIDS jar from Maven Central: [rapids-4-spark_2.12-26.02.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar). ## Running the Notebooks Once you have built your environment, please follow these instructions to run the notebooks. Make sure `jupyterlab` is installed in the environment. **Note**: for demonstration purposes, these examples just use a local Spark Standalone cluster with a single executor, but you should be able to run them on any distributed Spark cluster. ``` # setup environment variables export SPARK_HOME=/path/to/spark export RAPIDS_JAR=/path/to/rapids.jar # launches the standalone cluster and jupyter with pyspark ./start-spark-rapids.sh # BROWSE to localhost:8888 to view/run notebooks # stop spark standalone cluster ${SPARK_HOME}/sbin/stop-worker.sh; ${SPARK_HOME}/sbin/stop-master.sh ``` ================================================ FILE: examples/ML+DL-Examples/Spark-Rapids-ML/pca/notebooks/pca.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Principal Component Analysis (PCA)\n", "\n", "In this notebook, we will demonstrate the end-to-end workflow of Spark RAPIDS accelerated PCA." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import time\n", "import os" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No active Spark session found, initializing manually.\n", "File already exists. Skipping download.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "24/10/04 18:04:27 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n", "24/10/04 18:04:27 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", "24/10/04 18:04:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). 
For SparkR, use setLogLevel(newLevel).\n", "24/10/04 18:04:27 WARN RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1, private revision 9fac64da220ddd6bf5626bd7bd1dd74c08603eac\n", "24/10/04 18:04:27 WARN RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\n", "24/10/04 18:04:31 WARN GpuDeviceManager: RMM pool is disabled since spark.rapids.memory.gpu.pooling.enabled is set to false; however, this configuration is deprecated and the behavior may change in a future release.\n" ] } ], "source": [ "from pyspark.sql import SparkSession\n", "from pyspark import SparkConf\n", "\n", "def get_rapids_jar():\n", " import os\n", " import requests\n", "\n", " SPARK_RAPIDS_VERSION = \"26.02.0\"\n", " rapids_jar = f\"rapids-4-spark_2.12-{SPARK_RAPIDS_VERSION}.jar\"\n", " if not os.path.exists(rapids_jar):\n", " print(\"Downloading spark rapids jar\")\n", " url = f\"https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/{SPARK_RAPIDS_VERSION}/{rapids_jar}\"\n", " response = requests.get(url)\n", " if response.status_code == 200:\n", " with open(rapids_jar, \"wb\") as f:\n", " f.write(response.content)\n", " print(f\"File '{rapids_jar}' downloaded and saved successfully.\")\n", " else:\n", " print(f\"Failed to download the file. Status code: {response.status_code}\")\n", " else:\n", " print(\"File already exists. Skipping download.\")\n", " return rapids_jar\n", "\n", "def initialize_spark(rapids_jar: str):\n", " '''\n", " If no active Spark session is found, initialize and configure a new one. \n", " '''\n", " import socket\n", " hostname = socket.gethostname()\n", "\n", " conf = SparkConf()\n", " conf.setMaster(f\"spark://{hostname}:7077\") # Assuming master is on host and default port. 
\n", " conf.set(\"spark.task.maxFailures\", \"1\")\n", " conf.set(\"spark.driver.memory\", \"10g\")\n", " conf.set(\"spark.executor.memory\", \"8g\")\n", " conf.set(\"spark.rpc.message.maxSize\", \"1024\")\n", " conf.set(\"spark.sql.pyspark.jvmStacktrace.enabled\", \"true\")\n", " conf.set(\"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled\", \"false\")\n", " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n", " conf.set(\"spark.python.worker.reuse\", \"true\")\n", " conf.set(\"spark.rapids.ml.uvm.enabled\", \"true\")\n", " conf.set(\"spark.jars\", rapids_jar)\n", " conf.set(\"spark.executorEnv.PYTHONPATH\", rapids_jar)\n", " conf.set(\"spark.rapids.memory.gpu.minAllocFraction\", \"0.0001\")\n", " conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n", " conf.set(\"spark.locality.wait\", \"0s\")\n", " conf.set(\"spark.sql.cache.serializer\", \"com.nvidia.spark.ParquetCachedBatchSerializer\")\n", " conf.set(\"spark.rapids.memory.gpu.pooling.enabled\", \"false\")\n", " conf.set(\"spark.sql.execution.sortBeforeRepartition\", \"false\")\n", " conf.set(\"spark.rapids.sql.format.parquet.reader.type\", \"MULTITHREADED\")\n", " conf.set(\"spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel\", \"20\")\n", " conf.set(\"spark.rapids.sql.multiThreadedRead.numThreads\", \"20\")\n", " conf.set(\"spark.rapids.sql.python.gpu.enabled\", \"true\")\n", " conf.set(\"spark.rapids.memory.pinnedPool.size\", \"2G\")\n", " conf.set(\"spark.python.daemon.module\", \"rapids.daemon\")\n", " conf.set(\"spark.rapids.sql.batchSizeBytes\", \"512m\")\n", " conf.set(\"spark.sql.adaptive.enabled\", \"false\")\n", " conf.set(\"spark.sql.files.maxPartitionBytes\", \"512m\")\n", " conf.set(\"spark.rapids.sql.concurrentGpuTasks\", \"1\")\n", " conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"20000\")\n", " conf.set(\"spark.rapids.sql.explain\", \"NONE\")\n", " \n", " spark = SparkSession.builder.appName(\"spark-rapids-ml-pca\").config(conf=conf).getOrCreate()\n", " return spark\n", "\n", "# Check if Spark session is already active, if not, initialize it\n", "if 'spark' not in globals():\n", " print(\"No active Spark session found, initializing manually.\")\n", " rapids_jar = os.environ.get('RAPIDS_JAR')\n", " if rapids_jar is None:\n", " rapids_jar = get_rapids_jar()\n", " spark = initialize_spark(rapids_jar)\n", "else:\n", " print(\"Using existing Spark session.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate synthetic dataset\n", "\n", "Here we generate a 100,000 x 2048 random dataset." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "24/10/04 18:04:45 WARN TaskSetManager: Stage 0 contains a task of very large size (160085 KiB). 
The maximum recommended task size is 1000 KiB.\n", " \r" ] } ], "source": [ "rows = 100000\n", "dim = 2048\n", "dtype = 'float32'\n", "np.random.seed(42)\n", "\n", "data = np.random.rand(rows, dim).astype(dtype)\n", "pd_data = pd.DataFrame({\"features\": list(data)})\n", "prepare_df = spark.createDataFrame(pd_data)\n", "prepare_df.write.mode(\"overwrite\").parquet(\"PCA_data.parquet\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Spark-RAPIDS-ML accepts ArrayType input\n", "\n", "Note that in the original Spark-ML PCA, we must `Vectorize` the input column:\n", "\n", "```python\n", "from pyspark.ml.linalg import Vectors\n", "data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),),\n", " (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),\n", " (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)]\n", "df = spark.createDataFrame(data,[\"features\"])\n", "df.show()\n", "```\n", "\n", "...whereas the Spark-RAPIDS-ML version does not require extra Vectorization, and can accept an ArrayType column as the input column:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- features: array (nullable = true)\n", " | |-- element: float (containsNull = true)\n", "\n" ] } ], "source": [ "data_df = spark.read.parquet(\"PCA_data.parquet\")\n", "data_df.printSchema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using Spark-RAPIDS-ML PCA (GPU)\n", "\n", "Compared to the Spark-ML PCA training API:\n", "\n", "```python\n", "from pyspark.ml.feature import PCA\n", "pca = PCA(k=3, inputCol=\"features\")\n", "pca.setOutputCol(\"pca_features\")\n", "```\n", "\n", "We use a customized class which requires **no code change** from the user to enjoy GPU acceleration:\n", "\n", "```python\n", "from spark_rapids_ml.feature import PCA\n", "pca = PCA(k=3, inputCol=\"features\")\n", "pca.setOutputCol(\"pca_features\")\n", "```" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PCA_570681141389" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from spark_rapids_ml.feature import PCA\n", "\n", "gpu_pca = PCA(k=2, inputCol=\"features\")\n", "gpu_pca.setOutputCol(\"pca_features\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The PCA estimator object can be persisted and reloaded." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "estimator_path = \"/tmp/pca_estimator\"\n", "gpu_pca.write().overwrite().save(estimator_path)\n", "gpu_pca_loaded = PCA.load(estimator_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Fit" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "24/10/04 18:04:58 WARN MultiFileReaderThreadPool: Configuring the file reader thread pool with a max of 32 threads instead of spark.rapids.sql.multiThreadedRead.numThreads = 20\n", "2024-10-04 18:04:58,487 - spark_rapids_ml.feature.PCA - INFO - CUDA managed memory enabled.\n", "2024-10-04 18:04:58,570 - spark_rapids_ml.feature.PCA - INFO - Training spark-rapids-ml with 1 worker(s) ...\n", "INFO: Process 2762394 found CUDA visible device(s): 0\n", "2024-10-04 18:05:01,613 - spark_rapids_ml.feature.PCA - INFO - Loading data into python worker memory\n", "2024-10-04 18:05:02,551 - spark_rapids_ml.feature.PCA - INFO - Initializing cuml context\n", "2024-10-04 18:05:03,795 - spark_rapids_ml.feature.PCA - INFO - Invoking cuml fit\n", "2024-10-04 18:05:05,326 - spark_rapids_ml.feature.PCA - INFO - Cuml fit complete\n", "2024-10-04 18:05:06,858 - spark_rapids_ml.feature.PCA - INFO - Finished training\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "GPU PCA fit took: 8.90433144569397 sec\n" ] } ], "source": [ "start_time = time.time()\n", "gpu_pca_model = gpu_pca_loaded.fit(data_df)\n", "gpu_fit_time = time.time() - start_time\n", "print(f\"GPU PCA fit took: {gpu_fit_time} sec\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Transform" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------------------------+\n", "|pca_features |\n", "+---------------------------+\n", "|[0.062363233, 0.4037608] |\n", "|[0.49734917, 0.703541] |\n", "|[0.0035427138, 0.29358602] |\n", "|[-0.06798951, 0.37400067] |\n", "|[0.10075127, 0.34651726] |\n", "|[-0.22320557, 0.6660976] |\n", "|[0.49608234, 0.6761328] |\n", "|[0.25515205, 0.20352581] |\n", "|[-0.5102935, 0.319284] |\n", "|[-0.5109488, 0.2756377] |\n", "|[0.411546, -0.17954555] |\n", "|[0.21616393, -0.46268395] |\n", "|[-0.0924304, 0.65660465] |\n", "|[0.12355948, 0.9478601] |\n", "|[0.49234354, 0.63746333] |\n", "|[-0.86077166, 0.0037032962]|\n", "|[-0.013956882, 0.663955] |\n", "|[-0.30510652, 0.02372247] |\n", "|[-0.05999008, 0.28261736] |\n", "|[0.36605445, 0.9674797] |\n", "+---------------------------+\n", "only showing top 20 rows\n", "\n", "GPU PCA transform took: 0.43911027908325195 sec\n" ] } ], "source": [ "start_time = time.time()\n", "embeddings = gpu_pca_model.transform(data_df).select(\"pca_features\").show(truncate=False)\n", "gpu_transform_time = time.time() - start_time\n", "print(f\"GPU PCA transform took: {gpu_transform_time} sec\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using Spark-ML PCA (CPU)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PCA_58add243f20d" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pyspark.ml.feature import PCA\n", "\n", "cpu_pca = PCA(k=2, inputCol=\"features\")\n", "cpu_pca.setOutputCol(\"pca_features\")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "root\n", " |-- features: vector (nullable = true)\n", "\n" ] } ], "source": [ "from pyspark.ml.functions import array_to_vector\n", "\n", "vector_df = data_df.select(array_to_vector(\"features\").alias(\"features\"))\n", "vector_df.printSchema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Fit" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "24/10/04 17:07:07 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU PCA fit took: 63.37388610839844 sec\n" ] } ], "source": [ "start_time = time.time()\n", "cpu_pca_model = cpu_pca.fit(vector_df)\n", "pca_fit_time = time.time() - start_time\n", "print(f\"CPU PCA fit took: {pca_fit_time} sec\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Transform" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-------------------------------------------+\n", "|pca_features |\n", "+-------------------------------------------+\n", "|[0.24926765828229927,0.3425432972889563] |\n", "|[-0.5175207040808384,0.48893065865444574] |\n", "|[-0.2505049373829902,0.381272141155778] |\n", "|[-0.39046980420292005,0.4870705091697811] |\n", "|[-0.4024088726395023,0.707133448810984] |\n", "|[-0.3061227832285992,0.5363554872099332] |\n", "|[-0.6065136982526093,0.5205197626985932] |\n", "|[-0.21870566838630084,0.6516598402789231] |\n", "|[0.1910036552854184,0.6336513389989592] |\n", "|[0.6139537641786907,0.6055187085018856] |\n", "|[-0.026502904776425647,-0.0366087508156753]|\n", "|[-0.2989311781309336,-0.05136110567458389] |\n", "|[-0.5474468086054212,-0.18779964958125014] |\n", "|[-0.6644746232216499,0.10351178251944647] |\n", "|[-0.12685301272617464,0.47394431583661295] |\n", "|[-0.4355221246718862,-0.00346289187881239] |\n", "|[0.6222719258951077,0.5488293416698503] |\n", "|[0.04966907735703511,0.7138677407505005] |\n", "|[0.6260486995906139,0.3553228450428632] |\n", "|[0.16396683091519929,0.7382693234881972] |\n", "+-------------------------------------------+\n", "only showing top 20 rows\n", "\n", "CPU PCA transform took: 0.19607114791870117 sec\n" ] } ], "source": [ "start_time = time.time()\n", "embeddings = cpu_pca_model.transform(vector_df).select(\"pca_features\").show(truncate=False)\n", "pca_transform_time = time.time() - start_time\n", "print(f\"CPU PCA transform took: {pca_transform_time} sec\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Summary" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU runtime: (64.02s + 0.20s)\n", "GPU runtime: (8.76s + 0.42s)\n", "End-to-end speedup: CPU / GPU = 7.00x\n" ] } ], "source": [ "speedup = (pca_fit_time + pca_transform_time) / (gpu_fit_time + gpu_transform_time)\n", "print(f\"CPU runtime: ({pca_fit_time:.2f}s + {pca_transform_time:.2f}s)\")\n", "print(f\"GPU runtime: ({gpu_fit_time:.2f}s + {gpu_transform_time:.2f}s)\")\n", "print(f\"End-to-end speedup: CPU / GPU = {speedup:.2f}x\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "rapids-25.02", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": 
".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.10" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: examples/ML+DL-Examples/Spark-Rapids-ML/pca/start-spark-rapids.sh ================================================ #!/bin/bash # # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Check if SPARK_HOME is set if [ -z "$SPARK_HOME" ]; then echo "Please set the SPARK_HOME environment variable before running this script." exit 1 fi # Check if RAPIDS_JAR is set if [ -z "$RAPIDS_JAR" ]; then echo "Please set the RAPIDS_JAR environment variable before running this script." exit 1 fi # Configuration MASTER_HOSTNAME=$(hostname) MASTER=spark://${MASTER_HOSTNAME}:7077 CORES_PER_WORKER=8 MEMORY_PER_WORKER=16G # Environment variables export SPARK_HOME=${SPARK_HOME} export MASTER=${MASTER} export SPARK_WORKER_INSTANCES=1 export CORES_PER_WORKER=${CORES_PER_WORKER} export PYSPARK_DRIVER_PYTHON=jupyter export PYSPARK_DRIVER_PYTHON_OPTS='lab' # Start standalone cluster echo "Starting Spark standalone cluster..." ${SPARK_HOME}/sbin/start-master.sh ${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m ${MEMORY_PER_WORKER} ${MASTER} # Start Jupyter with PySpark echo "Launching PySpark with Jupyter..." 
${SPARK_HOME}/bin/pyspark --master ${MASTER} \ --driver-memory 10G \ --executor-memory 8G \ --conf spark.task.maxFailures=1 \ --conf spark.rpc.message.maxSize=1024 \ --conf spark.sql.pyspark.jvmStacktrace.enabled=true \ --conf spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled=false \ --conf spark.sql.execution.arrow.pyspark.enabled=true \ --conf spark.python.worker.reuse=true \ --conf spark.rapids.ml.uvm.enabled=true \ --conf spark.jars=${RAPIDS_JAR} \ --conf spark.executorEnv.PYTHONPATH=${RAPIDS_JAR} \ --conf spark.rapids.memory.gpu.minAllocFraction=0.0001 \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --conf spark.locality.wait=0s \ --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \ --conf spark.rapids.memory.gpu.pooling.enabled=false \ --conf spark.sql.execution.sortBeforeRepartition=false \ --conf spark.rapids.sql.format.parquet.reader.type=MULTITHREADED \ --conf spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel=20 \ --conf spark.rapids.sql.multiThreadedRead.numThreads=20 \ --conf spark.rapids.sql.python.gpu.enabled=true \ --conf spark.rapids.memory.pinnedPool.size=2G \ --conf spark.python.daemon.module=rapids.daemon \ --conf spark.rapids.sql.batchSizeBytes=512m \ --conf spark.sql.adaptive.enabled=false \ --conf spark.sql.files.maxPartitionBytes=512m \ --conf spark.rapids.sql.concurrentGpuTasks=1 \ --conf spark.sql.execution.arrow.maxRecordsPerBatch=20000 \ --conf spark.rapids.sql.explain=NONE ================================================ FILE: examples/SQL+DF-Examples/customer-churn/README.md ================================================ # Customer Churn This demo is derived from [data-science-blueprints](https://github.com/NVIDIA/data-science-blueprints) repository. The repository shows a realistic ETL workflow based on synthetic normalized data. It consists of two pieces: 1. _an augmentation notebook_, which synthesizes normalized (long-form) data from a wide-form input file, optionally augmenting it by duplicating records, and 2. _an ETL notebook_, which performs joins and aggregations in order to generate wide-form data from the synthetic long-form data. To learn more about the customer churn use case, you can read our [ebook](https://www.nvidia.com/en-us/ai-data-science/resources/churn-prediction-blueprint/). ================================================ FILE: examples/SQL+DF-Examples/customer-churn/notebooks/python/README.md ================================================ # telco-churn-augmentation This demo shows a realistic ETL workflow based on synthetic normalized data. It consists of two pieces: 1. _an [augmentation notebook](augment.ipynb)_, which synthesizes normalized (long-form) data from a wide-form input file, optionally augmenting it by duplicating records, and 2. _an [ETL notebook](etl.ipynb)_, which performs joins and aggregations in order to generate wide-form data from the synthetic long-form data. From a performance evaluation perspective, the latter is the interesting workload; the former is just a data generator for the latter. 
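For orientation, below is a minimal PySpark sketch (not taken from the notebooks) of the long-form to wide-form reshaping that the ETL step performs: starting from one row per `(customerID, feature, value)`, it pivots back to one row per customer. The column and feature names follow the schema used by the augmentation notebook; the real queries, including the joins across the billing, phone, internet, and account feature tables, live in [etl.ipynb](etl.ipynb).

```python
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.appName("wide-form-sketch").getOrCreate()

# Hypothetical long-form account features: one row per (customer, feature).
long_df = spark.createDataFrame(
    [("0001-A", "Contract", "Month-to-month"),
     ("0001-A", "PaperlessBilling", "Yes"),
     ("0002-B", "Contract", "Two year"),
     ("0002-B", "PaperlessBilling", "No")],
    ["customerID", "feature", "value"],
)

# Pivot back to wide form: one row per customer, one column per feature.
wide_df = (
    long_df.groupBy("customerID")
    .pivot("feature", ["Contract", "PaperlessBilling"])
    .agg(F.first("value"))
)
wide_df.show()
```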
================================================ FILE: examples/SQL+DF-Examples/customer-churn/notebooks/python/augment.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Customer churn augment\n", "\n", "This notebook is derived from [customer churn augment notebook](https://github.com/NVIDIA/data-science-blueprints/blob/main/churn/augment.ipynb), please refer to this [git repo](https://github.com/NVIDIA/data-science-blueprints/tree/main/churn) for more detail information.\n", "\n", " " ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "# notebook parameters\n", "\n", "import os\n", "\n", "spark_master = os.getenv(\"SPARK_MASTER_URL\", \"spark://ip:port\")\n", "app_name = \"augment\"\n", "dataRoot = os.getenv(\"DATA_ROOT\", \"data\")\n", "input_file = os.path.join(dataRoot, \"WA_Fn-UseC_-Telco-Customer-Churn-.csv\")\n", "output_mode = \"overwrite\"\n", "output_kind = \"parquet\"\n", "driver_memory = '12g'\n", "executor_memory = '8g'\n", "\n", "dup_times = 100\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import churn.augment\n", "\n", "churn.augment.register_options(\n", " spark_master = spark_master,\n", " app_name = app_name,\n", " input_file = input_file,\n", " output_mode = output_mode,\n", " output_kind = output_kind,\n", " driver_memory = driver_memory,\n", " executor_memory = executor_memory,\n", " dup_times = dup_times,\n", " use_decimal = True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Sanity-checking\n", "\n", "We're going to make sure we're running with a compatible JVM first — if we run on macOS, we might get one that doesn't work with Scala." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from os import getenv" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/data/usr/lib/jvm/java-8-openjdk-amd64'" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "getenv(\"JAVA_HOME\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Spark setup" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import pyspark" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "

SparkSession - hive

\n", " \n", "
\n", "

SparkContext

\n", "\n", "

Spark UI

\n", "\n", "
\n", "
Version
\n", "
v3.2.0
\n", "
Master
\n", "
spark://yuanli-System-Product-Name:7077
\n", "
AppName
\n", "
PySparkShell
\n", "
\n", "
\n", " \n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "session = pyspark.sql.SparkSession.builder \\\n", " .master(spark_master) \\\n", " .appName(app_name) \\\n", " .config(\"spark.driver.memory\", driver_memory) \\\n", " .config(\"spark.executor.memory\", executor_memory) \\\n", " .getOrCreate()\n", "session" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Schema definition\n", "\n", "Most of the fields are strings representing booleans or categoricals, but a few (`tenure`, `MonthlyCharges`, and `TotalCharges`) are numeric." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "read 7043 records from source dataset (7032 non-null records)\n" ] } ], "source": [ "from churn.augment import load_supplied_data\n", "\n", "df = load_supplied_data(session, input_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Splitting the data frame\n", "\n", "The training data schema looks like this:\n", "\n", "- customerID\n", "- gender\n", "- SeniorCitizen\n", "- Partner\n", "- Dependents\n", "- tenure\n", "- PhoneService\n", "- MultipleLines\n", "- InternetService\n", "- OnlineSecurity\n", "- OnlineBackup\n", "- DeviceProtection\n", "- TechSupport\n", "- StreamingTV\n", "- StreamingMovies\n", "- Contract\n", "- PaperlessBilling\n", "- PaymentMethod\n", "- MonthlyCharges\n", "- TotalCharges\n", "- Churn\n", "\n", "We want to divide the data frame into several frames that we can join together in an ETL job.\n", "\n", "Those frames will look like this:\n", "\n", "- **Customer metadata**\n", " - customerID\n", " - gender\n", " - date of birth (we'll derive age and senior citizen status from this)\n", " - Partner\n", " - Dependents\n", " - (nominal) MonthlyCharges\n", "- **Billing events**\n", " - customerID\n", " - date (we'll derive tenure from the number/duration of billing events)\n", " - kind (one of \"AccountCreation\", \"Charge\", or \"AccountTermination\")\n", " - value (either a positive nonzero amount or 0.00; we'll derive TotalCharges from the sum of amounts and Churn from the existence of an AccountTermination event)\n", "- **Customer phone features**\n", " - customerID\n", " - feature (one of \"PhoneService\" or \"MultipleLines\")\n", "- **Customer internet features**\n", " - customerID\n", " - feature (one of \"InternetService\", \"OnlineSecurity\", \"OnlineBackup\", \"DeviceProtection\", \"TechSupport\", \"StreamingTV\", \"StreamingMovies\")\n", " - value (one of \"Fiber\", \"DSL\", \"Yes\", \"No\")\n", "- **Customer account features**\n", " - customerID\n", " - feature (one of \"Contract\", \"PaperlessBilling\", \"PaymentMethod\")\n", " - value (one of \"Month-to-month\", \"One year\", \"Two year\", \"No\", \"Yes\", \"Credit card (automatic)\", \"Mailed check\", \"Bank transfer (automatic)\", \"Electronic check\")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- customerID: string (nullable = true)\n", " |-- gender: string (nullable = true)\n", " |-- SeniorCitizen: string (nullable = true)\n", " |-- Partner: string (nullable = true)\n", " |-- Dependents: string (nullable = true)\n", " |-- tenure: double (nullable = true)\n", " |-- PhoneService: string (nullable = true)\n", " |-- MultipleLines: string (nullable = true)\n", " |-- InternetService: 
string (nullable = true)\n", " |-- OnlineSecurity: string (nullable = true)\n", " |-- OnlineBackup: string (nullable = true)\n", " |-- DeviceProtection: string (nullable = true)\n", " |-- TechSupport: string (nullable = true)\n", " |-- StreamingTV: string (nullable = true)\n", " |-- StreamingMovies: string (nullable = true)\n", " |-- Contract: string (nullable = true)\n", " |-- PaperlessBilling: string (nullable = true)\n", " |-- PaymentMethod: string (nullable = true)\n", " |-- MonthlyCharges: double (nullable = true)\n", " |-- TotalCharges: double (nullable = true)\n", " |-- Churn: string (nullable = true)\n", "\n" ] } ], "source": [ "df.printSchema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll start by generating a series of monthly charges, then a series of account creation events, and finally a series of churn events. `billingEvents` is the data frame containing all of these events: account activation, account termination, and individual payment events." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/yuanli/work/spark-3.2.0-bin-hadoop3.2/python/pyspark/sql/functions.py:1353: FutureWarning: Deprecated in 3.2, use shiftright instead.\n", " warnings.warn(\"Deprecated in 3.2, use shiftright instead.\", FutureWarning)\n" ] } ], "source": [ "from churn.augment import billing_events\n", "billingEvents = billing_events(df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our next step is to generate customer metadata, which includes the following fields:\n", "\n", " - gender\n", " - date of birth (we'll derive age and senior citizen status from this)\n", " - Partner\n", " - Dependents\n", " \n", "We'll calculate date of birth by using the hash of the customer ID as a pseudorandom number and then assuming that ages are uniformly distributed between 18-65 and exponentially distributed over 65." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:36:31,848 WARN conf.HiveConf: HiveConf of name hive.stats.jdbc.timeout does not exist\n", "2022-04-05 09:36:31,849 WARN conf.HiveConf: HiveConf of name hive.stats.retries.wait does not exist\n", "2022-04-05 09:36:33,683 WARN metastore.ObjectStore: Version information not found in metastore. hive.metastore.schema.verification is not enabled so recording the schema version 2.3.0\n", "2022-04-05 09:36:33,683 WARN metastore.ObjectStore: setMetaStoreSchemaVersion called but recording version is disabled: version = 2.3.0, comment = Set by MetaStore yuanli@127.0.1.1\n", "2022-04-05 09:36:33,811 WARN metastore.ObjectStore: Failed to get database global_temp, returning NoSuchObjectException\n", "2022-04-05 09:36:33,892 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.LocalTableScanExec\n", " @Expression name#326 could run on GPU\n", " @Expression database#327 could run on GPU\n", " @Expression description#328 could run on GPU\n", " @Expression tableType#329 could run on GPU\n", " @Expression isTemporary#330 could run on GPU\n", "\n", "2022-04-05 09:36:33,960 WARN rapids.GpuOverrides: \n", " ! 
cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.RDDScanExec\n", " @Expression u_value#337 could run on GPU\n", "\n", " \r" ] } ], "source": [ "from churn.augment import customer_meta\n", "customerMeta = customer_meta(df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can generate customer phone features, which include:\n", "\n", " - customerID\n", " - feature (one of \"PhoneService\" or \"MultipleLines\")\n", " - value (always \"Yes\"; there are no records for \"No\" or \"No Phone Service\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from churn.augment import phone_features\n", "customerPhoneFeatures = phone_features(df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Customer internet features include:\n", " - customerID\n", " - feature (one of \"InternetService\", \"OnlineSecurity\", \"OnlineBackup\", \"DeviceProtection\", \"TechSupport\", \"StreamingTV\", \"StreamingMovies\")\n", " - value (one of \"Fiber\", \"DSL\", \"Yes\" -- no records for \"No\" or \"No internet service\")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from churn.augment import internet_features\n", "customerInternetFeatures = internet_features(df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Customer account features include:\n", "\n", " - customerID\n", " - feature (one of \"Contract\", \"PaperlessBilling\", \"PaymentMethod\")\n", " - value (one of \"Month-to-month\", \"One year\", \"Two year\", \"Yes\", \"Credit card (automatic)\", \"Mailed check\", \"Bank transfer (automatic)\", \"Electronic check\")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from churn.augment import account_features\n", "customerAccountFeatures = account_features(df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Write outputs" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:36:36,792 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.LocalTableScanExec\n", " @Expression name#798 could run on GPU\n", " @Expression database#799 could run on GPU\n", " @Expression description#800 could run on GPU\n", " @Expression tableType#801 could run on GPU\n", " @Expression isTemporary#802 could run on GPU\n", "\n", "2022-04-05 09:36:37,142 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression format_string(%s-%s, customerID#0, u_value#337) AS customerID#816 could run on GPU\n", " ! 
format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression kind#133 could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression date#156 could run on GPU\n", " @Expression month#315 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#0 could run on GPU\n", " @Expression Charge AS kind#133 could run on GPU\n", " @Expression Charge could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) AS date#156 could run on GPU\n", " ! add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(-(cast(_we0#149 as bigint) + last_month#135L) as int) could run on GPU\n", " @Expression -(cast(_we0#149 as bigint) + last_month#135L) could run on GPU\n", " @Expression (cast(_we0#149 as bigint) + last_month#135L) could run on GPU\n", " @Expression cast(_we0#149 as bigint) could run on GPU\n", " @Expression _we0#149 could run on GPU\n", " @Expression last_month#135L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#0 could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END AS last_month#135L could run on GPU\n", " @Expression CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END could run on GPU\n", " @Expression (Churn#20 = Yes) could run on GPU\n", " @Expression Churn#20 could run on GPU\n", " @Expression Yes could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression ((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + 
((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#0, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#0, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression explode(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\n", " ! array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\n", " @Expression cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\n", " @Expression (TotalCharges#19 / tenure#5) could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " !Expression cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression tenure#5 could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression Churn#20 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression ((atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) AND isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)))) could run on GPU\n", " @Expression (atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) could run on GPU\n", " @Expression atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression gender#1 could run on GPU\n", " @Expression SeniorCitizen#2 could run on GPU\n", " @Expression Partner#3 could run on GPU\n", " @Expression Dependents#4 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " @Expression PhoneService#6 could run on GPU\n", " @Expression MultipleLines#7 could run on GPU\n", " @Expression InternetService#8 could run on GPU\n", " @Expression OnlineSecurity#9 could run on GPU\n", " @Expression OnlineBackup#10 could run on GPU\n", " @Expression DeviceProtection#11 could run on GPU\n", " @Expression TechSupport#12 could run on GPU\n", " @Expression StreamingTV#13 could run on GPU\n", " @Expression StreamingMovies#14 could run on GPU\n", " @Expression Contract#15 could run on GPU\n", " @Expression PaperlessBilling#16 could run on GPU\n", " @Expression PaymentMethod#17 could run on GPU\n", " @Expression MonthlyCharges#18 could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression Churn#20 could run on GPU\n", " @Expression (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0) could run on GPU\n", " @Expression size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) could run on GPU\n", " ! 
array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\n", " @Expression cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\n", " @Expression (TotalCharges#19 / tenure#5) could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " !Expression cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression tenure#5 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\n", " ! array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\n", " @Expression cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\n", " @Expression (TotalCharges#19 / tenure#5) could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " !Expression cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression tenure#5 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#265 could run on GPU\n", " @Expression AccountCreation AS kind#191 could run on GPU\n", " @Expression AccountCreation could run on GPU\n", " @Expression 0.00 AS value#192 could run on GPU\n", " @Expression 0.00 could run on GPU\n", " @Expression add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) AS date#200 could run on GPU\n", " ! 
add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " !Expression cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression ((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) could run on GPU\n", " @Expression (-tenure#270 - 1.0) could run on GPU\n", " @Expression -tenure#270 could run on GPU\n", " @Expression tenure#270 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " @Expression CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END could run on GPU\n", " @Expression (Churn#285 = Yes) could run on GPU\n", " @Expression Churn#285 could run on GPU\n", " @Expression Yes could run on GPU\n", " @Expression cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression 
((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#265, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#265, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 0.0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#290 could run on GPU\n", " @Expression AccountTermination AS kind#258 could run on GPU\n", " @Expression AccountTermination could run on GPU\n", " @Expression 0.00 AS value#259 could run on GPU\n", " @Expression 0.00 could run on GPU\n", " @Expression CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END AS date#260 could run on GPU\n", " @Expression CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END could run on GPU\n", " @Expression (Churn#310 = Yes) could run on GPU\n", " @Expression Churn#310 could run on GPU\n", " @Expression Yes could run on GPU\n", " ! 
add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int) could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression ((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#290, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#290, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 2022-04-05 could run on GPU\n", "\n", "2022-04-05 09:36:37,176 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression format_string(%s-%s, customerID#0, u_value#337) AS customerID#816 could run on GPU\n", " ! 
format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression kind#133 could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression date#156 could run on GPU\n", " @Expression month#315 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#0 could run on GPU\n", " @Expression Charge AS kind#133 could run on GPU\n", " @Expression Charge could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) AS date#156 could run on GPU\n", " ! add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(-(cast(_we0#149 as bigint) + last_month#135L) as int) could run on GPU\n", " @Expression -(cast(_we0#149 as bigint) + last_month#135L) could run on GPU\n", " @Expression (cast(_we0#149 as bigint) + last_month#135L) could run on GPU\n", " @Expression cast(_we0#149 as bigint) could run on GPU\n", " @Expression _we0#149 could run on GPU\n", " @Expression last_month#135L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#0 could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END AS last_month#135L could run on GPU\n", " @Expression CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END could run on GPU\n", " @Expression (Churn#20 = Yes) could run on GPU\n", " @Expression Churn#20 could run on GPU\n", " @Expression Yes could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression ((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + 
((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#0, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#0, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression explode(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\n", " ! array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\n", " @Expression cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\n", " @Expression (TotalCharges#19 / tenure#5) could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " !Expression cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression tenure#5 could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression Churn#20 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression ((atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) AND isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)))) could run on GPU\n", " @Expression (atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) could run on GPU\n", " @Expression atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression gender#1 could run on GPU\n", " @Expression SeniorCitizen#2 could run on GPU\n", " @Expression Partner#3 could run on GPU\n", " @Expression Dependents#4 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " @Expression PhoneService#6 could run on GPU\n", " @Expression MultipleLines#7 could run on GPU\n", " @Expression InternetService#8 could run on GPU\n", " @Expression OnlineSecurity#9 could run on GPU\n", " @Expression OnlineBackup#10 could run on GPU\n", " @Expression DeviceProtection#11 could run on GPU\n", " @Expression TechSupport#12 could run on GPU\n", " @Expression StreamingTV#13 could run on GPU\n", " @Expression StreamingMovies#14 could run on GPU\n", " @Expression Contract#15 could run on GPU\n", " @Expression PaperlessBilling#16 could run on GPU\n", " @Expression PaymentMethod#17 could run on GPU\n", " @Expression MonthlyCharges#18 could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression Churn#20 could run on GPU\n", " @Expression (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0) could run on GPU\n", " @Expression size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) could run on GPU\n", " ! 
array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\n", " @Expression cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\n", " @Expression (TotalCharges#19 / tenure#5) could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " !Expression cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression tenure#5 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\n", " ! array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\n", " @Expression cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\n", " @Expression (TotalCharges#19 / tenure#5) could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " !Expression cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression tenure#5 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#265 could run on GPU\n", " @Expression AccountCreation AS kind#191 could run on GPU\n", " @Expression AccountCreation could run on GPU\n", " @Expression 0.00 AS value#192 could run on GPU\n", " @Expression 0.00 could run on GPU\n", " @Expression add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) AS date#200 could run on GPU\n", " ! 
add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " !Expression cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression ((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) could run on GPU\n", " @Expression (-tenure#270 - 1.0) could run on GPU\n", " @Expression -tenure#270 could run on GPU\n", " @Expression tenure#270 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " @Expression CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END could run on GPU\n", " @Expression (Churn#285 = Yes) could run on GPU\n", " @Expression Churn#285 could run on GPU\n", " @Expression Yes could run on GPU\n", " @Expression cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression 
((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#265, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#265, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 0.0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#290 could run on GPU\n", " @Expression AccountTermination AS kind#258 could run on GPU\n", " @Expression AccountTermination could run on GPU\n", " @Expression 0.00 AS value#259 could run on GPU\n", " @Expression 0.00 could run on GPU\n", " @Expression CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END AS date#260 could run on GPU\n", " @Expression CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END could run on GPU\n", " @Expression (Churn#310 = Yes) could run on GPU\n", " @Expression Churn#310 could run on GPU\n", " @Expression Yes could run on GPU\n", " ! 
add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int) could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression ((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#290, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#290, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 2022-04-05 could run on GPU\n", "\n", "2022-04-05 09:36:37,199 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression format_string(%s-%s, customerID#0, u_value#337) AS customerID#816 could run on GPU\n", " ! 
format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression kind#133 could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression date#156 could run on GPU\n", " @Expression month#315 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#0 could run on GPU\n", " @Expression Charge AS kind#133 could run on GPU\n", " @Expression Charge could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) AS date#156 could run on GPU\n", " ! add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(-(cast(_we0#149 as bigint) + last_month#135L) as int) could run on GPU\n", " @Expression -(cast(_we0#149 as bigint) + last_month#135L) could run on GPU\n", " @Expression (cast(_we0#149 as bigint) + last_month#135L) could run on GPU\n", " @Expression cast(_we0#149 as bigint) could run on GPU\n", " @Expression _we0#149 could run on GPU\n", " @Expression last_month#135L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#0 could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END AS last_month#135L could run on GPU\n", " @Expression CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END could run on GPU\n", " @Expression (Churn#20 = Yes) could run on GPU\n", " @Expression Churn#20 could run on GPU\n", " @Expression Yes could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression ((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + 
((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#0, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#0, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression explode(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\n", " ! array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\n", " @Expression cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\n", " @Expression (TotalCharges#19 / tenure#5) could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " !Expression cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression tenure#5 could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression Churn#20 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression ((atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) AND isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)))) could run on GPU\n", " @Expression (atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) could run on GPU\n", " @Expression atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression gender#1 could run on GPU\n", " @Expression SeniorCitizen#2 could run on GPU\n", " @Expression Partner#3 could run on GPU\n", " @Expression Dependents#4 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " @Expression PhoneService#6 could run on GPU\n", " @Expression MultipleLines#7 could run on GPU\n", " @Expression InternetService#8 could run on GPU\n", " @Expression OnlineSecurity#9 could run on GPU\n", " @Expression OnlineBackup#10 could run on GPU\n", " @Expression DeviceProtection#11 could run on GPU\n", " @Expression TechSupport#12 could run on GPU\n", " @Expression StreamingTV#13 could run on GPU\n", " @Expression StreamingMovies#14 could run on GPU\n", " @Expression Contract#15 could run on GPU\n", " @Expression PaperlessBilling#16 could run on GPU\n", " @Expression PaymentMethod#17 could run on GPU\n", " @Expression MonthlyCharges#18 could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression Churn#20 could run on GPU\n", " @Expression (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0) could run on GPU\n", " @Expression size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) could run on GPU\n", " ! 
array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\n", " @Expression cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\n", " @Expression (TotalCharges#19 / tenure#5) could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " !Expression cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression tenure#5 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\n", " ! array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\n", " @Expression cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\n", " @Expression (TotalCharges#19 / tenure#5) could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " !Expression cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression tenure#5 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#265 could run on GPU\n", " @Expression AccountCreation AS kind#191 could run on GPU\n", " @Expression AccountCreation could run on GPU\n", " @Expression 0.00 AS value#192 could run on GPU\n", " @Expression 0.00 could run on GPU\n", " @Expression add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) AS date#200 could run on GPU\n", " ! 
add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " !Expression cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression ((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) could run on GPU\n", " @Expression (-tenure#270 - 1.0) could run on GPU\n", " @Expression -tenure#270 could run on GPU\n", " @Expression tenure#270 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " @Expression CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END could run on GPU\n", " @Expression (Churn#285 = Yes) could run on GPU\n", " @Expression Churn#285 could run on GPU\n", " @Expression Yes could run on GPU\n", " @Expression cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression 
((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#265, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#265, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 0.0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#290 could run on GPU\n", " @Expression AccountTermination AS kind#258 could run on GPU\n", " @Expression AccountTermination could run on GPU\n", " @Expression 0.00 AS value#259 could run on GPU\n", " @Expression 0.00 could run on GPU\n", " @Expression CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END AS date#260 could run on GPU\n", " @Expression CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END could run on GPU\n", " @Expression (Churn#310 = Yes) could run on GPU\n", " @Expression Churn#310 could run on GPU\n", " @Expression Yes could run on GPU\n", " ! 
add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int) could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression ((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#290, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#290, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 2022-04-05 could run on GPU\n", "\n", "2022-04-05 09:36:37,305 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#0 could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END AS last_month#135L could run on GPU\n", " @Expression CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END could run on GPU\n", " @Expression (Churn#20 = Yes) could run on GPU\n", " @Expression Churn#20 could run on GPU\n", " @Expression Yes could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression ((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#0, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#0, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#0, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#0, 42), false) could run on GPU\n", " ! xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#0 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression explode(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\n", " ! 
array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\n", " @Expression cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\n", " @Expression (TotalCharges#19 / tenure#5) could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " !Expression cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression tenure#5 could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression Churn#20 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression ((atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) AND isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)))) could run on GPU\n", " @Expression (atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) could run on GPU\n", " @Expression atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression gender#1 could run on GPU\n", " @Expression SeniorCitizen#2 could run on GPU\n", " @Expression Partner#3 could run on GPU\n", " @Expression Dependents#4 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " @Expression PhoneService#6 could run on GPU\n", " @Expression MultipleLines#7 could run on GPU\n", " @Expression InternetService#8 could run on GPU\n", " @Expression OnlineSecurity#9 could run on GPU\n", " @Expression OnlineBackup#10 could run on GPU\n", " @Expression DeviceProtection#11 could run on GPU\n", " @Expression TechSupport#12 could run on GPU\n", " @Expression StreamingTV#13 could run on GPU\n", " @Expression StreamingMovies#14 could run on GPU\n", " @Expression Contract#15 could run on GPU\n", " @Expression PaperlessBilling#16 could run on GPU\n", " @Expression PaymentMethod#17 could run on GPU\n", " @Expression MonthlyCharges#18 could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression 
Churn#20 could run on GPU\n", " @Expression (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0) could run on GPU\n", " @Expression size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) could run on GPU\n", " ! array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\n", " @Expression cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\n", " @Expression (TotalCharges#19 / tenure#5) could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " !Expression cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression tenure#5 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\n", " ! array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\n", " @Expression cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\n", " @Expression (TotalCharges#19 / tenure#5) could run on GPU\n", " @Expression TotalCharges#19 could run on GPU\n", " @Expression tenure#5 could run on GPU\n", " !Expression cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression tenure#5 could run on GPU\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:36:37,476 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression format_string(%s-%s, customerID#0, u_value#337) AS customerID#816 could run on GPU\n", " ! format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression kind#133 could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression date#156 could run on GPU\n", " @Expression month#315 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#0 could run on GPU\n", " @Expression Charge AS kind#133 could run on GPU\n", " @Expression Charge could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) AS date#156 could run on GPU\n", " ! 
add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(-(cast(_we0#149 as bigint) + last_month#135L) as int) could run on GPU\n", " @Expression -(cast(_we0#149 as bigint) + last_month#135L) could run on GPU\n", " @Expression (cast(_we0#149 as bigint) + last_month#135L) could run on GPU\n", " @Expression cast(_we0#149 as bigint) could run on GPU\n", " @Expression _we0#149 could run on GPU\n", " @Expression last_month#135L could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#265 could run on GPU\n", " @Expression AccountCreation AS kind#191 could run on GPU\n", " @Expression AccountCreation could run on GPU\n", " @Expression 0.00 AS value#192 could run on GPU\n", " @Expression 0.00 could run on GPU\n", " @Expression add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) AS date#200 could run on GPU\n", " ! add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " !Expression cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression ((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) could run on GPU\n", " @Expression (-tenure#270 - 1.0) could run on GPU\n", " @Expression -tenure#270 could run on GPU\n", " @Expression tenure#270 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " @Expression CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END could run on GPU\n", " @Expression (Churn#285 = Yes) could run on GPU\n", " @Expression Churn#285 could run on GPU\n", " @Expression Yes could run on GPU\n", " @Expression cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression ((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#265, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#265, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 0.0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#290 could run on GPU\n", " @Expression AccountTermination AS kind#258 could run on GPU\n", " @Expression AccountTermination could run on GPU\n", " @Expression 0.00 AS value#259 could run on GPU\n", " @Expression 0.00 could run on GPU\n", " @Expression CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END AS date#260 could run on GPU\n", " @Expression CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END could run on GPU\n", " @Expression (Churn#310 = Yes) could run on GPU\n", " @Expression Churn#310 could run on GPU\n", " @Expression Yes could run on GPU\n", " ! 
add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int) could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression ((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#290, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#290, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 2022-04-05 could run on GPU\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:36:37,897 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression format_string(%s-%s, customerID#0, u_value#337) AS customerID#816 could run on GPU\n", " ! 
format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression kind#133 could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression date#156 could run on GPU\n", " @Expression month#315 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#0 could run on GPU\n", " @Expression Charge AS kind#133 could run on GPU\n", " @Expression Charge could run on GPU\n", " @Expression value#136 could run on GPU\n", " @Expression add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) AS date#156 could run on GPU\n", " ! add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(-(cast(_we0#149 as bigint) + last_month#135L) as int) could run on GPU\n", " @Expression -(cast(_we0#149 as bigint) + last_month#135L) could run on GPU\n", " @Expression (cast(_we0#149 as bigint) + last_month#135L) could run on GPU\n", " @Expression cast(_we0#149 as bigint) could run on GPU\n", " @Expression _we0#149 could run on GPU\n", " @Expression last_month#135L could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#265 could run on GPU\n", " @Expression AccountCreation AS kind#191 could run on GPU\n", " @Expression AccountCreation could run on GPU\n", " @Expression 0.00 AS value#192 could run on GPU\n", " @Expression 0.00 could run on GPU\n", " @Expression add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) AS date#200 could run on GPU\n", " ! 
add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " !Expression cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\n", " @Expression ((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) could run on GPU\n", " @Expression (-tenure#270 - 1.0) could run on GPU\n", " @Expression -tenure#270 could run on GPU\n", " @Expression tenure#270 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " @Expression CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END could run on GPU\n", " @Expression (Churn#285 = Yes) could run on GPU\n", " @Expression Churn#285 could run on GPU\n", " @Expression Yes could run on GPU\n", " @Expression cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression 
((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#265, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#265, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#265, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#265, 42), false) could run on GPU\n", " ! xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#265 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 0.0 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression customerID#290 could run on GPU\n", " @Expression AccountTermination AS kind#258 could run on GPU\n", " @Expression AccountTermination could run on GPU\n", " @Expression 0.00 AS value#259 could run on GPU\n", " @Expression 0.00 could run on GPU\n", " @Expression CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END AS date#260 could run on GPU\n", " @Expression CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END could run on GPU\n", " @Expression (Churn#310 = Yes) could run on GPU\n", " @Expression Churn#310 could run on GPU\n", " @Expression Yes could run on GPU\n", " ! 
add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int) could run on GPU\n", " @Expression -((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression ((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\n", " @Expression (((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) could run on GPU\n", " @Expression ((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) could run on GPU\n", " @Expression (((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) could run on GPU\n", " @Expression ((abs(xxhash64(customerID#290, 42), false) & 255) % 36) could run on GPU\n", " @Expression (abs(xxhash64(customerID#290, 42), false) & 255) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 36 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 8) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! 
xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 8 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 16) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 16 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 14 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 24) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 24 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 10 could run on GPU\n", " @Expression ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6) could run on GPU\n", " @Expression (shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) could run on GPU\n", " @Expression shiftright(abs(xxhash64(customerID#290, 42), false), 32) could run on GPU\n", " @Expression abs(xxhash64(customerID#290, 42), false) could run on GPU\n", " ! xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\n", " @Expression customerID#290 could run on GPU\n", " @Expression 32 could run on GPU\n", " @Expression 255 could run on GPU\n", " @Expression 6 could run on GPU\n", " @Expression 2022-04-05 could run on GPU\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:40:21,129 WARN rapids.GpuOverrides: \n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression customerID#395 ASC NULLS FIRST could run on GPU\n", " @Expression customerID#395 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression format_string(%s-%s, customerID#0, u_value#337) AS customerID#395 could run on GPU\n", " ! 
format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) AS dateOfBirth#460 could run on GPU\n", " @Expression date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) could run on GPU\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int) could run on GPU\n", " @Expression FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) could run on GPU\n", " @Expression CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END could run on GPU\n", " @Expression (cast(SeniorCitizen#2 as int) = 0) could run on GPU\n", " @Expression cast(SeniorCitizen#2 as int) could run on GPU\n", " @Expression SeniorCitizen#2 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) could run on GPU\n", " @Expression ((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) could run on GPU\n", " @Expression (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\n", " @Expression cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\n", " @Expression (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\n", " @Expression abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\n", " @Expression hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\n", " ! 
format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression 4096 could run on GPU\n", " @Expression 4096.0 could run on GPU\n", " @Expression 16801.5 could run on GPU\n", " @Expression 6574.5 could run on GPU\n", " @Expression (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) could run on GPU\n", " @Expression ((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) could run on GPU\n", " @Expression (-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) could run on GPU\n", " @Expression -LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\n", " @Expression LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\n", " @Expression -(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\n", " @Expression (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\n", " @Expression cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\n", " @Expression (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\n", " @Expression abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\n", " @Expression hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\n", " ! 
format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression 4096 could run on GPU\n", " @Expression 4096.0 could run on GPU\n", " @Expression 6.3 could run on GPU\n", " @Expression 365.25 could run on GPU\n", " @Expression 23741.25 could run on GPU\n", " @Expression gender#1 could run on GPU\n", " @Expression SeniorCitizen#2 could run on GPU\n", " @Expression Partner#3 could run on GPU\n", " @Expression Dependents#4 could run on GPU\n", " @Expression cast(MonthlyCharges#18 as decimal(8,2)) AS MonthlyCharges#441 could run on GPU\n", " @Expression cast(MonthlyCharges#18 as decimal(8,2)) could run on GPU\n", " @Expression MonthlyCharges#18 could run on GPU\n", " @Expression 2022-04-05 09:36:19.001066 AS now#439 could run on GPU\n", " @Expression 2022-04-05 09:36:19.001066 could run on GPU\n", "\n", "2022-04-05 09:40:21,133 WARN rapids.GpuOverrides: \n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression customerID#395 ASC NULLS FIRST could run on GPU\n", " @Expression customerID#395 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression format_string(%s-%s, customerID#0, u_value#337) AS customerID#395 could run on GPU\n", " ! format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) AS dateOfBirth#460 could run on GPU\n", " @Expression date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) could run on GPU\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int) could run on GPU\n", " @Expression FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) could run on GPU\n", " @Expression 
CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END could run on GPU\n", " @Expression (cast(SeniorCitizen#2 as int) = 0) could run on GPU\n", " @Expression cast(SeniorCitizen#2 as int) could run on GPU\n", " @Expression SeniorCitizen#2 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) could run on GPU\n", " @Expression ((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) could run on GPU\n", " @Expression (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\n", " @Expression cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\n", " @Expression (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\n", " @Expression abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\n", " @Expression hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\n", " ! format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression 4096 could run on GPU\n", " @Expression 4096.0 could run on GPU\n", " @Expression 16801.5 could run on GPU\n", " @Expression 6574.5 could run on GPU\n", " @Expression (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) could run on GPU\n", " @Expression ((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) could run on GPU\n", " @Expression (-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) could run on GPU\n", " @Expression -LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\n", " @Expression LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\n", " @Expression -(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\n", " @Expression (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\n", " @Expression cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\n", " @Expression (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\n", " @Expression abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\n", " @Expression hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\n", " ! 
format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression 4096 could run on GPU\n", " @Expression 4096.0 could run on GPU\n", " @Expression 6.3 could run on GPU\n", " @Expression 365.25 could run on GPU\n", " @Expression 23741.25 could run on GPU\n", " @Expression gender#1 could run on GPU\n", " @Expression SeniorCitizen#2 could run on GPU\n", " @Expression Partner#3 could run on GPU\n", " @Expression Dependents#4 could run on GPU\n", " @Expression cast(MonthlyCharges#18 as decimal(8,2)) AS MonthlyCharges#441 could run on GPU\n", " @Expression cast(MonthlyCharges#18 as decimal(8,2)) could run on GPU\n", " @Expression MonthlyCharges#18 could run on GPU\n", " @Expression 2022-04-05 09:36:19.001066 AS now#439 could run on GPU\n", " @Expression 2022-04-05 09:36:19.001066 could run on GPU\n", "\n", "2022-04-05 09:40:21,138 WARN rapids.GpuOverrides: \n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression customerID#395 ASC NULLS FIRST could run on GPU\n", " @Expression customerID#395 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression format_string(%s-%s, customerID#0, u_value#337) AS customerID#395 could run on GPU\n", " ! format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) AS dateOfBirth#460 could run on GPU\n", " @Expression date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) could run on GPU\n", " @Expression 2022-04-05 could run on GPU\n", " @Expression cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int) could run on GPU\n", " @Expression FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) could run on GPU\n", " @Expression 
CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END could run on GPU\n", " @Expression (cast(SeniorCitizen#2 as int) = 0) could run on GPU\n", " @Expression cast(SeniorCitizen#2 as int) could run on GPU\n", " @Expression SeniorCitizen#2 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) could run on GPU\n", " @Expression ((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) could run on GPU\n", " @Expression (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\n", " @Expression cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\n", " @Expression (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\n", " @Expression abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\n", " @Expression hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\n", " ! format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression 4096 could run on GPU\n", " @Expression 4096.0 could run on GPU\n", " @Expression 16801.5 could run on GPU\n", " @Expression 6574.5 could run on GPU\n", " @Expression (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) could run on GPU\n", " @Expression ((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) could run on GPU\n", " @Expression (-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) could run on GPU\n", " @Expression -LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\n", " @Expression LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\n", " @Expression -(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\n", " @Expression (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\n", " @Expression cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\n", " @Expression (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\n", " @Expression abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\n", " @Expression hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\n", " ! 
format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression 4096 could run on GPU\n", " @Expression 4096.0 could run on GPU\n", " @Expression 6.3 could run on GPU\n", " @Expression 365.25 could run on GPU\n", " @Expression 23741.25 could run on GPU\n", " @Expression gender#1 could run on GPU\n", " @Expression SeniorCitizen#2 could run on GPU\n", " @Expression Partner#3 could run on GPU\n", " @Expression Dependents#4 could run on GPU\n", " @Expression cast(MonthlyCharges#18 as decimal(8,2)) AS MonthlyCharges#441 could run on GPU\n", " @Expression cast(MonthlyCharges#18 as decimal(8,2)) could run on GPU\n", " @Expression MonthlyCharges#18 could run on GPU\n", " @Expression 2022-04-05 09:36:19.001066 AS now#439 could run on GPU\n", " @Expression 2022-04-05 09:36:19.001066 could run on GPU\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:40:23,697 WARN rapids.GpuOverrides: (0 + 1) / 1]\n", " !Exec cannot run on GPU because Unable to replace CustomShuffleReader due to child not being columnar\n", "\n", "2022-04-05 09:40:24,451 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.LocalTableScanExec\n", " @Expression name#894 could run on GPU\n", " @Expression database#895 could run on GPU\n", " @Expression description#896 could run on GPU\n", " @Expression tableType#897 could run on GPU\n", " @Expression isTemporary#898 could run on GPU\n", "\n", "2022-04-05 09:40:24,499 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression format_string(%s-%s, customerID#0, u_value#337) AS customerID#910 could run on GPU\n", " ! format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression feature#479 could run on GPU\n", " @Expression value#480 could run on GPU\n", "\n", "2022-04-05 09:40:25,815 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.LocalTableScanExec\n", " @Expression name#946 could run on GPU\n", " @Expression database#947 could run on GPU\n", " @Expression description#948 could run on GPU\n", " @Expression tableType#949 could run on GPU\n", " @Expression isTemporary#950 could run on GPU\n", "\n", "2022-04-05 09:40:25,888 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression format_string(%s-%s, customerID#0, u_value#337) AS customerID#962 could run on GPU\n", " ! format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#0 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression feature#513 could run on GPU\n", " @Expression value#514 could run on GPU\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:40:28,911 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.LocalTableScanExec\n", " @Expression name#998 could run on GPU\n", " @Expression database#999 could run on GPU\n", " @Expression description#1000 could run on GPU\n", " @Expression tableType#1001 could run on GPU\n", " @Expression isTemporary#1002 could run on GPU\n", "\n", "2022-04-05 09:40:28,964 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression format_string(%s-%s, customerID#721, u_value#337) AS customerID#1014 could run on GPU\n", " ! format_string(%s-%s, customerID#721, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\n", " @Expression %s-%s could run on GPU\n", " @Expression customerID#721 could run on GPU\n", " @Expression u_value#337 could run on GPU\n", " @Expression feature#722 could run on GPU\n", " @Expression value#723 could run on GPU\n", " ! 
cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.RDDScanExec\n", " @Expression customerID#721 could run on GPU\n", " @Expression feature#722 could run on GPU\n", " @Expression value#723 could run on GPU\n", "\n", "[Stage 41:==================================================> (10 + 1) / 11]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 214 ms, sys: 34 ms, total: 248 ms\n", "Wall time: 3min 54s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "%%time\n", "\n", "from churn.augment import write_df\n", "\n", "write_df(billingEvents, \"billing_events\", partition_by=\"month\")\n", "write_df(customerMeta, \"customer_meta\", skip_replication=True)\n", "write_df(customerPhoneFeatures, \"customer_phone_features\")\n", "write_df(customerInternetFeatures.orderBy(\"customerID\"), \"customer_internet_features\")\n", "write_df(customerAccountFeatures, \"customer_account_features\")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "billing_events 703200\n", "customer_meta 703200\n", "customer_phone_features 635200\n", "customer_internet_features 551200\n", "customer_account_features 703200\n" ] } ], "source": [ "for f in [\"billing_events\", \"customer_meta\", \"customer_phone_features\", \"customer_internet_features\", \"customer_account_features\"]:\n", " output_df = session.read.parquet(\"%s.parquet\" % f)\n", " print(f, output_df.select(\"customerID\").distinct().count())" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "import pyspark.sql.functions as F\n", "from functools import reduce\n", "\n", "output_dfs = []\n", "\n", "for f in [\"billing_events\", \"customer_meta\", \"customer_phone_features\", \"customer_internet_features\", \"customer_account_features\"]:\n", " output_dfs.append(\n", " session.read.parquet(\"%s.parquet\" % f).select(\n", " F.lit(f).alias(\"table\"),\n", " \"customerID\"\n", " )\n", " )\n", "\n", "all_customers = reduce(lambda l, r: l.unionAll(r), output_dfs)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:41:25,790 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) AS approx_unique_customers#2118 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression table#1189 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression MS[0]#1354L could run on GPU\n", " @Expression MS[1]#1355L could run on GPU\n", " @Expression MS[2]#1356L could run on GPU\n", " @Expression MS[3]#1357L could run on GPU\n", " @Expression MS[4]#1358L could run on GPU\n", " @Expression MS[5]#1359L could run on GPU\n", " @Expression MS[6]#1360L could run on GPU\n", " @Expression MS[7]#1361L could run on GPU\n", " @Expression MS[8]#1362L could run on GPU\n", " @Expression MS[9]#1363L could run on GPU\n", " @Expression MS[10]#1364L could run on GPU\n", " @Expression MS[11]#1365L could run on GPU\n", " @Expression MS[12]#1366L could run on GPU\n", " @Expression MS[13]#1367L could run on GPU\n", " @Expression MS[14]#1368L could run on GPU\n", " @Expression MS[15]#1369L could run on GPU\n", " @Expression MS[16]#1370L could run on GPU\n", " @Expression MS[17]#1371L could run on GPU\n", " @Expression MS[18]#1372L could run on GPU\n", " @Expression MS[19]#1373L could run on GPU\n", " @Expression MS[20]#1374L could run on GPU\n", " @Expression MS[21]#1375L could run on GPU\n", " @Expression MS[22]#1376L could run on GPU\n", " @Expression MS[23]#1377L could run on GPU\n", " @Expression MS[24]#1378L could run on GPU\n", " @Expression MS[25]#1379L could run on GPU\n", " @Expression MS[26]#1380L could run on GPU\n", " @Expression MS[27]#1381L could run on GPU\n", " @Expression MS[28]#1382L could run on GPU\n", " @Expression MS[29]#1383L could run on GPU\n", " @Expression MS[30]#1384L could run on GPU\n", " @Expression MS[31]#1385L could run on GPU\n", " @Expression MS[32]#1386L could run on GPU\n", " @Expression MS[33]#1387L could run on GPU\n", " @Expression MS[34]#1388L could run on GPU\n", " @Expression MS[35]#1389L could run on GPU\n", " @Expression MS[36]#1390L could run on GPU\n", " @Expression MS[37]#1391L could run on GPU\n", " @Expression MS[38]#1392L could run on GPU\n", " @Expression MS[39]#1393L could run on GPU\n", " @Expression MS[40]#1394L could run on GPU\n", " @Expression MS[41]#1395L could run on GPU\n", " @Expression MS[42]#1396L could run on GPU\n", " @Expression MS[43]#1397L could run on GPU\n", " @Expression MS[44]#1398L could run on GPU\n", " @Expression 
MS[45]#1399L could run on GPU\n", " @Expression MS[46]#1400L could run on GPU\n", " @Expression MS[47]#1401L could run on GPU\n", " @Expression MS[48]#1402L could run on GPU\n", " @Expression MS[49]#1403L could run on GPU\n", " @Expression MS[50]#1404L could run on GPU\n", " @Expression MS[51]#1405L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression MS[0]#1406L could run on GPU\n", " @Expression MS[1]#1407L could run on GPU\n", " @Expression MS[2]#1408L could run on GPU\n", " @Expression MS[3]#1409L could run on GPU\n", " @Expression MS[4]#1410L could run on GPU\n", " @Expression MS[5]#1411L could run on GPU\n", " @Expression MS[6]#1412L could run on GPU\n", " @Expression MS[7]#1413L could run on GPU\n", " @Expression MS[8]#1414L could run on GPU\n", " @Expression MS[9]#1415L could run on GPU\n", " @Expression MS[10]#1416L could run on GPU\n", " @Expression MS[11]#1417L could run on GPU\n", " @Expression MS[12]#1418L could run on GPU\n", " @Expression MS[13]#1419L could run on GPU\n", " @Expression MS[14]#1420L could run on GPU\n", " @Expression MS[15]#1421L could run on GPU\n", " @Expression MS[16]#1422L could run on GPU\n", " @Expression MS[17]#1423L could run on GPU\n", " @Expression MS[18]#1424L could run on GPU\n", " @Expression MS[19]#1425L could run on GPU\n", " @Expression MS[20]#1426L could run on GPU\n", " @Expression MS[21]#1427L could run on GPU\n", " @Expression MS[22]#1428L could run on GPU\n", " @Expression MS[23]#1429L could run on GPU\n", " @Expression MS[24]#1430L could run on GPU\n", " @Expression MS[25]#1431L could run on GPU\n", " @Expression MS[26]#1432L could run on GPU\n", " @Expression MS[27]#1433L could run on GPU\n", " @Expression MS[28]#1434L could run on GPU\n", " @Expression MS[29]#1435L could run on GPU\n", " @Expression MS[30]#1436L could run on GPU\n", " @Expression MS[31]#1437L could run on GPU\n", " @Expression MS[32]#1438L could run on GPU\n", " @Expression MS[33]#1439L could run on GPU\n", " @Expression MS[34]#1440L could run on GPU\n", " @Expression MS[35]#1441L could run on GPU\n", " @Expression MS[36]#1442L could run on GPU\n", " @Expression MS[37]#1443L could run on GPU\n", " @Expression MS[38]#1444L could run on GPU\n", " @Expression MS[39]#1445L could run on GPU\n", " @Expression MS[40]#1446L could run on GPU\n", " @Expression MS[41]#1447L could run on GPU\n", " @Expression MS[42]#1448L could run on GPU\n", " @Expression MS[43]#1449L could run on GPU\n", " @Expression MS[44]#1450L could run on GPU\n", " @Expression MS[45]#1451L could run on GPU\n", " @Expression MS[46]#1452L could run on GPU\n", " @Expression MS[47]#1453L could run on GPU\n", " @Expression MS[48]#1454L could run on GPU\n", " @Expression MS[49]#1455L could run on GPU\n", " @Expression MS[50]#1456L could run on GPU\n", " @Expression MS[51]#1457L could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0#2539 could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1883 could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\n", " @Expression all AS table#2537 could run on GPU\n", " @Expression all could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) AS approx_unique_customers#2538 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression 0#2539 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0 AS 0#2539 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1883 could run on GPU\n", " @Expression MS[0]#1905L could run on GPU\n", " @Expression MS[1]#1906L could run on GPU\n", " @Expression MS[2]#1907L could run on GPU\n", " @Expression MS[3]#1908L could run on GPU\n", " @Expression MS[4]#1909L could run on GPU\n", " @Expression MS[5]#1910L could run on GPU\n", " @Expression MS[6]#1911L could run on GPU\n", " @Expression MS[7]#1912L could run on GPU\n", " @Expression MS[8]#1913L could run on GPU\n", " @Expression MS[9]#1914L could run on GPU\n", " @Expression MS[10]#1915L could run on GPU\n", " @Expression MS[11]#1916L could run on GPU\n", " @Expression MS[12]#1917L could run on GPU\n", " @Expression MS[13]#1918L could run on GPU\n", " @Expression MS[14]#1919L could run on GPU\n", " @Expression MS[15]#1920L could run on GPU\n", " @Expression MS[16]#1921L could run on GPU\n", " @Expression MS[17]#1922L could run on GPU\n", " @Expression MS[18]#1923L could run on GPU\n", " @Expression MS[19]#1924L could run on GPU\n", " @Expression MS[20]#1925L could run on GPU\n", " @Expression MS[21]#1926L could run on GPU\n", " @Expression MS[22]#1927L could run on GPU\n", " @Expression MS[23]#1928L could run on GPU\n", " @Expression MS[24]#1929L could run on GPU\n", " @Expression MS[25]#1930L could run on GPU\n", " @Expression MS[26]#1931L could run on GPU\n", " @Expression MS[27]#1932L could run on GPU\n", " @Expression MS[28]#1933L could run on GPU\n", " @Expression MS[29]#1934L could run on GPU\n", " @Expression MS[30]#1935L could run on GPU\n", " @Expression MS[31]#1936L could run on GPU\n", " @Expression MS[32]#1937L could run on GPU\n", " @Expression MS[33]#1938L could run on GPU\n", " @Expression MS[34]#1939L could run on GPU\n", " @Expression MS[35]#1940L could run on GPU\n", " @Expression MS[36]#1941L could run on GPU\n", " @Expression MS[37]#1942L could run on GPU\n", " @Expression MS[38]#1943L could run on GPU\n", " @Expression MS[39]#1944L could run on GPU\n", " @Expression MS[40]#1945L could run on GPU\n", " @Expression MS[41]#1946L could run on GPU\n", " @Expression MS[42]#1947L could run on GPU\n", " @Expression MS[43]#1948L could run 
on GPU\n", " @Expression MS[44]#1949L could run on GPU\n", " @Expression MS[45]#1950L could run on GPU\n", " @Expression MS[46]#1951L could run on GPU\n", " @Expression MS[47]#1952L could run on GPU\n", " @Expression MS[48]#1953L could run on GPU\n", " @Expression MS[49]#1954L could run on GPU\n", " @Expression MS[50]#1955L could run on GPU\n", " @Expression MS[51]#1956L could run on GPU\n", " @Expression 0#2539 could run on GPU\n", " @Expression MS[0]#1957L could run on GPU\n", " @Expression MS[1]#1958L could run on GPU\n", " @Expression MS[2]#1959L could run on GPU\n", " @Expression MS[3]#1960L could run on GPU\n", " @Expression MS[4]#1961L could run on GPU\n", " @Expression MS[5]#1962L could run on GPU\n", " @Expression MS[6]#1963L could run on GPU\n", " @Expression MS[7]#1964L could run on GPU\n", " @Expression MS[8]#1965L could run on GPU\n", " @Expression MS[9]#1966L could run on GPU\n", " @Expression MS[10]#1967L could run on GPU\n", " @Expression MS[11]#1968L could run on GPU\n", " @Expression MS[12]#1969L could run on GPU\n", " @Expression MS[13]#1970L could run on GPU\n", " @Expression MS[14]#1971L could run on GPU\n", " @Expression MS[15]#1972L could run on GPU\n", " @Expression MS[16]#1973L could run on GPU\n", " @Expression MS[17]#1974L could run on GPU\n", " @Expression MS[18]#1975L could run on GPU\n", " @Expression MS[19]#1976L could run on GPU\n", " @Expression MS[20]#1977L could run on GPU\n", " @Expression MS[21]#1978L could run on GPU\n", " @Expression MS[22]#1979L could run on GPU\n", " @Expression MS[23]#1980L could run on GPU\n", " @Expression MS[24]#1981L could run on GPU\n", " @Expression MS[25]#1982L could run on GPU\n", " @Expression MS[26]#1983L could run on GPU\n", " @Expression MS[27]#1984L could run on GPU\n", " @Expression MS[28]#1985L could run on GPU\n", " @Expression MS[29]#1986L could run on GPU\n", " @Expression MS[30]#1987L could run on GPU\n", " @Expression MS[31]#1988L could run on GPU\n", " @Expression MS[32]#1989L could run on GPU\n", " @Expression MS[33]#1990L could run on GPU\n", " @Expression MS[34]#1991L could run on GPU\n", " @Expression MS[35]#1992L could run on GPU\n", " @Expression MS[36]#1993L could run on GPU\n", " @Expression MS[37]#1994L could run on GPU\n", " @Expression MS[38]#1995L could run on GPU\n", " @Expression MS[39]#1996L could run on GPU\n", " @Expression MS[40]#1997L could run on GPU\n", " @Expression MS[41]#1998L could run on GPU\n", " @Expression MS[42]#1999L could run on GPU\n", " @Expression MS[43]#2000L could run on GPU\n", " @Expression MS[44]#2001L could run on GPU\n", " @Expression MS[45]#2002L could run on GPU\n", " @Expression MS[46]#2003L could run on GPU\n", " @Expression MS[47]#2004L could run on GPU\n", " @Expression MS[48]#2005L could run on GPU\n", " @Expression MS[49]#2006L could run on GPU\n", " @Expression MS[50]#2007L could run on GPU\n", " @Expression MS[51]#2008L could run on GPU\n", "\n", "2022-04-05 09:41:25,794 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) AS approx_unique_customers#2118 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression table#1189 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression MS[0]#1354L could run on GPU\n", " @Expression MS[1]#1355L could run on GPU\n", " @Expression MS[2]#1356L could run on GPU\n", " @Expression MS[3]#1357L could run on GPU\n", " @Expression MS[4]#1358L could run on GPU\n", " @Expression MS[5]#1359L could run on GPU\n", " @Expression MS[6]#1360L could run on GPU\n", " @Expression MS[7]#1361L could run on GPU\n", " @Expression MS[8]#1362L could run on GPU\n", " @Expression MS[9]#1363L could run on GPU\n", " @Expression MS[10]#1364L could run on GPU\n", " @Expression MS[11]#1365L could run on GPU\n", " @Expression MS[12]#1366L could run on GPU\n", " @Expression MS[13]#1367L could run on GPU\n", " @Expression MS[14]#1368L could run on GPU\n", " @Expression MS[15]#1369L could run on GPU\n", " @Expression MS[16]#1370L could run on GPU\n", " @Expression MS[17]#1371L could run on GPU\n", " @Expression MS[18]#1372L could run on GPU\n", " @Expression MS[19]#1373L could run on GPU\n", " @Expression MS[20]#1374L could run on GPU\n", " @Expression MS[21]#1375L could run on GPU\n", " @Expression MS[22]#1376L could run on GPU\n", " @Expression MS[23]#1377L could run on GPU\n", " @Expression MS[24]#1378L could run on GPU\n", " @Expression MS[25]#1379L could run on GPU\n", " @Expression MS[26]#1380L could run on GPU\n", " @Expression MS[27]#1381L could run on GPU\n", " @Expression MS[28]#1382L could run on GPU\n", " @Expression MS[29]#1383L could run on GPU\n", " @Expression MS[30]#1384L could run on GPU\n", " @Expression MS[31]#1385L could run on GPU\n", " @Expression MS[32]#1386L could run on GPU\n", " @Expression MS[33]#1387L could run on GPU\n", " @Expression MS[34]#1388L could run on GPU\n", " @Expression MS[35]#1389L could run on GPU\n", " @Expression MS[36]#1390L could run on GPU\n", " @Expression MS[37]#1391L could run on GPU\n", " @Expression MS[38]#1392L 
could run on GPU\n", " @Expression MS[39]#1393L could run on GPU\n", " @Expression MS[40]#1394L could run on GPU\n", " @Expression MS[41]#1395L could run on GPU\n", " @Expression MS[42]#1396L could run on GPU\n", " @Expression MS[43]#1397L could run on GPU\n", " @Expression MS[44]#1398L could run on GPU\n", " @Expression MS[45]#1399L could run on GPU\n", " @Expression MS[46]#1400L could run on GPU\n", " @Expression MS[47]#1401L could run on GPU\n", " @Expression MS[48]#1402L could run on GPU\n", " @Expression MS[49]#1403L could run on GPU\n", " @Expression MS[50]#1404L could run on GPU\n", " @Expression MS[51]#1405L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression MS[0]#1406L could run on GPU\n", " @Expression MS[1]#1407L could run on GPU\n", " @Expression MS[2]#1408L could run on GPU\n", " @Expression MS[3]#1409L could run on GPU\n", " @Expression MS[4]#1410L could run on GPU\n", " @Expression MS[5]#1411L could run on GPU\n", " @Expression MS[6]#1412L could run on GPU\n", " @Expression MS[7]#1413L could run on GPU\n", " @Expression MS[8]#1414L could run on GPU\n", " @Expression MS[9]#1415L could run on GPU\n", " @Expression MS[10]#1416L could run on GPU\n", " @Expression MS[11]#1417L could run on GPU\n", " @Expression MS[12]#1418L could run on GPU\n", " @Expression MS[13]#1419L could run on GPU\n", " @Expression MS[14]#1420L could run on GPU\n", " @Expression MS[15]#1421L could run on GPU\n", " @Expression MS[16]#1422L could run on GPU\n", " @Expression MS[17]#1423L could run on GPU\n", " @Expression MS[18]#1424L could run on GPU\n", " @Expression MS[19]#1425L could run on GPU\n", " @Expression MS[20]#1426L could run on GPU\n", " @Expression MS[21]#1427L could run on GPU\n", " @Expression MS[22]#1428L could run on GPU\n", " @Expression MS[23]#1429L could run on GPU\n", " @Expression MS[24]#1430L could run on GPU\n", " @Expression MS[25]#1431L could run on GPU\n", " @Expression MS[26]#1432L could run on GPU\n", " @Expression MS[27]#1433L could run on GPU\n", " @Expression MS[28]#1434L could run on GPU\n", " @Expression MS[29]#1435L could run on GPU\n", " @Expression MS[30]#1436L could run on GPU\n", " @Expression MS[31]#1437L could run on GPU\n", " @Expression MS[32]#1438L could run on GPU\n", " @Expression MS[33]#1439L could run on GPU\n", " @Expression MS[34]#1440L could run on GPU\n", " @Expression MS[35]#1441L could run on GPU\n", " @Expression MS[36]#1442L could run on GPU\n", " @Expression MS[37]#1443L could run on GPU\n", " @Expression MS[38]#1444L could run on GPU\n", " @Expression MS[39]#1445L could run on GPU\n", " @Expression MS[40]#1446L could run on GPU\n", " @Expression MS[41]#1447L could run on GPU\n", " @Expression MS[42]#1448L could run on GPU\n", " @Expression MS[43]#1449L could run on GPU\n", " @Expression MS[44]#1450L could run on GPU\n", " @Expression MS[45]#1451L could run on GPU\n", " @Expression MS[46]#1452L could run on GPU\n", " @Expression MS[47]#1453L could run on GPU\n", " @Expression MS[48]#1454L could run on GPU\n", " @Expression MS[49]#1455L could run on GPU\n", " @Expression MS[50]#1456L could run on GPU\n", " @Expression MS[51]#1457L could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0#2539 could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1883 could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\n", " @Expression all AS table#2537 could run on GPU\n", " @Expression all could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) AS approx_unique_customers#2538 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression 0#2539 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0 AS 0#2539 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1883 could run on GPU\n", " @Expression MS[0]#1905L could run on GPU\n", " @Expression MS[1]#1906L could run on GPU\n", " @Expression MS[2]#1907L could run on GPU\n", " @Expression MS[3]#1908L could run on GPU\n", " @Expression MS[4]#1909L could run on GPU\n", " @Expression MS[5]#1910L could run on GPU\n", " @Expression MS[6]#1911L could run on GPU\n", " @Expression MS[7]#1912L could run on GPU\n", " @Expression MS[8]#1913L could run on GPU\n", " @Expression MS[9]#1914L could run on GPU\n", " @Expression MS[10]#1915L could run on GPU\n", " @Expression MS[11]#1916L could run on GPU\n", " @Expression MS[12]#1917L could run on GPU\n", " @Expression MS[13]#1918L could run on GPU\n", " @Expression MS[14]#1919L could run on GPU\n", " @Expression MS[15]#1920L could run on GPU\n", " @Expression MS[16]#1921L could run on GPU\n", " @Expression MS[17]#1922L could run on GPU\n", " @Expression MS[18]#1923L could run on GPU\n", " @Expression MS[19]#1924L could run on GPU\n", " @Expression MS[20]#1925L could run on GPU\n", " @Expression MS[21]#1926L could run on GPU\n", " @Expression MS[22]#1927L could run on GPU\n", " @Expression MS[23]#1928L could run on GPU\n", " @Expression MS[24]#1929L could run on GPU\n", " @Expression MS[25]#1930L could run on GPU\n", " @Expression MS[26]#1931L could run on GPU\n", " @Expression MS[27]#1932L could run on GPU\n", " @Expression MS[28]#1933L could run on GPU\n", " @Expression MS[29]#1934L could run on GPU\n", " @Expression MS[30]#1935L could run on GPU\n", " @Expression MS[31]#1936L could run on GPU\n", " @Expression MS[32]#1937L could run on GPU\n", " @Expression MS[33]#1938L could run on GPU\n", " @Expression MS[34]#1939L could run on GPU\n", " @Expression MS[35]#1940L could run on GPU\n", " @Expression MS[36]#1941L could run on GPU\n", " @Expression MS[37]#1942L could run on GPU\n", " @Expression MS[38]#1943L could run on GPU\n", " @Expression MS[39]#1944L could run on GPU\n", " @Expression MS[40]#1945L could run on GPU\n", " @Expression MS[41]#1946L could run on GPU\n", " @Expression MS[42]#1947L could run on GPU\n", " @Expression MS[43]#1948L could run 
on GPU\n", " @Expression MS[44]#1949L could run on GPU\n", " @Expression MS[45]#1950L could run on GPU\n", " @Expression MS[46]#1951L could run on GPU\n", " @Expression MS[47]#1952L could run on GPU\n", " @Expression MS[48]#1953L could run on GPU\n", " @Expression MS[49]#1954L could run on GPU\n", " @Expression MS[50]#1955L could run on GPU\n", " @Expression MS[51]#1956L could run on GPU\n", " @Expression 0#2539 could run on GPU\n", " @Expression MS[0]#1957L could run on GPU\n", " @Expression MS[1]#1958L could run on GPU\n", " @Expression MS[2]#1959L could run on GPU\n", " @Expression MS[3]#1960L could run on GPU\n", " @Expression MS[4]#1961L could run on GPU\n", " @Expression MS[5]#1962L could run on GPU\n", " @Expression MS[6]#1963L could run on GPU\n", " @Expression MS[7]#1964L could run on GPU\n", " @Expression MS[8]#1965L could run on GPU\n", " @Expression MS[9]#1966L could run on GPU\n", " @Expression MS[10]#1967L could run on GPU\n", " @Expression MS[11]#1968L could run on GPU\n", " @Expression MS[12]#1969L could run on GPU\n", " @Expression MS[13]#1970L could run on GPU\n", " @Expression MS[14]#1971L could run on GPU\n", " @Expression MS[15]#1972L could run on GPU\n", " @Expression MS[16]#1973L could run on GPU\n", " @Expression MS[17]#1974L could run on GPU\n", " @Expression MS[18]#1975L could run on GPU\n", " @Expression MS[19]#1976L could run on GPU\n", " @Expression MS[20]#1977L could run on GPU\n", " @Expression MS[21]#1978L could run on GPU\n", " @Expression MS[22]#1979L could run on GPU\n", " @Expression MS[23]#1980L could run on GPU\n", " @Expression MS[24]#1981L could run on GPU\n", " @Expression MS[25]#1982L could run on GPU\n", " @Expression MS[26]#1983L could run on GPU\n", " @Expression MS[27]#1984L could run on GPU\n", " @Expression MS[28]#1985L could run on GPU\n", " @Expression MS[29]#1986L could run on GPU\n", " @Expression MS[30]#1987L could run on GPU\n", " @Expression MS[31]#1988L could run on GPU\n", " @Expression MS[32]#1989L could run on GPU\n", " @Expression MS[33]#1990L could run on GPU\n", " @Expression MS[34]#1991L could run on GPU\n", " @Expression MS[35]#1992L could run on GPU\n", " @Expression MS[36]#1993L could run on GPU\n", " @Expression MS[37]#1994L could run on GPU\n", " @Expression MS[38]#1995L could run on GPU\n", " @Expression MS[39]#1996L could run on GPU\n", " @Expression MS[40]#1997L could run on GPU\n", " @Expression MS[41]#1998L could run on GPU\n", " @Expression MS[42]#1999L could run on GPU\n", " @Expression MS[43]#2000L could run on GPU\n", " @Expression MS[44]#2001L could run on GPU\n", " @Expression MS[45]#2002L could run on GPU\n", " @Expression MS[46]#2003L could run on GPU\n", " @Expression MS[47]#2004L could run on GPU\n", " @Expression MS[48]#2005L could run on GPU\n", " @Expression MS[49]#2006L could run on GPU\n", " @Expression MS[50]#2007L could run on GPU\n", " @Expression MS[51]#2008L could run on GPU\n", "\n", "2022-04-05 09:41:25,797 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", "\n", "2022-04-05 09:41:25,801 WARN util.package: Truncated the string representation of a plan since it was too large. 
This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n", "2022-04-05 09:41:25,806 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression table#1189 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression MS[0]#1354L could run on GPU\n", " @Expression MS[1]#1355L could run on GPU\n", " @Expression MS[2]#1356L could run on GPU\n", " @Expression MS[3]#1357L could run on GPU\n", " @Expression MS[4]#1358L could run on GPU\n", " @Expression MS[5]#1359L could run on GPU\n", " @Expression MS[6]#1360L could run on GPU\n", " @Expression MS[7]#1361L could run on GPU\n", " @Expression MS[8]#1362L could run on GPU\n", " @Expression MS[9]#1363L could run on GPU\n", " @Expression MS[10]#1364L could run on GPU\n", " @Expression MS[11]#1365L could run on GPU\n", " @Expression MS[12]#1366L could run on GPU\n", " @Expression MS[13]#1367L could run on GPU\n", " @Expression MS[14]#1368L could run on GPU\n", " @Expression MS[15]#1369L could run on GPU\n", " @Expression MS[16]#1370L could run on GPU\n", " @Expression MS[17]#1371L could run on GPU\n", " @Expression MS[18]#1372L could run on GPU\n", " @Expression MS[19]#1373L could run on GPU\n", " @Expression MS[20]#1374L could run on GPU\n", " @Expression MS[21]#1375L could run on GPU\n", " @Expression MS[22]#1376L could run on GPU\n", " @Expression MS[23]#1377L could run on GPU\n", " @Expression MS[24]#1378L could run on GPU\n", " @Expression MS[25]#1379L could run on GPU\n", " @Expression MS[26]#1380L could run on GPU\n", " @Expression MS[27]#1381L could run on GPU\n", " @Expression MS[28]#1382L could run on GPU\n", " @Expression MS[29]#1383L could run on GPU\n", " @Expression MS[30]#1384L could run on GPU\n", " @Expression MS[31]#1385L could run on GPU\n", " @Expression MS[32]#1386L could run on GPU\n", " @Expression MS[33]#1387L could run on GPU\n", " @Expression MS[34]#1388L could run on GPU\n", " @Expression MS[35]#1389L could run on GPU\n", " @Expression MS[36]#1390L could run on GPU\n", " @Expression MS[37]#1391L could run on GPU\n", " @Expression MS[38]#1392L could run on GPU\n", " @Expression MS[39]#1393L could run on GPU\n", " @Expression MS[40]#1394L could run on GPU\n", " @Expression MS[41]#1395L could run on GPU\n", " @Expression MS[42]#1396L could run on GPU\n", " @Expression MS[43]#1397L could run on GPU\n", " @Expression MS[44]#1398L could run on GPU\n", " @Expression MS[45]#1399L could run on GPU\n", " @Expression MS[46]#1400L could run on GPU\n", " @Expression MS[47]#1401L could run on GPU\n", " @Expression MS[48]#1402L could run on GPU\n", " @Expression MS[49]#1403L could run on GPU\n", " @Expression MS[50]#1404L could run on GPU\n", " @Expression MS[51]#1405L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression MS[0]#1406L could run on GPU\n", " @Expression MS[1]#1407L could run on GPU\n", " @Expression MS[2]#1408L could run on GPU\n", " @Expression MS[3]#1409L could run on GPU\n", " @Expression MS[4]#1410L could run on 
GPU\n", " @Expression MS[5]#1411L could run on GPU\n", " @Expression MS[6]#1412L could run on GPU\n", " @Expression MS[7]#1413L could run on GPU\n", " @Expression MS[8]#1414L could run on GPU\n", " @Expression MS[9]#1415L could run on GPU\n", " @Expression MS[10]#1416L could run on GPU\n", " @Expression MS[11]#1417L could run on GPU\n", " @Expression MS[12]#1418L could run on GPU\n", " @Expression MS[13]#1419L could run on GPU\n", " @Expression MS[14]#1420L could run on GPU\n", " @Expression MS[15]#1421L could run on GPU\n", " @Expression MS[16]#1422L could run on GPU\n", " @Expression MS[17]#1423L could run on GPU\n", " @Expression MS[18]#1424L could run on GPU\n", " @Expression MS[19]#1425L could run on GPU\n", " @Expression MS[20]#1426L could run on GPU\n", " @Expression MS[21]#1427L could run on GPU\n", " @Expression MS[22]#1428L could run on GPU\n", " @Expression MS[23]#1429L could run on GPU\n", " @Expression MS[24]#1430L could run on GPU\n", " @Expression MS[25]#1431L could run on GPU\n", " @Expression MS[26]#1432L could run on GPU\n", " @Expression MS[27]#1433L could run on GPU\n", " @Expression MS[28]#1434L could run on GPU\n", " @Expression MS[29]#1435L could run on GPU\n", " @Expression MS[30]#1436L could run on GPU\n", " @Expression MS[31]#1437L could run on GPU\n", " @Expression MS[32]#1438L could run on GPU\n", " @Expression MS[33]#1439L could run on GPU\n", " @Expression MS[34]#1440L could run on GPU\n", " @Expression MS[35]#1441L could run on GPU\n", " @Expression MS[36]#1442L could run on GPU\n", " @Expression MS[37]#1443L could run on GPU\n", " @Expression MS[38]#1444L could run on GPU\n", " @Expression MS[39]#1445L could run on GPU\n", " @Expression MS[40]#1446L could run on GPU\n", " @Expression MS[41]#1447L could run on GPU\n", " @Expression MS[42]#1448L could run on GPU\n", " @Expression MS[43]#1449L could run on GPU\n", " @Expression MS[44]#1450L could run on GPU\n", " @Expression MS[45]#1451L could run on GPU\n", " @Expression MS[46]#1452L could run on GPU\n", " @Expression MS[47]#1453L could run on GPU\n", " @Expression MS[48]#1454L could run on GPU\n", " @Expression MS[49]#1455L could run on GPU\n", " @Expression MS[50]#1456L could run on GPU\n", " @Expression MS[51]#1457L could run on GPU\n", "\n", "2022-04-05 09:41:25,810 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression 0#2539 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0 AS 0#2539 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1883 could run on GPU\n", " @Expression MS[0]#1905L could run on GPU\n", " @Expression MS[1]#1906L could run on GPU\n", " @Expression MS[2]#1907L could run on GPU\n", " @Expression MS[3]#1908L could run on GPU\n", " @Expression MS[4]#1909L could run on GPU\n", " @Expression MS[5]#1910L could run on GPU\n", " @Expression MS[6]#1911L could run on GPU\n", " @Expression MS[7]#1912L could run on GPU\n", " @Expression MS[8]#1913L could run on GPU\n", " @Expression MS[9]#1914L could run on GPU\n", " @Expression MS[10]#1915L could run on GPU\n", " @Expression MS[11]#1916L could run on GPU\n", " @Expression MS[12]#1917L could run on GPU\n", " @Expression MS[13]#1918L could run on GPU\n", " @Expression MS[14]#1919L could run on GPU\n", " @Expression MS[15]#1920L could run on GPU\n", " @Expression MS[16]#1921L could run on GPU\n", " @Expression MS[17]#1922L could run on GPU\n", " @Expression MS[18]#1923L could run on GPU\n", " @Expression MS[19]#1924L could run on GPU\n", " @Expression MS[20]#1925L could run on GPU\n", " @Expression MS[21]#1926L could run on GPU\n", " @Expression MS[22]#1927L could run on GPU\n", " @Expression MS[23]#1928L could run on GPU\n", " @Expression MS[24]#1929L could run on GPU\n", " @Expression MS[25]#1930L could run on GPU\n", " @Expression MS[26]#1931L could run on GPU\n", " @Expression MS[27]#1932L could run on GPU\n", " @Expression MS[28]#1933L could run on GPU\n", " @Expression MS[29]#1934L could run on GPU\n", " @Expression MS[30]#1935L could run on GPU\n", " @Expression MS[31]#1936L could run on GPU\n", " @Expression MS[32]#1937L could run on GPU\n", " @Expression MS[33]#1938L could run on GPU\n", " @Expression MS[34]#1939L could run on GPU\n", " @Expression MS[35]#1940L could run on GPU\n", " @Expression MS[36]#1941L could run on GPU\n", " @Expression MS[37]#1942L could run on GPU\n", " @Expression MS[38]#1943L could run on GPU\n", " @Expression MS[39]#1944L could run on GPU\n", " @Expression MS[40]#1945L could run on GPU\n", " @Expression MS[41]#1946L could run on GPU\n", " @Expression MS[42]#1947L could run on GPU\n", " @Expression MS[43]#1948L could run on GPU\n", " @Expression MS[44]#1949L could run on GPU\n", " @Expression MS[45]#1950L could run on GPU\n", " @Expression MS[46]#1951L could run on GPU\n", " @Expression MS[47]#1952L could run on GPU\n", " @Expression MS[48]#1953L could run on GPU\n", " @Expression MS[49]#1954L could run on GPU\n", " @Expression MS[50]#1955L could run on GPU\n", " @Expression MS[51]#1956L could run on GPU\n", " @Expression 0#2539 could run on GPU\n", " @Expression MS[0]#1957L could run on GPU\n", " @Expression MS[1]#1958L could run on GPU\n", " @Expression MS[2]#1959L could run on GPU\n", " @Expression MS[3]#1960L could run on GPU\n", " @Expression MS[4]#1961L could run on GPU\n", " @Expression MS[5]#1962L could run on GPU\n", " @Expression MS[6]#1963L could run on GPU\n", " @Expression MS[7]#1964L could run on GPU\n", " @Expression MS[8]#1965L could run on GPU\n", " @Expression MS[9]#1966L could run on GPU\n", " @Expression MS[10]#1967L could run on GPU\n", " @Expression MS[11]#1968L could run on GPU\n", " @Expression MS[12]#1969L could run on GPU\n", " @Expression MS[13]#1970L could run on GPU\n", " @Expression MS[14]#1971L could run on GPU\n", " @Expression MS[15]#1972L could run on GPU\n", " @Expression 
MS[16]#1973L could run on GPU\n", " @Expression MS[17]#1974L could run on GPU\n", " @Expression MS[18]#1975L could run on GPU\n", " @Expression MS[19]#1976L could run on GPU\n", " @Expression MS[20]#1977L could run on GPU\n", " @Expression MS[21]#1978L could run on GPU\n", " @Expression MS[22]#1979L could run on GPU\n", " @Expression MS[23]#1980L could run on GPU\n", " @Expression MS[24]#1981L could run on GPU\n", " @Expression MS[25]#1982L could run on GPU\n", " @Expression MS[26]#1983L could run on GPU\n", " @Expression MS[27]#1984L could run on GPU\n", " @Expression MS[28]#1985L could run on GPU\n", " @Expression MS[29]#1986L could run on GPU\n", " @Expression MS[30]#1987L could run on GPU\n", " @Expression MS[31]#1988L could run on GPU\n", " @Expression MS[32]#1989L could run on GPU\n", " @Expression MS[33]#1990L could run on GPU\n", " @Expression MS[34]#1991L could run on GPU\n", " @Expression MS[35]#1992L could run on GPU\n", " @Expression MS[36]#1993L could run on GPU\n", " @Expression MS[37]#1994L could run on GPU\n", " @Expression MS[38]#1995L could run on GPU\n", " @Expression MS[39]#1996L could run on GPU\n", " @Expression MS[40]#1997L could run on GPU\n", " @Expression MS[41]#1998L could run on GPU\n", " @Expression MS[42]#1999L could run on GPU\n", " @Expression MS[43]#2000L could run on GPU\n", " @Expression MS[44]#2001L could run on GPU\n", " @Expression MS[45]#2002L could run on GPU\n", " @Expression MS[46]#2003L could run on GPU\n", " @Expression MS[47]#2004L could run on GPU\n", " @Expression MS[48]#2005L could run on GPU\n", " @Expression MS[49]#2006L could run on GPU\n", " @Expression MS[50]#2007L could run on GPU\n", " @Expression MS[51]#2008L could run on GPU\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:42:07,736 WARN rapids.GpuOverrides: > (0 + 0) / 815]\n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) AS approx_unique_customers#2118 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0#2539 could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1883 could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\n", " @Expression all AS table#2537 could run on GPU\n", " @Expression all could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) AS approx_unique_customers#2538 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\n", "\n", "2022-04-05 09:42:46,961 WARN rapids.GpuOverrides: =============>(812 + 1) / 815]\n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) AS approx_unique_customers#2118 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0#2539 could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1883 could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\n", " @Expression all AS table#2537 could run on GPU\n", " @Expression all could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) AS approx_unique_customers#2538 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\n", "\n", "2022-04-05 09:42:46,964 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) AS approx_unique_customers#2118 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " !Exec cannot run on GPU because Unable to replace CustomShuffleReader due to child not being columnar\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0#2539 could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1883 could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\n", " @Expression all AS table#2537 could run on GPU\n", " @Expression all could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) AS approx_unique_customers#2538 could run on GPU\n", " @Expression cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) could run on GPU\n", " @Expression approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\n", " !Exec cannot run on GPU because Unable to replace CustomShuffleReader due to child not being columnar\n", "\n", " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+-----------------------+\n", "| table|approx_unique_customers|\n", "+--------------------+-----------------------+\n", "| billing_events| 699470|\n", "| customer_meta| 699470|\n", "|customer_phone_fe...| 631148|\n", "|customer_internet...| 521053|\n", "|customer_account_...| 699470|\n", "| all| 699470|\n", "+--------------------+-----------------------+\n", "\n" ] } ], "source": [ "\n", "each_table = all_customers.groupBy(\"table\").agg(F.approx_count_distinct(\"customerID\").alias(\"approx_unique_customers\"))\n", "overall = all_customers.groupBy(F.lit(\"all\").alias(\"table\")).agg(F.approx_count_distinct(\"customerID\").alias(\"approx_unique_customers\"))\n", "\n", "each_table.union(overall).show()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:42:47,133 WARN rapids.GpuOverrides: \n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L AS approx_unique_customers#1353L could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression table#1189 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression MS[0]#1354L could run on GPU\n", " @Expression MS[1]#1355L could run on GPU\n", " @Expression MS[2]#1356L could run on GPU\n", " @Expression MS[3]#1357L could run on GPU\n", " @Expression MS[4]#1358L could run on GPU\n", " @Expression MS[5]#1359L could run on GPU\n", " @Expression MS[6]#1360L could run on GPU\n", " @Expression MS[7]#1361L could run on GPU\n", " @Expression MS[8]#1362L could run on GPU\n", " @Expression MS[9]#1363L could run on GPU\n", " @Expression MS[10]#1364L could run on GPU\n", " @Expression MS[11]#1365L could run on GPU\n", " @Expression MS[12]#1366L could run on GPU\n", " @Expression MS[13]#1367L could run on GPU\n", " @Expression MS[14]#1368L could run on GPU\n", " @Expression MS[15]#1369L could run on GPU\n", " @Expression MS[16]#1370L could run on GPU\n", " @Expression MS[17]#1371L could run on GPU\n", " @Expression MS[18]#1372L could run on GPU\n", " @Expression MS[19]#1373L could run on GPU\n", " @Expression MS[20]#1374L could run on GPU\n", " @Expression MS[21]#1375L could run on GPU\n", " @Expression MS[22]#1376L could run on GPU\n", " @Expression MS[23]#1377L could run on GPU\n", " @Expression MS[24]#1378L could run on GPU\n", " @Expression MS[25]#1379L could run on GPU\n", " @Expression MS[26]#1380L could run on GPU\n", " @Expression MS[27]#1381L could run on GPU\n", " @Expression MS[28]#1382L could run on GPU\n", " @Expression MS[29]#1383L could run on GPU\n", " @Expression MS[30]#1384L could run on GPU\n", " @Expression MS[31]#1385L could run on GPU\n", " @Expression MS[32]#1386L could run on GPU\n", " @Expression MS[33]#1387L could run on GPU\n", " @Expression MS[34]#1388L could run on GPU\n", " @Expression MS[35]#1389L could run on GPU\n", " @Expression MS[36]#1390L could run on GPU\n", " @Expression MS[37]#1391L could run on GPU\n", " @Expression MS[38]#1392L could run on GPU\n", " @Expression MS[39]#1393L could run on GPU\n", " @Expression MS[40]#1394L could run on GPU\n", " @Expression MS[41]#1395L could run on GPU\n", " @Expression MS[42]#1396L could run on GPU\n", " @Expression MS[43]#1397L could run on GPU\n", " @Expression MS[44]#1398L could run on GPU\n", " @Expression MS[45]#1399L could run on GPU\n", " @Expression MS[46]#1400L could run on GPU\n", " @Expression MS[47]#1401L could run on GPU\n", 
" @Expression MS[48]#1402L could run on GPU\n", " @Expression MS[49]#1403L could run on GPU\n", " @Expression MS[50]#1404L could run on GPU\n", " @Expression MS[51]#1405L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression MS[0]#1406L could run on GPU\n", " @Expression MS[1]#1407L could run on GPU\n", " @Expression MS[2]#1408L could run on GPU\n", " @Expression MS[3]#1409L could run on GPU\n", " @Expression MS[4]#1410L could run on GPU\n", " @Expression MS[5]#1411L could run on GPU\n", " @Expression MS[6]#1412L could run on GPU\n", " @Expression MS[7]#1413L could run on GPU\n", " @Expression MS[8]#1414L could run on GPU\n", " @Expression MS[9]#1415L could run on GPU\n", " @Expression MS[10]#1416L could run on GPU\n", " @Expression MS[11]#1417L could run on GPU\n", " @Expression MS[12]#1418L could run on GPU\n", " @Expression MS[13]#1419L could run on GPU\n", " @Expression MS[14]#1420L could run on GPU\n", " @Expression MS[15]#1421L could run on GPU\n", " @Expression MS[16]#1422L could run on GPU\n", " @Expression MS[17]#1423L could run on GPU\n", " @Expression MS[18]#1424L could run on GPU\n", " @Expression MS[19]#1425L could run on GPU\n", " @Expression MS[20]#1426L could run on GPU\n", " @Expression MS[21]#1427L could run on GPU\n", " @Expression MS[22]#1428L could run on GPU\n", " @Expression MS[23]#1429L could run on GPU\n", " @Expression MS[24]#1430L could run on GPU\n", " @Expression MS[25]#1431L could run on GPU\n", " @Expression MS[26]#1432L could run on GPU\n", " @Expression MS[27]#1433L could run on GPU\n", " @Expression MS[28]#1434L could run on GPU\n", " @Expression MS[29]#1435L could run on GPU\n", " @Expression MS[30]#1436L could run on GPU\n", " @Expression MS[31]#1437L could run on GPU\n", " @Expression MS[32]#1438L could run on GPU\n", " @Expression MS[33]#1439L could run on GPU\n", " @Expression MS[34]#1440L could run on GPU\n", " @Expression MS[35]#1441L could run on GPU\n", " @Expression MS[36]#1442L could run on GPU\n", " @Expression MS[37]#1443L could run on GPU\n", " @Expression MS[38]#1444L could run on GPU\n", " @Expression MS[39]#1445L could run on GPU\n", " @Expression MS[40]#1446L could run on GPU\n", " @Expression MS[41]#1447L could run on GPU\n", " @Expression MS[42]#1448L could run on GPU\n", " @Expression MS[43]#1449L could run on GPU\n", " @Expression MS[44]#1450L could run on GPU\n", " @Expression MS[45]#1451L could run on GPU\n", " @Expression MS[46]#1452L could run on GPU\n", " @Expression MS[47]#1453L could run on GPU\n", " @Expression MS[48]#1454L could run on GPU\n", " @Expression MS[49]#1455L could run on GPU\n", " @Expression MS[50]#1456L could run on GPU\n", " @Expression MS[51]#1457L could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0#4023 could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#3375 could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\n", " @Expression all AS table#1564 could run on GPU\n", " @Expression all could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L AS approx_unique_customers#1672L could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression 0#4023 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0 AS 0#4023 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#3375 could run on GPU\n", " @Expression MS[0]#3397L could run on GPU\n", " @Expression MS[1]#3398L could run on GPU\n", " @Expression MS[2]#3399L could run on GPU\n", " @Expression MS[3]#3400L could run on GPU\n", " @Expression MS[4]#3401L could run on GPU\n", " @Expression MS[5]#3402L could run on GPU\n", " @Expression MS[6]#3403L could run on GPU\n", " @Expression MS[7]#3404L could run on GPU\n", " @Expression MS[8]#3405L could run on GPU\n", " @Expression MS[9]#3406L could run on GPU\n", " @Expression MS[10]#3407L could run on GPU\n", " @Expression MS[11]#3408L could run on GPU\n", " @Expression MS[12]#3409L could run on GPU\n", " @Expression MS[13]#3410L could run on GPU\n", " @Expression MS[14]#3411L could run on GPU\n", " @Expression MS[15]#3412L could run on GPU\n", " @Expression MS[16]#3413L could run on GPU\n", " @Expression MS[17]#3414L could run on GPU\n", " @Expression MS[18]#3415L could run on GPU\n", " @Expression MS[19]#3416L could run on GPU\n", " @Expression MS[20]#3417L could run on GPU\n", " @Expression MS[21]#3418L could run on GPU\n", " @Expression MS[22]#3419L could run on GPU\n", " @Expression MS[23]#3420L could run on GPU\n", " @Expression MS[24]#3421L could run on GPU\n", " @Expression MS[25]#3422L could run on GPU\n", " @Expression MS[26]#3423L could run on GPU\n", " @Expression MS[27]#3424L could run on GPU\n", " @Expression MS[28]#3425L could run on GPU\n", " @Expression MS[29]#3426L could run on GPU\n", " @Expression MS[30]#3427L could run on GPU\n", " @Expression MS[31]#3428L could run on GPU\n", " @Expression MS[32]#3429L could run on GPU\n", " @Expression MS[33]#3430L could run on GPU\n", " @Expression MS[34]#3431L could run on GPU\n", " @Expression MS[35]#3432L could run on GPU\n", " @Expression MS[36]#3433L could run on GPU\n", " @Expression MS[37]#3434L could run on GPU\n", " @Expression MS[38]#3435L could run on GPU\n", " @Expression MS[39]#3436L could run on GPU\n", " @Expression MS[40]#3437L could run on GPU\n", " @Expression MS[41]#3438L could run on GPU\n", " @Expression MS[42]#3439L could run on GPU\n", " @Expression MS[43]#3440L could run on GPU\n", " @Expression MS[44]#3441L could run on GPU\n", " @Expression MS[45]#3442L could run on GPU\n", " @Expression 
MS[46]#3443L could run on GPU\n", " @Expression MS[47]#3444L could run on GPU\n", " @Expression MS[48]#3445L could run on GPU\n", " @Expression MS[49]#3446L could run on GPU\n", " @Expression MS[50]#3447L could run on GPU\n", " @Expression MS[51]#3448L could run on GPU\n", " @Expression 0#4023 could run on GPU\n", " @Expression MS[0]#3449L could run on GPU\n", " @Expression MS[1]#3450L could run on GPU\n", " @Expression MS[2]#3451L could run on GPU\n", " @Expression MS[3]#3452L could run on GPU\n", " @Expression MS[4]#3453L could run on GPU\n", " @Expression MS[5]#3454L could run on GPU\n", " @Expression MS[6]#3455L could run on GPU\n", " @Expression MS[7]#3456L could run on GPU\n", " @Expression MS[8]#3457L could run on GPU\n", " @Expression MS[9]#3458L could run on GPU\n", " @Expression MS[10]#3459L could run on GPU\n", " @Expression MS[11]#3460L could run on GPU\n", " @Expression MS[12]#3461L could run on GPU\n", " @Expression MS[13]#3462L could run on GPU\n", " @Expression MS[14]#3463L could run on GPU\n", " @Expression MS[15]#3464L could run on GPU\n", " @Expression MS[16]#3465L could run on GPU\n", " @Expression MS[17]#3466L could run on GPU\n", " @Expression MS[18]#3467L could run on GPU\n", " @Expression MS[19]#3468L could run on GPU\n", " @Expression MS[20]#3469L could run on GPU\n", " @Expression MS[21]#3470L could run on GPU\n", " @Expression MS[22]#3471L could run on GPU\n", " @Expression MS[23]#3472L could run on GPU\n", " @Expression MS[24]#3473L could run on GPU\n", " @Expression MS[25]#3474L could run on GPU\n", " @Expression MS[26]#3475L could run on GPU\n", " @Expression MS[27]#3476L could run on GPU\n", " @Expression MS[28]#3477L could run on GPU\n", " @Expression MS[29]#3478L could run on GPU\n", " @Expression MS[30]#3479L could run on GPU\n", " @Expression MS[31]#3480L could run on GPU\n", " @Expression MS[32]#3481L could run on GPU\n", " @Expression MS[33]#3482L could run on GPU\n", " @Expression MS[34]#3483L could run on GPU\n", " @Expression MS[35]#3484L could run on GPU\n", " @Expression MS[36]#3485L could run on GPU\n", " @Expression MS[37]#3486L could run on GPU\n", " @Expression MS[38]#3487L could run on GPU\n", " @Expression MS[39]#3488L could run on GPU\n", " @Expression MS[40]#3489L could run on GPU\n", " @Expression MS[41]#3490L could run on GPU\n", " @Expression MS[42]#3491L could run on GPU\n", " @Expression MS[43]#3492L could run on GPU\n", " @Expression MS[44]#3493L could run on GPU\n", " @Expression MS[45]#3494L could run on GPU\n", " @Expression MS[46]#3495L could run on GPU\n", " @Expression MS[47]#3496L could run on GPU\n", " @Expression MS[48]#3497L could run on GPU\n", " @Expression MS[49]#3498L could run on GPU\n", " @Expression MS[50]#3499L could run on GPU\n", " @Expression MS[51]#3500L could run on GPU\n", "\n", "2022-04-05 09:42:47,136 WARN rapids.GpuOverrides: \n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L AS approx_unique_customers#1353L could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression table#1189 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression MS[0]#1354L could run on GPU\n", " @Expression MS[1]#1355L could run on GPU\n", " @Expression MS[2]#1356L could run on GPU\n", " @Expression MS[3]#1357L could run on GPU\n", " @Expression MS[4]#1358L could run on GPU\n", " @Expression MS[5]#1359L could run on GPU\n", " @Expression MS[6]#1360L could run on GPU\n", " @Expression MS[7]#1361L could run on GPU\n", " @Expression MS[8]#1362L could run on GPU\n", " @Expression MS[9]#1363L could run on GPU\n", " @Expression MS[10]#1364L could run on GPU\n", " @Expression MS[11]#1365L could run on GPU\n", " @Expression MS[12]#1366L could run on GPU\n", " @Expression MS[13]#1367L could run on GPU\n", " @Expression MS[14]#1368L could run on GPU\n", " @Expression MS[15]#1369L could run on GPU\n", " @Expression MS[16]#1370L could run on GPU\n", " @Expression MS[17]#1371L could run on GPU\n", " @Expression MS[18]#1372L could run on GPU\n", " @Expression MS[19]#1373L could run on GPU\n", " @Expression MS[20]#1374L could run on GPU\n", " @Expression MS[21]#1375L could run on GPU\n", " @Expression MS[22]#1376L could run on GPU\n", " @Expression MS[23]#1377L could run on GPU\n", " @Expression MS[24]#1378L could run on GPU\n", " @Expression MS[25]#1379L could run on GPU\n", " @Expression MS[26]#1380L could run on GPU\n", " @Expression MS[27]#1381L could run on GPU\n", " @Expression MS[28]#1382L could run on GPU\n", " @Expression MS[29]#1383L could run on GPU\n", " @Expression MS[30]#1384L could run on GPU\n", " @Expression MS[31]#1385L could run on GPU\n", " @Expression MS[32]#1386L could run on GPU\n", " @Expression MS[33]#1387L could run on GPU\n", " @Expression MS[34]#1388L could run on GPU\n", " @Expression MS[35]#1389L could run on GPU\n", " @Expression MS[36]#1390L could run on GPU\n", " @Expression MS[37]#1391L could run on GPU\n", " @Expression MS[38]#1392L could run on GPU\n", " @Expression MS[39]#1393L could run on GPU\n", " @Expression MS[40]#1394L could run on GPU\n", " @Expression MS[41]#1395L could run on GPU\n", " @Expression MS[42]#1396L could run on GPU\n", " @Expression MS[43]#1397L could run on GPU\n", " @Expression MS[44]#1398L could run on GPU\n", " @Expression MS[45]#1399L could run on GPU\n", " @Expression MS[46]#1400L could run on GPU\n", " @Expression MS[47]#1401L could run on GPU\n", 
" @Expression MS[48]#1402L could run on GPU\n", " @Expression MS[49]#1403L could run on GPU\n", " @Expression MS[50]#1404L could run on GPU\n", " @Expression MS[51]#1405L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression MS[0]#1406L could run on GPU\n", " @Expression MS[1]#1407L could run on GPU\n", " @Expression MS[2]#1408L could run on GPU\n", " @Expression MS[3]#1409L could run on GPU\n", " @Expression MS[4]#1410L could run on GPU\n", " @Expression MS[5]#1411L could run on GPU\n", " @Expression MS[6]#1412L could run on GPU\n", " @Expression MS[7]#1413L could run on GPU\n", " @Expression MS[8]#1414L could run on GPU\n", " @Expression MS[9]#1415L could run on GPU\n", " @Expression MS[10]#1416L could run on GPU\n", " @Expression MS[11]#1417L could run on GPU\n", " @Expression MS[12]#1418L could run on GPU\n", " @Expression MS[13]#1419L could run on GPU\n", " @Expression MS[14]#1420L could run on GPU\n", " @Expression MS[15]#1421L could run on GPU\n", " @Expression MS[16]#1422L could run on GPU\n", " @Expression MS[17]#1423L could run on GPU\n", " @Expression MS[18]#1424L could run on GPU\n", " @Expression MS[19]#1425L could run on GPU\n", " @Expression MS[20]#1426L could run on GPU\n", " @Expression MS[21]#1427L could run on GPU\n", " @Expression MS[22]#1428L could run on GPU\n", " @Expression MS[23]#1429L could run on GPU\n", " @Expression MS[24]#1430L could run on GPU\n", " @Expression MS[25]#1431L could run on GPU\n", " @Expression MS[26]#1432L could run on GPU\n", " @Expression MS[27]#1433L could run on GPU\n", " @Expression MS[28]#1434L could run on GPU\n", " @Expression MS[29]#1435L could run on GPU\n", " @Expression MS[30]#1436L could run on GPU\n", " @Expression MS[31]#1437L could run on GPU\n", " @Expression MS[32]#1438L could run on GPU\n", " @Expression MS[33]#1439L could run on GPU\n", " @Expression MS[34]#1440L could run on GPU\n", " @Expression MS[35]#1441L could run on GPU\n", " @Expression MS[36]#1442L could run on GPU\n", " @Expression MS[37]#1443L could run on GPU\n", " @Expression MS[38]#1444L could run on GPU\n", " @Expression MS[39]#1445L could run on GPU\n", " @Expression MS[40]#1446L could run on GPU\n", " @Expression MS[41]#1447L could run on GPU\n", " @Expression MS[42]#1448L could run on GPU\n", " @Expression MS[43]#1449L could run on GPU\n", " @Expression MS[44]#1450L could run on GPU\n", " @Expression MS[45]#1451L could run on GPU\n", " @Expression MS[46]#1452L could run on GPU\n", " @Expression MS[47]#1453L could run on GPU\n", " @Expression MS[48]#1454L could run on GPU\n", " @Expression MS[49]#1455L could run on GPU\n", " @Expression MS[50]#1456L could run on GPU\n", " @Expression MS[51]#1457L could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0#4023 could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#3375 could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\n", " @Expression all AS table#1564 could run on GPU\n", " @Expression all could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L AS approx_unique_customers#1672L could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression 0#4023 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0 AS 0#4023 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#3375 could run on GPU\n", " @Expression MS[0]#3397L could run on GPU\n", " @Expression MS[1]#3398L could run on GPU\n", " @Expression MS[2]#3399L could run on GPU\n", " @Expression MS[3]#3400L could run on GPU\n", " @Expression MS[4]#3401L could run on GPU\n", " @Expression MS[5]#3402L could run on GPU\n", " @Expression MS[6]#3403L could run on GPU\n", " @Expression MS[7]#3404L could run on GPU\n", " @Expression MS[8]#3405L could run on GPU\n", " @Expression MS[9]#3406L could run on GPU\n", " @Expression MS[10]#3407L could run on GPU\n", " @Expression MS[11]#3408L could run on GPU\n", " @Expression MS[12]#3409L could run on GPU\n", " @Expression MS[13]#3410L could run on GPU\n", " @Expression MS[14]#3411L could run on GPU\n", " @Expression MS[15]#3412L could run on GPU\n", " @Expression MS[16]#3413L could run on GPU\n", " @Expression MS[17]#3414L could run on GPU\n", " @Expression MS[18]#3415L could run on GPU\n", " @Expression MS[19]#3416L could run on GPU\n", " @Expression MS[20]#3417L could run on GPU\n", " @Expression MS[21]#3418L could run on GPU\n", " @Expression MS[22]#3419L could run on GPU\n", " @Expression MS[23]#3420L could run on GPU\n", " @Expression MS[24]#3421L could run on GPU\n", " @Expression MS[25]#3422L could run on GPU\n", " @Expression MS[26]#3423L could run on GPU\n", " @Expression MS[27]#3424L could run on GPU\n", " @Expression MS[28]#3425L could run on GPU\n", " @Expression MS[29]#3426L could run on GPU\n", " @Expression MS[30]#3427L could run on GPU\n", " @Expression MS[31]#3428L could run on GPU\n", " @Expression MS[32]#3429L could run on GPU\n", " @Expression MS[33]#3430L could run on GPU\n", " @Expression MS[34]#3431L could run on GPU\n", " @Expression MS[35]#3432L could run on GPU\n", " @Expression MS[36]#3433L could run on GPU\n", " @Expression MS[37]#3434L could run on GPU\n", " @Expression MS[38]#3435L could run on GPU\n", " @Expression MS[39]#3436L could run on GPU\n", " @Expression MS[40]#3437L could run on GPU\n", " @Expression MS[41]#3438L could run on GPU\n", " @Expression MS[42]#3439L could run on GPU\n", " @Expression MS[43]#3440L could run on GPU\n", " @Expression MS[44]#3441L could run on GPU\n", " @Expression MS[45]#3442L could run on GPU\n", " @Expression 
MS[46]#3443L could run on GPU\n", " @Expression MS[47]#3444L could run on GPU\n", " @Expression MS[48]#3445L could run on GPU\n", " @Expression MS[49]#3446L could run on GPU\n", " @Expression MS[50]#3447L could run on GPU\n", " @Expression MS[51]#3448L could run on GPU\n", " @Expression 0#4023 could run on GPU\n", " @Expression MS[0]#3449L could run on GPU\n", " @Expression MS[1]#3450L could run on GPU\n", " @Expression MS[2]#3451L could run on GPU\n", " @Expression MS[3]#3452L could run on GPU\n", " @Expression MS[4]#3453L could run on GPU\n", " @Expression MS[5]#3454L could run on GPU\n", " @Expression MS[6]#3455L could run on GPU\n", " @Expression MS[7]#3456L could run on GPU\n", " @Expression MS[8]#3457L could run on GPU\n", " @Expression MS[9]#3458L could run on GPU\n", " @Expression MS[10]#3459L could run on GPU\n", " @Expression MS[11]#3460L could run on GPU\n", " @Expression MS[12]#3461L could run on GPU\n", " @Expression MS[13]#3462L could run on GPU\n", " @Expression MS[14]#3463L could run on GPU\n", " @Expression MS[15]#3464L could run on GPU\n", " @Expression MS[16]#3465L could run on GPU\n", " @Expression MS[17]#3466L could run on GPU\n", " @Expression MS[18]#3467L could run on GPU\n", " @Expression MS[19]#3468L could run on GPU\n", " @Expression MS[20]#3469L could run on GPU\n", " @Expression MS[21]#3470L could run on GPU\n", " @Expression MS[22]#3471L could run on GPU\n", " @Expression MS[23]#3472L could run on GPU\n", " @Expression MS[24]#3473L could run on GPU\n", " @Expression MS[25]#3474L could run on GPU\n", " @Expression MS[26]#3475L could run on GPU\n", " @Expression MS[27]#3476L could run on GPU\n", " @Expression MS[28]#3477L could run on GPU\n", " @Expression MS[29]#3478L could run on GPU\n", " @Expression MS[30]#3479L could run on GPU\n", " @Expression MS[31]#3480L could run on GPU\n", " @Expression MS[32]#3481L could run on GPU\n", " @Expression MS[33]#3482L could run on GPU\n", " @Expression MS[34]#3483L could run on GPU\n", " @Expression MS[35]#3484L could run on GPU\n", " @Expression MS[36]#3485L could run on GPU\n", " @Expression MS[37]#3486L could run on GPU\n", " @Expression MS[38]#3487L could run on GPU\n", " @Expression MS[39]#3488L could run on GPU\n", " @Expression MS[40]#3489L could run on GPU\n", " @Expression MS[41]#3490L could run on GPU\n", " @Expression MS[42]#3491L could run on GPU\n", " @Expression MS[43]#3492L could run on GPU\n", " @Expression MS[44]#3493L could run on GPU\n", " @Expression MS[45]#3494L could run on GPU\n", " @Expression MS[46]#3495L could run on GPU\n", " @Expression MS[47]#3496L could run on GPU\n", " @Expression MS[48]#3497L could run on GPU\n", " @Expression MS[49]#3498L could run on GPU\n", " @Expression MS[50]#3499L could run on GPU\n", " @Expression MS[51]#3500L could run on GPU\n", "\n", "2022-04-05 09:42:47,139 WARN rapids.GpuOverrides: \n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L AS approx_unique_customers#1353L could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression table#1189 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression MS[0]#1354L could run on GPU\n", " @Expression MS[1]#1355L could run on GPU\n", " @Expression MS[2]#1356L could run on GPU\n", " @Expression MS[3]#1357L could run on GPU\n", " @Expression MS[4]#1358L could run on GPU\n", " @Expression MS[5]#1359L could run on GPU\n", " @Expression MS[6]#1360L could run on GPU\n", " @Expression MS[7]#1361L could run on GPU\n", " @Expression MS[8]#1362L could run on GPU\n", " @Expression MS[9]#1363L could run on GPU\n", " @Expression MS[10]#1364L could run on GPU\n", " @Expression MS[11]#1365L could run on GPU\n", " @Expression MS[12]#1366L could run on GPU\n", " @Expression MS[13]#1367L could run on GPU\n", " @Expression MS[14]#1368L could run on GPU\n", " @Expression MS[15]#1369L could run on GPU\n", " @Expression MS[16]#1370L could run on GPU\n", " @Expression MS[17]#1371L could run on GPU\n", " @Expression MS[18]#1372L could run on GPU\n", " @Expression MS[19]#1373L could run on GPU\n", " @Expression MS[20]#1374L could run on GPU\n", " @Expression MS[21]#1375L could run on GPU\n", " @Expression MS[22]#1376L could run on GPU\n", " @Expression MS[23]#1377L could run on GPU\n", " @Expression MS[24]#1378L could run on GPU\n", " @Expression MS[25]#1379L could run on GPU\n", " @Expression MS[26]#1380L could run on GPU\n", " @Expression MS[27]#1381L could run on GPU\n", " @Expression MS[28]#1382L could run on GPU\n", " @Expression MS[29]#1383L could run on GPU\n", " @Expression MS[30]#1384L could run on GPU\n", " @Expression MS[31]#1385L could run on GPU\n", " @Expression MS[32]#1386L could run on GPU\n", " @Expression MS[33]#1387L could run on GPU\n", " @Expression MS[34]#1388L could run on GPU\n", " @Expression MS[35]#1389L could run on GPU\n", " @Expression MS[36]#1390L could run on GPU\n", " @Expression MS[37]#1391L could run on GPU\n", " @Expression MS[38]#1392L could run on GPU\n", " @Expression MS[39]#1393L could run on GPU\n", " @Expression MS[40]#1394L could run on GPU\n", " @Expression MS[41]#1395L could run on GPU\n", " @Expression MS[42]#1396L could run on GPU\n", " @Expression MS[43]#1397L could run on GPU\n", " @Expression MS[44]#1398L could run on GPU\n", " @Expression MS[45]#1399L could run on GPU\n", " @Expression MS[46]#1400L could run on GPU\n", " @Expression MS[47]#1401L could run on GPU\n", 
" @Expression MS[48]#1402L could run on GPU\n", " @Expression MS[49]#1403L could run on GPU\n", " @Expression MS[50]#1404L could run on GPU\n", " @Expression MS[51]#1405L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression MS[0]#1406L could run on GPU\n", " @Expression MS[1]#1407L could run on GPU\n", " @Expression MS[2]#1408L could run on GPU\n", " @Expression MS[3]#1409L could run on GPU\n", " @Expression MS[4]#1410L could run on GPU\n", " @Expression MS[5]#1411L could run on GPU\n", " @Expression MS[6]#1412L could run on GPU\n", " @Expression MS[7]#1413L could run on GPU\n", " @Expression MS[8]#1414L could run on GPU\n", " @Expression MS[9]#1415L could run on GPU\n", " @Expression MS[10]#1416L could run on GPU\n", " @Expression MS[11]#1417L could run on GPU\n", " @Expression MS[12]#1418L could run on GPU\n", " @Expression MS[13]#1419L could run on GPU\n", " @Expression MS[14]#1420L could run on GPU\n", " @Expression MS[15]#1421L could run on GPU\n", " @Expression MS[16]#1422L could run on GPU\n", " @Expression MS[17]#1423L could run on GPU\n", " @Expression MS[18]#1424L could run on GPU\n", " @Expression MS[19]#1425L could run on GPU\n", " @Expression MS[20]#1426L could run on GPU\n", " @Expression MS[21]#1427L could run on GPU\n", " @Expression MS[22]#1428L could run on GPU\n", " @Expression MS[23]#1429L could run on GPU\n", " @Expression MS[24]#1430L could run on GPU\n", " @Expression MS[25]#1431L could run on GPU\n", " @Expression MS[26]#1432L could run on GPU\n", " @Expression MS[27]#1433L could run on GPU\n", " @Expression MS[28]#1434L could run on GPU\n", " @Expression MS[29]#1435L could run on GPU\n", " @Expression MS[30]#1436L could run on GPU\n", " @Expression MS[31]#1437L could run on GPU\n", " @Expression MS[32]#1438L could run on GPU\n", " @Expression MS[33]#1439L could run on GPU\n", " @Expression MS[34]#1440L could run on GPU\n", " @Expression MS[35]#1441L could run on GPU\n", " @Expression MS[36]#1442L could run on GPU\n", " @Expression MS[37]#1443L could run on GPU\n", " @Expression MS[38]#1444L could run on GPU\n", " @Expression MS[39]#1445L could run on GPU\n", " @Expression MS[40]#1446L could run on GPU\n", " @Expression MS[41]#1447L could run on GPU\n", " @Expression MS[42]#1448L could run on GPU\n", " @Expression MS[43]#1449L could run on GPU\n", " @Expression MS[44]#1450L could run on GPU\n", " @Expression MS[45]#1451L could run on GPU\n", " @Expression MS[46]#1452L could run on GPU\n", " @Expression MS[47]#1453L could run on GPU\n", " @Expression MS[48]#1454L could run on GPU\n", " @Expression MS[49]#1455L could run on GPU\n", " @Expression MS[50]#1456L could run on GPU\n", " @Expression MS[51]#1457L could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0#4023 could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#3375 could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\n", " @Expression all AS table#1564 could run on GPU\n", " @Expression all could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L AS approx_unique_customers#1672L could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\n", " !Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression 0#4023 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0 AS 0#4023 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#3375 could run on GPU\n", " @Expression MS[0]#3397L could run on GPU\n", " @Expression MS[1]#3398L could run on GPU\n", " @Expression MS[2]#3399L could run on GPU\n", " @Expression MS[3]#3400L could run on GPU\n", " @Expression MS[4]#3401L could run on GPU\n", " @Expression MS[5]#3402L could run on GPU\n", " @Expression MS[6]#3403L could run on GPU\n", " @Expression MS[7]#3404L could run on GPU\n", " @Expression MS[8]#3405L could run on GPU\n", " @Expression MS[9]#3406L could run on GPU\n", " @Expression MS[10]#3407L could run on GPU\n", " @Expression MS[11]#3408L could run on GPU\n", " @Expression MS[12]#3409L could run on GPU\n", " @Expression MS[13]#3410L could run on GPU\n", " @Expression MS[14]#3411L could run on GPU\n", " @Expression MS[15]#3412L could run on GPU\n", " @Expression MS[16]#3413L could run on GPU\n", " @Expression MS[17]#3414L could run on GPU\n", " @Expression MS[18]#3415L could run on GPU\n", " @Expression MS[19]#3416L could run on GPU\n", " @Expression MS[20]#3417L could run on GPU\n", " @Expression MS[21]#3418L could run on GPU\n", " @Expression MS[22]#3419L could run on GPU\n", " @Expression MS[23]#3420L could run on GPU\n", " @Expression MS[24]#3421L could run on GPU\n", " @Expression MS[25]#3422L could run on GPU\n", " @Expression MS[26]#3423L could run on GPU\n", " @Expression MS[27]#3424L could run on GPU\n", " @Expression MS[28]#3425L could run on GPU\n", " @Expression MS[29]#3426L could run on GPU\n", " @Expression MS[30]#3427L could run on GPU\n", " @Expression MS[31]#3428L could run on GPU\n", " @Expression MS[32]#3429L could run on GPU\n", " @Expression MS[33]#3430L could run on GPU\n", " @Expression MS[34]#3431L could run on GPU\n", " @Expression MS[35]#3432L could run on GPU\n", " @Expression MS[36]#3433L could run on GPU\n", " @Expression MS[37]#3434L could run on GPU\n", " @Expression MS[38]#3435L could run on GPU\n", " @Expression MS[39]#3436L could run on GPU\n", " @Expression MS[40]#3437L could run on GPU\n", " @Expression MS[41]#3438L could run on GPU\n", " @Expression MS[42]#3439L could run on GPU\n", " @Expression MS[43]#3440L could run on GPU\n", " @Expression MS[44]#3441L could run on GPU\n", " @Expression MS[45]#3442L could run on GPU\n", " @Expression 
MS[46]#3443L could run on GPU\n", " @Expression MS[47]#3444L could run on GPU\n", " @Expression MS[48]#3445L could run on GPU\n", " @Expression MS[49]#3446L could run on GPU\n", " @Expression MS[50]#3447L could run on GPU\n", " @Expression MS[51]#3448L could run on GPU\n", " @Expression 0#4023 could run on GPU\n", " @Expression MS[0]#3449L could run on GPU\n", " @Expression MS[1]#3450L could run on GPU\n", " @Expression MS[2]#3451L could run on GPU\n", " @Expression MS[3]#3452L could run on GPU\n", " @Expression MS[4]#3453L could run on GPU\n", " @Expression MS[5]#3454L could run on GPU\n", " @Expression MS[6]#3455L could run on GPU\n", " @Expression MS[7]#3456L could run on GPU\n", " @Expression MS[8]#3457L could run on GPU\n", " @Expression MS[9]#3458L could run on GPU\n", " @Expression MS[10]#3459L could run on GPU\n", " @Expression MS[11]#3460L could run on GPU\n", " @Expression MS[12]#3461L could run on GPU\n", " @Expression MS[13]#3462L could run on GPU\n", " @Expression MS[14]#3463L could run on GPU\n", " @Expression MS[15]#3464L could run on GPU\n", " @Expression MS[16]#3465L could run on GPU\n", " @Expression MS[17]#3466L could run on GPU\n", " @Expression MS[18]#3467L could run on GPU\n", " @Expression MS[19]#3468L could run on GPU\n", " @Expression MS[20]#3469L could run on GPU\n", " @Expression MS[21]#3470L could run on GPU\n", " @Expression MS[22]#3471L could run on GPU\n", " @Expression MS[23]#3472L could run on GPU\n", " @Expression MS[24]#3473L could run on GPU\n", " @Expression MS[25]#3474L could run on GPU\n", " @Expression MS[26]#3475L could run on GPU\n", " @Expression MS[27]#3476L could run on GPU\n", " @Expression MS[28]#3477L could run on GPU\n", " @Expression MS[29]#3478L could run on GPU\n", " @Expression MS[30]#3479L could run on GPU\n", " @Expression MS[31]#3480L could run on GPU\n", " @Expression MS[32]#3481L could run on GPU\n", " @Expression MS[33]#3482L could run on GPU\n", " @Expression MS[34]#3483L could run on GPU\n", " @Expression MS[35]#3484L could run on GPU\n", " @Expression MS[36]#3485L could run on GPU\n", " @Expression MS[37]#3486L could run on GPU\n", " @Expression MS[38]#3487L could run on GPU\n", " @Expression MS[39]#3488L could run on GPU\n", " @Expression MS[40]#3489L could run on GPU\n", " @Expression MS[41]#3490L could run on GPU\n", " @Expression MS[42]#3491L could run on GPU\n", " @Expression MS[43]#3492L could run on GPU\n", " @Expression MS[44]#3493L could run on GPU\n", " @Expression MS[45]#3494L could run on GPU\n", " @Expression MS[46]#3495L could run on GPU\n", " @Expression MS[47]#3496L could run on GPU\n", " @Expression MS[48]#3497L could run on GPU\n", " @Expression MS[49]#3498L could run on GPU\n", " @Expression MS[50]#3499L could run on GPU\n", " @Expression MS[51]#3500L could run on GPU\n", "\n", "2022-04-05 09:42:47,147 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression table#1189 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression MS[0]#1354L could run on GPU\n", " @Expression MS[1]#1355L could run on GPU\n", " @Expression MS[2]#1356L could run on GPU\n", " @Expression MS[3]#1357L could run on GPU\n", " @Expression MS[4]#1358L could run on GPU\n", " @Expression MS[5]#1359L could run on GPU\n", " @Expression MS[6]#1360L could run on GPU\n", " @Expression MS[7]#1361L could run on GPU\n", " @Expression MS[8]#1362L could run on GPU\n", " @Expression MS[9]#1363L could run on GPU\n", " @Expression MS[10]#1364L could run on GPU\n", " @Expression MS[11]#1365L could run on GPU\n", " @Expression MS[12]#1366L could run on GPU\n", " @Expression MS[13]#1367L could run on GPU\n", " @Expression MS[14]#1368L could run on GPU\n", " @Expression MS[15]#1369L could run on GPU\n", " @Expression MS[16]#1370L could run on GPU\n", " @Expression MS[17]#1371L could run on GPU\n", " @Expression MS[18]#1372L could run on GPU\n", " @Expression MS[19]#1373L could run on GPU\n", " @Expression MS[20]#1374L could run on GPU\n", " @Expression MS[21]#1375L could run on GPU\n", " @Expression MS[22]#1376L could run on GPU\n", " @Expression MS[23]#1377L could run on GPU\n", " @Expression MS[24]#1378L could run on GPU\n", " @Expression MS[25]#1379L could run on GPU\n", " @Expression MS[26]#1380L could run on GPU\n", " @Expression MS[27]#1381L could run on GPU\n", " @Expression MS[28]#1382L could run on GPU\n", " @Expression MS[29]#1383L could run on GPU\n", " @Expression MS[30]#1384L could run on GPU\n", " @Expression MS[31]#1385L could run on GPU\n", " @Expression MS[32]#1386L could run on GPU\n", " @Expression MS[33]#1387L could run on GPU\n", " @Expression MS[34]#1388L could run on GPU\n", " @Expression MS[35]#1389L could run on GPU\n", " @Expression MS[36]#1390L could run on GPU\n", " @Expression MS[37]#1391L could run on GPU\n", " @Expression MS[38]#1392L could run on GPU\n", " @Expression MS[39]#1393L could run on GPU\n", " @Expression MS[40]#1394L could run on GPU\n", " @Expression MS[41]#1395L could run on GPU\n", " @Expression MS[42]#1396L could run on GPU\n", " @Expression MS[43]#1397L could run on GPU\n", " @Expression MS[44]#1398L could run on GPU\n", " @Expression MS[45]#1399L could run on GPU\n", " @Expression MS[46]#1400L could run on GPU\n", " @Expression MS[47]#1401L could run on GPU\n", " @Expression MS[48]#1402L could run on GPU\n", " @Expression MS[49]#1403L could run on GPU\n", " @Expression MS[50]#1404L could run on GPU\n", " @Expression MS[51]#1405L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression MS[0]#1406L could run on GPU\n", " @Expression MS[1]#1407L could run on GPU\n", " @Expression MS[2]#1408L could run on GPU\n", " @Expression MS[3]#1409L could run on GPU\n", " @Expression MS[4]#1410L could run on GPU\n", " @Expression MS[5]#1411L could run on GPU\n", " @Expression MS[6]#1412L could run on GPU\n", " @Expression MS[7]#1413L could run on GPU\n", " @Expression MS[8]#1414L could run on GPU\n", " @Expression MS[9]#1415L could run on GPU\n", " @Expression MS[10]#1416L could run on GPU\n", " @Expression MS[11]#1417L could run on GPU\n", " @Expression MS[12]#1418L could run on GPU\n", " @Expression MS[13]#1419L could run on GPU\n", " @Expression MS[14]#1420L could run on GPU\n", " @Expression MS[15]#1421L could run on GPU\n", " 
@Expression MS[16]#1422L could run on GPU\n", " @Expression MS[17]#1423L could run on GPU\n", " @Expression MS[18]#1424L could run on GPU\n", " @Expression MS[19]#1425L could run on GPU\n", " @Expression MS[20]#1426L could run on GPU\n", " @Expression MS[21]#1427L could run on GPU\n", " @Expression MS[22]#1428L could run on GPU\n", " @Expression MS[23]#1429L could run on GPU\n", " @Expression MS[24]#1430L could run on GPU\n", " @Expression MS[25]#1431L could run on GPU\n", " @Expression MS[26]#1432L could run on GPU\n", " @Expression MS[27]#1433L could run on GPU\n", " @Expression MS[28]#1434L could run on GPU\n", " @Expression MS[29]#1435L could run on GPU\n", " @Expression MS[30]#1436L could run on GPU\n", " @Expression MS[31]#1437L could run on GPU\n", " @Expression MS[32]#1438L could run on GPU\n", " @Expression MS[33]#1439L could run on GPU\n", " @Expression MS[34]#1440L could run on GPU\n", " @Expression MS[35]#1441L could run on GPU\n", " @Expression MS[36]#1442L could run on GPU\n", " @Expression MS[37]#1443L could run on GPU\n", " @Expression MS[38]#1444L could run on GPU\n", " @Expression MS[39]#1445L could run on GPU\n", " @Expression MS[40]#1446L could run on GPU\n", " @Expression MS[41]#1447L could run on GPU\n", " @Expression MS[42]#1448L could run on GPU\n", " @Expression MS[43]#1449L could run on GPU\n", " @Expression MS[44]#1450L could run on GPU\n", " @Expression MS[45]#1451L could run on GPU\n", " @Expression MS[46]#1452L could run on GPU\n", " @Expression MS[47]#1453L could run on GPU\n", " @Expression MS[48]#1454L could run on GPU\n", " @Expression MS[49]#1455L could run on GPU\n", " @Expression MS[50]#1456L could run on GPU\n", " @Expression MS[51]#1457L could run on GPU\n", "\n", "2022-04-05 09:42:47,151 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because Columnar exchange without columnar children is inefficient\n", " @Partitioning could run on GPU\n", " @Expression 0#4023 could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0 AS 0#4023 could run on GPU\n", " @Expression 0 could run on GPU\n", " @Expression partial_approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#3375 could run on GPU\n", " @Expression MS[0]#3397L could run on GPU\n", " @Expression MS[1]#3398L could run on GPU\n", " @Expression MS[2]#3399L could run on GPU\n", " @Expression MS[3]#3400L could run on GPU\n", " @Expression MS[4]#3401L could run on GPU\n", " @Expression MS[5]#3402L could run on GPU\n", " @Expression MS[6]#3403L could run on GPU\n", " @Expression MS[7]#3404L could run on GPU\n", " @Expression MS[8]#3405L could run on GPU\n", " @Expression MS[9]#3406L could run on GPU\n", " @Expression MS[10]#3407L could run on GPU\n", " @Expression MS[11]#3408L could run on GPU\n", " @Expression MS[12]#3409L could run on GPU\n", " @Expression MS[13]#3410L could run on GPU\n", " @Expression MS[14]#3411L could run on GPU\n", " @Expression MS[15]#3412L could run on GPU\n", " @Expression MS[16]#3413L could run on GPU\n", " @Expression MS[17]#3414L could run on GPU\n", " @Expression MS[18]#3415L could run on GPU\n", " @Expression MS[19]#3416L could run on GPU\n", " @Expression MS[20]#3417L could run on GPU\n", " @Expression MS[21]#3418L could run on GPU\n", " @Expression MS[22]#3419L could run on GPU\n", " @Expression MS[23]#3420L could run on GPU\n", " @Expression MS[24]#3421L could run on GPU\n", " @Expression MS[25]#3422L could run on GPU\n", " @Expression MS[26]#3423L could run on GPU\n", " @Expression MS[27]#3424L could run on GPU\n", " @Expression MS[28]#3425L could run on GPU\n", " @Expression MS[29]#3426L could run on GPU\n", " @Expression MS[30]#3427L could run on GPU\n", " @Expression MS[31]#3428L could run on GPU\n", " @Expression MS[32]#3429L could run on GPU\n", " @Expression MS[33]#3430L could run on GPU\n", " @Expression MS[34]#3431L could run on GPU\n", " @Expression MS[35]#3432L could run on GPU\n", " @Expression MS[36]#3433L could run on GPU\n", " @Expression MS[37]#3434L could run on GPU\n", " @Expression MS[38]#3435L could run on GPU\n", " @Expression MS[39]#3436L could run on GPU\n", " @Expression MS[40]#3437L could run on GPU\n", " @Expression MS[41]#3438L could run on GPU\n", " @Expression MS[42]#3439L could run on GPU\n", " @Expression MS[43]#3440L could run on GPU\n", " @Expression MS[44]#3441L could run on GPU\n", " @Expression MS[45]#3442L could run on GPU\n", " @Expression MS[46]#3443L could run on GPU\n", " @Expression MS[47]#3444L could run on GPU\n", " @Expression MS[48]#3445L could run on GPU\n", " @Expression MS[49]#3446L could run on GPU\n", " @Expression MS[50]#3447L could run on GPU\n", " @Expression MS[51]#3448L could run on GPU\n", " @Expression 0#4023 could run on GPU\n", " @Expression MS[0]#3449L could run on GPU\n", " @Expression MS[1]#3450L could run on GPU\n", " @Expression MS[2]#3451L could run on GPU\n", " @Expression MS[3]#3452L could run on GPU\n", " @Expression MS[4]#3453L could run on GPU\n", " @Expression MS[5]#3454L could run on GPU\n", " @Expression MS[6]#3455L could run on GPU\n", " @Expression MS[7]#3456L could run on GPU\n", " @Expression MS[8]#3457L could run on GPU\n", " @Expression MS[9]#3458L could run on GPU\n", " @Expression MS[10]#3459L could run on GPU\n", " @Expression MS[11]#3460L could run on GPU\n", " @Expression MS[12]#3461L could run on GPU\n", " @Expression MS[13]#3462L could run on GPU\n", " @Expression MS[14]#3463L could run on GPU\n", " @Expression MS[15]#3464L could run on GPU\n", " @Expression 
MS[16]#3465L could run on GPU\n", " @Expression MS[17]#3466L could run on GPU\n", " @Expression MS[18]#3467L could run on GPU\n", " @Expression MS[19]#3468L could run on GPU\n", " @Expression MS[20]#3469L could run on GPU\n", " @Expression MS[21]#3470L could run on GPU\n", " @Expression MS[22]#3471L could run on GPU\n", " @Expression MS[23]#3472L could run on GPU\n", " @Expression MS[24]#3473L could run on GPU\n", " @Expression MS[25]#3474L could run on GPU\n", " @Expression MS[26]#3475L could run on GPU\n", " @Expression MS[27]#3476L could run on GPU\n", " @Expression MS[28]#3477L could run on GPU\n", " @Expression MS[29]#3478L could run on GPU\n", " @Expression MS[30]#3479L could run on GPU\n", " @Expression MS[31]#3480L could run on GPU\n", " @Expression MS[32]#3481L could run on GPU\n", " @Expression MS[33]#3482L could run on GPU\n", " @Expression MS[34]#3483L could run on GPU\n", " @Expression MS[35]#3484L could run on GPU\n", " @Expression MS[36]#3485L could run on GPU\n", " @Expression MS[37]#3486L could run on GPU\n", " @Expression MS[38]#3487L could run on GPU\n", " @Expression MS[39]#3488L could run on GPU\n", " @Expression MS[40]#3489L could run on GPU\n", " @Expression MS[41]#3490L could run on GPU\n", " @Expression MS[42]#3491L could run on GPU\n", " @Expression MS[43]#3492L could run on GPU\n", " @Expression MS[44]#3493L could run on GPU\n", " @Expression MS[45]#3494L could run on GPU\n", " @Expression MS[46]#3495L could run on GPU\n", " @Expression MS[47]#3496L could run on GPU\n", " @Expression MS[48]#3497L could run on GPU\n", " @Expression MS[49]#3498L could run on GPU\n", " @Expression MS[50]#3499L could run on GPU\n", " @Expression MS[51]#3500L could run on GPU\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:43:28,385 WARN rapids.GpuOverrides: > (0 + 0) / 815]\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L AS approx_unique_customers#1353L could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0#4023 could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#3375 could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\n", " @Expression all AS table#1564 could run on GPU\n", " @Expression all could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L AS approx_unique_customers#1672L could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\n", "\n", "2022-04-05 09:44:07,480 WARN rapids.GpuOverrides: =============>(812 + 1) / 815]\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L AS approx_unique_customers#1353L could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0#4023 could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#3375 could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\n", " @Expression all AS table#1564 could run on GPU\n", " @Expression all could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L AS approx_unique_customers#1672L could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\n", "\n", "2022-04-05 09:44:07,482 WARN rapids.GpuOverrides: \n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\n", " ! 
approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#1179 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " @Expression table#1189 could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L AS approx_unique_customers#1353L could run on GPU\n", " @Expression approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\n", " !Exec cannot run on GPU because Unable to replace CustomShuffleReader due to child not being columnar\n", " !Exec cannot run on GPU because not all expressions can be replaced\n", " @Expression 0#4023 could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\n", " ! approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\n", " @Expression customerID#3375 could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\n", " @Expression all AS table#1564 could run on GPU\n", " @Expression all could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L AS approx_unique_customers#1672L could run on GPU\n", " @Expression approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\n", " !Exec cannot run on GPU because Unable to replace CustomShuffleReader due to child not being columnar\n", "\n", " \r" ] } ], "source": [ "rows = each_table.union(overall).collect()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'billing_events': 699470,\n", " 'customer_meta': 699470,\n", " 'customer_phone_features': 631148,\n", " 'customer_internet_features': 521053,\n", " 'customer_account_features': 699470,\n", " 'all': 699470}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dict([(row[0], row[1]) for row in rows])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: examples/SQL+DF-Examples/customer-churn/notebooks/python/churn/augment.py ================================================ # Copyright (c) 2022, NVIDIA Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
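#
# Module overview: augment.py synthesizes a larger, normalized churn dataset from
# the small wide-format telco churn CSV.  In outline, the functions below
# - register_options / get_currency_type: store run options globally, including an
#   optional DecimalType for the currency columns,
# - load_supplied_data: read the source CSV with an explicit schema and drop nulls,
# - replicate_df / _get_uniques: replicate every customer under collision-free ID
#   suffixes so the data can be scaled up by `dup_times`,
# - billing_events: expand each customer's tenure into monthly Charge events plus
#   AccountCreation / AccountTermination events,
# - customer_meta, phone_features, internet_features, account_features: split the
#   wide table into per-topic tables,
# - write_df / resolve_path: persist each table as parquet (or csv) under
#   `output_prefix`.
#
# A minimal usage sketch (illustrative only -- the SparkSession `spark`, the input
# path, and the option values are assumptions, not defined in this module):
#
#   from churn import augment   # assuming notebooks/python is on sys.path
#   augment.register_options(dup_times=10, output_prefix="augmented/",
#                            output_kind="parquet", output_mode="overwrite")
#   df = augment.load_supplied_data(spark, "path/to/churn-source.csv")
#   augment.write_df(augment.billing_events(df), "billing_events", partition_by="month")
#   # customer_meta() already replicates internally, so skip a second replication
#   augment.write_df(augment.customer_meta(df), "customer_meta", skip_replication=True)
#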
import datetime import os import pyspark from pyspark.sql.types import StructType, StructField, StringType, DoubleType, DecimalType import pyspark.sql.functions as F from collections import defaultdict options = defaultdict(lambda: None) now = datetime.datetime.now(datetime.timezone.utc) AUGMENT_VERSION = "0.7" AUGMENT_CUSTOMER_TAG = "0007" session = None currencyType = None def get_currency_type(): global options global currencyType if currencyType is not None: return currencyType if "use_decimal" in options and options["use_decimal"]: if "decimal_precision" in options : assert options["decimal_precision"] > 5, "Decimal precision is too small; was %d but should be at least 6" % options["decimal_precision"] currencyType = DecimalType(options["decimal_precision"], 2) else: # "999,999.99 should be enough for anyone" currencyType = DecimalType(8, 2) else: currencyType = DoubleType() return currencyType def _register_session(s): global session session = s def _get_uniques(ct): global session table_names = set([table.name for table in session.catalog.listTables()]) if ("uniques_%d" % ct) in table_names: return session.table("uniques_%d" % ct) else: def str_part(seed=0x5CA1AB1E): "generate the string part of a unique ID" import random r = random.Random(seed) from base64 import b64encode while True: yield "%s-%s" % (b64encode(r.getrandbits(72).to_bytes(9, "big"), b"@_").decode( "utf-8" ), AUGMENT_CUSTOMER_TAG) sp = str_part() uniques = ( session.createDataFrame( schema=StructType([StructField("u_value", StringType())]), data=[dict(u_value=next(sp)) for _ in range(min(int(ct * 1.02), ct + 2))], ) .distinct() .orderBy("u_value") .limit(ct) ).cache() uc = uniques.count() assert (uc == ct), "due to prng collision we had %d instead of %d replicas" % (uc, ct) uniques.createOrReplaceTempView("uniques_%d" % ct) return uniques def register_options(**kwargs): global options for k, v in kwargs.items(): options[k] = v def load_supplied_data(session, input_file): _register_session(session) fields = [ "customerID", "gender", "SeniorCitizen", "Partner", "Dependents", "tenure", "PhoneService", "MultipleLines", "InternetService", "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies", "Contract", "PaperlessBilling", "PaymentMethod", "MonthlyCharges", "TotalCharges", "Churn", ] double_fields = set(["tenure", "MonthlyCharges", "TotalCharges"]) schema = pyspark.sql.types.StructType( [ pyspark.sql.types.StructField( f, DoubleType() if f in double_fields else StringType() ) for f in fields ] ) df = session.read.csv(input_file, header=True, schema=schema) source_count = df.count() df = df.dropna() nn_count = df.count() if source_count == nn_count: print("read %d records from source dataset with no nulls -- is this what you expect?" 
% source_count) else: print("read %d records from source dataset (%d non-null records)" % (source_count, nn_count)) return df def replicate_df(df, duplicates): if duplicates > 1: uniques = _get_uniques(duplicates) df = ( df.crossJoin(uniques.distinct()) .withColumn("customerID", F.format_string("%s-%s", "customerID", "u_value")) .drop("u_value") ) return df def examine_categoricals(df, columns=None): """ Returns (to driver memory) a list of tuples consisting of every unique value for each column in `columns` or for every categorical column in the source data if no columns are specified """ default_columns = [ "SeniorCitizen", "Partner", "Dependents", "PhoneService", "MultipleLines", "InternetService", "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies", "Contract", "PaperlessBilling", "PaymentMethod", ] columns = columns or default_columns return [(c, [row[0] for row in df.select(c).distinct().rdd.collect()]) for c in columns] def billing_events(df): import datetime MAX_MONTH = 72 def get_last_month(col): h = F.abs(F.xxhash64(col)) h1 = (h.bitwiseAND(0xff)) % (MAX_MONTH // 2) h2 = (F.shiftRight(h, 8).bitwiseAND(0xff)) % (MAX_MONTH // 3) h3 = (F.shiftRight(h, 16).bitwiseAND(0xff)) % (MAX_MONTH // 5) h4 = (F.shiftRight(h, 24).bitwiseAND(0xff)) % (MAX_MONTH // 7) h5 = (F.shiftRight(h, 32).bitwiseAND(0xff)) % (MAX_MONTH // 11) return -(h1 + h2 + h3 + h4 + h5) w = pyspark.sql.Window.orderBy(F.lit("")).partitionBy(df.customerID) charges = ( df.select( df.customerID, F.lit("Charge").alias("kind"), F.explode( F.array_repeat((df.TotalCharges / df.tenure).cast(get_currency_type()), df.tenure.cast("int")) ).alias("value"), F.when(df.Churn == "Yes", get_last_month(df.customerID)).otherwise(0).alias("last_month") ) .withColumn("now", F.lit(now).cast("date")) .withColumn("month_number", -(F.row_number().over(w) + F.col("last_month"))) .withColumn("date", F.expr("add_months(now, month_number)")) .drop("now", "month_number", "last_month") ) serviceStarts = ( df.withColumn("last_month", F.when(df.Churn == "Yes", get_last_month(df.customerID)).otherwise(0)).select( df.customerID, F.lit("AccountCreation").alias("kind"), F.lit(0.0).cast(get_currency_type()).alias("value"), F.lit(now).alias("now"), (-df.tenure - 1 + F.col("last_month")).alias("month_number"), ) .withColumn("date", F.expr("add_months(now, month_number)")) .drop("now", "month_number") ) serviceTerminations = df.withColumn("last_month", F.when(df.Churn == "Yes", get_last_month(df.customerID)).otherwise(0)).where( df.Churn == "Yes" ).withColumn("now", F.lit(now)).select( df.customerID, F.lit("AccountTermination").alias("kind"), F.lit(0.0).cast(get_currency_type()).alias("value"), F.expr("add_months(now, last_month)").alias("date") ) billingEvents = charges.union(serviceStarts).union(serviceTerminations).orderBy("date").withColumn("month", F.substring("date", 0, 7)) return billingEvents def resolve_path(name): output_prefix = options["output_prefix"] or "" output_mode = options["output_mode"] or "overwrite" output_kind = options["output_kind"] or "parquet" name = "%s.%s" % (name, output_kind) if output_prefix != "": name = "%s%s" % (output_prefix, name) return name def write_df(df, name, skip_replication=False, partition_by=None): dup_times = options["dup_times"] or 1 output_prefix = options["output_prefix"] or "" output_mode = options["output_mode"] or "overwrite" output_kind = options["output_kind"] or "parquet" if not skip_replication: df = replicate_df(df, dup_times) write = df.write if 
partition_by is not None: if type(partition_by) == str: partition_by = [partition_by] write = write.partitionBy(*partition_by) name = "%s.%s" % (name, output_kind) if output_prefix != "": name = "%s%s" % (output_prefix, name) kwargs = {} if output_kind == "csv": kwargs["header"] = True getattr(write.mode(output_mode), output_kind)(name, **kwargs) def customer_meta(df): SENIOR_CUTOFF = 65 ADULT_CUTOFF = 18 DAYS_IN_YEAR = 365.25 EXPONENTIAL_DIST_SCALE = 6.3 augmented_original = replicate_df(df, options["dup_times"] or 1) customerMetaRaw = augmented_original.select( "customerID", F.lit(now).alias("now"), (F.abs(F.hash(augmented_original.customerID)) % 4096 / 4096).alias("choice"), "SeniorCitizen", "gender", "Partner", "Dependents", F.col("MonthlyCharges").cast(get_currency_type()).alias("MonthlyCharges"), ) customerMetaRaw = customerMetaRaw.withColumn( "ageInDays", F.floor( F.when( customerMetaRaw.SeniorCitizen == 0, ( customerMetaRaw.choice * ((SENIOR_CUTOFF - ADULT_CUTOFF - 1) * DAYS_IN_YEAR) ) + (ADULT_CUTOFF * DAYS_IN_YEAR), ).otherwise( (SENIOR_CUTOFF * DAYS_IN_YEAR) + ( DAYS_IN_YEAR * (-F.log1p(-customerMetaRaw.choice) * EXPONENTIAL_DIST_SCALE) ) ) ).cast("int"), ) customerMetaRaw = customerMetaRaw.withColumn( "dateOfBirth", F.expr("date_sub(now, ageInDays)") ) return customerMetaRaw.select( "customerID", "dateOfBirth", "gender", "SeniorCitizen", "Partner", "Dependents", "MonthlyCharges", "now", ).orderBy("customerID") def phone_features(df): phoneService = df.select( "customerID", F.lit("PhoneService").alias("feature"), F.lit("Yes").alias("value") ).where(df.PhoneService == "Yes") multipleLines = df.select( "customerID", F.lit("MultipleLines").alias("feature"), F.lit("Yes").alias("value") ).where(df.MultipleLines == "Yes") return phoneService.union(multipleLines).orderBy("customerID") def internet_features(df): internet_service = df.select( "customerID", F.lit("InternetService").alias("feature"), df.InternetService.alias("value"), ).where(df.InternetService != "No") customerInternetFeatures = internet_service for feature in [ "InternetService", "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies", ]: tmpdf = df.select( "customerID", F.lit(feature).alias("feature"), df[feature].alias("value"), ).where(df[feature] == "Yes") customerInternetFeatures = customerInternetFeatures.union(tmpdf) return customerInternetFeatures def account_features(df): session = df.sql_ctx.sparkSession accountSchema = pyspark.sql.types.StructType( [ pyspark.sql.types.StructField(f, StringType()) for f in ["customerID", "feature", "value"] ] ) customerAccountFeatures = session.createDataFrame(schema=accountSchema, data=[]) for feature in ["Contract", "PaperlessBilling", "PaymentMethod"]: tmpdf = df.select( "customerID", F.lit(feature).alias("feature"), df[feature].alias("value"), ).where(df[feature] != "No") customerAccountFeatures = customerAccountFeatures.union(tmpdf) return customerAccountFeatures def debug_augmentation(df): return ( df.select("customerID") .distinct() .select( "customerID", F.substring("customerID", 0, 10).alias("originalID"), F.element_at(F.split("customerID", "-", -1), 3).alias("suffix"), ) ) ================================================ FILE: examples/SQL+DF-Examples/customer-churn/notebooks/python/churn/eda.py ================================================ # Copyright (c) 2022, NVIDIA Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from pyspark.sql import types as T from pyspark.sql import functions as F eda_options = { 'use_array_ops' : False } def isnumeric(data_type): numeric_types = [T.ByteType, T.ShortType, T.IntegerType, T.LongType, T.FloatType, T.DoubleType, T.DecimalType] return any([isinstance(data_type, t) for t in numeric_types]) def percent_true(df, cols): denominator = df.count() return {col : df.where(F.col(col) == True).count() / denominator for col in cols} def cardinalities(df, cols): from functools import reduce counts = df.agg( F.struct(*[F.countDistinct(F.col(c)).alias(c) for c in cols] + [F.count(F.col(cols[0])).alias('total')]).alias("results") ).select("results").collect()[0][0].asDict() counts.update({'total' : df.count()}) return counts def likely_unique(counts): total = counts["total"] return [k for (k, v) in counts.items() if k != "total" and abs(total - v) < total * 0.15] def likely_categoricals(counts): total = counts["total"] return [k for (k, v) in counts.items() if v < total * 0.15 or v < 128] def unique_values(df, cols): if eda_options['use_array_ops']: return unique_values_array(df, cols) else: return unique_values_driver(df, cols) def unique_values_array(df, cols): from functools import reduce counts = df.groupBy( F.lit(True).alias("drop_me") ).agg( *[F.array_sort(F.collect_set(F.col(c))).alias(c) for c in cols] ).drop("drop_me").cache() result = reduce(lambda l, r: l.unionAll(r), [counts.select(F.lit(c).alias("field"), F.col(c).alias("unique_vals")) for c in counts.columns]).collect() return dict([(r[0],r[1]) for r in result]) def unique_values_driver(df, cols): return { col : [v[0] for v in df.select(F.col(col).alias('value')).distinct().orderBy(F.col('value')).collect()] for col in cols} def approx_ecdf(df, cols): from functools import reduce quantiles = [0.0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 1.0] qs = df.approxQuantile(cols, quantiles, 0.01) result = dict(zip(cols, qs)) return {c: dict(zip(quantiles, vs)) for (c, vs) in result.items()} def gen_summary(df, output_prefix=""): summary = {} string_cols = [] boolean_cols = [] numeric_cols = [] other_cols = [] for field in df.schema.fields: if isinstance(field.dataType, T.StringType): string_cols.append(field.name) elif isinstance(field.dataType, T.BooleanType): boolean_cols.append(field.name) elif isnumeric(field.dataType): numeric_cols.append(field.name) else: other_cols.append(field.name) counts = cardinalities(df, string_cols) uniques = likely_unique(counts) categoricals = unique_values(df, likely_categoricals(counts)) for span in [2,3,4,6,12]: thecube = df.cube("Churn", F.ceil(df.tenure / span).alias("%d_month_spans" % span), "gender", "Partner", "SeniorCitizen", "Contract", "PaperlessBilling", "PaymentMethod", F.ceil(F.log2(F.col("MonthlyCharges"))*10).alias("log_charges")).count() therollup = df.rollup("Churn", F.ceil(df.tenure / span).alias("%d_month_spans" % span), "SeniorCitizen", "Contract", "PaperlessBilling", "PaymentMethod", F.ceil(F.log2(F.col("MonthlyCharges"))*10).alias("log_charges")).agg(F.sum(F.col("TotalCharges")).alias("sum_charges")) thecube.write.mode("overwrite").parquet("%scube-%d.parquet" % 
def gen_summary(df, output_prefix=""):
    summary = {}

    string_cols = []
    boolean_cols = []
    numeric_cols = []
    other_cols = []

    for field in df.schema.fields:
        if isinstance(field.dataType, T.StringType):
            string_cols.append(field.name)
        elif isinstance(field.dataType, T.BooleanType):
            boolean_cols.append(field.name)
        elif isnumeric(field.dataType):
            numeric_cols.append(field.name)
        else:
            other_cols.append(field.name)

    counts = cardinalities(df, string_cols)
    uniques = likely_unique(counts)
    categoricals = unique_values(df, likely_categoricals(counts))

    for span in [2, 3, 4, 6, 12]:
        thecube = df.cube(
            "Churn",
            F.ceil(df.tenure / span).alias("%d_month_spans" % span),
            "gender", "Partner", "SeniorCitizen", "Contract",
            "PaperlessBilling", "PaymentMethod",
            F.ceil(F.log2(F.col("MonthlyCharges")) * 10).alias("log_charges")
        ).count()
        therollup = df.rollup(
            "Churn",
            F.ceil(df.tenure / span).alias("%d_month_spans" % span),
            "SeniorCitizen", "Contract",
            "PaperlessBilling", "PaymentMethod",
            F.ceil(F.log2(F.col("MonthlyCharges")) * 10).alias("log_charges")
        ).agg(F.sum(F.col("TotalCharges")).alias("sum_charges"))
        thecube.write.mode("overwrite").parquet("%scube-%d.parquet" % (output_prefix, span))
        therollup.write.mode("overwrite").parquet("%srollup-%d.parquet" % (output_prefix, span))

    encoding_struct = {
        "categorical": categoricals,
        "numeric": numeric_cols + boolean_cols,
        "unique": uniques
    }

    summary["schema"] = df.schema.jsonValue()
    summary["ecdfs"] = approx_ecdf(df, numeric_cols)
    summary["true_percentage"] = percent_true(df, boolean_cols)
    summary["encoding"] = encoding_struct
    summary["distinct_customers"] = df.select(df.customerID).distinct().count()

    return summary


def losses_by_month(be):
    customer_lifetime_values = be.groupBy("customerID").sum("value").alias("value")
    return be.where(be.kind == "AccountTermination") \
        .join(customer_lifetime_values, "customerID") \
        .groupBy("month").sum("value").alias("value") \
        .sort("month").toPandas().to_json()


def output_reports(df, be=None, report_prefix=""):
    import json

    summary = gen_summary(df, report_prefix)

    if be is not None:
        summary["losses_by_month"] = losses_by_month(be)

    with open("%ssummary.json" % report_prefix, "w") as sf:
        json.dump(summary, sf)

    with open("%sencodings.json" % report_prefix, "w") as ef:
        json.dump(summary["encoding"], ef)


================================================
FILE: examples/SQL+DF-Examples/customer-churn/notebooks/python/churn/etl.py
================================================
#!/usr/bin/env python
# coding: utf-8

# Copyright (c) 2022, NVIDIA Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
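"""Helpers for reconstructing the wide churn training table from the raw event and feature tables.

Illustrative usage sketch (editorial addition): this mirrors the flow of the
accompanying etl.ipynb notebook. It assumes an existing SparkSession named
``session`` and parquet inputs named as in that notebook; the option values
shown here are examples, not requirements.

    import churn.etl as etl

    etl.register_options(input_kind="parquet", output_kind="parquet",
                         output_mode="overwrite", output_prefix="",
                         coalesce_output=0)
    customer_billing = etl.join_billing_data(etl.read_df(session, "billing_events"))
    phone = etl.join_phone_features(etl.read_df(session, "customer_phone_features"))
    internet = etl.join_internet_features(etl.read_df(session, "customer_internet_features"))
    account = etl.join_account_features(etl.read_df(session, "customer_account_features"))
    meta = etl.process_account_meta(etl.read_df(session, "customer_meta"))
    wide = etl.join_wide_table(customer_billing, phone, internet, account, meta)
    etl.write_df(etl.cast_and_coalesce_wide_data(wide), "churn-etl")
"""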
import pyspark
import pyspark.sql
import pyspark.sql.functions as F

from collections import defaultdict

options = defaultdict(lambda: None)
session = None

ETL_VERSION = '0.7'


def register_options(**kwargs):
    global options
    for k, v in kwargs.items():
        options[k] = v


def _register_session(s):
    global session
    session = s


def _register_views(lvars, *names):
    for n in names:
        if n in lvars:
            lvars[n].createOrReplaceTempView(n)


def withsession(df_arg=0):
    def decorate(fn):
        def wrapped(*args, **kwargs):
            _register_session(args[df_arg].sql_ctx.sparkSession)
            fn(*args, **kwargs)
        return wrapped
    return decorate


def read_df(session, fn):
    kwargs = {}
    _register_session(session)
    input_kind = options["input_kind"]
    if input_kind == "csv":
        kwargs["header"] = True
    return getattr(session.read, input_kind)("%s.%s" % (fn, input_kind), **kwargs)


def find_customers(billing_events_df):
    customers = billing_events_df.select("customerID").distinct()
    if 'cache_customers' in options:
        customers.cache()
    customers.createOrReplaceTempView("customers")
    return customers


def customers():
    global session
    return session.table("customers")


def join_billing_data(billing_events_df):
    _register_session(billing_events_df.sql_ctx.sparkSession)

    billing_events = billing_events_df.withColumn("value", billing_events_df.value)
    customers = find_customers(billing_events)

    counts_and_charges = billing_events.groupBy("customerID", "kind").agg(
        F.count(billing_events.value).alias("event_counts"),
        F.sum(billing_events.value).alias("total_charges"),
    )
    counts_and_charges.createOrReplaceTempView("counts_and_charges")

    terminations = billing_events.where(F.col("kind") == "AccountTermination").select(
        F.col("customerID").alias("Churn")
    )

    churned = customers.join(
        terminations, customers.customerID == terminations.Churn, how="leftouter"
    ).select(
        "customerID",
        F.when(F.col("Churn").isNull(), F.lit(False)).otherwise(F.lit(True)).alias("Churn")
    )

    customer_charges = customers.join(
        counts_and_charges.where(F.col("kind") == "Charge"), "customerID", how="leftouter"
    ).select(
        "customerID",
        F.col("event_counts").alias("tenure"),
        F.col("total_charges").alias("TotalCharges"),
    ).fillna({'tenure': 0, 'TotalCharges': 0.0})

    _register_views(locals(), "counts_and_charges", "terminations", "churned", "customer_charges")

    # counts_and_charges.createOrReplaceTempView("counts_and_charges")
    # terminations.createOrReplaceTempView("terminations")
    # churned.createOrReplaceTempView("churned")
    # customer_charges.createOrReplaceTempView("customer_charges")

    customer_billing = churned.join(customer_charges, "customerID")

    _register_views(locals(), "counts_and_charges", "terminations", "churned",
                    "customer_charges", "customer_billing")

    return customer_billing


def join_phone_features(phone_features_df):
    phone_features = phone_features_df

    phone_service = phone_features.where(F.col("feature") == "PhoneService").select(
        "customerID", F.lit("Yes").alias("PhoneService")
    )
    multiple_lines = phone_features.where(F.col("feature") == "MultipleLines").select(
        "customerID", F.lit("Yes").alias("MultipleLines")
    )

    customer_phone_features = (
        customers().join(phone_service, "customerID", how="leftouter")
        .join(multiple_lines, "customerID", how="leftouter")
        .select(
            "customerID",
            F.when(F.col("PhoneService").isNull(), "No")
            .otherwise("Yes")
            .alias("PhoneService"),
            "MultipleLines",
        )
        .select(
            "customerID",
            "PhoneService",
            F.when(F.col("PhoneService") == "No", "No phone service")
            .otherwise(F.when(F.col("MultipleLines").isNull(), "No").otherwise("Yes"))
            .alias("MultipleLines"),
        )
    )

    _register_views(locals(), "phone_service", "multiple_lines", "customer_phone_features")

    return customer_phone_features


def untidy_feature(df, feature):
    """ 'untidies' a feature by turning it into a column """
    return df.where(F.col("feature") == feature).select(
        "customerID", F.col("value").alias(feature)
    )


def chained_join(column, base_df, dfs, how="leftouter"):
    """ repeatedly joins a sequence of data frames on the same column """
    acc = base_df
    for df in dfs:
        acc = acc.join(df, column, how=how)
    return acc


def resolve_nullable_column(df, col, null_val="No"):
    return F.when(df[col].isNull(), null_val).otherwise(df[col]).alias(col)


def resolve_dependent_column(
    df,
    col,
    parent_col="InternetService",
    null_val="No",
    null_parent_val="No internet service",
):
    return (
        F.when((df[parent_col] == "No") | (df[parent_col].isNull()), null_parent_val)
        .otherwise(F.when(df[col].isNull(), null_val).otherwise(df[col]))
        .alias(col)
    )
def join_internet_features(internet_features_df):
    internet_features = internet_features_df

    internet_service = untidy_feature(internet_features, "InternetService")
    online_security = untidy_feature(internet_features, "OnlineSecurity")
    online_backup = untidy_feature(internet_features, "OnlineBackup")
    device_protection = untidy_feature(internet_features, "DeviceProtection")
    tech_support = untidy_feature(internet_features, "TechSupport")
    streaming_tv = untidy_feature(internet_features, "StreamingTV")
    streaming_movies = untidy_feature(internet_features, "StreamingMovies")

    customer_internet_features = chained_join(
        "customerID",
        customers(),
        [
            internet_service,
            online_security,
            online_backup,
            device_protection,
            tech_support,
            streaming_tv,
            streaming_movies,
        ],
    )

    customer_internet_features = customer_internet_features.select(
        "customerID",
        resolve_nullable_column(customer_internet_features, "InternetService"),
        resolve_dependent_column(customer_internet_features, "OnlineSecurity", "InternetService"),
        resolve_dependent_column(customer_internet_features, "OnlineBackup", "InternetService"),
        resolve_dependent_column(customer_internet_features, "DeviceProtection", "InternetService"),
        resolve_dependent_column(customer_internet_features, "TechSupport", "InternetService"),
        resolve_dependent_column(customer_internet_features, "StreamingTV", "InternetService"),
        resolve_dependent_column(customer_internet_features, "StreamingMovies", "InternetService"),
    )

    _register_views(locals(),
                    "internet_service", "online_security", "online_backup",
                    "device_protection", "tech_support", "streaming_tv",
                    "streaming_movies", "customer_internet_features")

    return customer_internet_features


def join_account_features(account_features_df):
    account_features = account_features_df

    contracts = untidy_feature(account_features, "Contract")
    paperless = untidy_feature(account_features, "PaperlessBilling")
    payment = untidy_feature(account_features, "PaymentMethod")

    customer_account_features = chained_join(
        "customerID", customers(), [contracts, paperless, payment]
    )

    customer_account_features = customer_account_features.select(
        "customerID",
        "Contract",
        resolve_nullable_column(customer_account_features, "PaperlessBilling"),
        "PaymentMethod",
    )

    _register_views(locals(), "contracts", "paperless", "payment", "customer_account_features")

    return customer_account_features


def process_account_meta(account_meta_df, usecal=None):
    def is_senior_citizen(nowcol, dobcol):
        if options['use_calendar_arithmetic']:
            return F.when(
                F.col("now") >= F.add_months(F.col("dateOfBirth"), 65 * 12), F.lit(True)
            ).otherwise(F.lit(False))
        else:
            # fall back to comparing year, month, and day of month against the date of birth
            return (F.year(F.col(nowcol)) > (F.year(F.col(dobcol)) + 65)) | \
                (F.year(F.col(nowcol)) == (F.year(F.col(dobcol)) + 65)) & \
                (
                    (F.month(F.col(nowcol)) < F.month(F.col(dobcol))) |
                    (
                        (F.month(F.col(nowcol)) == F.month(F.col(dobcol))) &
                        (F.dayofmonth(F.col(nowcol)) <= F.dayofmonth(F.col(dobcol)))
                    )
                )

    customer_account_meta = account_meta_df.select(
        "customerID",
        is_senior_citizen("now", "dateOfBirth").alias("SeniorCitizen"),
        "Partner",
        "Dependents",
        "gender",
        "MonthlyCharges",
    )

    _register_views(locals(), "customer_account_meta")

    return customer_account_meta


def forcefloat(c):
    return F.col(c).cast("float").alias(c)


def join_wide_table(customer_billing, customer_phone_features, customer_internet_features,
                    customer_account_features, customer_account_meta):
    wide_data = chained_join(
        "customerID",
        customers(),
        [
            customer_billing,
            customer_phone_features,
            customer_internet_features,
            customer_account_features,
            customer_account_meta,
        ],
    ).select(
        "customerID",
        "gender",
        "SeniorCitizen",
        "Partner",
        "Dependents",
        "tenure",
        "PhoneService",
        "MultipleLines",
        "InternetService",
        "OnlineSecurity",
        "OnlineBackup",
        "DeviceProtection",
        "TechSupport",
        "StreamingTV",
        "StreamingMovies",
        "Contract",
        "PaperlessBilling",
        "PaymentMethod",
        "MonthlyCharges",
        "TotalCharges",
        "Churn",
    )

    return wide_data


# In[ ]:


def cast_and_coalesce_wide_data(wd):
    if options["coalesce_output"] > 0:
        wd = wd.coalesce(options["coalesce_output"])

    return wd.select(
        "customerID",
        "gender",
        "SeniorCitizen",
        "Partner",
        "Dependents",
        "tenure",
        "PhoneService",
        "MultipleLines",
        "InternetService",
        "OnlineSecurity",
        "OnlineBackup",
        "DeviceProtection",
        "TechSupport",
        "StreamingTV",
        "StreamingMovies",
        "Contract",
        "PaperlessBilling",
        "PaymentMethod",
        forcefloat("MonthlyCharges"),
        forcefloat("TotalCharges"),
        "Churn",
    )


def write_df(df, name):
    output_kind = options["output_kind"]
    output_mode = options["output_mode"]
    output_prefix = options["output_prefix"]
    name = "%s.%s" % (name, output_kind)
    if output_prefix != "":
        name = "%s%s" % (output_prefix, name)
    kwargs = {}
    if output_kind == "csv":
        kwargs["header"] = True
    getattr(df.write.mode(output_mode), output_kind)(name, **kwargs)


================================================
FILE: examples/SQL+DF-Examples/customer-churn/notebooks/python/etl.ipynb
================================================
{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Transforming and joining raw data\n", "\n", "The \"raw\" data is divided among the following tables:\n", "\n", "- **Customer metadata**\n", " - customerID\n", " - gender\n", " - date of birth (we'll derive age and senior citizen status from this)\n", " - Partner\n", " - Dependents\n", " - (nominal) MonthlyCharges\n", "- **Billing events**\n", " - customerID\n", " - date (we'll derive tenure from the number/duration of billing events)\n", " - kind (one of \"AccountCreation\", \"Charge\", or \"AccountTermination\")\n", " - value (either a positive nonzero amount or 0.00; we'll derive TotalCharges from the sum of amounts and Churn from the existence of an AccountTermination event)\n", "- **Customer phone features**\n", " - customerID\n", " - feature (one of \"PhoneService\" or \"MultipleLines\")\n", "- **Customer internet features**\n", " - customerID\n", " - feature (one of \"InternetService\", \"OnlineSecurity\", \"OnlineBackup\", \"DeviceProtection\", \"TechSupport\", \"StreamingTV\", \"StreamingMovies\")\n", " - value (one of \"Fiber\", \"DSL\", \"Yes\", \"No\")\n", "- **Customer account features**\n", " - customerID\n", " - feature (one of \"Contract\", 
\"PaperlessBilling\", \"PaymentMethod\")\n", " - value (one of \"Month-to-month\", \"One year\", \"Two year\", \"No\", \"Yes\", \"Credit card (automatic)\", \"Mailed check\", \"Bank transfer (automatic)\", \"Electronic check\")\n", "\n", "We want to join these together to reconstitute a training data set with this schema:\n", "\n", "- customerID\n", "- gender\n", "- SeniorCitizen\n", "- Partner\n", "- Dependents\n", "- tenure\n", "- PhoneService\n", "- MultipleLines\n", "- InternetService\n", "- OnlineSecurity\n", "- OnlineBackup\n", "- DeviceProtection\n", "- TechSupport\n", "- StreamingTV\n", "- StreamingMovies\n", "- Contract\n", "- PaperlessBilling\n", "- PaymentMethod\n", "- MonthlyCharges\n", "- TotalCharges\n", "- Churn" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "# notebook parameters\n", "\n", "import os\n", "\n", "spark_master = os.getenv(\"SPARK_MASTER_URL\", \"/spark://ip:port\")\n", "app_name = \"churn-etl\"\n", "input_files = dict(\n", " billing=\"billing_events\", \n", " account_features=\"customer_account_features\", \n", " internet_features=\"customer_internet_features\", \n", " meta=\"customer_meta\", \n", " phone_features=\"customer_phone_features\"\n", ")\n", "output_file = \"churn-etl\"\n", "output_mode = \"overwrite\"\n", "output_kind = \"parquet\"\n", "input_kind = \"parquet\"\n", "driver_memory = '8g'\n", "executor_memory = '8g'\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "

SparkSession - hive

\n", " \n", "
\n", "

SparkContext

\n", "\n", "

Spark UI

\n", "\n", "
\n", "
Version
\n", "
v3.2.0
\n", "
Master
\n", "
spark://yuanli-System-Product-Name:7077
\n", "
AppName
\n", "
PySparkShell
\n", "
\n", "
\n", " \n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pyspark\n", "\n", "session = pyspark.sql.SparkSession.builder \\\n", " .master(spark_master) \\\n", " .appName(app_name) \\\n", " .config(\"spark.eventLog.enabled\", True) \\\n", " .config(\"spark.eventLog.dir\", \".\") \\\n", " .config(\"spark.driver.memory\", driver_memory) \\\n", " .config(\"spark.executor.memory\", executor_memory) \\\n", " .getOrCreate()\n", "session" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import churn.etl\n", "\n", "churn.etl.register_options(\n", " spark_master = spark_master,\n", " app_name = app_name,\n", " input_files = input_files,\n", " output_mode = output_mode,\n", " output_kind = output_kind,\n", " input_kind = input_kind,\n", " driver_memory = driver_memory,\n", " executor_memory = executor_memory\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Reconstructing billing events and charges" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- customerID: string (nullable = true)\n", " |-- kind: string (nullable = true)\n", " |-- value: decimal(8,2) (nullable = true)\n", " |-- date: date (nullable = true)\n", " |-- month: string (nullable = true)\n", "\n" ] } ], "source": [ "from churn.etl import read_df\n", "billing_events = read_df(session, input_files[\"billing\"])\n", "billing_events.printSchema()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from churn.etl import join_billing_data\n", "customer_billing = join_billing_data(billing_events)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DataFrame[customerID: string, Churn: boolean, tenure: bigint, TotalCharges: decimal(18,2)]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "customer_billing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When we aggregated billing data, we also captured a unique list of customers in a temporary view. 
For convenience, we can access it as follows:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from churn.etl import customers as get_customers\n", "customers = get_customers()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Reconstructing phone features\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- customerID: string (nullable = true)\n", " |-- feature: string (nullable = true)\n", " |-- value: string (nullable = true)\n", "\n" ] } ], "source": [ "phone_features = read_df(session, input_files[\"phone_features\"])\n", "phone_features.printSchema()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from churn.etl import join_phone_features\n", "customer_phone_features = join_phone_features(phone_features)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Reconstructing internet features\n", "\n", "Whereas phone features only include whether or not there are multiple lines, there are several internet-specific features in accounts:\n", "\n", "- `InternetService` (one of `Fiber optic` or `DSL` in the \"raw\" data; its absence translates to `No` in the processed data)\n", "- `OnlineSecurity` (`Yes` in the \"raw\" data if present; one of `No`, `Yes`, or `No internet service` in the processed data)\n", "- `OnlineBackup` (`Yes` in the \"raw\" data if present; one of `No`, `Yes`, or `No internet service` in the processed data)\n", "- `DeviceProtection` (`Yes` in the \"raw\" data if present; one of `No`, `Yes`, or `No internet service` in the processed data)\n", "- `TechSupport` (`Yes` in the \"raw\" data if present; one of `No`, `Yes`, or `No internet service` in the processed data)\n", "- `StreamingTV` (`Yes` in the \"raw\" data if present; one of `No`, `Yes`, or `No internet service` in the processed data)\n", "- `StreamingMovies` (`Yes` in the \"raw\" data if present; one of `No`, `Yes`, or `No internet service` in the processed data)\n", "\n", "This will lead to some slightly more interesting joins!" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- customerID: string (nullable = true)\n", " |-- feature: string (nullable = true)\n", " |-- value: string (nullable = true)\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:59:39,224 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+---------------+-----+\n", "| customerID| feature|value|\n", "+--------------------+---------------+-----+\n", "|7590-VHVEG-Mg8VG5...|InternetService| DSL|\n", "|7590-VHVEG-5xLi5Z...|InternetService| DSL|\n", "|7590-VHVEG-ZePlJi...|InternetService| DSL|\n", "|7590-VHVEG-x9IoNd...|InternetService| DSL|\n", "|7590-VHVEG-Z9yCIk...|InternetService| DSL|\n", "|7590-VHVEG-K8kBya...|InternetService| DSL|\n", "|7590-VHVEG-4ZjnIU...|InternetService| DSL|\n", "|7590-VHVEG-0stTDJ...|InternetService| DSL|\n", "|7590-VHVEG-lqhKlh...|InternetService| DSL|\n", "|7590-VHVEG-4Y_zUA...|InternetService| DSL|\n", "|7590-VHVEG-34V86Q...|InternetService| DSL|\n", "|7590-VHVEG-GCNzU2...|InternetService| DSL|\n", "|7590-VHVEG-i0AFUE...|InternetService| DSL|\n", "|7590-VHVEG-F1ALBc...|InternetService| DSL|\n", "|7590-VHVEG-aEfHl7...|InternetService| DSL|\n", "|7590-VHVEG-eiqTDe...|InternetService| DSL|\n", "|7590-VHVEG-3K15yQ...|InternetService| DSL|\n", "|7590-VHVEG-iMYyeZ...|InternetService| DSL|\n", "|7590-VHVEG-rReekB...|InternetService| DSL|\n", "|7590-VHVEG-2l92Zs...|InternetService| DSL|\n", "+--------------------+---------------+-----+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "internet_features = read_df(session, input_files[\"internet_features\"])\n", "internet_features.printSchema()\n", "internet_features.show()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from churn.etl import join_internet_features\n", "customer_internet_features = join_internet_features(internet_features)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Reconstructing account features" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- customerID: string (nullable = true)\n", " |-- feature: string (nullable = true)\n", " |-- value: string (nullable = true)\n", "\n", "+--------------------+-------------+----------------+\n", "| customerID| feature| value|\n", "+--------------------+-------------+----------------+\n", "|7590-VHVEG-Mg8VG5...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-5xLi5Z...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-ZePlJi...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-x9IoNd...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-Z9yCIk...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-K8kBya...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-4ZjnIU...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-0stTDJ...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-lqhKlh...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-4Y_zUA...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-34V86Q...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-GCNzU2...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-i0AFUE...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-F1ALBc...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-aEfHl7...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-3K15yQ...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-eiqTDe...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-iMYyeZ...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-rReekB...|PaymentMethod|Electronic check|\n", "|7590-VHVEG-2l92Zs...|PaymentMethod|Electronic check|\n", "+--------------------+-------------+----------------+\n", "only showing top 20 
rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 09:59:42,068 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", "\n" ] } ], "source": [ "account_features = read_df(session, input_files[\"account_features\"])\n", "account_features.printSchema()\n", "account_features.show()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from churn.etl import join_account_features\n", "customer_account_features = join_account_features(account_features)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Account metadata" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- customerID: string (nullable = true)\n", " |-- dateOfBirth: date (nullable = true)\n", " |-- gender: string (nullable = true)\n", " |-- SeniorCitizen: string (nullable = true)\n", " |-- Partner: string (nullable = true)\n", " |-- Dependents: string (nullable = true)\n", " |-- MonthlyCharges: decimal(8,2) (nullable = true)\n", " |-- now: timestamp (nullable = true)\n", "\n" ] } ], "source": [ "account_meta = read_df(session, input_files[\"meta\"])\n", "account_meta.printSchema()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from churn.etl import process_account_meta\n", "customer_account_meta = process_account_meta(account_meta)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Putting it all together" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "from churn.etl import chained_join\n", "from churn.etl import forcefloat\n", "\n", "wide_data = chained_join(\n", " \"customerID\",\n", " customers,\n", " [\n", " customer_billing,\n", " customer_phone_features,\n", " customer_internet_features,\n", " customer_account_features,\n", " customer_account_meta\n", " ]\n", ").select(\n", " \"customerID\", \n", " \"gender\", \n", " \"SeniorCitizen\", \n", " \"Partner\", \n", " \"Dependents\", \n", " \"tenure\", \n", " \"PhoneService\", \n", " \"MultipleLines\", \n", " \"InternetService\", \n", " \"OnlineSecurity\", \n", " \"OnlineBackup\", \n", " \"DeviceProtection\", \n", " \"TechSupport\", \n", " \"StreamingTV\", \n", " \"StreamingMovies\", \n", " \"Contract\", \n", " \"PaperlessBilling\", \n", " \"PaymentMethod\", \n", " forcefloat(\"MonthlyCharges\"),\n", " forcefloat(\"TotalCharges\"), \n", " \"Churn\"\n", ")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "== Physical Plan ==\n", "AdaptiveSparkPlan isFinalPlan=false\n", "+- Project [customerID#0, gender#265, SeniorCitizen#279, Partner#267, Dependents#268, tenure#61L, PhoneService#97, MultipleLines#101, InternetService#199, OnlineSecurity#200, OnlineBackup#201, DeviceProtection#202, TechSupport#203, StreamingTV#204, StreamingMovies#205, Contract#233, PaperlessBilling#258, PaymentMethod#239, cast(MonthlyCharges#269 as float) AS MonthlyCharges#366, cast(TotalCharges#62 as float) AS TotalCharges#367, Churn#41]\n", " +- 
BroadcastHashJoin [customerID#0], [customerID#263], LeftOuter, BuildRight, false\n", " :- Project [customerID#0, Churn#41, tenure#61L, TotalCharges#62, PhoneService#97, MultipleLines#101, InternetService#199, OnlineSecurity#200, OnlineBackup#201, DeviceProtection#202, TechSupport#203, StreamingTV#204, StreamingMovies#205, Contract#233, PaperlessBilling#258, PaymentMethod#239]\n", " : +- SortMergeJoin [customerID#0], [customerID#324], LeftOuter\n", " : :- Project [customerID#0, Churn#41, tenure#61L, TotalCharges#62, PhoneService#97, MultipleLines#101, InternetService#199, OnlineSecurity#200, OnlineBackup#201, DeviceProtection#202, TechSupport#203, StreamingTV#204, StreamingMovies#205]\n", " : : +- SortMergeJoin [customerID#0], [customerID#306], LeftOuter\n", " : : :- Project [customerID#0, Churn#41, tenure#61L, TotalCharges#62, PhoneService#97, MultipleLines#101]\n", " : : : +- SortMergeJoin [customerID#0], [customerID#295], LeftOuter\n", " : : : :- Project [customerID#0, Churn#41, tenure#61L, TotalCharges#62]\n", " : : : : +- SortMergeJoin [customerID#0], [customerID#286], LeftOuter\n", " : : : : :- Sort [customerID#0 ASC NULLS FIRST], false, 0\n", " : : : : : +- HashAggregate(keys=[customerID#0], functions=[])\n", " : : : : : +- Exchange hashpartitioning(customerID#0, 200), ENSURE_REQUIREMENTS, [id=#550]\n", " : : : : : +- HashAggregate(keys=[customerID#0], functions=[])\n", " : : : : : +- Project [customerID#0]\n", " : : : : : +- FileScan parquet [customerID#0,month#4] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct\n", " : : : : +- Project [customerID#286, Churn#41, tenure#61L, TotalCharges#62]\n", " : : : : +- SortMergeJoin [customerID#286], [customerID#66], Inner\n", " : : : : :- Project [customerID#286, isnotnull(Churn#30) AS Churn#41]\n", " : : : : : +- SortMergeJoin [customerID#286], [Churn#30], LeftOuter\n", " : : : : : :- Sort [customerID#286 ASC NULLS FIRST], false, 0\n", " : : : : : : +- HashAggregate(keys=[customerID#286], functions=[])\n", " : : : : : : +- Exchange hashpartitioning(customerID#286, 200), ENSURE_REQUIREMENTS, [id=#552]\n", " : : : : : : +- HashAggregate(keys=[customerID#286], functions=[])\n", " : : : : : : +- Project [customerID#286]\n", " : : : : : : +- Filter isnotnull(customerID#286)\n", " : : : : : : +- FileScan parquet [customerID#286,month#290] Batched: true, DataFilters: [isnotnull(customerID#286)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(customerID)], ReadSchema: struct\n", " : : : : : +- Sort [Churn#30 ASC NULLS FIRST], false, 0\n", " : : : : : +- Exchange hashpartitioning(Churn#30, 200), ENSURE_REQUIREMENTS, [id=#556]\n", " : : : : : +- Project [customerID#32 AS Churn#30]\n", " : : : : : +- Filter ((isnotnull(kind#33) AND (kind#33 = AccountTermination)) AND isnotnull(customerID#32))\n", " : : : : : +- FileScan parquet [customerID#32,kind#33,month#36] Batched: true, DataFilters: [isnotnull(kind#33), (kind#33 = AccountTermination), isnotnull(customerID#32)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(kind), EqualTo(kind,AccountTermination), IsNotNull(customerID)], ReadSchema: struct\n", " : : : : +- Project 
[customerID#66, coalesce(event_counts#23L, 0) AS tenure#61L, coalesce(total_charges#25, 0.00) AS TotalCharges#62]\n", " : : : : +- SortMergeJoin [customerID#66], [customerID#44], LeftOuter\n", " : : : : :- Sort [customerID#66 ASC NULLS FIRST], false, 0\n", " : : : : : +- HashAggregate(keys=[customerID#66], functions=[])\n", " : : : : : +- Exchange hashpartitioning(customerID#66, 200), ENSURE_REQUIREMENTS, [id=#561]\n", " : : : : : +- HashAggregate(keys=[customerID#66], functions=[])\n", " : : : : : +- Project [customerID#66]\n", " : : : : : +- Filter isnotnull(customerID#66)\n", " : : : : : +- FileScan parquet [customerID#66,month#70] Batched: true, DataFilters: [isnotnull(customerID#66)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(customerID)], ReadSchema: struct\n", " : : : : +- Sort [customerID#44 ASC NULLS FIRST], false, 0\n", " : : : : +- Exchange hashpartitioning(customerID#44, 200), ENSURE_REQUIREMENTS, [id=#567]\n", " : : : : +- HashAggregate(keys=[customerID#44, kind#45], functions=[count(value#46), sum(UnscaledValue(value#46))])\n", " : : : : +- Exchange hashpartitioning(customerID#44, kind#45, 200), ENSURE_REQUIREMENTS, [id=#563]\n", " : : : : +- HashAggregate(keys=[customerID#44, kind#45], functions=[partial_count(value#46), partial_sum(UnscaledValue(value#46))])\n", " : : : : +- Project [customerID#44, kind#45, value#46]\n", " : : : : +- Filter ((isnotnull(kind#45) AND (kind#45 = Charge)) AND isnotnull(customerID#44))\n", " : : : : +- FileScan parquet [customerID#44,kind#45,value#46,month#48] Batched: true, DataFilters: [isnotnull(kind#45), (kind#45 = Charge), isnotnull(customerID#44)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(kind), EqualTo(kind,Charge), IsNotNull(customerID)], ReadSchema: struct\n", " : : : +- Sort [customerID#295 ASC NULLS FIRST], false, 0\n", " : : : +- Project [customerID#295, CASE WHEN isnull(PhoneService#82) THEN No ELSE Yes END AS PhoneService#97, CASE WHEN isnull(PhoneService#82) THEN No phone service ELSE CASE WHEN isnull(MultipleLines#85) THEN No ELSE Yes END END AS MultipleLines#101]\n", " : : : +- BroadcastHashJoin [customerID#295], [customerID#91], LeftOuter, BuildRight, false\n", " : : : :- Project [customerID#295, PhoneService#82]\n", " : : : : +- BroadcastHashJoin [customerID#295], [customerID#76], LeftOuter, BuildRight, false\n", " : : : : :- HashAggregate(keys=[customerID#295], functions=[])\n", " : : : : : +- Exchange hashpartitioning(customerID#295, 200), ENSURE_REQUIREMENTS, [id=#580]\n", " : : : : : +- HashAggregate(keys=[customerID#295], functions=[])\n", " : : : : : +- Project [customerID#295]\n", " : : : : : +- Filter isnotnull(customerID#295)\n", " : : : : : +- FileScan parquet [customerID#295,month#299] Batched: true, DataFilters: [isnotnull(customerID#295)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(customerID)], ReadSchema: struct\n", " : : : : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#583]\n", " : : : : +- Project [customerID#76, Yes AS PhoneService#82]\n", " : : : : +- Filter ((isnotnull(feature#77) AND (feature#77 = PhoneService)) AND isnotnull(customerID#76))\n", " : : : : +- 
FileScan parquet [customerID#76,feature#77] Batched: true, DataFilters: [isnotnull(feature#77), (feature#77 = PhoneService), isnotnull(customerID#76)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,PhoneService), IsNotNull(customerID)], ReadSchema: struct\n", " : : : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#587]\n", " : : : +- Project [customerID#91, Yes AS MultipleLines#85]\n", " : : : +- Filter ((isnotnull(feature#92) AND (feature#92 = MultipleLines)) AND isnotnull(customerID#91))\n", " : : : +- FileScan parquet [customerID#91,feature#92] Batched: true, DataFilters: [isnotnull(feature#92), (feature#92 = MultipleLines), isnotnull(customerID#91)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,MultipleLines), IsNotNull(customerID)], ReadSchema: struct\n", " : : +- Sort [customerID#306 ASC NULLS FIRST], false, 0\n", " : : +- Project [customerID#306, CASE WHEN isnull(InternetService#124) THEN No ELSE InternetService#124 END AS InternetService#199, CASE WHEN ((InternetService#124 = No) OR isnull(InternetService#124)) THEN No internet service ELSE CASE WHEN isnull(OnlineSecurity#127) THEN No ELSE OnlineSecurity#127 END END AS OnlineSecurity#200, CASE WHEN ((InternetService#124 = No) OR isnull(InternetService#124)) THEN No internet service ELSE CASE WHEN isnull(OnlineBackup#130) THEN No ELSE OnlineBackup#130 END END AS OnlineBackup#201, CASE WHEN ((InternetService#124 = No) OR isnull(InternetService#124)) THEN No internet service ELSE CASE WHEN isnull(DeviceProtection#133) THEN No ELSE DeviceProtection#133 END END AS DeviceProtection#202, CASE WHEN ((InternetService#124 = No) OR isnull(InternetService#124)) THEN No internet service ELSE CASE WHEN isnull(TechSupport#136) THEN No ELSE TechSupport#136 END END AS TechSupport#203, CASE WHEN ((InternetService#124 = No) OR isnull(InternetService#124)) THEN No internet service ELSE CASE WHEN isnull(StreamingTV#139) THEN No ELSE StreamingTV#139 END END AS StreamingTV#204, CASE WHEN ((InternetService#124 = No) OR isnull(InternetService#124)) THEN No internet service ELSE CASE WHEN isnull(StreamingMovies#142) THEN No ELSE StreamingMovies#142 END END AS StreamingMovies#205]\n", " : : +- BroadcastHashJoin [customerID#306], [customerID#188], LeftOuter, BuildRight, false\n", " : : :- Project [customerID#306, InternetService#124, OnlineSecurity#127, OnlineBackup#130, DeviceProtection#133, TechSupport#136, StreamingTV#139]\n", " : : : +- BroadcastHashJoin [customerID#306], [customerID#178], LeftOuter, BuildRight, false\n", " : : : :- Project [customerID#306, InternetService#124, OnlineSecurity#127, OnlineBackup#130, DeviceProtection#133, TechSupport#136]\n", " : : : : +- BroadcastHashJoin [customerID#306], [customerID#169], LeftOuter, BuildRight, false\n", " : : : : :- Project [customerID#306, InternetService#124, OnlineSecurity#127, OnlineBackup#130, DeviceProtection#133]\n", " : : : : : +- BroadcastHashJoin [customerID#306], [customerID#161], LeftOuter, BuildRight, false\n", " : : : : : :- Project [customerID#306, InternetService#124, OnlineSecurity#127, OnlineBackup#130]\n", " : : : : : : +- BroadcastHashJoin [customerID#306], [customerID#154], LeftOuter, BuildRight, false\n", " : : : : : : :- Project 
[customerID#306, InternetService#124, OnlineSecurity#127]\n", " : : : : : : : +- BroadcastHashJoin [customerID#306], [customerID#148], LeftOuter, BuildRight, false\n", " : : : : : : : :- Project [customerID#306, InternetService#124]\n", " : : : : : : : : +- BroadcastHashJoin [customerID#306], [customerID#105], LeftOuter, BuildRight, false\n", " : : : : : : : : :- HashAggregate(keys=[customerID#306], functions=[])\n", " : : : : : : : : : +- Exchange hashpartitioning(customerID#306, 200), ENSURE_REQUIREMENTS, [id=#595]\n", " : : : : : : : : : +- HashAggregate(keys=[customerID#306], functions=[])\n", " : : : : : : : : : +- Project [customerID#306]\n", " : : : : : : : : : +- Filter isnotnull(customerID#306)\n", " : : : : : : : : : +- FileScan parquet [customerID#306,month#310] Batched: true, DataFilters: [isnotnull(customerID#306)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(customerID)], ReadSchema: struct\n", " : : : : : : : : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#598]\n", " : : : : : : : : +- Project [customerID#105, value#107 AS InternetService#124]\n", " : : : : : : : : +- Filter ((isnotnull(feature#106) AND (feature#106 = InternetService)) AND isnotnull(customerID#105))\n", " : : : : : : : : +- FileScan parquet [customerID#105,feature#106,value#107] Batched: true, DataFilters: [isnotnull(feature#106), (feature#106 = InternetService), isnotnull(customerID#105)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,InternetService), IsNotNull(customerID)], ReadSchema: struct\n", " : : : : : : : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#602]\n", " : : : : : : : +- Project [customerID#148, value#150 AS OnlineSecurity#127]\n", " : : : : : : : +- Filter ((isnotnull(feature#149) AND (feature#149 = OnlineSecurity)) AND isnotnull(customerID#148))\n", " : : : : : : : +- FileScan parquet [customerID#148,feature#149,value#150] Batched: true, DataFilters: [isnotnull(feature#149), (feature#149 = OnlineSecurity), isnotnull(customerID#148)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,OnlineSecurity), IsNotNull(customerID)], ReadSchema: struct\n", " : : : : : : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#606]\n", " : : : : : : +- Project [customerID#154, value#156 AS OnlineBackup#130]\n", " : : : : : : +- Filter ((isnotnull(feature#155) AND (feature#155 = OnlineBackup)) AND isnotnull(customerID#154))\n", " : : : : : : +- FileScan parquet [customerID#154,feature#155,value#156] Batched: true, DataFilters: [isnotnull(feature#155), (feature#155 = OnlineBackup), isnotnull(customerID#154)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,OnlineBackup), IsNotNull(customerID)], ReadSchema: struct\n", " : : : : : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#610]\n", " : : : : : +- Project [customerID#161, value#163 AS 
DeviceProtection#133]\n", " : : : : : +- Filter ((isnotnull(feature#162) AND (feature#162 = DeviceProtection)) AND isnotnull(customerID#161))\n", " : : : : : +- FileScan parquet [customerID#161,feature#162,value#163] Batched: true, DataFilters: [isnotnull(feature#162), (feature#162 = DeviceProtection), isnotnull(customerID#161)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,DeviceProtection), IsNotNull(customerID)], ReadSchema: struct\n", " : : : : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#614]\n", " : : : : +- Project [customerID#169, value#171 AS TechSupport#136]\n", " : : : : +- Filter ((isnotnull(feature#170) AND (feature#170 = TechSupport)) AND isnotnull(customerID#169))\n", " : : : : +- FileScan parquet [customerID#169,feature#170,value#171] Batched: true, DataFilters: [isnotnull(feature#170), (feature#170 = TechSupport), isnotnull(customerID#169)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,TechSupport), IsNotNull(customerID)], ReadSchema: struct\n", " : : : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#618]\n", " : : : +- Project [customerID#178, value#180 AS StreamingTV#139]\n", " : : : +- Filter ((isnotnull(feature#179) AND (feature#179 = StreamingTV)) AND isnotnull(customerID#178))\n", " : : : +- FileScan parquet [customerID#178,feature#179,value#180] Batched: true, DataFilters: [isnotnull(feature#179), (feature#179 = StreamingTV), isnotnull(customerID#178)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,StreamingTV), IsNotNull(customerID)], ReadSchema: struct\n", " : : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#622]\n", " : : +- Project [customerID#188, value#190 AS StreamingMovies#142]\n", " : : +- Filter ((isnotnull(feature#189) AND (feature#189 = StreamingMovies)) AND isnotnull(customerID#188))\n", " : : +- FileScan parquet [customerID#188,feature#189,value#190] Batched: true, DataFilters: [isnotnull(feature#189), (feature#189 = StreamingMovies), isnotnull(customerID#188)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,StreamingMovies), IsNotNull(customerID)], ReadSchema: struct\n", " : +- Sort [customerID#324 ASC NULLS FIRST], false, 0\n", " : +- Project [customerID#324, Contract#233, CASE WHEN isnull(PaperlessBilling#236) THEN No ELSE PaperlessBilling#236 END AS PaperlessBilling#258, PaymentMethod#239]\n", " : +- BroadcastHashJoin [customerID#324], [customerID#251], LeftOuter, BuildRight, false\n", " : :- Project [customerID#324, Contract#233, PaperlessBilling#236]\n", " : : +- BroadcastHashJoin [customerID#324], [customerID#245], LeftOuter, BuildRight, false\n", " : : :- Project [customerID#324, Contract#233]\n", " : : : +- BroadcastHashJoin [customerID#324], [customerID#214], LeftOuter, BuildRight, false\n", " : : : :- HashAggregate(keys=[customerID#324], functions=[])\n", " : : : : +- Exchange 
hashpartitioning(customerID#324, 200), ENSURE_REQUIREMENTS, [id=#630]\n", " : : : : +- HashAggregate(keys=[customerID#324], functions=[])\n", " : : : : +- Project [customerID#324]\n", " : : : : +- Filter isnotnull(customerID#324)\n", " : : : : +- FileScan parquet [customerID#324,month#328] Batched: true, DataFilters: [isnotnull(customerID#324)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(customerID)], ReadSchema: struct\n", " : : : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#633]\n", " : : : +- Project [customerID#214, value#216 AS Contract#233]\n", " : : : +- Filter ((isnotnull(feature#215) AND (feature#215 = Contract)) AND isnotnull(customerID#214))\n", " : : : +- FileScan parquet [customerID#214,feature#215,value#216] Batched: true, DataFilters: [isnotnull(feature#215), (feature#215 = Contract), isnotnull(customerID#214)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,Contract), IsNotNull(customerID)], ReadSchema: struct\n", " : : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#637]\n", " : : +- Project [customerID#245, value#247 AS PaperlessBilling#236]\n", " : : +- Filter ((isnotnull(feature#246) AND (feature#246 = PaperlessBilling)) AND isnotnull(customerID#245))\n", " : : +- FileScan parquet [customerID#245,feature#246,value#247] Batched: true, DataFilters: [isnotnull(feature#246), (feature#246 = PaperlessBilling), isnotnull(customerID#245)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,PaperlessBilling), IsNotNull(customerID)], ReadSchema: struct\n", " : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#641]\n", " : +- Project [customerID#251, value#253 AS PaymentMethod#239]\n", " : +- Filter ((isnotnull(feature#252) AND (feature#252 = PaymentMethod)) AND isnotnull(customerID#251))\n", " : +- FileScan parquet [customerID#251,feature#252,value#253] Batched: true, DataFilters: [isnotnull(feature#252), (feature#252 = PaymentMethod), isnotnull(customerID#251)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,PaymentMethod), IsNotNull(customerID)], ReadSchema: struct\n", " +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#650]\n", " +- Project [customerID#263, ((year(cast(now#270 as date)) > (year(dateOfBirth#264) + 65)) OR ((year(cast(now#270 as date)) = (year(dateOfBirth#264) + 65)) AND ((month(cast(now#270 as date)) < month(dateOfBirth#264)) OR ((month(cast(now#270 as date)) = month(dateOfBirth#264)) AND (dayofmonth(cast(now#270 as date)) <= dayofmonth(cast(now#270 as date))))))) AS SeniorCitizen#279, Partner#267, Dependents#268, gender#265, MonthlyCharges#269]\n", " +- Filter isnotnull(customerID#263)\n", " +- FileScan parquet [customerID#263,dateOfBirth#264,gender#265,Partner#267,Dependents#268,MonthlyCharges#269,now#270] Batched: true, DataFilters: [isnotnull(customerID#263)], Format: Parquet, Location: 
InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(customerID)], ReadSchema: struct (28 + 1) / 29]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.15 s, sys: 188 ms, total: 1.34 s\n", "Wall time: 2min 58s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "%%time\n", "from churn.etl import write_df\n", "write_df(wide_data, output_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Inspecting individual tables\n", "\n", "If we need to inspect individual components of our processing, we can. Each constituent of these joins is registered as a temporary view. For example, we loaded `customers` earlier using a method from `churn.etl`, but it is also available as a table:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "customers = session.table(\"customers\")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 10:02:56,112 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", "\n", "2022-04-05 10:02:56,113 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", "\n", "2022-04-05 10:02:56,114 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", "\n", "2022-04-05 10:03:35,633 WARN rapids.GpuOverrides: =============>(790 + 1) / 795]\n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", "\n", "2022-04-05 10:03:35,634 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+--------------------+\n", "| customerID|\n", "+--------------------+\n", "|9102-OXKFY-fCyaBG...|\n", "|5478-JJVZK-pUoKfE...|\n", "|1843-TLSGD-PjdWrt...|\n", "|2027-FECZV-5HMwOd...|\n", "|3793-MMFUH-FBa4QK...|\n", "|5360-XGYAZ-F1ALBc...|\n", "|1843-TLSGD-L@JxWt...|\n", "|5872-OEQNH-5NXyac...|\n", "|6773-LQTVT-XB@vuC...|\n", "|3301-VKTGC-PjdWrt...|\n", "|9251-AWQGT-fCyaBG...|\n", "|9830-ECLEN-lqhKlh...|\n", "|7969-FFOWG-fPARzA...|\n", "|9451-WLYRI-0stTDJ...|\n", "|4293-ETKAP-dkh3P1...|\n", "|6281-FKEWS-0V3zMQ...|\n", "|8220-OCUFY-PjdWrt...|\n", "|0578-SKVMF-GSLp0h...|\n", "|2165-VOEGB-K8kBya...|\n", "|6754-WKSHP-rt81Nn...|\n", "+--------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "customers.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see which tables are available by querying the session catalog:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-04-05 10:03:38,813 WARN conf.HiveConf: HiveConf of name hive.stats.jdbc.timeout does not exist\n", "2022-04-05 10:03:38,814 WARN conf.HiveConf: HiveConf of name hive.stats.retries.wait does not exist\n", "2022-04-05 10:03:40,550 WARN metastore.ObjectStore: Version information not found in metastore. hive.metastore.schema.verification is not enabled so recording the schema version 2.3.0\n", "2022-04-05 10:03:40,550 WARN metastore.ObjectStore: setMetaStoreSchemaVersion called but recording version is disabled: version = 2.3.0, comment = Set by MetaStore yuanli@127.0.1.1\n", "2022-04-05 10:03:40,703 WARN metastore.ObjectStore: Failed to get database global_temp, returning NoSuchObjectException\n", "2022-04-05 10:03:40,833 WARN rapids.GpuOverrides: \n", "! 
cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.LocalTableScanExec\n", " @Expression name#507 could run on GPU\n", " @Expression database#508 could run on GPU\n", " @Expression description#509 could run on GPU\n", " @Expression tableType#510 could run on GPU\n", " @Expression isTemporary#511 could run on GPU\n", "\n" ] }, { "data": { "text/plain": [ "['churned',\n", " 'contracts',\n", " 'counts_and_charges',\n", " 'customer_account_features',\n", " 'customer_account_meta',\n", " 'customer_billing',\n", " 'customer_charges',\n", " 'customer_internet_features',\n", " 'customer_phone_features',\n", " 'customers',\n", " 'device_protection',\n", " 'internet_service',\n", " 'multiple_lines',\n", " 'online_backup',\n", " 'online_security',\n", " 'paperless',\n", " 'payment',\n", " 'phone_service',\n", " 'streaming_movies',\n", " 'streaming_tv',\n", " 'tech_support',\n", " 'terminations']" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tables = session.catalog.listTables()\n", "[t.name for t in tables]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Finishing up" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "session.stop()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: examples/SQL+DF-Examples/demo/Spark_get_json_object.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "Td_alkbOv3Aj", "metadata": { "id": "Td_alkbOv3Aj" }, "source": [ "# Spark RAPIDS get_json_object acceleration\n", "\n" ] }, { "cell_type": "markdown", "id": "c6ed860b", "metadata": { "id": "c6ed860b" }, "source": [ "\n", " \"Open\n", "\n" ] }, { "cell_type": "markdown", "id": "AhUsdz6jLdMi", "metadata": { "id": "AhUsdz6jLdMi" }, "source": [ "\n", "Before getting started - be sure to change your runtime to use a GPU Hardware accelerator! Use the Runtime -> \"Change runtime type\" menu option to add a GPU." 
] }, { "cell_type": "markdown", "id": "ZfNDlz0SM0DB", "metadata": { "id": "ZfNDlz0SM0DB" }, "source": [ "# Let's get started using the RAPIDS Accelerator for Apache Spark" ] }, { "cell_type": "code", "execution_count": null, "id": "PzW61-K04A1E", "metadata": { "id": "PzW61-K04A1E" }, "outputs": [], "source": [ "!nvidia-smi" ] }, { "cell_type": "code", "source": [ "!cat /proc/cpuinfo" ], "metadata": { "id": "OIEun51OCyC4" }, "id": "OIEun51OCyC4", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "spark_version='3.5.0'\n", "rapids_version='24.12.0'" ], "metadata": { "id": "NEGt46X7nEqf" }, "id": "NEGt46X7nEqf", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "%pip install --quiet \\\n", " pyspark=={spark_version}" ], "metadata": { "id": "g9XK28gcnHiG" }, "id": "g9XK28gcnHiG", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from importlib.resources import files\n", "from pyspark.sql import SparkSession\n", "import glob\n", "import os\n", "import re\n", "import time\n", "import statistics" ], "metadata": { "id": "gr2msGD1nLh-" }, "id": "gr2msGD1nLh-", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "pyspark_files = files('pyspark')\n", "spark_sql_jar_path, *_ = glob.glob(f\"{pyspark_files}/*/spark-sql_*jar\")\n", "spark_sql_jar = os.path.basename(spark_sql_jar_path)\n", "scala_version = re.search(r'^spark-sql_(\\d+.\\d+)-.*\\.jar$', spark_sql_jar).group(1)" ], "metadata": { "id": "0uXK6z8KoFUt" }, "id": "0uXK6z8KoFUt", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "spark = (\n", " SparkSession.builder\n", " .appName('JSON PySpark RAPIDS=ON/OFF')\n", " .config('spark.driver.memory', '5g')\n", " .config('spark.plugins', 'com.nvidia.spark.SQLPlugin')\n", " .config('spark.jars.packages', f\"com.nvidia:rapids-4-spark_{scala_version}:{rapids_version}\")\n", " .getOrCreate()\n", ")\n", "spark" ], "metadata": { "id": "ayT5VJQvnQv4" }, "id": "ayT5VJQvnQv4", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "location = \"./TMP_DATA\"\n", "iters = 3" ], "metadata": { "id": "3VsYyTATpNG1" }, "id": "3VsYyTATpNG1", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "def mk_json_column(i):\n", " return \"\"\" '\"', CAST(rand(\"\"\" + str(i) + \"\"\") * 10000 AS LONG), '\":\"\"\" + str(i) + \"\"\"'\"\"\"\n", "\n", "# generate json lines with very sparse keys\n", "spark.range(1000000).selectExpr(\"\"\"concat('{', \"\"\" + (\"\"\", ',' ,\"\"\".join([mk_json_column(i) for i in range(100)])) + \"\"\"'}') as json\"\"\").write.mode(\"overwrite\").parquet(location)" ], "metadata": { "id": "diUi3mxWh91X" }, "id": "diUi3mxWh91X", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Test pulling out a few keys using the GPU\n", "spark.conf.set(\"spark.rapids.sql.enabled\",True)\n", "gpu_times = []\n", "for i in range(iters):\n", " start = time.time()\n", " df = spark.read.parquet(location).selectExpr(\"count(get_json_object(json,'$.0')) as zero\", \"count(get_json_object(json,'$.10')) as ten\", \"count(get_json_object(json,'$.100')) as hundred\", \"count(get_json_object(json,'$.1000')) as thousand\", \"count(get_json_object(json,'$.1001')) as thousandAndOne\", \"avg(octet_length(json)) as len\")\n", " if i == 0:\n", " df.show()\n", " else:\n", " df.collect()\n", " end = time.time()\n", " gpu_times.append(end - start)\n", "\n", "\n", "print(f\"Median execution time of {iters} runs for GPU 
get_json_object: {statistics.median(gpu_times):.3f}\")" ], "metadata": { "id": "iXaXVgBNt4pK" }, "id": "iXaXVgBNt4pK", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Run the same test using the CPU. Note that this is a exceptional result\n", "# because Colab provides very little CPU (2 cores) to go with the GPU (T4)\n", "# on a 16 core AMD CPU that is not overcommited and with an NVMe to load the\n", "# data, and an A6000 GPU, the GPU takes about 0.662 seconds to complete and\n", "# the CPU taks about 2.986 seconds, or about a 4.5x speedup, compared to this\n", "# notebook's ~30x speedup.\n", "spark.conf.set(\"spark.rapids.sql.enabled\",False)\n", "cpu_times = []\n", "for i in range(iters):\n", " start = time.time()\n", " df = spark.read.parquet(location).selectExpr(\"count(get_json_object(json,'$.0')) as zero\", \"count(get_json_object(json,'$.10')) as ten\", \"count(get_json_object(json,'$.100')) as hundred\", \"count(get_json_object(json,'$.1000')) as thousand\", \"count(get_json_object(json,'$.1001')) as thousandAndOne\", \"avg(octet_length(json)) as len\")\n", " if i == 0:\n", " df.show()\n", " else:\n", " df.collect()\n", " end = time.time()\n", " cpu_times.append(end - start)\n", "\n", "print(f\"Median execution time of {iters} runs for CPU get_json_object: {statistics.median(cpu_times):.3f}\")" ], "metadata": { "id": "lUmVe12Wic5X" }, "id": "lUmVe12Wic5X", "execution_count": null, "outputs": [] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3.9.12 ('base')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "vscode": { "interpreter": { "hash": "5327a248d9883bedf47bfd9e608af95bf318797e621edcc550c6b5b3fdc820cc" } } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/SQL+DF-Examples/demo/Spark_parquet_microkernels.ipynb ================================================ { "cells": [ { "cell_type": "raw", "id": "Td_alkbOv3Aj", "metadata": { "id": "Td_alkbOv3Aj" }, "source": [ "{\n", " \"cells\": [\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"id\": \"Td_alkbOv3Aj\",\n", " \"metadata\": {\n", " \"id\": \"Td_alkbOv3Aj\"\n", " },\n", " \"source\": [\n", " \"# Spark RAPIDS Parquet acceleration\\n\",\n", " \"\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"id\": \"c6ed860b\",\n", " \"metadata\": {\n", " \"id\": \"c6ed860b\"\n", " },\n", " \"source\": [\n", " \"\\n\",\n", " \" \\\"Open\\n\",\n", " \"\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"id\": \"AhUsdz6jLdMi\",\n", " \"metadata\": {\n", " \"id\": \"AhUsdz6jLdMi\"\n", " },\n", " \"source\": [\n", " \"\\n\",\n", " \"Before getting started - be sure to change your runtime to use a GPU Hardware accelerator! 
Use the Runtime -> \\\"Change runtime type\\\" menu option to add a GPU.\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"id\": \"ZfNDlz0SM0DB\",\n", " \"metadata\": {\n", " \"id\": \"ZfNDlz0SM0DB\"\n", " },\n", " \"source\": [\n", " \"# Let's get started using the RAPIDS Accelerator for Apache Spark\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": null,\n", " \"id\": \"PzW61-K04A1E\",\n", " \"metadata\": {\n", " \"id\": \"PzW61-K04A1E\"\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"!nvidia-smi\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"source\": [\n", " \"!cat /proc/cpuinfo\"\n", " ],\n", " \"metadata\": {\n", " \"id\": \"OIEun51OCyC4\"\n", " },\n", " \"id\": \"OIEun51OCyC4\",\n", " \"execution_count\": null,\n", " \"outputs\": []\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"source\": [\n", " \"spark_version='3.5.0'\\n\",\n", " \"rapids_version='24.12.0'\"\n", " ],\n", " \"metadata\": {\n", " \"id\": \"NEGt46X7nEqf\"\n", " },\n", " \"id\": \"NEGt46X7nEqf\",\n", " \"execution_count\": null,\n", " \"outputs\": []\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"source\": [\n", " \"%pip install --quiet \\\\\\n\",\n", " \" pyspark=={spark_version}\"\n", " ],\n", " \"metadata\": {\n", " \"id\": \"g9XK28gcnHiG\"\n", " },\n", " \"id\": \"g9XK28gcnHiG\",\n", " \"execution_count\": null,\n", " \"outputs\": []\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"source\": [\n", " \"from importlib.resources import files\\n\",\n", " \"from pyspark.sql import SparkSession\\n\",\n", " \"import glob\\n\",\n", " \"import os\\n\",\n", " \"import re\\n\",\n", " \"import time\\n\",\n", " \"import statistics\"\n", " ],\n", " \"metadata\": {\n", " \"id\": \"gr2msGD1nLh-\"\n", " },\n", " \"id\": \"gr2msGD1nLh-\",\n", " \"execution_count\": null,\n", " \"outputs\": []\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"source\": [\n", " \"pyspark_files = files('pyspark')\\n\",\n", " \"spark_sql_jar_path, *_ = glob.glob(f\\\"{pyspark_files}/*/spark-sql_*jar\\\")\\n\",\n", " \"spark_sql_jar = os.path.basename(spark_sql_jar_path)\\n\",\n", " \"scala_version = re.search(r'^spark-sql_(\\\\d+.\\\\d+)-.*\\\\.jar$', spark_sql_jar).group(1)\"\n", " ],\n", " \"metadata\": {\n", " \"id\": \"0uXK6z8KoFUt\"\n", " },\n", " \"id\": \"0uXK6z8KoFUt\",\n", " \"execution_count\": null,\n", " \"outputs\": []\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"source\": [\n", " \"spark = (\\n\",\n", " \" SparkSession.builder\\n\",\n", " \" .appName('Parquet Spark GPU Acceleration')\\n\",\n", " \" .master('local[*]')\\n\",\n", " \" .config('spark.driver.memory', '5g')\\n\",\n", " \" .config('spark.plugins', 'com.nvidia.spark.SQLPlugin')\\n\",\n", " \" .config('spark.jars.packages', f\\\"com.nvidia:rapids-4-spark_{scala_version}:{rapids_version}\\\")\\n\",\n", " \" .getOrCreate()\\n\",\n", " \")\\n\",\n", " \"spark\"\n", " ],\n", " \"metadata\": {\n", " \"id\": \"ayT5VJQvnQv4\"\n", " },\n", " \"id\": \"ayT5VJQvnQv4\",\n", " \"execution_count\": null,\n", " \"outputs\": []\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"source\": [\n", " \"location = \\\"./TMP_DATA\\\"\\n\",\n", " \"iters = 5\"\n", " ],\n", " \"metadata\": {\n", " \"id\": \"3VsYyTATpNG1\"\n", " },\n", " \"id\": \"3VsYyTATpNG1\",\n", " \"execution_count\": null,\n", " \"outputs\": []\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"source\": [\n", " \"from pyspark.sql.types import IntegerType, StringType, StructType, 
StructField\\n\",\n", " \"from pyspark.sql import functions as F\\n\",\n", " \"import random\\n\",\n", " \"import string\\n\",\n", " \"\\n\",\n", " \"# Define schema\\n\",\n", " \"schema = StructType([\\n\",\n", " \" StructField(\\\"id\\\", IntegerType(), False),\\n\",\n", " \" StructField(\\\"name\\\", StringType(), False),\\n\",\n", " \" StructField(\\\"age\\\", IntegerType(), False),\\n\",\n", " \" StructField(\\\"salary\\\", IntegerType(), False)\\n\",\n", " \"])\\n\",\n", " \"\\n\",\n", " \"# Function to generate random strings\\n\",\n", " \"def random_string(length=10):\\n\",\n", " \" return ''.join(random.choices(string.ascii_letters, k=length))\\n\",\n", " \"\\n\",\n", " \"# Generate DataFrame with 20M rows\\n\",\n", " \"df = spark.range(0, 20_000_000).toDF(\\\"id\\\") \\\\\\n\",\n", " \" .withColumn(\\\"name\\\", F.udf(lambda: random_string(), StringType())()) \\\\\\n\",\n", " \" .withColumn(\\\"age\\\", (F.rand() * 50 + 20).cast(IntegerType())) \\\\\\n\",\n", " \" .withColumn(\\\"salary\\\", (F.rand() * 100000 + 30000).cast(IntegerType()))\\n\",\n", " \"\\n\",\n", " \"df.write.mode(\\\"overwrite\\\").parquet(location)\"\n", " ],\n", " \"metadata\": {\n", " \"id\": \"diUi3mxWh91X\"\n", " },\n", " \"id\": \"diUi3mxWh91X\",\n", " \"execution_count\": null,\n", " \"outputs\": []\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"source\": [\n", " \"# Run the Parquet scan test on the GPU\\n\",\n", " \"spark.conf.set(\\\"spark.rapids.sql.enabled\\\",True)\\n\",\n", " \"gpu_times = []\\n\",\n", " \"for i in range(iters):\\n\",\n", " \" start = time.time()\\n\",\n", " \" df = spark.read.parquet(location).selectExpr(\\\"count(name) as rows\\\", \\\"avg(salary) as average_salary\\\", \\\"median(salary) as median_salary\\\", \\\"sum(salary) as total_salary\\\", \\\"avg(age) as average_age\\\", \\\"median(age) as median_age\\\")\\n\",\n", " \" if i == 0:\\n\",\n", " \" df.show()\\n\",\n", " \" else:\\n\",\n", " \" df.collect()\\n\",\n", " \" end = time.time()\\n\",\n", " \" gpu_times.append(end - start)\\n\",\n", " \"\\n\",\n", " \"gpu_median = statistics.median(gpu_times)\\n\",\n", " \"\\n\",\n", " \"print(f\\\"Median execution time of {iters} runs for GPU Parquet scan: {gpu_median:.3f}\\\")\"\n", " ],\n", " \"metadata\": {\n", " \"id\": \"iXaXVgBNt4pK\"\n", " },\n", " \"id\": \"iXaXVgBNt4pK\",\n", " \"execution_count\": null,\n", " \"outputs\": []\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"source\": [\n", " \"# Run the Parquet scan test on the CPU\\n\",\n", " \"spark.conf.set(\\\"spark.rapids.sql.enabled\\\",False)\\n\",\n", " \"cpu_times = []\\n\",\n", " \"for i in range(iters):\\n\",\n", " \" start = time.time()\\n\",\n", " \" df = spark.read.parquet(location).selectExpr(\\\"count(name) as rows\\\", \\\"avg(salary) as average_salary\\\", \\\"median(salary) as median_salary\\\", \\\"sum(salary) as total_salary\\\", \\\"avg(age) as average_age\\\", \\\"median(age) as median_age\\\")\\n\",\n", " \" if i == 0:\\n\",\n", " \" df.show()\\n\",\n", " \" else:\\n\",\n", " \" df.collect()\\n\",\n", " \" end = time.time()\\n\",\n", " \" cpu_times.append(end - start)\\n\",\n", " \"\\n\",\n", " \"cpu_median = statistics.median(cpu_times)\\n\",\n", " \"print(f\\\"Median execution time of {iters} runs for CPU Parquet scan: {cpu_median:.3f}\\\")\"\n", " ],\n", " \"metadata\": {\n", " \"id\": \"lUmVe12Wic5X\"\n", " },\n", " \"id\": \"lUmVe12Wic5X\",\n", " \"execution_count\": null,\n", " \"outputs\": []\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"source\": [\n", " \"# GPU 
speedup should be in the range of 5-10x\\n\",\n", " \"speedup = cpu_median / gpu_median\\n\",\n", " \"print(f\\\"GPU speedup: {speedup:.2f}x\\\")\"\n", " ],\n", " \"metadata\": {\n", " \"id\": \"CxROFk_AoQQl\"\n", " },\n", " \"id\": \"CxROFk_AoQQl\",\n", " \"execution_count\": null,\n", " \"outputs\": []\n", " }\n", " ],\n", " \"metadata\": {\n", " \"accelerator\": \"GPU\",\n", " \"colab\": {\n", " \"provenance\": []\n", " },\n", " \"gpuClass\": \"standard\",\n", " \"kernelspec\": {\n", " \"display_name\": \"Python 3.9.12 ('base')\",\n", " \"language\": \"python\",\n", " \"name\": \"python3\"\n", " },\n", " \"language_info\": {\n", " \"codemirror_mode\": {\n", " \"name\": \"ipython\",\n", " \"version\": 3\n", " },\n", " \"file_extension\": \".py\",\n", " \"mimetype\": \"text/x-python\",\n", " \"name\": \"python\",\n", " \"nbconvert_exporter\": \"python\",\n", " \"pygments_lexer\": \"ipython3\",\n", " \"version\": \"3.9.12\"\n", " },\n", " \"vscode\": {\n", " \"interpreter\": {\n", " \"hash\": \"5327a248d9883bedf47bfd9e608af95bf318797e621edcc550c6b5b3fdc820cc\"\n", " }\n", " }\n", " },\n", " \"nbformat\": 4,\n", " \"nbformat_minor\": 5\n", "}\n" ] }, { "cell_type": "markdown", "id": "AhUsdz6jLdMi", "metadata": { "id": "AhUsdz6jLdMi" }, "source": [ "\n", "Before getting started - be sure to change your runtime to use a GPU Hardware accelerator! Use the Runtime -> \"Change runtime type\" menu option to add a GPU." ] }, { "cell_type": "code", "execution_count": null, "id": "PzW61-K04A1E", "metadata": { "id": "PzW61-K04A1E" }, "outputs": [], "source": [ "!nvidia-smi" ] }, { "cell_type": "code", "source": [ "spark_version='3.5.0'\n", "rapids_version='24.12.0'" ], "metadata": { "id": "NEGt46X7nEqf" }, "id": "NEGt46X7nEqf", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from importlib.resources import files\n", "from pyspark.sql import SparkSession\n", "import glob\n", "import os\n", "import re\n", "import time\n", "import statistics" ], "metadata": { "id": "gr2msGD1nLh-" }, "id": "gr2msGD1nLh-", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "spark = (\n", " SparkSession.builder\n", " .appName('Parquet Spark GPU Acceleration')\n", " .master('local[*]')\n", " .config('spark.driver.memory', '5g')\n", " .config('spark.plugins', 'com.nvidia.spark.SQLPlugin')\n", " .config('spark.jars.packages', f\"com.nvidia:rapids-4-spark_{scala_version}:{rapids_version}\")\n", " .getOrCreate()\n", ")\n", "spark" ], "metadata": { "id": "ayT5VJQvnQv4" }, "id": "ayT5VJQvnQv4", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from pyspark.sql.types import IntegerType, StringType, StructType, StructField\n", "from pyspark.sql import functions as F\n", "import random\n", "import string\n", "\n", "# Define schema\n", "schema = StructType([\n", " StructField(\"id\", IntegerType(), False),\n", " StructField(\"name\", StringType(), False),\n", " StructField(\"age\", IntegerType(), False),\n", " StructField(\"salary\", IntegerType(), False)\n", "])\n", "\n", "# Function to generate random strings\n", "def random_string(length=10):\n", " return ''.join(random.choices(string.ascii_letters, k=length))\n", "\n", "# Generate DataFrame with 20M rows\n", "df = spark.range(0, 20_000_000).toDF(\"id\") \\\n", " .withColumn(\"name\", F.udf(lambda: random_string(), StringType())()) \\\n", " .withColumn(\"age\", (F.rand() * 50 + 20).cast(IntegerType())) \\\n", " .withColumn(\"salary\", (F.rand() * 100000 + 30000).cast(IntegerType()))\n", "\n", 
"df.write.mode(\"overwrite\").parquet(location)" ], "metadata": { "id": "diUi3mxWh91X" }, "id": "diUi3mxWh91X", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Run the Parquet scan test on the CPU\n", "spark.conf.set(\"spark.rapids.sql.enabled\",False)\n", "cpu_times = []\n", "for i in range(iters):\n", " start = time.time()\n", " df = spark.read.parquet(location).selectExpr(\"count(name) as rows\", \"avg(salary) as average_salary\", \"median(salary) as median_salary\", \"sum(salary) as total_salary\", \"avg(age) as average_age\", \"median(age) as median_age\")\n", " if i == 0:\n", " df.show()\n", " else:\n", " df.collect()\n", " end = time.time()\n", " cpu_times.append(end - start)\n", "\n", "cpu_median = statistics.median(cpu_times)\n", "print(f\"Median execution time of {iters} runs for CPU Parquet scan: {cpu_median:.3f}\")" ], "metadata": { "id": "lUmVe12Wic5X" }, "id": "lUmVe12Wic5X", "execution_count": null, "outputs": [] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3.9.12 ('base')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "vscode": { "interpreter": { "hash": "5327a248d9883bedf47bfd9e608af95bf318797e621edcc550c6b5b3fdc820cc" } } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/SQL+DF-Examples/micro-benchmarks/README.md ================================================ # Microbenchmark Standard industry benchmarks are a great way to measure performance over a period of time but another barometer to measure performance is to measure performance of common operators that are used in the data preprocessing stage or in data analytics. The microbenchmark notebook in this repo uses five such queries in the chart shown below: - **Count Distinct**: a function used to estimate the number of unique page views or unique customers visiting an e-commerce site. - **Window**: a critical operator necessary for preprocessing components in analyzing timestamped event data in marketing or financial industry. - **Intersect**: an operator used to remove duplicates in a dataframe. - **Cross-join**: A common use for a cross join is to obtain all combinations of items. - **Hash-join**: Joining two tables together by matching rows based on a common column. These queries were run on a standard eight-nodes CPU cluster with 2 CPU (128 cores), 512GB memory and 1xA100 GPUs per node. The dataset used was of size 3TB with multiple different data types. The queries are based on several tables in NDS parquet format with Decimal. These four queries show not only performance and cost benefits but also the range of speed-up (27x to 1.5x) varies depending on compute intensity. These queries vary in compute and network utilization similar to a practical use case in data preprocessing.To test these queries, you can generate the parquet format dataset using this NDS dataset generator tool. All the queries are running on the SF3000(Scale Factor 3000) dataset. 
You can generate it with the following command:

```
# Assuming your platform is Linux
# Install sbt
echo "deb https://repo.scala-sbt.org/scalasbt/debian all main" | sudo tee /etc/apt/sources.list.d/sbt.list
echo "deb https://repo.scala-sbt.org/scalasbt/debian /" | sudo tee /etc/apt/sources.list.d/sbt_old.list
curl -sL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x2EE0EA64E40A89B84B2DF73499E82A75642AC823" | sudo apt-key add
sudo apt-get update
sudo apt-get install sbt

# Install jdk
sudo apt-get install openjdk-8-jdk

# clone related repos
git clone https://github.com/databricks/spark-sql-perf.git
git clone https://github.com/databricks/tpcds-kit.git

# build
cd tpcds-kit/tools
make OS=LINUX

sbt "test:runMain com.databricks.spark.sql.perf.tpcds.GenTPCDSData -d /databricks-tpcds-kit-path -s 3000G -l /your-dataset-path -f parquet"
```

![microbenchmark-speedup](/docs/img/guides/microbm.png)
================================================ FILE: examples/SQL+DF-Examples/micro-benchmarks/notebooks/micro-benchmarks-cpu.ipynb ================================================
{ "cells": [ { "cell_type": "markdown", "id": "d89df9bf", "metadata": {}, "source": [ "# Microbenchmarks on CPU\n", "This is a notebook for microbenchmarks running on CPU." ] }, { "cell_type": "code", "execution_count": 1, "id": "d08c8bae", "metadata": {}, "outputs": [], "source": [ "from pyspark.sql import SparkSession\n", "from pyspark.conf import SparkConf\n", "from time import time\n", "import os\n", "\n", "# Change to your cluster ip:port\n", "SPARK_MASTER_URL = os.getenv(\"SPARK_MASTER_URL\", \"spark://your-ip:port\")\n" ] }, { "cell_type": "markdown", "id": "6842522a", "metadata": {}, "source": [ "Run the microbenchmark with retry times" ] }, { "cell_type": "code", "execution_count": 2, "id": "45f50252", "metadata": {}, "outputs": [], "source": [ "def runMicroBenchmark(spark, appName, query, retryTimes):\n", " count = 0\n", " total_time = 0\n", " # You can print the physical plan of each query\n", " # spark.sql(query).explain()\n", " while count < retryTimes:\n", " start = time()\n", " spark.sql(query).show(5)\n", " end = time()\n", " total_time += round(end - start, 2)\n", " count = count + 1\n", " print(\"Retry times : {}, \".format(count) + appName + \" Microbenchmark takes {} seconds\".format(round(end - start, 2)))\n", " print(appName + \" Microbenchmark takes average {} seconds after {} retries\".format(round(total_time/retryTimes),retryTimes))\n", " " ] }, { "cell_type": "code", "execution_count": 3, "id": "682c67b1", "metadata": {}, "outputs": [], "source": [ "# You need to update data path with your real path and hardware resource!\n", "dataRoot = os.getenv(\"DATA_ROOT\", \"/data\")\n", "driverMem = os.getenv(\"DRIVER_MEM\", \"50g\")\n", "executorMem = os.getenv(\"EXECUTOR_MEM\", \"12g\")\n", "maxPartionBytes = os.getenv(\"MAX_PARTITION_BYTES\", \"1g\")\n", "executorCores = int(os.getenv(\"EXECUTOR_CORES\", \"4\"))\n", "# common spark settings\n", "conf = SparkConf()\n", "conf.setMaster(SPARK_MASTER_URL)\n", "conf.setAppName(\"Microbenchmark on CPU\")\n", "conf.set(\"spark.driver.memory\", driverMem)\n", "conf.set(\"spark.executor.memory\", executorMem)\n", "conf.set(\"spark.executor.cores\", executorCores)\n", " \n", "conf.set(\"spark.locality.wait\", \"0\")\n", "conf.set(\"spark.sql.files.maxPartitionBytes\", maxPartionBytes) \n", "conf.set(\"spark.dynamicAllocation.enabled\", \"false\") \n", "conf.set(\"spark.sql.adaptive.enabled\", \"true\") \n", "\n", "# create spark session\n", "spark = 
SparkSession.builder.config(conf=conf).getOrCreate()\n", "# Load dataframe and create tempView\n", "spark.read.parquet(dataRoot + \"/tpcds/store_sales\").createOrReplaceTempView(\"store_sales\")\n", "spark.read.parquet(dataRoot + \"/tpcds/catalog_sales\").createOrReplaceTempView(\"catalog_sales\")\n", "spark.read.parquet(dataRoot + \"/tpcds/web_sales\").createOrReplaceTempView(\"web_sales\")\n", "spark.read.parquet(dataRoot + \"/tpcds/item\").createOrReplaceTempView(\"item\")\n", "spark.read.parquet(dataRoot + \"/tpcds/date_dim\").createOrReplaceTempView(\"date_dim\")\n" ] }, { "cell_type": "markdown", "id": "89512b77", "metadata": {}, "source": [ "### Expand&HashAggregate\n", "This is a microbenchmark about Expand&HashAggregate expressions running on the CPU. The query calculates the distinct value of some dimension columns and average birth year by different c_salutation of customers after grouping by c_current_hdemo_sk." ] }, { "cell_type": "code", "execution_count": 4, "id": "3272ef56", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------\n" ] } ], "source": [ "# As a part of this query the size of the data in each task grows a lot. \n", "# By default, Spark will try to distribute the data among all the tasks in the cluster, \n", "# but on large clusters with large parquet files the splittable portions of the parquet files end up not being distributed evenly \n", "# and it is faster to re-partition the data to redistribute it than to deal with skew.\n", "spark.read.parquet(dataRoot + \"/tpcds/customer\").repartition(512).createOrReplaceTempView(\"customer\")\n", "\n", "print(\"-\"*50)" ] }, { "cell_type": "code", "execution_count": 5, "id": "dd12d749", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------\n" ] } ], "source": [ "query = '''\n", "select c_current_hdemo_sk,\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_salutation,null)) as c1,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_salutation,null)) as c12,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_salutation,null)) as c13,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_first_name,null)) as c2,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_first_name,null)) as c22,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_first_name,null)) as c23,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_last_name,null)) as c3,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_last_name,null)) as c32,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_last_name,null)) as c33,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_birth_country,null)) as c4,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_birth_country,null)) as c42,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_birth_country,null)) as c43,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_email_address,null)) as c5,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_email_address,null)) as c52,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_email_address,null)) as c53,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_login,null)) as c6,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_login,null)) as c62,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_login,null)) as c63,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_preferred_cust_flag,null)) as c7,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_preferred_cust_flag,null)) as c72,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_preferred_cust_flag,null)) as 
c73,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_birth_month,null)) as c8,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_birth_month,null)) as c82,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_birth_month,null)) as c83,\n", "\n", "avg(if(c_salutation==\"Ms.\",c_birth_year,null)) as avg1,\n", "avg(if(c_salutation==\"Mr.\",c_birth_year,null)) as avg2,\n", "avg(if(c_salutation==\"Dr.\",c_birth_year,null)) as avg3,\n", "avg(if(c_salutation==\"Miss.\",c_birth_year,null)) as avg4,\n", "avg(if(c_salutation==\"Mrs.\",c_birth_year,null)) as avg5,\n", "avg(if(c_salutation==\"Sir.\",c_birth_year,null)) as avg6,\n", "avg(if(c_salutation==\"Professor.\",c_birth_year,null)) as avg7,\n", "avg(if(c_salutation==\"Teacher.\",c_birth_year,null)) as avg8,\n", "avg(if(c_salutation==\"Agent.\",c_birth_year,null)) as avg9,\n", "avg(if(c_salutation==\"Director.\",c_birth_year,null)) as avg10\n", "from customer group by c_current_hdemo_sk\n", "'''\n", "print(\"-\"*50)" ] }, { "cell_type": "code", "execution_count": 6, "id": "2e105bf8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n", "|c_current_hdemo_sk| c1|c12|c13| c2|c22|c23| c3|c32|c33| c4|c42|c43| c5|c52| c53| c6|c62|c63| c7|c72|c73| c8|c82|c83| avg1| avg2| avg3|avg4| avg5|avg6|avg7|avg8|avg9|avg10|\n", "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n", "| 5803| 1| 1| 1|285|272|592|358|496|791|185|202|210|458|674|1177| 0| 0| 0| 2| 2| 2| 12| 12| 12|1959.5225806451613|1959.6557863501484|1958.5581196581197|null| 1958.873873873874|null|null|null|null| null|\n", "| 1591| 1| 1| 1|283|237|544|374|489|739|193|206|211|476|664|1144| 0| 0| 0| 2| 2| 2| 12| 12| 12|1957.3514644351465|1958.2278860569716|1958.6174672489083|null|1958.4357894736843|null|null|null|null| null|\n", "| 3918| 1| 1| 1|300|266|539|392|499|755|190|203|210|507|675|1140| 0| 0| 0| 2| 2| 2| 12| 12| 12|1957.6745562130177|1958.2998522895125|1958.8992994746059|null|1959.4233009708737|null|null|null|null| null|\n", "| 1580| 1| 1| 1|296|256|562|392|499|808|190|203|211|499|692|1222| 0| 0| 0| 2| 2| 2| 12| 12| 12|1958.5771543086173| 1957.53591954023|1957.3303278688525|null|1958.3611691022963|null|null|null|null| null|\n", "| 148| 1| 1| 1|309|260|562|392|501|772|187|207|211|488|668|1154| 0| 0| 0| 2| 2| 2| 12| 12| 12| 1956.219008264463|1958.9161676646706|1957.8076256499132|null|1958.3412017167382|null|null|null|null| null|\n", "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n", "only showing top 5 rows\n", "\n", "Retry times : 1, Expand&HashAggregate Microbenchmark takes 65.21 seconds\n", "Expand&HashAggregate Microbenchmark takes average 65 seconds after 1 retries\n" ] } ], "source": [ "# Run microbenchmark with n retry time\n", "runMicroBenchmark(spark,\"Expand&HashAggregate\",query ,1)" ] }, { "cell_type": "markdown", "id": "57da403a", "metadata": {}, "source": [ "### Windowing (without data skew)\n", "This is a microbenchmark about windowing expressions running on CPU mode. 
The sub-query calculates the average ss_sales_price of a fixed window function partition by ss_customer_sk, and the parent query calculates the average price of the sub-query grouping by each customer." ] }, { "cell_type": "code", "execution_count": 10, "id": "68169e7f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------\n" ] } ], "source": [ "query = '''\n", "select ss_customer_sk,avg(avg_price) as avg_price\n", "from\n", "(\n", "SELECT ss_customer_sk ,avg(ss_sales_price) OVER (PARTITION BY ss_customer_sk order by ss_sold_date_sk ROWS BETWEEN 50 PRECEDING AND 50 FOLLOWING ) as avg_price\n", "FROM store_sales\n", "where ss_customer_sk is not null\n", ") group by ss_customer_sk order by 2 desc \n", "'''\n", "print(\"-\"*50)" ] }, { "cell_type": "code", "execution_count": 11, "id": "f4d1d9ea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------+------------------+\n", "|ss_customer_sk| avg_price|\n", "+--------------+------------------+\n", "| 15924921|52.453036568858586|\n", "| 24796404|52.406491887877976|\n", "| 10174233|52.217149302596276|\n", "| 27571451| 52.14256448618126|\n", "| 14299506| 52.09827897444722|\n", "+--------------+------------------+\n", "only showing top 5 rows\n", "\n", "Retry times : 1, Windowing without skew Microbenchmark takes 176.61 seconds\n", "Windowing without skew Microbenchmark takes average 177 seconds after 1 retries\n" ] } ], "source": [ "# Run microbenchmark with n retry time\n", "runMicroBenchmark(spark,\"Windowing without skew\",query , 1)" ] }, { "cell_type": "markdown", "id": "7df0e850", "metadata": {}, "source": [ "### Windowing(with data skew)\n", "Data skew is caused by many null values in the ss_customer_sk column." ] }, { "cell_type": "code", "execution_count": 15, "id": "12ec99fb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------\n" ] } ], "source": [ "query = '''\n", "select ss_customer_sk,avg(avg_price) as avg_price\n", "from\n", "(\n", "SELECT ss_customer_sk ,avg(ss_sales_price) OVER (PARTITION BY ss_customer_sk order by ss_sold_date_sk ROWS BETWEEN 50 PRECEDING AND 50 FOLLOWING ) as avg_price\n", "FROM store_sales\n", ") group by ss_customer_sk order by 2 desc \n", "'''\n", "print(\"-\"*50)" ] }, { "cell_type": "code", "execution_count": 16, "id": "86e12b88", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------+------------------+\n", "|ss_customer_sk| avg_price|\n", "+--------------+------------------+\n", "| 15924921| 52.44865972015809|\n", "| 24796404|52.406491887877976|\n", "| 10174233|52.215293069577626|\n", "| 27571451| 52.14256448618126|\n", "| 14299506| 52.09827897444722|\n", "+--------------+------------------+\n", "only showing top 5 rows\n", "\n", "Retry times : 1, Windowing with skew Microbenchmark takes 1666.07 seconds\n", "Windowing with skew Microbenchmark takes average 1666 seconds after 1 retries\n" ] } ], "source": [ "# Run microbenchmark with n retry time\n", "runMicroBenchmark(spark,\"Windowing with skew\",query ,1)" ] }, { "cell_type": "markdown", "id": "2ef292cc", "metadata": {}, "source": [ "### Intersection\n", "This is a microbenchmark about intersection operation running on CPU mode. The query calculates items in the same brand, class, and category that are sold in all three sales channels in two consecutive years." 
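, "\n", "\n", "As a small, self-contained illustration of the operator itself (an aside, not part of the benchmark): `INTERSECT` keeps only the distinct rows that appear in both inputs, which the DataFrame API exposes as `intersect`:\n", "\n", "```python\n", "a = spark.createDataFrame([(1,), (1,), (2,)], \"x int\")\n", "b = spark.createDataFrame([(1,), (3,)], \"x int\")\n", "a.intersect(b).show()  # a single row: 1 (the result is de-duplicated)\n", "```"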
] }, { "cell_type": "code", "execution_count": 19, "id": "30c8eb8e", "metadata": {}, "outputs": [], "source": [ "query = '''\n", "select i_item_sk ss_item_sk\n", " from item,\n", " (select iss.i_brand_id brand_id, iss.i_class_id class_id, iss.i_category_id category_id\n", " from store_sales, item iss, date_dim d1\n", " where ss_item_sk = iss.i_item_sk\n", " and ss_sold_date_sk = d1.d_date_sk\n", " and d1.d_year between 1999 AND 1999 + 2\n", " intersect\n", " select ics.i_brand_id, ics.i_class_id, ics.i_category_id\n", " from catalog_sales, item ics, date_dim d2\n", " where cs_item_sk = ics.i_item_sk\n", " and cs_sold_date_sk = d2.d_date_sk\n", " and d2.d_year between 1999 AND 1999 + 2\n", " intersect\n", " select iws.i_brand_id, iws.i_class_id, iws.i_category_id\n", " from web_sales, item iws, date_dim d3\n", " where ws_item_sk = iws.i_item_sk\n", " and ws_sold_date_sk = d3.d_date_sk\n", " and d3.d_year between 1999 AND 1999 + 2) x\n", " where i_brand_id = brand_id\n", " and i_class_id = class_id\n", " and i_category_id = category_id\n", "'''" ] }, { "cell_type": "code", "execution_count": 20, "id": "d4f9f669", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------+\n", "|ss_item_sk|\n", "+----------+\n", "| 326835|\n", "| 248465|\n", "| 174935|\n", "| 130715|\n", "| 78159|\n", "+----------+\n", "only showing top 5 rows\n", "\n", "Retry times : 1, NDS Q14a subquery Microbenchmark takes 62.42 seconds\n", "NDS Q14a subquery Microbenchmark takes average 62 seconds after 1 retries\n" ] } ], "source": [ "# Run microbenchmark with n retry time\n", "runMicroBenchmark(spark,\"NDS Q14a subquery\",query ,1)" ] }, { "cell_type": "markdown", "id": "5b051d6b", "metadata": {}, "source": [ "### Crossjoin\n", "This is a microbenchmark for a 1-million rows crossjoin with itself." 
] }, { "cell_type": "code", "execution_count": 21, "id": "56af3f00", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------\n" ] } ], "source": [ "# You have to stop the sparksession and create a new one \n", "# because in this query we need to create more executors with less cores to get the best performance\n", "spark.stop()\n", "conf = SparkConf()\n", "# Common spark settings\n", "conf.setMaster(SPARK_MASTER_URL)\n", "conf.setAppName(\"Crossjoin Microbenchmark on CPU\")\n", " \n", "conf.set(\"spark.driver.memory\", driverMem)\n", "conf.set(\"spark.executor.memory\", executorMem)\n", "conf.set(\"spark.executor.cores\", executorCores)\n", " \n", "conf.set(\"spark.locality.wait\", \"0\")\n", "conf.set(\"spark.sql.files.maxPartitionBytes\", maxPartionBytes) \n", "conf.set(\"spark.dynamicAllocation.enabled\", \"false\") \n", "conf.set(\"spark.sql.adaptive.enabled\", \"true\")\n", "# We can get a better performance by broadcast one table to change CartesianJoin to BroadCastNestLoopJoin\n", "conf.set(\"spark.sql.autoBroadcastJoinThreshold\",1000000000)\n", "# Get or create spark session\n", "spark = SparkSession.builder.config(conf=conf).getOrCreate()\n", "\n", "print(\"-\"*50)" ] }, { "cell_type": "code", "execution_count": 22, "id": "ae9cdc08", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "scanning and writing parquet cost : 18.18 seconds\n", "--------------------------------------------------\n" ] } ], "source": [ "# Load dataframe and create tempView\n", "start = time() \n", "spark.read.parquet(dataRoot + \"/tpcds/customer\").limit(1000000).write.format(\"parquet\").mode(\"overwrite\").save(\"/data/tmp/customer1m\")\n", "end = time()\n", "print(\"scanning and writing parquet cost : {} seconds\".format(round(end - start, 2)))\n", "# We need to tune the partition number to get the best performance.\n", "spark.read.parquet(\"/data/tmp/customer1m\").repartition(16000).createOrReplaceTempView(\"costomer_df_1_million\")\n", "query = '''\n", "select count(*) from costomer_df_1_million c1 inner join costomer_df_1_million c2 on c1.c_customer_sk>c2.c_customer_sk\n", "'''\n", "print(\"-\"*50)" ] }, { "cell_type": "code", "execution_count": 23, "id": "0571d861", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------+\n", "| count(1)|\n", "+------------+\n", "|499999500000|\n", "+------------+\n", "\n", "Retry times : 1, Crossjoin Microbenchmark takes 78.8 seconds\n", "Crossjoin Microbenchmark takes average 79 seconds after 1 retries\n" ] } ], "source": [ "# Run microbenchmark with n retry time\n", "runMicroBenchmark(spark,\"Crossjoin\",query ,1)" ] }, { "cell_type": "markdown", "id": "56f915c2-9b9a-4982-8c4e-5b570c17bfeb", "metadata": {}, "source": [ "### HashJoin\n", "This is a microbenchmark for a HashJoin. The query on GPU will be more than 10x times faster than CPU based on the cluster in the readme." 
] }, { "cell_type": "code", "execution_count": null, "id": "040603c9-a96f-4017-bcdb-5f93e12996a4", "metadata": {}, "outputs": [], "source": [ "spark.read.parquet(dataRoot + \"/tpcds/store_sales\").createOrReplaceTempView(\"store_sales\")\n", "spark.read.parquet(dataRoot + \"/tpcds/store_returns\").createOrReplaceTempView(\"store_returns\")\n", "\n", "print(\"-\"*50)\n", "query = '''\n", "select sum(store_sales.ss_ext_wholesale_cost)\n", "from store_sales\n", "join store_returns on (ss_item_sk = sr_item_sk) and (ss_addr_sk=sr_addr_sk)\n", "'''\n", "runMicroBenchmark(spark,\"HashJoin\",query,1)" ] }, { "cell_type": "code", "execution_count": 24, "id": "7c118cc9", "metadata": {}, "outputs": [], "source": [ "spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "c9e43255", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/SQL+DF-Examples/micro-benchmarks/notebooks/micro-benchmarks-gpu.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "62787244", "metadata": {}, "source": [ "# Microbenchmarks on GPU\n", "This is a notebook for microbenchmarks running on GPU. " ] }, { "cell_type": "code", "execution_count": 1, "id": "1c3a15d7", "metadata": {}, "outputs": [], "source": [ "from pyspark.sql import SparkSession\n", "from pyspark.conf import SparkConf\n", "from time import time\n", "import os\n", "# Change to your cluster ip:port and directories\n", "SPARK_MASTER_URL = os.getenv(\"SPARK_MASTER_URL\", \"spark:your-ip:port\")\n", "RAPIDS_JAR = os.getenv(\"RAPIDS_JAR\", \"/your-path/rapids-4-spark_2.12-26.02.0.jar\")\n" ] }, { "cell_type": "markdown", "id": "b10a2ad1", "metadata": {}, "source": [ "Run the microbenchmark with retryTimes" ] }, { "cell_type": "code", "execution_count": 2, "id": "0c3536ad", "metadata": {}, "outputs": [], "source": [ "def runMicroBenchmark(spark, appName, query, retryTimes):\n", " count = 0\n", " total_time = 0\n", " # You can print the physical plan of each query\n", " # spark.sql(query).explain()\n", " while count < retryTimes:\n", " start = time()\n", " spark.sql(query).show(5)\n", " end = time()\n", " total_time += round(end - start, 2)\n", " count = count + 1\n", " print(\"Retry times : {}, \".format(count) + appName + \" microbenchmark takes {} seconds\".format(round(end - start, 2)))\n", " print(appName + \" microbenchmark takes average {} seconds after {} retries\".format(round(total_time/retryTimes),retryTimes))\n", " with open('result.txt', 'a') as file:\n", " file.write(\"{},{},{}\\n\".format(appName, round(total_time/retryTimes), retryTimes))" ] }, { "cell_type": "code", "execution_count": null, "id": "975717da", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------\n" ] } ], "source": [ "# You need to update with your real hardware resource \n", "driverMem = os.getenv(\"DRIVER_MEM\", \"50g\")\n", "executorMem = os.getenv(\"EXECUTOR_MEM\", \"16g\")\n", "maxPartionBytes = os.getenv(\"MAX_PARTITION_BYTES\", \"1g\")\n", "pinnedPoolSize = os.getenv(\"PINNED_POOL_SIZE\", 
\"8g\")\n", "concurrentGpuTasks = os.getenv(\"CONCURRENT_GPU_TASKS\", \"4\")\n", "executorCores = int(os.getenv(\"EXECUTOR_CORES\", \"16\"))\n", "eventlogDir = \"file:\"+os.getenv(\"EVENTLOG_DIR\")\n", "gpuPerExecutor = 1/executorCores\n", "# Common spark settings\n", "conf = SparkConf()\n", "conf.setMaster(SPARK_MASTER_URL)\n", "conf.setAppName(\"Microbenchmark on GPU\")\n", "conf.set(\"spark.driver.memory\", driverMem)\n", "## The tasks will run on GPU memory, so there is no need to set a high host memory\n", "conf.set(\"spark.executor.memory\", executorMem)\n", "## The tasks will run on GPU cores, so there is no need to use many cpu cores\n", "conf.set(\"spark.executor.cores\", executorCores)\n", "conf.set(\"spark.locality.wait\", \"0\")\n", "conf.set(\"spark.sql.files.maxPartitionBytes\", maxPartionBytes) \n", "conf.set(\"spark.dynamicAllocation.enabled\", \"false\") \n", "conf.set(\"spark.sql.adaptive.enabled\", \"true\") \n", "\n", "# Plugin settings\n", "conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", "# 4 tasks will run concurrently per GPU\n", "conf.set(\"spark.rapids.sql.concurrentGpuTasks\", concurrentGpuTasks)\n", "# Pinned 8g host memory to transfer data between GPU and host memory\n", "conf.set(\"spark.rapids.memory.pinnedPool.size\", pinnedPoolSize)\n", "# 16 tasks will run concurrently per executor, as we set spark.executor.cores=16\n", "conf.set(\"spark.task.resource.gpu.amount\", gpuPerExecutor) \n", "conf.set(\"spark.rapids.sql.enabled\", \"true\") \n", "conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n", "conf.set(\"spark.rapids.sql.variableFloatAgg.enabled\", \"true\")\n", "conf.set(\"spark.driver.extraClassPath\", RAPIDS_JAR)\n", "conf.set(\"spark.executor.extraClassPath\", RAPIDS_JAR)\n", "conf.set(\"spark.jars\", RAPIDS_JAR)\n", "conf.set(\"spark.eventLog.enabled\", \"true\")\n", "conf.set(\"spark.eventLog.dir\", eventlogDir)\n", "# Create spark session\n", "spark = SparkSession.builder.config(conf=conf).getOrCreate()\n", "# Load dataframe and create tempView\n", "# You need to update data path to your real path!\n", "dataRoot = os.getenv(\"DATA_ROOT\", \"/data\")\n", "spark.read.parquet(dataRoot + \"/customer.dat\").createOrReplaceTempView(\"customer\")\n", "spark.read.parquet(dataRoot + \"/store_sales.dat\").createOrReplaceTempView(\"store_sales\")\n", "spark.read.parquet(dataRoot + \"/catalog_sales.dat\").createOrReplaceTempView(\"catalog_sales\")\n", "spark.read.parquet(dataRoot + \"/web_sales.dat\").createOrReplaceTempView(\"web_sales\")\n", "spark.read.parquet(dataRoot + \"/item.dat\").createOrReplaceTempView(\"item\")\n", "spark.read.parquet(dataRoot + \"/date_dim.dat\").createOrReplaceTempView(\"date_dim\")\n", "print(\"-\"*50)" ] }, { "cell_type": "markdown", "id": "7136eb63", "metadata": {}, "source": [ "### Expand&HashAggregate\n", "This is a microbenchmark about Expand&HashAggregate expressions running on the GPU. The query calculates the distinct value of some dimension columns and average birth year by different c_salutation of customers after grouping by c_current_hdemo_sk. You will see about 10x speedups in this query. Because an additional shuffle involved by the repartition operator in CPU mode. And GPUExpand and GPUHashAggregate is much faster than Expand and HashAggregate because GPU algorithms allow us to parallelize the computation and we can utilize most of the GPU cores. The tasks' duration in the third stage is less than one second but will cost 20x-40x while running on CPU. 
There will be a more significant performance improvement along with the increasing number of count distinct columns and aggregate functions." ] }, { "cell_type": "code", "execution_count": 4, "id": "dd12d749", "metadata": {}, "outputs": [], "source": [ "query = '''\n", "select c_current_hdemo_sk,\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_salutation,null)) as c1,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_salutation,null)) as c12,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_salutation,null)) as c13,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_first_name,null)) as c2,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_first_name,null)) as c22,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_first_name,null)) as c23,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_last_name,null)) as c3,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_last_name,null)) as c32,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_last_name,null)) as c33,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_birth_country,null)) as c4,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_birth_country,null)) as c42,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_birth_country,null)) as c43,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_email_address,null)) as c5,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_email_address,null)) as c52,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_email_address,null)) as c53,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_login,null)) as c6,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_login,null)) as c62,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_login,null)) as c63,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_preferred_cust_flag,null)) as c7,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_preferred_cust_flag,null)) as c72,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_preferred_cust_flag,null)) as c73,\n", "\n", "count(DISTINCT if(c_salutation==\"Ms.\",c_birth_month,null)) as c8,\n", "count(DISTINCT if(c_salutation==\"Mr.\",c_birth_month,null)) as c82,\n", "count(DISTINCT if(c_salutation==\"Dr.\",c_birth_month,null)) as c83,\n", "\n", "avg(if(c_salutation==\"Ms.\",c_birth_year,null)) as avg1,\n", "avg(if(c_salutation==\"Mr.\",c_birth_year,null)) as avg2,\n", "avg(if(c_salutation==\"Dr.\",c_birth_year,null)) as avg3,\n", "avg(if(c_salutation==\"Miss.\",c_birth_year,null)) as avg4,\n", "avg(if(c_salutation==\"Mrs.\",c_birth_year,null)) as avg5,\n", "avg(if(c_salutation==\"Sir.\",c_birth_year,null)) as avg6,\n", "avg(if(c_salutation==\"Professor.\",c_birth_year,null)) as avg7,\n", "avg(if(c_salutation==\"Teacher.\",c_birth_year,null)) as avg8,\n", "avg(if(c_salutation==\"Agent.\",c_birth_year,null)) as avg9,\n", "avg(if(c_salutation==\"Director.\",c_birth_year,null)) as avg10\n", "from customer group by c_current_hdemo_sk\n", "'''" ] }, { "cell_type": "code", "execution_count": 5, "id": "2e105bf8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n", "|c_current_hdemo_sk| c1|c12|c13| c2|c22|c23| c3|c32|c33| c4|c42|c43| c5|c52| c53| c6|c62|c63| c7|c72|c73| c8|c82|c83| avg1| avg2| avg3|avg4| avg5|avg6|avg7|avg8|avg9|avg10|\n", 
"+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n", "| 1238| 1| 1| 1|284|255|562|358|467|772|194|203|211|452|664|1157| 0| 0| 0| 2| 2| 2| 12| 12| 12|1957.2444933920706|1958.8547655068078|1957.2870771899393|null| 1958.042643923241|null|null|null|null| null|\n", "| 6658| 1| 1| 1|318|253|541|384|492|752|190|203|210|516|647|1115| 0| 0| 0| 2| 2| 2| 12| 12| 12|1959.0155945419103|1958.9720930232559|1958.0089525514773|null|1959.2618025751074|null|null|null|null| null|\n", "| 1088| 1| 1| 1|302|263|547|374|476|736|191|206|210|487|648|1074| 0| 0| 0| 2| 2| 2| 12| 12| 12|1957.7084188911704|1959.1323076923077|1957.2780898876404|null|1958.5641025641025|null|null|null|null| null|\n", "| 4818| 1| 1| 1|276|248|542|368|514|747|183|204|211|460|691|1093| 0| 0| 0| 2| 2| 2| 12| 12| 12|1957.8954248366013|1958.1313131313132|1957.5018315018315|null|1958.0252293577983|null|null|null|null| null|\n", "| 148| 1| 1| 1|309|260|562|392|501|772|187|207|211|488|668|1154| 0| 0| 0| 2| 2| 2| 12| 12| 12| 1956.219008264463|1958.9161676646706|1957.8076256499132|null|1958.3412017167382|null|null|null|null| null|\n", "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n", "only showing top 5 rows\n", "\n", "Retry times : 1, Expand&HashAggregate microbenchmark takes 11.13 seconds\n", "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n", "|c_current_hdemo_sk| c1|c12|c13| c2|c22|c23| c3|c32|c33| c4|c42|c43| c5|c52| c53| c6|c62|c63| c7|c72|c73| c8|c82|c83| avg1| avg2| avg3|avg4| avg5|avg6|avg7|avg8|avg9|avg10|\n", "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n", "| 1238| 1| 1| 1|284|255|562|358|467|772|194|203|211|452|664|1157| 0| 0| 0| 2| 2| 2| 12| 12| 12|1957.2444933920706|1958.8547655068078|1957.2870771899393|null| 1958.042643923241|null|null|null|null| null|\n", "| 6658| 1| 1| 1|318|253|541|384|492|752|190|203|210|516|647|1115| 0| 0| 0| 2| 2| 2| 12| 12| 12|1959.0155945419103|1958.9720930232559|1958.0089525514773|null|1959.2618025751074|null|null|null|null| null|\n", "| 4818| 1| 1| 1|276|248|542|368|514|747|183|204|211|460|691|1093| 0| 0| 0| 2| 2| 2| 12| 12| 12|1957.8954248366013|1958.1313131313132|1957.5018315018315|null|1958.0252293577983|null|null|null|null| null|\n", "| 1088| 1| 1| 1|302|263|547|374|476|736|191|206|210|487|648|1074| 0| 0| 0| 2| 2| 2| 12| 12| 12|1957.7084188911704|1959.1323076923077|1957.2780898876404|null|1958.5641025641025|null|null|null|null| null|\n", "| 148| 1| 1| 1|309|260|562|392|501|772|187|207|211|488|668|1154| 0| 0| 0| 2| 2| 2| 12| 12| 12| 1956.219008264463|1958.9161676646706|1957.8076256499132|null|1958.3412017167382|null|null|null|null| null|\n", "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n", "only showing top 5 rows\n", "\n", 
"Retry times : 2, Expand&HashAggregate microbenchmark takes 7.74 seconds\n", "Expand&HashAggregate microbenchmark takes average 9 seconds after 2 retries\n" ] } ], "source": [ "# Run microbenchmark with n retry time\n", "runMicroBenchmark(spark,\"Expand&HashAggregate\",query,2)" ] }, { "cell_type": "markdown", "id": "f50ec183", "metadata": {}, "source": [ "### Windowing(without data skew)\n", "This is a microbenchmark about windowing expressions running on GPU mode. The sub-query calculates the average ss_sales_price of a fixed window function partition by ss_customer_sk, and the parent query calculates the average price of the sub-query grouping by each customer. You will see about 25x speedups in this query. The speedup mainly comes from GPUSort/GPUWindow/GPUHashAggregate. The avg aggregation function evaluates all rows which are generated by the sub-query's window function. There will be a more significant performance improvement along with the increasing number of sub-query aggregate functions." ] }, { "cell_type": "code", "execution_count": 6, "id": "31bd0635", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------\n" ] } ], "source": [ "query = '''\n", "select ss_customer_sk,avg(avg_price) as avg_price\n", "from\n", "(\n", "SELECT ss_customer_sk ,avg(ss_sales_price) OVER (PARTITION BY ss_customer_sk order by ss_sold_date_sk ROWS BETWEEN 50 PRECEDING AND 50 FOLLOWING ) as avg_price\n", "FROM store_sales\n", "where ss_customer_sk is not null\n", ") group by ss_customer_sk order by 2 desc \n", "'''\n", "print(\"-\"*50)" ] }, { "cell_type": "code", "execution_count": 7, "id": "f9e93983", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------+------------------+\n", "|ss_customer_sk| avg_price|\n", "+--------------+------------------+\n", "| 15924921|52.375180502283705|\n", "| 24796404| 52.21073975966333|\n", "| 14299506| 52.16263537127018|\n", "| 27571451|52.156112032252395|\n", "| 10174233| 52.06401030721082|\n", "+--------------+------------------+\n", "only showing top 5 rows\n", "\n", "Retry times : 1, Windowing without skew microbenchmark takes 11.39 seconds\n", "+--------------+-----------------+\n", "|ss_customer_sk| avg_price|\n", "+--------------+-----------------+\n", "| 15924921|52.53781291335107|\n", "| 24796404|52.39683466140243|\n", "| 27571451|52.18830023174899|\n", "| 14299506|52.10829141087412|\n", "| 10174233|51.92766214818386|\n", "+--------------+-----------------+\n", "only showing top 5 rows\n", "\n", "Retry times : 2, Windowing without skew microbenchmark takes 9.53 seconds\n", "Windowing without skew microbenchmark takes average 10 seconds after 2 retries\n" ] } ], "source": [ "# Run microbenchmark with n retry time\n", "runMicroBenchmark(spark,\"Windowing without skew\",query,2)" ] }, { "cell_type": "markdown", "id": "dcf08e47", "metadata": {}, "source": [ "### Windowing(with data skew)\n", "Data skew is caused by many null values in the ss_customer_sk column. You will see about 80x speedups in this query. The heavier skew task a query has, the more improved performance we will get because GPU parallelizes the computation, CPU is limited to just a single core because of how the algorithms are written." 
] }, { "cell_type": "code", "execution_count": 8, "id": "2b9d223c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------\n" ] } ], "source": [ "query = '''\n", "select ss_customer_sk,avg(avg_price) as avg_price\n", "from\n", "(\n", "SELECT ss_customer_sk ,avg(ss_sales_price) OVER (PARTITION BY ss_customer_sk order by ss_sold_date_sk ROWS BETWEEN 50 PRECEDING AND 50 FOLLOWING ) as avg_price\n", "FROM store_sales\n", ") group by ss_customer_sk order by 2 desc \n", "'''\n", "print(\"-\"*50)" ] }, { "cell_type": "code", "execution_count": 9, "id": "0d7c65ee", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------------+------------------+\n", "|ss_customer_sk| avg_price|\n", "+--------------+------------------+\n", "| 24796404| 52.40675225109215|\n", "| 27571451|52.396675141359374|\n", "| 15924921| 52.30557497833058|\n", "| 10174233|52.088916933379096|\n", "| 14299506|51.995045713009794|\n", "+--------------+------------------+\n", "only showing top 5 rows\n", "\n", "Retry times : 1, Windowing with skew microbenchmark takes 17.46 seconds\n", "+--------------+------------------+\n", "|ss_customer_sk| avg_price|\n", "+--------------+------------------+\n", "| 24796404|52.403564615099896|\n", "| 15924921|52.262694645994465|\n", "| 27571451| 52.14256448618127|\n", "| 10174233| 52.11346591610992|\n", "| 14299506| 51.99180221022445|\n", "+--------------+------------------+\n", "only showing top 5 rows\n", "\n", "Retry times : 2, Windowing with skew microbenchmark takes 16.63 seconds\n", "Windowing with skew microbenchmark takes average 17 seconds after 2 retries\n" ] } ], "source": [ "# Run microbenchmark with n retry time\n", "runMicroBenchmark(spark,\"Windowing with skew\",query,2)" ] }, { "cell_type": "markdown", "id": "53c0ed28", "metadata": {}, "source": [ "### Intersection\n", "This is a microbenchmark about intersection operation running on GPU mode. The query calculates items in the same brand, class, and category that are sold in all three sales channels in two consecutive years. You will see about 10x speedups in this query. This is a competition between high cardinality SortMergeJoin vs GpuShuffleHashJoin. The mainly improved performance comes from two SortMergeJoin(s) in this query running on CPU get converted to GpuShuffleHashJoin running on GPU." 
] }, { "cell_type": "code", "execution_count": 10, "id": "643c2e8a", "metadata": {}, "outputs": [], "source": [ "query = '''\n", "select i_item_sk ss_item_sk\n", " from item,\n", " (select iss.i_brand_id brand_id, iss.i_class_id class_id, iss.i_category_id category_id\n", " from store_sales, item iss, date_dim d1\n", " where ss_item_sk = iss.i_item_sk\n", " and ss_sold_date_sk = d1.d_date_sk\n", " and d1.d_year between 1999 AND 1999 + 2\n", " intersect\n", " select ics.i_brand_id, ics.i_class_id, ics.i_category_id\n", " from catalog_sales, item ics, date_dim d2\n", " where cs_item_sk = ics.i_item_sk\n", " and cs_sold_date_sk = d2.d_date_sk\n", " and d2.d_year between 1999 AND 1999 + 2\n", " intersect\n", " select iws.i_brand_id, iws.i_class_id, iws.i_category_id\n", " from web_sales, item iws, date_dim d3\n", " where ws_item_sk = iws.i_item_sk\n", " and ws_sold_date_sk = d3.d_date_sk\n", " and d3.d_year between 1999 AND 1999 + 2) x\n", " where i_brand_id = brand_id\n", " and i_class_id = class_id\n", " and i_category_id = category_id\n", "'''" ] }, { "cell_type": "code", "execution_count": 11, "id": "61bc2260", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------+\n", "|ss_item_sk|\n", "+----------+\n", "| 4323|\n", "| 4324|\n", "| 4325|\n", "| 4327|\n", "| 4328|\n", "+----------+\n", "only showing top 5 rows\n", "\n", "Retry times : 1, NDS Q14a subquery microbenchmark takes 6.71 seconds\n", "+----------+\n", "|ss_item_sk|\n", "+----------+\n", "| 14103|\n", "| 14104|\n", "| 14105|\n", "| 14107|\n", "| 14108|\n", "+----------+\n", "only showing top 5 rows\n", "\n", "Retry times : 2, NDS Q14a subquery microbenchmark takes 6.11 seconds\n", "NDS Q14a subquery microbenchmark takes average 6 seconds after 2 retries\n" ] } ], "source": [ "# Run microbenchmark with n retry time\n", "runMicroBenchmark(spark,\"NDS Q14a subquery\",query,2)" ] }, { "cell_type": "markdown", "id": "1346d126", "metadata": {}, "source": [ "### Crossjoin\n", "This is a microbenchmark for a 1-million rows crossjoin with itself. You will see about 10x speedups in this query. The mainly improved performance comes from converting BroadcastNestedLoogJoin running on CPU to GpuBroadcastNestedLoogJoin running on GPU." 
] }, { "cell_type": "code", "execution_count": 12, "id": "286ea45d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "scanning and writing parquet cost : 5.31 seconds\n", "--------------------------------------------------\n" ] } ], "source": [ "start = time() \n", "spark.read.parquet(dataRoot + \"/customer.dat\").limit(1000000).write.format(\"parquet\").mode(\"overwrite\").save(\"/data/tmp/customer1m\")\n", "end = time()\n", "# Parquet file scanning and writing will be about 3 times faster running on GPU\n", "print(\"scanning and writing parquet cost : {} seconds\".format(round(end - start, 2)))\n", "spark.read.parquet(\"/data/tmp/customer1m\").repartition(200).createOrReplaceTempView(\"costomer_df_1_million\")\n", "query = '''\n", "select count(*) from costomer_df_1_million c1 inner join costomer_df_1_million c2 on c1.c_customer_sk>c2.c_customer_sk\n", "'''\n", "print(\"-\"*50)" ] }, { "cell_type": "code", "execution_count": 13, "id": "f41b8d54", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------+\n", "| count(1)|\n", "+------------+\n", "|499999500000|\n", "+------------+\n", "\n", "Retry times : 1, Crossjoin microbenchmark takes 6.7 seconds\n", "+------------+\n", "| count(1)|\n", "+------------+\n", "|499999500000|\n", "+------------+\n", "\n", "Retry times : 2, Crossjoin microbenchmark takes 6.37 seconds\n", "Crossjoin microbenchmark takes average 7 seconds after 2 retries\n" ] } ], "source": [ "# Run microbenchmark with n retry time\n", "runMicroBenchmark(spark,\"Crossjoin\",query,2)" ] }, { "cell_type": "markdown", "id": "06b351e6-b7bd-4063-a20b-fe4fd71221f9", "metadata": {}, "source": [ "### HashJoin\n", "This is a microbenchmark for a HashJoin. The query on GPU will be more than 10x times faster than CPU based on the cluster in the readme." ] }, { "cell_type": "code", "execution_count": null, "id": "191d0c9a-2d3a-40f4-89aa-f61dab5caa90", "metadata": {}, "outputs": [], "source": [ "spark.read.parquet(dataRoot + \"/store_sales.dat\").createOrReplaceTempView(\"store_sales\")\n", "spark.read.parquet(dataRoot + \"/store_returns.dat\").createOrReplaceTempView(\"store_returns\")\n", "\n", "print(\"-\"*50)\n", "query = '''\n", "select sum(store_sales.ss_ext_wholesale_cost)\n", "from store_sales\n", "join store_returns on (ss_item_sk = sr_item_sk) and (ss_addr_sk=sr_addr_sk)\n", "'''\n", "runMicroBenchmark(spark,\"HashJoin\",query,1)" ] }, { "cell_type": "code", "execution_count": null, "id": "fc2092e8", "metadata": {}, "outputs": [], "source": [ "spark.stop()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/SQL+DF-Examples/retail-analytics/README.md ================================================ # Overview Retail Analytics This repository contains two Jupyter notebooks: Data Generation: This notebook generates sample data that can be used for analysis. It demonstrates how to use various Python libraries to create synthetic data sets that can be used for testing and experimentation. 
This notebook can be run on a GCP n1-standard-32 instance. Data Cleaning and Analysis: This notebook takes the generated data and performs a series of cleaning and analysis tasks. It demonstrates how to use the Spark RAPIDS library to manipulate and analyze data sets. ================================================ FILE: examples/SQL+DF-Examples/retail-analytics/notebooks/python/retail-analytic.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random\n", "from pyspark.sql import SparkSession\n", "from pyspark import broadcast, SparkConf\n", "import time\n", "import os\n", "\n", "RAPIDS_JAR = os.getenv(\"RAPIDS_JAR\", \"/path/to/your/jars/rapids.jar\")\n", "SPARK_MASTER = os.getenv(\"SPARK_MASTER_URL\", \"spark://ip:port\")\n", "print(\"RAPIDS_JAR: {}\".format(RAPIDS_JAR))\n", "if \"sc\" in globals():\n", " sc.stop()\n", "\n", "### Configure the parameters based on your dataproc cluster ###\n", "conf = SparkConf().setAppName(\"Retail Analytics\")\n", "conf.setMaster(SPARK_MASTER)\n", "conf.set(\"spark.driver.extraClassPath\", RAPIDS_JAR)\n", "conf.set(\"spark.executor.extraClassPath\", RAPIDS_JAR)\n", "conf.set(\"spark.jars\", RAPIDS_JAR)\n", "conf.set(\"spark.executor.instances\", \"1\")\n", "conf.set(\"spark.executor.cores\", \"4\")\n", "conf.set(\"spark.task.resource.gpu.amount\", \"0.25\")\n", "conf.set(\"spark.rapids.sql.concurrentGpuTasks\", \"2\")\n", "conf.set(\"spark.executor.memory\", \"4g\")\n", "conf.set(\"spark.sql.files.maxPartitionBytes\", \"128m\")\n", "conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", "conf.set(\"spark.rapids.memory.pinnedPool.size\", \"2048m\")\n", "conf.set(\"spark.executor.memoryOverhead\", \"4096m\")\n", "conf.set(\"spark.dynamicAllocation.enabled\", \"false\")\n", "conf.set(\"spark.rapids.sql.format.json.read.enabled\",True)\n", "conf.set(\"spark.rapids.sql.castStringToTimestamp.enabled\",True)\n", "conf.set(\"spark.rapids.sql.expression.PercentRank\",False)\n", "conf.set(\"spark.rapids.sql.castDecimalToString.enabled\",True)\n", "conf.set(\"spark.rapids.sql.hasExtendedYearValues\",False)\n", "conf.set(\"spark.rapids.sql.enabled\",True)\n", "conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n", "conf.set(\"spark.rapids.sql.allowMultipleJars\", \"ALWAYS\")\n", "\n", "# create the SparkSession from the configuration above\n", "spark = SparkSession.builder \\\n", " .config(conf=conf) \\\n", " .getOrCreate()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "# You need to update these to your real paths!\n", "dataRoot = os.getenv(\"DATA_ROOT\", 'path/to/your/datasets')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.functions import *\n", "from pyspark.sql.types import *\n", "from pyspark.sql.window import Window\n", "\n", "start = time.time()\n", "\n", "def clean_data(df):\n", " # remove missing values\n", " df = df.dropna()\n", " # remove duplicate data\n", " df = df.dropDuplicates()\n", " return df\n", "\n", "\n", "def read_data(spark, format, file_path):\n", " if format==\"csv\":\n", " return spark.read.format(format).load(file_path,header=True)\n", " else:\n", " return spark.read.format(format).load(file_path)\n", "\n", "# read sales data\n", "sales_df = read_data(spark, \"csv\", dataRoot+\"/sales/\")\n", "\n", "# read stock data\n", "stock_df =
read_data(spark, \"json\", dataRoot+\"/stock/\")\n", "\n", "# read supplier data\n", "supplier_df = read_data(spark, \"json\", dataRoot+\"/supplier/\")\n", "\n", "# read customer data\n", "customer_df = read_data(spark, \"csv\", dataRoot+\"/customer/\")\n", "\n", "# read market data\n", "market_df = read_data(spark, \"csv\", dataRoot+\"/market/\")\n", "\n", "# read logistic data\n", "logistic_df = read_data(spark, \"csv\", dataRoot+\"/logistic/\")\n", "\n", "\n", "# data cleaning\n", "sales_df = clean_data(sales_df)\n", "stock_df = clean_data(stock_df)\n", "supplier_df = clean_data(supplier_df)\n", "customer_df = clean_data(customer_df)\n", "market_df = clean_data(market_df)\n", "logistic_df = clean_data(logistic_df)\n", "\n", "\n", "# convert date columns to date type\n", "sales_df = sales_df.withColumn(\"date_of_sale\", to_date(col(\"date_of_sale\")))\n", "stock_df = stock_df.withColumn(\"date_received\", to_date(col(\"date_received\")))\n", "supplier_df = supplier_df.withColumn(\"date_ordered\", to_date(col(\"date_ordered\")))\n", "\n", "# standardize case of string columns\n", "sales_df = sales_df.withColumn(\"product_name\", upper(col(\"product_name\")))\n", "stock_df = stock_df.withColumn(\"product_name\", upper(col(\"product_name\")))\n", "stock_df = stock_df.withColumn(\"location\", upper(col(\"location\")))\n", "supplier_df = supplier_df.withColumn(\"product_name\", upper(col(\"product_name\")))\n", "customer_df = customer_df.withColumn(\"customer_name\", upper(col(\"customer_name\")))\n", "market_df = market_df.withColumn(\"product_name\", upper(col(\"product_name\")))\n", "logistic_df = logistic_df.withColumn(\"product_name\", upper(col(\"product_name\")))\n", "\n", "# remove leading and trailing whitespaces\n", "sales_df = sales_df.withColumn(\"product_name\", trim(col(\"product_name\")))\n", "stock_df = stock_df.withColumn(\"location\", trim(col(\"location\")))\n", "\n", "supplier_df = supplier_df.withColumn(\"product_name\", trim(col(\"product_name\")))\n", "customer_df = customer_df.withColumn(\"customer_name\", trim(col(\"customer_name\")))\n", "market_df = market_df.withColumn(\"product_name\", trim(col(\"product_name\")))\n", "logistic_df = logistic_df.withColumn(\"product_name\", trim(col(\"product_name\")))\n", "\n", "# check for invalid values\n", "sales_df = sales_df.filter(col(\"product_name\").isNotNull())\n", "stock_df = stock_df.filter(col(\"location\").isNotNull())\n", "customer_df = customer_df.filter(col(\"gender\").isin(\"male\",\"female\"))\n", "market_df = market_df.filter(col(\"product_name\").isNotNull())\n", "logistic_df = logistic_df.filter(col(\"product_name\").isNotNull())\n", "\n", "#drop extra columns\n", "market_df = market_df.drop(\"price\")\n", "supplier_df = supplier_df.drop(\"price\")\n", "\n", "# join all data\n", "data_int = sales_df.join(stock_df, \"product_name\",\"leftouter\").join(supplier_df, \"product_name\",\"leftouter\").join(market_df, \"product_name\",\"leftouter\").join(logistic_df, \"product_name\",\"leftouter\").join(customer_df, \"customer_id\",\"leftouter\") \n", "\n", "# write the cleaned data\n", "os.makedirs(dataRoot+\"cleaned/\", exist_ok=True)\n", "data_int.write.mode(\"overwrite\").format(\"parquet\").save(dataRoot+\"/cleaned/\")\n", "\n", "end = time.time()\n", "\n", "print(\"Time taken on GPU for Data Cleaning: \", end - start)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.functions import *\n", "from pyspark.sql.types import *\n", "from 
pyspark.sql.window import Window\n", "\n", "#DO VARIOUS RETAIL DATA ANALYTICS \n", "\n", "start = time.time()\n", "\n", "# read cleaned data\n", "\n", "data = spark.read.format(\"parquet\").load(dataRoot+\"/cleaned/\")\n", "\n", "#Case when statement to create a new column to indicate whether the product is perishable or not:\n", "\n", "data = data.withColumn(\"perishable\", when(col(\"shelf_life\") <= 30, \"yes\").otherwise(\"no\"))\n", "\n", "# You can use the when() and otherwise() functions to create new columns based on certain conditions:\n", "\n", "data = data.withColumn(\"sales_status\", when(col(\"quantity_sold\") > 50, \"good\").otherwise(\"bad\"))\n", "\n", "# create a window to perform time series analysis\n", "window = Window.partitionBy(\"product_name\").orderBy(\"date_of_sale\")\n", "\n", "# calculate the rolling average of sales for each product\n", "time_series_df = data.withColumn(\"rolling_avg_sales\", avg(\"quantity_sold\").over(window))\n", "\n", "# use window function for forecasting\n", "\n", "forecast_df = time_series_df.withColumn(\"prev_sales\", lag(\"rolling_avg_sales\").over(window))\\\n", " .withColumn(\"next_sales\", lead(\"rolling_avg_sales\").over(window))\n", "\n", "\n", "# Calculate the average price of a product, grouped by supplier\n", "forecast_df.groupBy(\"sup_id\").agg({\"price\": \"avg\"}).show()\n", "\n", "\n", "# Calculate the total quantity in stock and total sales by supplier\n", "forecast_df.groupBy(\"sup_id\").agg({\"quantity_in_stock\": \"sum\", \"price\": \"sum\"}).show()\n", "\n", "#Calculate the number of perishable v/s non-perishable product per location\n", "forecast_df.groupBy(\"perishable\").agg({\"perishable\": \"count\"}).show()\n", "\n", "\n", "#Calculate number of good v/s bad sales status per location\n", "forecast_df.groupBy(\"sales_status\").agg({\"sales_status\": \"count\"}).show()\n", "\n", "# Count the number of sales that contain a 10% off promotion\n", "countt = forecast_df.filter(forecast_df[\"contains_promotion\"].contains(\"10% off\")).count()\n", "print(countt)\n", "# Perform some complex analysis on the DataFrame\n", "\n", "# Calculate the total sales, quantity sold by product and location\n", "total_sales_by_product_location = forecast_df.groupBy(\"product_name\", \"location\").agg(sum(\"price\").alias(\"total_price\"),sum(\"quantity_ordered\").alias(\"total_quantity_sold\"),avg(\"quantity_sold\").alias(\"avg_quantity_sold\")).sort(desc(\"total_price\"))\n", "\n", "# Group the data by product_name\n", "grouped_df = forecast_df.groupBy(\"product_name\")\n", "\n", "#Sum the quantity_in_stock, quantity_ordered, quantity_sold, and (price * quantity_sold) for each group\n", "aggregated_df = grouped_df.agg(sum(\"quantity_in_stock\").alias(\"total_quantity_in_stock\"),avg(\"price\").alias(\"average_price\"),sum(\"quantity_ordered\").alias(\"total_quantity_ordered\"),sum(\"quantity_sold\").alias(\"total_quantity_sold\"),sum(col(\"price\") * col(\"quantity_sold\")).alias(\"total_sales\"),sum(\"prev_sales\").alias(\"total_prev_sales\"),sum(\"next_sales\").alias(\"total_next_sales\"),).sort(desc(\"total_sales\"))\n", "\n", "#WRITE THE AGGREGATES TO DISK\n", "aggregated_df.write.mode(\"overwrite\").format(\"parquet\").save(dataRoot+\"/app/data.parquet\")\n", "total_sales_by_product_location.write.mode(\"overwrite\").format(\"parquet\").save(dataRoot+\"/app1/data.parquet\")\n", "\n", "end = time.time()\n", "\n", "print(\"Time taken on GPU for Data Analysis: \", end - start)" ] }, { "cell_type": "code", "execution_count": 5, 
"metadata": {}, "outputs": [], "source": [ "spark.stop()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/SQL+DF-Examples/retail-analytics/notebooks/python/retail-datagen.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Generating and Writing Data to GCS" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "import multiprocessing as mp\n", "import random\n", "\n", "# You need to update these to your real paths!\n", "dataRoot = os.getenv(\"DATA_ROOT\", '/path/to/your/datasets')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#We define the generate_data function which takes an integer i as input and generates sales data using random numbers. The generated data includes sales ID, product name, price, quantity sold, date of sale, and customer ID. The function returns a tuple of the generated data.\n", "def generate_data(i):\n", " sales_id = \"s_{}\".format(i)\n", " product_name = \"Product_{}\".format(i)\n", " price = random.uniform(1,10)\n", " quantity_sold = random.randint(1,10)\n", " date_of_sale = \"2022-{}-{}\".format(random.randint(1,12), random.randint(1,28))\n", " customer_id = \"c_{}\".format(random.randint(1,10000))\n", " return (sales_id, product_name, price, quantity_sold, date_of_sale, customer_id)\n", "\n", "with mp.Pool(mp.cpu_count()) as p:\n", " sales_data = p.map(generate_data, range(1000000))\n", " sales_data = list(sales_data)\n", " \n", "print(\"write to gcs started\")\n", "sales_df = pd.DataFrame(sales_data, columns=[\"sales_id\", \"product_name\", \"price\", \"quantity_sold\", \"date_of_sale\", \"customer_id\"])\n", "os.makedirs(dataRoot+\"/sales/\", exist_ok=True)\n", "sales_df.to_csv(dataRoot+\"/sales/data.csv\", index=False, header=True)\n", "print(\"Write to gcs completed\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def generate_data(i):\n", " product_name = \"Product_{}\".format(i)\n", " shelf_life = random.randint(1,365)\n", " contains_promotion = \"{} % off\".format(random.randint(0,10))\n", " quantity_in_stock = random.randint(1,100)\n", " location = \"Location_{}\".format(random.randint(1,100))\n", " date_received = \"2022-{}-{}\".format(random.randint(1,12), random.randint(1,28))\n", " return (product_name,shelf_life,contains_promotion,quantity_in_stock, location, date_received)\n", "\n", "with mp.Pool(mp.cpu_count()) as p:\n", " stock_data = p.map(generate_data, range(50000))\n", " stock_data = list(stock_data)\n", " \n", "stock_df = pd.DataFrame(stock_data, columns=[\"product_name\",\"shelf_life\",\"contains_promotion\",\"quantity_in_stock\", \"location\", \"date_received\"])\n", "os.makedirs(dataRoot+\"/stock/\", exist_ok=True)\n", "stock_df.to_json(dataRoot+\"/stock/stock.json\", orient='records')\n", "print(\"Write to gcs completed\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def generate_data(i):\n", " sup_id = \"s_{}\".format(i)\n", " product_name 
= \"Product_{}\".format(i)\n", " quantity_ordered = random.randint(1,100)\n", " price = random.uniform(1,10)\n", " date_ordered = \"2022-{}-{}\".format(random.randint(1,12), random.randint(1,28))\n", " return (sup_id,product_name, quantity_ordered, price, date_ordered)\n", "\n", "with mp.Pool(mp.cpu_count()) as p:\n", " supplier_data = p.map(generate_data, range(50000))\n", " supplier_data = list(supplier_data)\n", " \n", "supplier_df = pd.DataFrame(supplier_data, columns=[\"sup_id\",\"product_name\", \"quantity_ordered\", \"price\", \"date_ordered\"])\n", "os.makedirs(dataRoot+\"/supplier/\", exist_ok=True)\n", "supplier_df.to_json(dataRoot+\"/supplier/supplier.json\", orient='records')\n", "print(\"Write to gcs completed\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def generate_data(i):\n", " customer_id = \"c_{}\".format(i)\n", " customer_name = \"Customer_{}\".format(i)\n", " age = random.randint(20,70)\n", " gender = random.choice([\"male\", \"female\"])\n", " purchase_history = random.randint(1,100)\n", " contact_info = \"email_{}@gmail.com\".format(i)\n", " return (customer_id,customer_name, age, gender, purchase_history, contact_info)\n", "\n", "with mp.Pool(mp.cpu_count()) as p:\n", " customer_data = p.map(generate_data, range(1000))\n", " customer_data = list(customer_data)\n", " \n", "customer_df = pd.DataFrame(customer_data, columns=[\"customer_id\",\"customer_name\", \"age\", \"gender\", \"purchase_history\", \"contact_info\"])\n", "os.makedirs(dataRoot+\"/customer/\", exist_ok=True)\n", "customer_df.to_csv(dataRoot+\"/customer/customer.csv\", index=False,header=True)\n", "print(\"Write to gcs completed\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def generate_data(i):\n", " product_name = \"Product_{}\".format(i)\n", " competitor_price = random.uniform(1,100)\n", " sales_trend = random.randint(1,100)\n", " demand_forecast = random.randint(1,100)\n", " return (product_name, competitor_price, sales_trend, demand_forecast)\n", "\n", "with mp.Pool(mp.cpu_count()) as p:\n", " market_data = p.map(generate_data, range(500000))\n", " market_data = list(market_data)\n", " \n", "market_df = pd.DataFrame(market_data, columns=[\"product_name\", \"competitor_price\", \"sales_trend\", \"demand_forecast\"])\n", "os.makedirs(dataRoot+\"/market/\", exist_ok=True)\n", "market_df.to_csv(dataRoot+\"/market/market.csv\", index=False,header=True)\n", "print(\"Write to gcs completed\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def generate_data(i):\n", " product_name = \"Product_{}\".format(i)\n", " shipping_cost = random.uniform(1,10)\n", " transportation_cost = random.uniform(1,10)\n", " warehouse_cost = random.uniform(1,10)\n", " return (product_name, shipping_cost, transportation_cost, warehouse_cost)\n", "\n", "with mp.Pool(mp.cpu_count()) as p:\n", " logistic_data = p.map(generate_data, range(500000))\n", " logistic_data = list(logistic_data)\n", " \n", "logistic_df = pd.DataFrame(logistic_data, columns=[\"product_name\", \"shipping_cost\", \"transportation_cost\", \"warehouse_cost\"])\n", "os.makedirs(dataRoot+\"/logistic/\", exist_ok=True)\n", "logistic_df.to_csv(dataRoot+\"/logistic/logistic.csv\", index=False,header=True)\n", "print(\"Write to gcs completed\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/SQL+DF-Examples/tpcds/README.md ================================================ # TPC-DS Scale Factor 10 (GiB) - CPU Spark vs GPU Spark [TPC-DS](https://www.tpc.org/tpcds/) is a decision support benchmark often used to evaluate performance of OLAP Databases and Big Data systems. The notebook in this folder runs a user-specified subset of the TPC-DS queries on the Scale Factor 10 (GiB) dataset. It uses [TPCDS PySpark](https://github.com/cerndb/SparkTraining/blob/master/notebooks/TPCDS_PySpark_CERN_SWAN_getstarted.ipynb) to execute TPC-DS queries with SparkSQL on GPU and CPU capturing the metrics as a Pandas dataframe. It then plots a comparison bar chart visualizing the GPU acceleration achieved for the queries run with RAPIDS Spark in this very notebook. This notebook can be opened and executed using standard - Jupyter(Lab) - in VSCode with Jupyter [extension](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter) It can also be opened and evaluated on hosted Notebook environments. Use the link below to launch on Google Colab and connect it to a [GPU instance](https://research.google.com/colaboratory/faq.html). Open In Colab Here is the bar chart from a recent execution on Google Colab's T4 High RAM instance using RAPIDS Spark 26.02.0 with Apache Spark 3.5.0 ![tpcds-speedup](/docs/img/guides/tpcds.png) ================================================ FILE: examples/SQL+DF-Examples/tpcds/notebooks/TPCDS-SF10.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "editable": true, "id": "HtgYO0bXEBrN", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "# TPC-DS 10GiB - Apache Spark acceleration on GPU with RAPIDS Spark\n", "\n", "based on https://colab.research.google.com/github/LucaCanali/Miscellaneous/blob/master/Performance_Testing/TPCDS_PySpark/Labs_and_Notes/TPCDS_PySpark_getstarted.ipynb#scrollTo=6bab7772" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Install packages" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "spark_version='3.5.5'\n", "rapids_version='26.02.0'\n", "sparkmeasure_version='0.27'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "executionInfo": { "elapsed": 1630, "status": "ok", "timestamp": 1729291037060, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "Yq230e1Nho_M" }, "outputs": [], "source": [ "%pip install --quiet \\\n", " tpcds_pyspark \\\n", " pyspark=={spark_version} \\\n", " pandas \\\n", " sparkmeasure=={sparkmeasure_version}.0 \\\n", " matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Import modules" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "executionInfo": { "elapsed": 1052, "status": "ok", "timestamp": 1729291488008, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "uq_LmKsB36R_" }, "outputs": [], "source": [ "from importlib.resources import files\n", "from pyspark.sql import SparkSession\n", "from tpcds_pyspark import TPCDS\n", "import glob\n", "import os\n", "import pandas as pd\n", "import re\n", "import time" ] }, { "cell_type": "markdown", 
"metadata": { "id": "edMCFrhvgDS8" }, "source": [ "# Download TPC-DS 10GiB Scale Parquet Dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "executionInfo": { "elapsed": 41530, "status": "ok", "timestamp": 1729292943990, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "DY8TkhPQTjbB" }, "outputs": [], "source": [ "if not os.path.isdir('tpcds_10'):\n", " if not os.path.isfile('tpcds_10.zip'):\n", " !wget https://sparkdltrigger.web.cern.ch/sparkdltrigger/TPCDS/tpcds_10.zip\n", " !unzip -q tpcds_10.zip" ] }, { "cell_type": "markdown", "metadata": { "id": "6tgF9LWcgUEs" }, "source": [ "# Init a SparkSession with RAPIDS Spark" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Detect Scala Version used in PySpark package" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pyspark_files = files('pyspark')\n", "spark_sql_jar_path, *_ = glob.glob(f\"{pyspark_files}/*/spark-sql_*jar\")\n", "spark_sql_jar = os.path.basename(spark_sql_jar_path)\n", "scala_version = re.search(r'^spark-sql_(\\d+.\\d+)-.*\\.jar$', spark_sql_jar).group(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create Spark Session" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "executionInfo": { "elapsed": 39420, "status": "ok", "timestamp": 1729289098419, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "-L-wMZTpfYxs" }, "outputs": [], "source": [ "extra_packages = [\n", " f\"com.nvidia:rapids-4-spark_{scala_version}:{rapids_version}\",\n", " f\"ch.cern.sparkmeasure:spark-measure_{scala_version}:{sparkmeasure_version}\"\n", "]\n", "spark = (\n", " SparkSession.builder\n", " .appName('TPCDS PySpark RAPIDS=ON/OFF')\n", " .config('spark.driver.memory', '5g')\n", " .config('spark.plugins', 'com.nvidia.spark.SQLPlugin')\n", " .config('spark.jars.packages', ','.join(extra_packages))\n", " .getOrCreate()\n", ")\n", "spark\n" ] }, { "cell_type": "markdown", "metadata": { "id": "_4sYje2NiNA7" }, "source": [ "# Verify SQL Acceleration on GPU can be enabled by checking the query plan" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 0 }, "executionInfo": { "elapsed": 5921, "status": "ok", "timestamp": 1729289104337, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "nUyQBKtkga9y", "outputId": "5d493a51-58de-4aed-bbaf-d73c82769836" }, "outputs": [], "source": [ "spark.conf.set('spark.rapids.sql.enabled', True)\n", "sum_df = spark.range(1000).selectExpr('SUM(*)')\n", "sum_df.collect()\n", "sum_df.explain()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# TPCDS App" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 0 }, "executionInfo": { "elapsed": 4, "status": "ok", "timestamp": 1729289104337, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "BYPgafupcxaY", "outputId": "fdfb427f-6cc0-4dff-9295-dc44e6ead132" }, "outputs": [], "source": [ "# https://github.com/LucaCanali/Miscellaneous/tree/master/Performance_Testing/TPCDS_PySpark/tpcds_pyspark/Queries\n", "\n", "# queries = None to run all (takes much longer)\n", "queries = None\n", "queries = [\n", " 'q14a',\n", " 'q14b',\n", " 'q23a',\n", " 'q23b',\n", " # 'q24a',\n", " # 'q24b',\n", " # 
'q88',\n", "]\n", "\n", "demo_start = time.time()\n", "tpcds = TPCDS(data_path='./tpcds_10', num_runs=1, queries_repeat_times=1, queries=queries)" ] }, { "cell_type": "markdown", "metadata": { "id": "0Yaaw2GfliC5" }, "source": [ "## Register TPC-DS tables before running queries" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 0 }, "executionInfo": { "elapsed": 2992, "status": "ok", "timestamp": 1729289107327, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "kfsHodFqdDl7", "outputId": "5a810f9d-e353-456c-b7bb-48ae3290178a" }, "outputs": [], "source": [ "tpcds.map_tables()" ] }, { "cell_type": "markdown", "metadata": { "id": "bs6X_54UhuqJ" }, "source": [ "## Measure Apache Spark GPU" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 0 }, "executionInfo": { "elapsed": 45658, "status": "ok", "timestamp": 1729290819190, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "8vXDasUom70g", "outputId": "adccdd7f-99f0-4c82-d600-056b59f53933" }, "outputs": [], "source": [ "tpcds.spark.conf.set('spark.rapids.sql.enabled', True)\n", "%time tpcds.run_TPCDS()\n", "gpu_grouped_results = tpcds.grouped_results_pdf.copy()\n", "gpu_grouped_results" ] }, { "cell_type": "markdown", "metadata": { "id": "ulyFidEPhg_l" }, "source": [ "## Measure Apache Spark CPU" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 0 }, "executionInfo": { "elapsed": 135425, "status": "ok", "timestamp": 1729289242749, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "Dg0itS7cdIf4", "outputId": "4ce1f8a2-5ac7-4805-e6f6-37a8acb7e039" }, "outputs": [], "source": [ "tpcds.spark.conf.set('spark.rapids.sql.enabled', False)\n", "%time tpcds.run_TPCDS()\n", "cpu_grouped_results = tpcds.grouped_results_pdf.copy()\n", "cpu_grouped_results" ] }, { "cell_type": "markdown", "metadata": { "id": "PcZ12b13h3cq" }, "source": [ "## Show Speedup Factors achieved by GPU\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "executionInfo": { "elapsed": 5, "status": "ok", "timestamp": 1729289293047, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "cJxS9Nqi3AQj" }, "outputs": [], "source": [ "res = pd.merge(cpu_grouped_results, gpu_grouped_results, on='query', how='inner', suffixes=['_cpu', '_gpu'])\n", "res['speedup'] = res['elapsedTime_cpu'] / res['elapsedTime_gpu']\n", "res = res.sort_values(by='elapsedTime_cpu', ascending=False)\n", "res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "demo_dur = time.time() - demo_start\n", "print(f\"CPU and GPU run took: {demo_dur=} seconds\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 510 }, "executionInfo": { "elapsed": 1041, "status": "ok", "timestamp": 1729289294084, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "wn7u33fZlUJL", "outputId": "8d1ef757-e5c2-4761-fc58-f65f833bdffc" }, "outputs": [], "source": [ "res.plot(title='TPC-DS query elapsedTime on CPU vs GPU (lower is better)', \n", " kind='bar', x='query', y=['elapsedTime_cpu', 'elapsedTime_gpu'],\n", " color=['blue', 
'#76B900'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 510 }, "executionInfo": { "elapsed": 381, "status": "ok", "timestamp": 1729289294462, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "hW-LZponwGQE", "outputId": "ed456120-ca7f-4c91-a2bf-de87c2401f0c" }, "outputs": [], "source": [ "res.plot(title='Speedup factors of TPC-DS queries on GPU', kind='bar', \n", " x='query', y='speedup', color='#76B900')" ] }, { "cell_type": "markdown", "metadata": { "id": "Pk2TR4yimNqP" }, "source": [ "# Run Queries interactively" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tpcds_pyspark_files = files('tpcds_pyspark')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "executionInfo": { "elapsed": 4, "status": "ok", "timestamp": 1729289294462, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "RpIl6NyNzqYU" }, "outputs": [], "source": [ "query = 'q88'\n", "with open(f\"{tpcds_pyspark_files}/Queries/{query}.sql\") as f:\n", " q = f.read()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 0 }, "executionInfo": { "elapsed": 3, "status": "ok", "timestamp": 1729289294462, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "cuCVL1Ed1lQd", "outputId": "d256f4b7-e0e2-450c-ba88-aff0d7571510" }, "outputs": [], "source": [ "print(q)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 0 }, "editable": true, "executionInfo": { "elapsed": 1470, "status": "ok", "timestamp": 1729289295930, "user": { "displayName": "Gera Shegalov", "userId": "07399839501144323282" }, "user_tz": 420 }, "id": "n4QUdq17040i", "outputId": "7d7c7562-fae6-4426-97a7-ec23b8fe2f0d", "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "spark.conf.set('spark.rapids.sql.enabled', True)\n", "df = spark.sql(q)\n", "%time df.collect()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "machine_shape": "hm", "provenance": [] }, "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/Dockerfile ================================================ # # Copyright (c) 2021-2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# # A container that can be used to build UDF native code against libcudf ARG CUDA_VERSION=12.9.1 ARG LINUX_VERSION=rockylinux8 FROM nvidia/cuda:${CUDA_VERSION}-devel-${LINUX_VERSION} ARG TOOLSET_VERSION=13 ENV TOOLSET_VERSION=13 ARG PARALLEL_LEVEL=10 ENV PARALLEL_LEVEL=10 ### Install basic requirements RUN dnf --enablerepo=powertools install -y \ gcc-toolset-${TOOLSET_VERSION} \ git \ java-1.8.0-openjdk \ maven \ ninja-build \ patch \ python39 \ scl-utils \ tar \ wget \ zlib-devel \ && alternatives --set python /usr/bin/python3 # 3.22.3: CUDA architecture 'native' support + flexible CMAKE__*_LAUNCHER for ccache ARG CMAKE_VERSION=3.30.4 # default x86_64 from x86 build, aarch64 cmake for arm build ARG CMAKE_ARCH=x86_64 RUN cd /usr/local && wget --quiet https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-linux-${CMAKE_ARCH}.tar.gz && \ tar zxf cmake-${CMAKE_VERSION}-linux-${CMAKE_ARCH}.tar.gz && \ rm cmake-${CMAKE_VERSION}-linux-${CMAKE_ARCH}.tar.gz ENV PATH /usr/local/cmake-${CMAKE_VERSION}-linux-${CMAKE_ARCH}/bin:$PATH # ccache for interactive builds ARG CCACHE_VERSION=4.11.2 RUN cd /tmp && wget --quiet https://github.com/ccache/ccache/releases/download/v${CCACHE_VERSION}/ccache-${CCACHE_VERSION}.tar.gz && \ tar zxf ccache-${CCACHE_VERSION}.tar.gz && \ rm ccache-${CCACHE_VERSION}.tar.gz && \ cd ccache-${CCACHE_VERSION} && \ mkdir build && \ cd build && \ scl enable gcc-toolset-${TOOLSET_VERSION} \ "cmake .. \ -DCMAKE_BUILD_TYPE=Release \ -DZSTD_FROM_INTERNET=ON \ -DREDIS_STORAGE_BACKEND=OFF && \ cmake --build . --parallel ${PARALLEL_LEVEL} --target install" && \ cd ../.. && \ rm -rf ccache-${CCACHE_VERSION} ENTRYPOINT /usr/bin/scl enable gcc-toolset-${TOOLSET_VERSION} -- bash ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/README.md ================================================ # RAPIDS Accelerated UDF Examples This project contains sample implementations of RAPIDS accelerated user-defined functions. The ideal solution would be to replace the UDF with a series of DataFrame or SQL operations. If that is not possible, we also provide a [UDF compiler extension](https://nvidia.github.io/spark-rapids/docs/additional-functionality/udf-to-catalyst-expressions.html) to translate UDFs to Catalyst expressions. The extension is limited to compiling simple operations. For complicated cases, you can choose to implement a RAPIDS accelerated UDF. ## Spark Scala UDF Examples [URLDecode](src/main/scala/com/nvidia/spark/rapids/udf/scala/URLDecode.scala) is the simplest demo for getting started. From the code you can see there is an original CPU implementation provided by the `apply` method. We only need to implement the RapidsUDF interface, which provides a single method to override called `evaluateColumnar`. The CPU URLDecode function processes the input row by row, but the GPU evaluateColumnar returns a cudf ColumnVector, because the GPU gets its speed by performing operations on many rows at a time. In the `evaluateColumnar` function, there is a cudf implementation of URL decode that we're leveraging, so we don't need to write any native C++ code. This is all done through the [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy). The benefit of implementing via the Java API is ease of development, but the memory model is not friendly for doing GPU operations because the JVM makes the assumption that everything we're trying to do is in heap memory.
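For illustration only, here is a hedged sketch of what such a UDF can look like. The class name is hypothetical, the exact `RapidsUDF` method signature can differ between plugin versions, and the cudf calls shown (`Scalar.fromString`, `stringReplace`, `urlDecode`) are just one plausible way to express the decode; see the linked URLDecode.scala for the real implementation.
```scala
import ai.rapids.cudf.{ColumnVector, Scalar}
import com.nvidia.spark.RapidsUDF

// Hypothetical sketch -- not the repository's URLDecode implementation.
class MyUrlDecode extends Function1[String, String] with RapidsUDF {

  // CPU row-by-row implementation, used when the expression falls back to the CPU
  override def apply(url: String): String =
    if (url == null) null else java.net.URLDecoder.decode(url, "utf-8")

  // GPU columnar implementation: processes a whole batch of rows in one call
  override def evaluateColumnar(numRows: Int, args: ColumnVector*): ColumnVector = {
    require(args.length == 1, s"Unexpected argument count: ${args.length}")
    val input = args.head
    // cudf objects hold GPU memory the JVM garbage collector knows nothing about,
    // so every intermediate result is closed in a finally block once it is no longer needed.
    val plus = Scalar.fromString("+")
    try {
      val space = Scalar.fromString(" ")
      try {
        val replaced = input.stringReplace(plus, space) // '+' encodes a space in URLs
        try {
          replaced.urlDecode() // the returned column is owned by the caller
        } finally {
          replaced.close()
        }
      } finally {
        space.close()
      }
    } finally {
      plus.close()
    }
  }
}
```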
We therefore need to free the GPU resources in a timely manner with try-finally blocks. Note that we need to implement both CPU and GPU functions so the UDF will still work if a higher-level operation involving the RAPIDS accelerated UDF falls back to the CPU. - [URLDecode](src/main/scala/com/nvidia/spark/rapids/udf/scala/URLDecode.scala) decodes URL-encoded strings using the [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy) - [URLEncode](src/main/scala/com/nvidia/spark/rapids/udf/scala/URLEncode.scala) URL-encodes strings using the [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy) ## Spark Java UDF Examples Below are some examples of implementing RAPIDS accelerated Java UDFs, including one implemented via JNI and native code. If there is no existing simple Java API we can leverage, we can write custom native code. Take [CosineSimilarity](src/main/java/com/nvidia/spark/rapids/udf/java/CosineSimilarity.java) as an example: the Java class for the UDF is similar to the previous URLDecode/URLEncode demos. We implement a cosineSimilarity function in C++, and the Java method transitions into the native code as quickly as possible, because it is easier to write the code safely there. The native code `reinterpret_cast`s the inputs to column views, does some sanity checking, converts them to list column views, computes the cosine similarity, and finally releases the result column from its unique pointer so it can be returned to Java. On the Java side we wrap the returned handle in a column vector, which then owns that resource. In `cosine_similarity.cu` we implement the computation as the actual CUDA kernel. In the CUDA kernel we can leverage the [Thrust template library](https://docs.nvidia.com/cuda/thrust/index.html) to write standard parallel algorithms for the GPU. The benefit of implementing the UDF in native code is maximum control over GPU memory utilization and performance. However, the trade-off is a more complicated build environment, as we need to build against libcudf, with significantly longer build times. Implementing a RAPIDS accelerated UDF in native code is a significant effort. - [URLDecode](src/main/java/com/nvidia/spark/rapids/udf/java/URLDecode.java) decodes URL-encoded strings using the [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy) - [URLEncode](src/main/java/com/nvidia/spark/rapids/udf/java/URLEncode.java) URL-encodes strings using the [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy) - [CosineSimilarity](src/main/java/com/nvidia/spark/rapids/udf/java/CosineSimilarity.java) computes the [cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity) between two float vectors using [native code](src/main/cpp/src) ## Hive UDF Examples Below are some examples of implementing RAPIDS accelerated Hive UDFs, using either the cudf Java APIs or JNI and native code. - [URLDecode](src/main/java/com/nvidia/spark/rapids/udf/hive/URLDecode.java) implements a Hive simple UDF using the [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy) to decode URL-encoded strings - [URLEncode](src/main/java/com/nvidia/spark/rapids/udf/hive/URLEncode.java) implements a Hive generic UDF using the [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy) to URL-encode strings - [StringWordCount](src/main/java/com/nvidia/spark/rapids/udf/hive/StringWordCount.java) implements a Hive simple UDF using [native code](src/main/cpp/src) to count words in strings ## Building and run the tests without Native Code Examples Some UDF examples use native code in their implementation.
Building the native code requires a libcudf build environment, so these examples do not build by default. ### Prerequisites Download [Apache Spark](https://spark.apache.org/downloads.html) and set `SPARK_HOME` environment variable. Install Python 3.8+, then install `pytest`, `sre_yield` by using pip or conda. For example: ``` export SPARK_HOME=path-to-spark pip install pytest # If running in the docker container, please use pip3 pip install sre_yield # If running in the docker container, please use pip3 ``` Run the following command to build and run tests ```bash cd spark-rapids-examples/examples/UDF-Examples/RAPIDS-accelerated-UDFs mvn clean package ./run_pyspark_from_build.sh -m "not rapids_udf_example_native" ``` ## Building with Native Code Examples and run test cases The `udf-native-examples` Maven profile can be used to include the native UDF examples in the build, i.e.: specify `-Pudf-native-examples` on the `mvn` command-line. ### Creating a libcudf Build Environment Building the native code requires a libcudf build environment. The `Dockerfile` in this directory can be used to setup a Docker image that provides a libcudf build environment. This repository will either need to be cloned or mounted into a container using that Docker image. The `Dockerfile` contains build arguments to control the Linux version, CUDA version, and other settings. See the top of the `Dockerfile` for details. First install docker and [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) Run the following commands to build and start a docker ```bash cd spark-rapids-examples/examples/UDF-Examples/RAPIDS-accelerated-UDFs docker build -t my-local:my-udf-example . nvidia-docker run -it my-local:my-udf-example ``` ### Build the udf-examples jar #### Option 1: Fast Build Using Prebuilt libcudf (Recommended) Instead of building cuDF from source (which takes a long time), you can use the prebuilt `libcudf.so` from the `rapids-4-spark` jar. This is much faster! **Prerequisites:** - rapids-4-spark jar must be available in your local Maven repository **Steps:** 1. Extract libcudf.so and cuDF headers (automatic with Maven): ```bash cd spark-rapids-examples/examples/UDF-Examples/RAPIDS-accelerated-UDFs mvn clean package -Pudf-native-examples ``` The build will automatically: - Extract `libcudf.so` from the rapids-4-spark jar - Clone cuDF repository for headers (shallow clone) - Build only your UDF native code against the prebuilt library **Or manually extract first:** ```bash ./extract-cudf-libs.sh mvn clean package -Pudf-native-examples ``` This approach typically reduces the native cuDF build time by almost **2 hours**! #### Option 2: Build cuDF from Source (Slow but Complete) If you need to build cuDF from source, you can disable the prebuilt library option. 
**How it works:** - The Maven property `USE_PREBUILT_CUDF` (default: `ON` in pom.xml) is passed to CMake - Use `-DUSE_PREBUILT_CUDF=OFF` as a Maven system property to override the default - Maven replaces `${USE_PREBUILT_CUDF}` in pom.xml and passes it to CMake as `-DUSE_PREBUILT_CUDF=OFF` **Build with source:** ```bash cd spark-rapids-examples/examples/UDF-Examples/RAPIDS-accelerated-UDFs export LOCAL_CCACHE_DIR="$HOME/.ccache" mkdir -p $LOCAL_CCACHE_DIR export CCACHE_DIR="$LOCAL_CCACHE_DIR" export CMAKE_C_COMPILER_LAUNCHER="ccache" export CMAKE_CXX_COMPILER_LAUNCHER="ccache" export CMAKE_CUDA_COMPILER_LAUNCHER="ccache" export CMAKE_CXX_LINKER_LAUNCHER="ccache" mvn clean package -Pudf-native-examples -DUSE_PREBUILT_CUDF=OFF ``` **Alternative: Edit CMakeLists.txt directly** You can also edit `src/main/cpp/CMakeLists.txt` and set: ```cmake option(USE_PREBUILT_CUDF "Use prebuilt libcudf.so from rapids-4-spark jar" OFF) ``` #### Configurable Maven Properties You can customize the build by passing Maven system properties via `-D=`. These properties are defined in `pom.xml` and passed to CMake: | Maven Property | Default Value | Description | |----------------|---------------|-------------| | `USE_PREBUILT_CUDF` | `ON` | Use prebuilt libcudf.so from rapids-4-spark jar (faster build) | | `GPU_ARCHS` | `RAPIDS` | GPU architectures to compile for (e.g., `60;70;75;80`) | | `CPP_PARALLEL_LEVEL` | `10` | Number of parallel compilation jobs | | `BUILD_UDF_BENCHMARKS` | `OFF` | Build benchmark executables | | `PER_THREAD_DEFAULT_STREAM` | `ON` | Enable per-thread default CUDA streams | | `CUDF_ENABLE_ARROW_S3` | `OFF` | Enable Arrow S3 support in cuDF | | `cudf.git.branch` | `main` | cuDF git branch to clone for headers | | `skipCudfExtraction` | `false` | Skip extracting cuDF dependencies from jar | **Example usage:** ```bash # Build for specific GPU architectures with more parallel jobs mvn clean package -Pudf-native-examples -DGPU_ARCHS="75;80;86" -DCPP_PARALLEL_LEVEL=16 # Skip cuDF extraction and use existing dependencies mvn clean package -Pudf-native-examples -DskipCudfExtraction=true ``` #### Using ccache to Accelerate Builds The Docker container has installed ccache 4.6 to accelerate the incremental building. You can change the LOCAL_CCACHE_DIR to a mounted folder so that the cache can persist. If you don't want to use ccache, you can remove or unset the ccache environment variables. ```bash unset CCACHE_DIR unset CMAKE_C_COMPILER_LAUNCHER unset CMAKE_CXX_COMPILER_LAUNCHER unset CMAKE_CUDA_COMPILER_LAUNCHER unset CMAKE_CXX_LINKER_LAUNCHER ``` The first build could take a long time (e.g.: 1.5 hours). Then the rapids-4-spark-udf-examples*.jar is generated under RAPIDS-accelerated-UDFs/target directory. The following build can benefit from ccache if you enable it. If you want to enable building with ccache on your own system, please refer to the commands which build ccache from the source code in the Dockerfile. ### Run all the examples including native examples in the docker See the above [Prerequisites section](#prerequisites) ``` export SPARK_HOME=path-to-spark pip install pytest pip install sre_yield ``` Run the following command to run tests ``` ./run_pyspark_from_build.sh ``` ## How to run the Native UDFs on Spark local mode First finish the steps in [Building with Native Code Examples and run test cases](#building-with-native-code-examples-and-run-test-cases) section, then do the following inside the Docker container. 
### Get jars from Maven Central [rapids-4-spark_2.12-26.02.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar) ### Launch a local mode Spark ```bash export SPARK_RAPIDS_PLUGIN_JAR=path-to-rapids-4-spark-jar export SPARK_RAPIDS_UDF_EXAMPLES_JAR=path-to-udf-examples-jar $SPARK_HOME/bin/pyspark --master local[*] \ --conf spark.executor.cores=6 \ --driver-memory 5G \ --executor-memory 5G \ --jars ${SPARK_RAPIDS_PLUGIN_JAR},${SPARK_RAPIDS_UDF_EXAMPLES_JAR} \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --conf spark.rapids.sql.enabled=true ``` ### Test native based UDF Input the following commands to test wordcount JNI UDF ```python from pyspark.sql.types import * schema = StructType([ StructField("c1", StringType()), StructField("c2", IntegerType()), ]) data = [ ("a b c d",1), ("",2), (None,3), ("the quick brown fox jumped over the lazy dog",3), ] df = spark.createDataFrame( SparkContext.getOrCreate().parallelize(data, numSlices=2), schema) df.createOrReplaceTempView("tab") spark.sql("CREATE TEMPORARY FUNCTION {} AS '{}'".format("wordcount", "com.nvidia.spark.rapids.udf.hive.StringWordCount")) spark.sql("select c1, wordcount(c1) from tab").show() spark.sql("select c1, wordcount(c1) from tab").explain() ``` ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/clone-cudf-repo.sh ================================================ #!/bin/bash # # Copyright (c) 2026, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ############################################################################### # Clone or update cuDF repository for header files # # This script is called by Maven during the build process to obtain cuDF # headers needed for compiling native UDF code. # # Usage: # clone-cudf-repo.sh # # Arguments: # target_directory - Directory where cuDF repo will be cloned # branch_name - Git branch to clone/checkout # # Exit codes: # 0 - Success # 1 - Failed to clone, fetch, or checkout ############################################################################### set -e set -o pipefail # Parse arguments if [ $# -ne 2 ]; then echo "ERROR: Usage: $0 " >&2 exit 1 fi CUDF_DIR="$1" BRANCH="$2" echo "================================================" echo "cuDF Repository Management" echo " Target directory: $CUDF_DIR" echo " Branch: $BRANCH" echo "================================================" # Check if repository already exists if [ ! -d "$CUDF_DIR/.git" ]; then # Repository doesn't exist - clone it echo "Cloning cuDF repository ($BRANCH branch)..." git clone --depth 1 --branch "$BRANCH" \ https://github.com/rapidsai/cudf.git "$CUDF_DIR" || { echo "ERROR: Failed to clone cuDF from branch $BRANCH" >&2 echo "Please check:" >&2 echo " 1. Network connectivity to GitHub" >&2 echo " 2. Branch '$BRANCH' exists in cuDF repository" >&2 exit 1 } echo "✓ Successfully cloned cuDF repository" else # Repository exists - verify and update if needed echo "cuDF repository exists, verifying branch..." 
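# Note: when the existing checkout is on a different branch, the update path below does a
# shallow fetch of only the requested branch and then hard-resets to origin/<branch>,
# so any local edits inside this headers checkout are discarded.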
cd "$CUDF_DIR" || { echo "ERROR: Cannot access directory $CUDF_DIR" >&2 exit 1 } # Get current branch CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown") if [ "$CURRENT_BRANCH" != "$BRANCH" ]; then # Branch mismatch - fetch and switch to correct branch echo "Branch mismatch detected:" echo " Current branch: $CURRENT_BRANCH" echo " Expected branch: $BRANCH" echo "Fetching and switching to $BRANCH..." git fetch --depth 1 origin "$BRANCH" || { echo "ERROR: Failed to fetch branch $BRANCH from origin" >&2 echo "Please check:" >&2 echo " 1. Network connectivity to GitHub" >&2 echo " 2. Branch '$BRANCH' exists in cuDF repository" >&2 exit 1 } git checkout "$BRANCH" || { echo "ERROR: Failed to checkout branch $BRANCH" >&2 exit 1 } git reset --hard "origin/$BRANCH" || { echo "ERROR: Failed to reset to origin/$BRANCH" >&2 exit 1 } echo "✓ Switched to branch $BRANCH" else echo "✓ Already on correct branch ($BRANCH)" fi fi echo "================================================" echo "✓ cuDF repository ready at: $CUDF_DIR" echo " Branch: $BRANCH" echo "================================================" exit 0 ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/conftest.py ================================================ # Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. def pytest_addoption(parser): """Pytest hook to define command line options for pytest""" parser.addoption( "--mortgage_format", action="store", default="parquet", help="format of Mortgage data" ) parser.addoption( "--mortgage_path", action="store", default=None, help="path to Mortgage data" ) parser.addoption( "--std_input_path", action="store", default=None, help="path to standard input files" ) parser.addoption( "--tmp_path", action="store", default=None, help="path to store tmp files" ) parser.addoption( "--debug_tmp_path", action='store_true', default=False, help="if true don't delete tmp_path contents for debugging" ) parser.addoption( "--runtime_env", action='store', default="Apache", help="the runtime environment for the tests - apache or databricks" ) parser.addoption( "--cudf_udf", action='store_true', default=False, help="if true enable cudf_udf test" ) parser.addoption( "--rapids_udf_example_native", action='store_true', default=False, help="if true enable tests for RAPIDS UDF examples with native code" ) parser.addoption( "--test_type", action='store', default="developer", help="the type of tests that are being run to help check all the correct tests are run - developer, pre-commit, or nightly" ) ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/extract-cudf-libs.sh ================================================ #!/bin/bash # # Copyright (c) 2026, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ############################################################################### # Extract libcudf.so from rapids-4-spark jar # # This script extracts prebuilt cuDF libraries from the rapids-4-spark jar # to enable faster builds by avoiding building cuDF from source. # # Configuration values are read from pom.xml by default, but can be overridden # using environment variables: # # Usage: # ./extract-cudf-libs.sh # # Environment Variables (optional, will use pom.xml values if not set): # RAPIDS4SPARK_VERSION - rapids-4-spark version (e.g., 26.02.0 or 26.06.0-SNAPSHOT) # SCALA_VERSION - Scala binary version (e.g., 2.12, 2.13) # CUDA_VERSION - CUDA version (e.g., cuda11, cuda12) # CUDF_BRANCH - cuDF git branch for headers (e.g., main, branch-26.02) # # Example with overrides: # RAPIDS4SPARK_VERSION=26.02.0 CUDA_VERSION=cuda11 ./extract-cudf-libs.sh ############################################################################### set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" TARGET_DIR="$SCRIPT_DIR/target" NATIVE_DEPS_DIR="$TARGET_DIR/native-deps" CUDF_REPO_DIR="$TARGET_DIR/cudf-repo" POM_FILE="$SCRIPT_DIR/pom.xml" # Function to extract property value from pom.xml # Usage: extract_pom_property "property_name" extract_pom_property() { local property_name="$1" local value # Use xmllint if available (more reliable) if command -v xmllint >/dev/null 2>&1; then value=$(xmllint --xpath "string(//*[local-name()='project']/*[local-name()='properties']/*[local-name()='${property_name}'])" "$POM_FILE" 2>/dev/null) else # Fallback to grep/sed (less robust but widely available) value=$(grep -A 1 "<${property_name}>" "$POM_FILE" | grep -v "^--$" | sed -n "s/.*<${property_name}>\(.*\)<\/${property_name}>.*/\1/p" | head -1 | xargs) fi echo "$value" } echo "==================================================" echo "Extract cuDF Dependencies for UDF Examples" echo "==================================================" echo "Reading configuration from pom.xml..." 
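# Illustrative usage of the helper above: extract_pom_property "cuda.version" echoes the
# text of the matching <properties> entry (for example, a value like "cuda12"); a missing
# property yields an empty string, which (unless overridden by the corresponding
# environment variable) the validation below reports as an error.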
# Read defaults from pom.xml POM_RAPIDS4SPARK_VERSION=$(extract_pom_property "rapids4spark.version") POM_SCALA_VERSION=$(extract_pom_property "scala.binary.version") POM_CUDA_VERSION=$(extract_pom_property "cuda.version") POM_CUDF_BRANCH=$(extract_pom_property "cudf.git.branch") # Use environment variables if set, otherwise use pom.xml values RAPIDS4SPARK_VERSION="${RAPIDS4SPARK_VERSION:-${POM_RAPIDS4SPARK_VERSION}}" SCALA_VERSION="${SCALA_VERSION:-${POM_SCALA_VERSION}}" CUDA_VERSION="${CUDA_VERSION:-${POM_CUDA_VERSION}}" CUDF_BRANCH="${CUDF_BRANCH:-${POM_CUDF_BRANCH}}" # Validate that we have all required values if [ -z "$RAPIDS4SPARK_VERSION" ] || [ -z "$SCALA_VERSION" ] || [ -z "$CUDA_VERSION" ] || [ -z "$CUDF_BRANCH" ]; then echo "ERROR: Failed to read required properties from pom.xml" >&2 echo "Please ensure pom.xml exists and contains all required properties:" >&2 echo " - rapids4spark.version" >&2 echo " - scala.binary.version" >&2 echo " - cuda.version" >&2 echo " - cudf.git.branch" >&2 exit 1 fi echo "Configuration:" echo " RAPIDS4SPARK_VERSION: $RAPIDS4SPARK_VERSION" echo " SCALA_VERSION: $SCALA_VERSION" echo " CUDA_VERSION: $CUDA_VERSION" echo " CUDF_BRANCH: $CUDF_BRANCH" echo "==================================================" # Create directories mkdir -p "$NATIVE_DEPS_DIR" mkdir -p "$CUDF_REPO_DIR" # Find rapids-4-spark jar in local Maven repository MAVEN_REPO="${HOME}/.m2/repository" # Try multiple naming patterns JAR_PATH_WITH_CLASSIFIER="$MAVEN_REPO/com/nvidia/rapids-4-spark_${SCALA_VERSION}/${RAPIDS4SPARK_VERSION}/rapids-4-spark_${SCALA_VERSION}-${RAPIDS4SPARK_VERSION}-${CUDA_VERSION}.jar" JAR_PATH_NO_CLASSIFIER="$MAVEN_REPO/com/nvidia/rapids-4-spark_${SCALA_VERSION}/${RAPIDS4SPARK_VERSION}/rapids-4-spark_${SCALA_VERSION}-${RAPIDS4SPARK_VERSION}.jar" echo "Looking for rapids-4-spark jar..." echo " Pattern 1 (with classifier): $JAR_PATH_WITH_CLASSIFIER" echo " Pattern 2 (no classifier): $JAR_PATH_NO_CLASSIFIER" if [ -f "$JAR_PATH_WITH_CLASSIFIER" ]; then JAR_PATH="$JAR_PATH_WITH_CLASSIFIER" echo "✓ Found jar (with classifier): $JAR_PATH" elif [ -f "$JAR_PATH_NO_CLASSIFIER" ]; then JAR_PATH="$JAR_PATH_NO_CLASSIFIER" echo "✓ Found jar (no classifier): $JAR_PATH" else echo "" echo "ERROR: rapids-4-spark jar not found!" echo "Tried:" echo " $JAR_PATH_WITH_CLASSIFIER" echo " $JAR_PATH_NO_CLASSIFIER" echo "" echo "For SNAPSHOT versions:" echo " cd /path/to/spark-rapids" echo " mvn clean install -DskipTests" echo "" echo "For release versions:" echo " mvn dependency:get -Dartifact=com.nvidia:rapids-4-spark_${SCALA_VERSION}:${RAPIDS4SPARK_VERSION}:jar:${CUDA_VERSION}" exit 1 fi # Extract libcudf.so and dependencies echo "Extracting native libraries from jar..." echo " Jar: $JAR_PATH" echo " Looking for: */libcudf.so*, */libnvcomp.so*" # Use unzip without -q to capture output, but redirect to log for debugging UNZIP_OUTPUT=$(unzip -o "$JAR_PATH" "*/libcudf.so*" "*/libnvcomp.so*" -d "$TARGET_DIR/temp" 2>&1) UNZIP_EXIT_CODE=$? # Check unzip exit code if [ $UNZIP_EXIT_CODE -ne 0 ]; then echo "ERROR: Failed to extract libraries from jar" >&2 echo "unzip exit code: $UNZIP_EXIT_CODE" >&2 # Provide helpful diagnostics case $UNZIP_EXIT_CODE in 11) echo "Reason: No matching files found in jar" >&2 echo "" >&2 echo "The jar may not contain native libraries for your platform." 
>&2 echo "Expected patterns: */libcudf.so*, */libnvcomp.so*" >&2 echo "" >&2 echo "Listing jar contents:" >&2 unzip -l "$JAR_PATH" | grep -E '\.(so|dylib|dll)' || echo " No native libraries found" >&2 ;; *) echo "Reason: unzip command failed" >&2 echo "Output: $UNZIP_OUTPUT" >&2 ;; esac echo "" >&2 echo "Falling back to source build..." >&2 exit 1 fi # Verify that we actually extracted some files EXTRACTED_COUNT=$(find "$TARGET_DIR/temp" -name "*.so*" 2>/dev/null | wc -l) echo "Extracted $EXTRACTED_COUNT library file(s)" if [ "$EXTRACTED_COUNT" -eq 0 ]; then echo "ERROR: No library files were extracted from jar" >&2 echo "This usually means the jar doesn't contain native libraries." >&2 echo "" >&2 echo "Listing jar contents:" >&2 unzip -l "$JAR_PATH" | head -20 >&2 exit 1 fi # Move libraries to native-deps directory, detecting conflicts echo "Moving extracted libraries..." CONFLICT_COUNT=0 # Use process substitution to avoid subshell issues while IFS= read -r source_file; do filename=$(basename "$source_file") dest_file="$NATIVE_DEPS_DIR/$filename" if [ -f "$dest_file" ]; then # File already exists - check if it's the same if ! cmp -s "$source_file" "$dest_file"; then echo "WARNING: Conflicting library detected: $filename" >&2 echo " Existing: $dest_file" >&2 echo " New: $source_file" >&2 echo " Keeping existing file, skipping new one" >&2 CONFLICT_COUNT=$((CONFLICT_COUNT + 1)) fi # Remove the duplicate source file rm -f "$source_file" else # No conflict, move the file mv "$source_file" "$dest_file" fi done < <(find "$TARGET_DIR/temp" -name "*.so*") if [ "$CONFLICT_COUNT" -gt 0 ]; then echo "WARNING: $CONFLICT_COUNT library file(s) had conflicts. Review the warnings above." >&2 fi rm -rf "$TARGET_DIR/temp" # Verify that libcudf.so was successfully moved to final location if [ ! -f "$NATIVE_DEPS_DIR/libcudf.so" ]; then echo "ERROR: libcudf.so not found in $NATIVE_DEPS_DIR" >&2 echo "" >&2 echo "This could mean:" >&2 echo " 1. The jar didn't contain libcudf.so" >&2 echo " 2. Extraction succeeded but moving files failed" >&2 echo " 3. Wrong architecture (jar might be for a different platform)" >&2 echo "" >&2 echo "Contents of $NATIVE_DEPS_DIR:" >&2 ls -lh "$NATIVE_DEPS_DIR" >&2 || echo " Directory is empty or doesn't exist" >&2 exit 1 fi echo "✓ Successfully extracted libraries to: $NATIVE_DEPS_DIR" ls -lh "$NATIVE_DEPS_DIR" # Clone cuDF repo for headers (shallow clone) if [ ! -d "$CUDF_REPO_DIR/.git" ]; then echo "Cloning cuDF repository for headers..." git clone --depth 1 --branch "$CUDF_BRANCH" https://github.com/rapidsai/cudf.git "$CUDF_REPO_DIR" echo "✓ Cloned cuDF repo to: $CUDF_REPO_DIR" else echo "✓ cuDF repo already exists at: $CUDF_REPO_DIR" echo " (Delete it to re-clone: rm -rf \"$CUDF_REPO_DIR\")" fi echo "" echo "==================================================" echo "Setup complete! You can now build with:" echo " mvn clean package -P udf-native-examples" echo "" echo "This will use prebuilt libcudf.so and avoid" echo "building cuDF from source (much faster!)." 
echo "==================================================" ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/pom.xml ================================================ 4.0.0 com.nvidia rapids-4-spark-udf-examples_2.12 RAPIDS Accelerator for Apache Spark UDF Examples Sample implementations of RAPIDS accelerated user defined functions for use with the RAPIDS Accelerator for Apache Spark 26.06.0-SNAPSHOT 1.8 1.8 8 UTF-8 UTF-8 UTF-8 cuda12 2.12 26.02.0 3.1.1 2.12.15 ${project.build.directory}/cpp-build main false OFF RAPIDS ON 10 OFF ON org.apache.spark spark-hive_${scala.binary.version} ${spark.version} org.scala-lang scala-library ${scala.version} com.nvidia rapids-4-spark_${scala.binary.version} ${rapids4spark.version} provided ${project.build.directory}/extra-resources true org.apache.maven.plugins maven-jar-plugin 3.2.0 default-test-jar none **/* net.alchim31.maven scala-maven-plugin 4.3.0 org.apache.rat apache-rat-plugin 0.13 org.apache.maven.plugins maven-antrun-plugin 3.0.0 generate-build-info none org.codehaus.mojo exec-maven-plugin 3.0.0 run pyspark tests verify exec ./run_pyspark_from_build.sh ./ ${skipTests} org.apache.maven.plugins maven-dependency-plugin copy-dist-jar package copy true com.nvidia rapids-4-spark_${scala.binary.version} ${rapids4spark.version} udf-native-examples org.apache.maven.plugins maven-dependency-plugin download-rapids-jar-with-classifier generate-sources copy com.nvidia rapids-4-spark_${scala.binary.version} ${rapids4spark.version} ${cuda.version} jar false ${project.build.directory}/rapids-jar false true download-rapids-jar-no-classifier generate-sources copy com.nvidia rapids-4-spark_${scala.binary.version} ${rapids4spark.version} jar false ${project.build.directory}/rapids-jar false true maven-antrun-plugin extract-cudf-dependencies generate-sources ${skipCudfExtraction} run cmake compile run maven-resources-plugin 3.2.0 copy-native-libs-to-deps process-classes copy-resources true ${project.build.directory}/native-deps/${os.arch}/${os.name} ${udf.native.build.path} libudfexamplesjni.so copy-native-libs-to-classes process-classes copy-resources true ${project.build.outputDirectory}/${os.arch}/${os.name} ${udf.native.build.path} libudfexamplesjni.so ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/pytest.ini ================================================ ; Copyright (c) 2020-2022, NVIDIA CORPORATION. ; ; Licensed under the Apache License, Version 2.0 (the "License"); ; you may not use this file except in compliance with the License. ; You may obtain a copy of the License at ; ; http://www.apache.org/licenses/LICENSE-2.0 ; ; Unless required by applicable law or agreed to in writing, software ; distributed under the License is distributed on an "AS IS" BASIS, ; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ; See the License for the specific language governing permissions and ; limitations under the License. [pytest] markers = rapids_udf_example_native: test UDFs that require custom cuda compilation ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/run_pyspark_from_build.sh ================================================ #!/bin/bash # Copyright (c) 2022-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -ex

SCRIPTPATH="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
cd "$SCRIPTPATH"

if [[ $( echo ${SKIP_TESTS} | tr '[:upper:]' '[:lower:]' ) == "true" ]]; then
    echo "PYTHON INTEGRATION TESTS SKIPPED..."
    exit 0
elif [[ -z "$SPARK_HOME" ]]; then
    >&2 echo "SPARK_HOME IS NOT SET CANNOT RUN PYTHON INTEGRATION TESTS..."
    exit 1
else
    echo "WILL RUN TESTS WITH SPARK_HOME: ${SPARK_HOME}"
    # Spark 3.1.1 includes https://github.com/apache/spark/pull/31540
    # which helps with spurious task failures as observed in our tests. If you are running
    # Spark versions before 3.1.1, this sets spark.task.maxFailures to 4 to allow for a
    # more lenient configuration; otherwise it is set to 1, since spurious task failures
    # are not expected for Spark 3.1.1+
    VERSION_STRING=`$SPARK_HOME/bin/pyspark --version 2>&1|grep -v Scala|awk '/version\ [0-9.]+/{print $NF}'`
    VERSION_STRING="${VERSION_STRING/-SNAPSHOT/}"
    [[ -z $VERSION_STRING ]] && { echo "Unable to detect the Spark version at $SPARK_HOME"; exit 1; }
    [[ -z $SPARK_SHIM_VER ]] && { SPARK_SHIM_VER="spark${VERSION_STRING//./}"; }
    echo "Detected Spark version $VERSION_STRING (shim version: $SPARK_SHIM_VER)"

    PLUGIN_JARS=$(echo "$SCRIPTPATH"/target/dependency/rapids-4-spark*.jar)
    UDF_EXAMPLE_JARS=$(echo "$SCRIPTPATH"/target/rapids-4-spark-udf-examples*.jar)
    ALL_JARS="$PLUGIN_JARS $UDF_EXAMPLE_JARS"
    echo "AND PLUGIN JARS: $ALL_JARS"

    RUN_TESTS_COMMAND=("$SCRIPTPATH"/runtests.py --rootdir "$SCRIPTPATH" "$SCRIPTPATH"/src/main/python)

    # --ignore=target is used to exclude the target directory which contains unrelated python files.
    TEST_COMMON_OPTS=(-v -rfExXs "$TEST_ARGS" --color=yes --ignore=target "$@")

    "$SPARK_HOME"/bin/spark-submit --jars "${ALL_JARS// /,}" \
        --master local[1] \
        "${RUN_TESTS_COMMAND[@]}" "${TEST_COMMON_OPTS[@]}"
fi

================================================
FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/runtests.py
================================================
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

from pytest import main
#import cProfile

if __name__ == '__main__':
    #cProfile.run('main(sys.argv[1:])', 'test_profile')
    # arguments are the same as for pytest https://docs.pytest.org/en/latest/usage.html
    # or run pytest -h
    sys.exit(main(sys.argv[1:]))

================================================
FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/CMakeLists.txt
================================================
#=============================================================================
# Copyright (c) 2021-2026, NVIDIA CORPORATION.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #============================================================================= # Keep the same with https://github.com/rapidsai/rapids-cmake/blob/main/RAPIDS.cmake cmake_minimum_required(VERSION 3.30.4 FATAL_ERROR) # set to the rapids-cmake-branch set(rapids-cmake-branch "main") file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/${rapids-cmake-branch}/RAPIDS.cmake ${CMAKE_BINARY_DIR}/RAPIDS.cmake) include(${CMAKE_BINARY_DIR}/RAPIDS.cmake) include(rapids-cmake) include(rapids-cpm) include(rapids-cuda) include(rapids-export) include(rapids-find) # Get the rapids-cmake directory for later use # After include(rapids-cmake), CPM will download rapids-cmake to _deps # We can get it from the CPM cache get_property(rapids-cmake-dir GLOBAL PROPERTY rapids-cmake-dir) if(NOT rapids-cmake-dir) # Fallback: rapids-cmake is downloaded by CPM to _deps set(rapids-cmake-dir "${CMAKE_BINARY_DIR}/_deps/rapids-cmake-src") message(STATUS "rapids-cmake property not set, using fallback path") endif() # Verify rapids-cmake directory exists if(NOT EXISTS "${rapids-cmake-dir}") message(FATAL_ERROR "rapids-cmake directory not found: ${rapids-cmake-dir}\n" "This usually means rapids-cmake wasn't properly fetched by CPM.\n" "Try deleting the build directory and reconfiguring:\n" " rm -rf ${CMAKE_BINARY_DIR}\n" " cmake ..") endif() message(STATUS "rapids-cmake directory: ${rapids-cmake-dir}") # Use GPU_ARCHS if it is defined if(DEFINED GPU_ARCHS) set(CMAKE_CUDA_ARCHITECTURES "${GPU_ARCHS}") endif() rapids_cuda_init_architectures(UDFEXAMPLESJNI) project(UDFEXAMPLESJNI VERSION 26.06.0 LANGUAGES C CXX CUDA) option(PER_THREAD_DEFAULT_STREAM "Build with per-thread default stream" OFF) option(BUILD_UDF_BENCHMARKS "Build the benchmarks" OFF) ################################################################################################### # - build type ------------------------------------------------------------------------------------ # Set a default build type if none was specified set(DEFAULT_BUILD_TYPE "Release") ################################################################################################### # - compiler options ------------------------------------------------------------------------------ set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_COMPILER $ENV{CXX}) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CUDA_STANDARD 20) set(CMAKE_CUDA_STANDARD_REQUIRED ON) if(CMAKE_COMPILER_IS_GNUCXX) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unknown-pragmas") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations") endif(CMAKE_COMPILER_IS_GNUCXX) if(CMAKE_CUDA_COMPILER_VERSION) # Compute the version. 
from CMAKE_CUDA_COMPILER_VERSION string(REGEX REPLACE "([0-9]+)\\.([0-9]+).*" "\\1" CUDA_VERSION_MAJOR ${CMAKE_CUDA_COMPILER_VERSION}) string(REGEX REPLACE "([0-9]+)\\.([0-9]+).*" "\\2" CUDA_VERSION_MINOR ${CMAKE_CUDA_COMPILER_VERSION}) set(CUDA_VERSION "${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR}" CACHE STRING "Version of CUDA as computed from nvcc.") mark_as_advanced(CUDA_VERSION) endif() message(STATUS "CUDA_VERSION_MAJOR: ${CUDA_VERSION_MAJOR}") message(STATUS "CUDA_VERSION_MINOR: ${CUDA_VERSION_MINOR}") message(STATUS "CUDA_VERSION: ${CUDA_VERSION}") # Always set this convenience variable set(CUDA_VERSION_STRING "${CUDA_VERSION}") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -w --expt-extended-lambda --expt-relaxed-constexpr") #################################################################################################### # - cudf ------------------------------------------------------------------------------------------- # Check if USE_PREBUILT_CUDF was explicitly set by user (e.g., via -DUSE_PREBUILT_CUDF=...) # This must be done BEFORE the option() command if(DEFINED USE_PREBUILT_CUDF) set(USER_SET_USE_PREBUILT_CUDF TRUE) message(STATUS "USE_PREBUILT_CUDF explicitly set by user to: ${USE_PREBUILT_CUDF}") else() set(USER_SET_USE_PREBUILT_CUDF FALSE) endif() option(USE_PREBUILT_CUDF "Use prebuilt libcudf.so from rapids-4-spark jar" ON) message(STATUS "USE_PREBUILT_CUDF is set to: ${USE_PREBUILT_CUDF}") # Check if Maven created a marker to force source build # This happens when rapids-4-spark jar is not found if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../../../target/USE_SOURCE_BUILD") message(STATUS "Found USE_SOURCE_BUILD marker file from Maven (rapids-4-spark jar not found)") if(USE_PREBUILT_CUDF) if(USER_SET_USE_PREBUILT_CUDF) # User explicitly requested prebuilt mode, but jar is missing - fail fast with clear error message(FATAL_ERROR "\n" "================================================================\n" "ERROR: rapids-4-spark jar not found, but USE_PREBUILT_CUDF=ON\n" "was explicitly set by the user.\n" "\n" "Cannot proceed with prebuilt mode because required libraries\n" "are not available.\n" "\n" "Solutions:\n" " 1. Remove -DUSE_PREBUILT_CUDF=ON to allow automatic fallback\n" " to building from source\n" "\n" " 2. Build and install rapids-4-spark:\n" " cd /path/to/spark-rapids\n" " mvn clean install -DskipTests\n" "\n" " 3. 
Explicitly use source build:\n" " -DUSE_PREBUILT_CUDF=OFF\n" "================================================================\n") else() # Not explicitly set by user - safe to auto-fallback message(STATUS "Auto-fallback: Switching to source build due to missing jar") set(USE_PREBUILT_CUDF OFF CACHE BOOL "Auto-fallback to source build (jar not found)" FORCE) endif() endif() endif() # Check prebuilt availability before making final decision # This avoids modifying cache variables within conditional blocks set(SHOULD_USE_PREBUILT ${USE_PREBUILT_CUDF}) if(USE_PREBUILT_CUDF AND NOT USER_SET_USE_PREBUILT_CUDF) # User didn't explicitly set the option - check if prebuilt components are available # Set paths for prebuilt library and headers set(CUDF_LIB_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../target/native-deps") set(CUDF_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../target/cudf-repo/cpp/include") message(STATUS "Checking for prebuilt libcudf.so from rapids-4-spark jar") message(STATUS "Looking in: ${CUDF_LIB_DIR}") # Check if prebuilt components are available set(PREBUILT_AVAILABLE TRUE) if(NOT EXISTS "${CUDF_LIB_DIR}") message(STATUS "Directory ${CUDF_LIB_DIR} does not exist") set(PREBUILT_AVAILABLE FALSE) else() # Try to find the library # Note: find_library sets variable to -NOTFOUND on failure, not undefined find_library(CUDF_LIBRARY_CHECK NAMES cudf PATHS ${CUDF_LIB_DIR} NO_DEFAULT_PATH ) # Proper check: find_library failure results in -NOTFOUND string if(CUDF_LIBRARY_CHECK MATCHES "-NOTFOUND$") message(STATUS "libcudf.so not found in ${CUDF_LIB_DIR}") set(PREBUILT_AVAILABLE FALSE) else() message(STATUS "Found libcudf at: ${CUDF_LIBRARY_CHECK}") endif() endif() # Auto-fallback to source build if components not available if(NOT PREBUILT_AVAILABLE) message(WARNING "\n" "================================================================\n" "Prebuilt libcudf.so not available.\n" "Automatically falling back to building cuDF from source.\n" "This will take 30+ minutes.\n" "\n" "To use fast build mode in future:\n" " 1. For SNAPSHOT versions: Build and install rapids-4-spark\n" " cd /path/to/spark-rapids\n" " mvn clean install -DskipTests\n" " 2. 
Run: mvn clean package -Pudf-native-examples\n" "\n" "NOTE: If you need to reset this decision, delete:\n" " ${CMAKE_BINARY_DIR}/CMakeCache.txt\n" "================================================================\n") set(SHOULD_USE_PREBUILT FALSE) # Update cache for subsequent runs set(USE_PREBUILT_CUDF OFF CACHE BOOL "Auto-fallback to source build" FORCE) endif() endif() # Now use the final decision consistently if(SHOULD_USE_PREBUILT) # Set paths as cache variables for user customization set(CUDF_LIB_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../target/native-deps" CACHE PATH "Path to directory containing libcudf.so") set(CUDF_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../target/cudf-repo/cpp/include" CACHE PATH "Path to cudf headers") message(STATUS "✓ Using FAST BUILD mode with prebuilt libcudf.so") # Find the prebuilt libcudf.so (should succeed based on earlier check) find_library(CUDF_LIBRARY NAMES cudf PATHS ${CUDF_LIB_DIR} NO_DEFAULT_PATH REQUIRED ) message(STATUS "✓ Found libcudf: ${CUDF_LIBRARY}") message(STATUS "✓ cuDF include directory: ${CUDF_INCLUDE_DIR}") # Verify cuDF source directory exists (cloned by Maven) set(CUDF_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../target/cudf-repo/cpp") if(NOT EXISTS "${CUDF_SOURCE_DIR}/CMakeLists.txt") message(FATAL_ERROR "cuDF source directory not found: ${CUDF_SOURCE_DIR}\n" "The cuDF repository should have been cloned by Maven.\n" "Check if target/cudf-repo/ exists.") endif() message(STATUS "✓ Found cuDF source at: ${CUDF_SOURCE_DIR}") # We'll use cuDF's dependency fetching mechanism but create our own target # First, let rapids-cpm fetch the dependencies that cuDF needs message(STATUS "Fetching cuDF dependencies (this may take a few minutes)...") rapids_cpm_init() # Set options to avoid building unnecessary components set(BUILD_TESTS OFF CACHE BOOL "" FORCE) set(BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE) # Use rapids-cmake's helper scripts to get CCCL and RMM # These scripts use versions defined in rapids-cmake (avoiding duplicate version definitions) message(STATUS "Using rapids-cmake helper scripts for CCCL and RMM") # Get CCCL (Thrust, libcudacxx, CUB) - version defined in rapids-cmake set(CCCL_CMAKE_FILE "${rapids-cmake-dir}/cpm/cccl.cmake") if(NOT EXISTS "${CCCL_CMAKE_FILE}") message(FATAL_ERROR "rapids-cmake CCCL helper script not found: ${CCCL_CMAKE_FILE}\n" "Expected location: ${rapids-cmake-dir}/cpm/cccl.cmake\n" "This indicates rapids-cmake directory structure is incomplete or incorrect.") endif() include(${CCCL_CMAKE_FILE}) rapids_cpm_cccl() # Use rapids-cpm to get RMM - this is what cuDF uses set(RMM_CMAKE_FILE "${rapids-cmake-dir}/cpm/rmm.cmake") if(NOT EXISTS "${RMM_CMAKE_FILE}") message(FATAL_ERROR "rapids-cmake RMM helper script not found: ${RMM_CMAKE_FILE}\n" "Expected location: ${rapids-cmake-dir}/cpm/rmm.cmake\n" "This indicates rapids-cmake directory structure is incomplete or incorrect.") endif() include(${RMM_CMAKE_FILE}) rapids_cpm_rmm() # After rapids_cpm_rmm(), the rmm::rmm target should be available # Verify it exists if(NOT TARGET rmm::rmm) message(FATAL_ERROR "rmm::rmm target not created by rapids_cpm_rmm()") endif() # Get RMM include directory from the target get_target_property(RMM_INCLUDE_DIR rmm::rmm INTERFACE_INCLUDE_DIRECTORIES) message(STATUS "RMM include directories: ${RMM_INCLUDE_DIR}") # Now create our own imported target for cudf using the prebuilt library add_library(cudf_imported SHARED IMPORTED GLOBAL) set_target_properties(cudf_imported PROPERTIES IMPORTED_LOCATION ${CUDF_LIBRARY} ) # Add 
include directories to the imported target # Include cuDF headers and RMM headers target_include_directories(cudf_imported INTERFACE ${CUDF_INCLUDE_DIR} ${RMM_INCLUDE_DIR} ) # Link against RMM to get other dependencies target_link_libraries(cudf_imported INTERFACE rmm::rmm) # Create an alias to match expected name add_library(cudf::cudf ALIAS cudf_imported) message(STATUS "✓ Prebuilt cuDF configured with all dependencies") message(STATUS " Prebuilt library: ${CUDF_LIBRARY}") message(STATUS " cuDF headers: ${CUDF_INCLUDE_DIR}") message(STATUS " Dependencies: CCCL, RMM (via rapids-cpm)") else() message(STATUS "Building cuDF from source (this will take a long time)") # Ensure CUDA runtime is dynamic despite statically linking Arrow in libcudf set(CUDA_USE_STATIC_CUDA_RUNTIME ON) rapids_cpm_init() rapids_cpm_find(cudf 26.06.00 CPM_ARGS GIT_REPOSITORY https://github.com/rapidsai/cudf.git GIT_TAG ${rapids-cmake-branch} GIT_SHALLOW TRUE SOURCE_SUBDIR cpp OPTIONS "BUILD_TESTS OFF" "BUILD_BENCHMARKS OFF" "CUDF_USE_ARROW_STATIC ON" "JITIFY_USE_CACHE ON" "CUDA_STATIC_RUNTIME ${CUDA_USE_STATIC_CUDA_RUNTIME}" "DISABLE_DEPRECATION_WARNING ON" "AUTO_DETECT_CUDA_ARCHITECTURES OFF" "CUDF_KVIKIO_REMOTE_IO OFF" ) endif() ################################################################################################### # - benchmarks ------------------------------------------------------------------------------------ if(BUILD_UDF_BENCHMARKS) # Find or install GoogleBench CPMFindPackage(NAME benchmark VERSION 1.5.2 GIT_REPOSITORY https://github.com/google/benchmark.git GIT_TAG v1.5.2 GIT_SHALLOW TRUE OPTIONS "BENCHMARK_ENABLE_TESTING OFF" "BENCHMARK_ENABLE_INSTALL OFF") add_subdirectory(benchmarks) endif() ################################################################################################### # - find JNI ------------------------------------------------------------------------------------- find_package(JNI REQUIRED) if(JNI_FOUND) message(STATUS "JDK with JNI in ${JNI_INCLUDE_DIRS}") else() message(FATAL_ERROR "JDK with JNI not found, please check your settings.") endif(JNI_FOUND) ################################################################################################### # - library paths --------------------------------------------------------------------------------- # CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES is an undocumented/unsupported variable containing the link directories for nvcc link_directories("${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES}" "${CMAKE_BINARY_DIR}/lib") ################################################################################################### # - library targets ------------------------------------------------------------------------------- set(SOURCE_FILES "src/CosineSimilarityJni.cpp" "src/StringWordCountJni.cpp" "src/cosine_similarity.cu" "src/string_word_count.cu") add_library(udfexamplesjni SHARED ${SOURCE_FILES}) #Override RPATH for udfexamplesjni SET_TARGET_PROPERTIES(udfexamplesjni PROPERTIES BUILD_RPATH "\$ORIGIN") ################################################################################################### # - build options --------------------------------------------------------------------------------- option(PER_THREAD_DEFAULT_STREAM "Build with per-thread default stream" OFF) if(PER_THREAD_DEFAULT_STREAM) message(STATUS "Using per-thread default stream") target_compile_definitions(udfexamplesjni PRIVATE CUDA_API_PER_THREAD_DEFAULT_STREAM) endif(PER_THREAD_DEFAULT_STREAM) target_include_directories(udfexamplesjni PRIVATE ${JNI_INCLUDE_DIRS}) 
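# Illustrative manual configure/build of this library (normally driven by the Maven
# udf-native-examples profile; the flag values below are examples only):
#   cmake -S src/main/cpp -B target/cpp-build -DUSE_PREBUILT_CUDF=ON -DGPU_ARCHS=80
#   cmake --build target/cpp-build -j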
################################################################################################### # - rmm logging level ----------------------------------------------------------------------------- set(RMM_LOGGING_LEVEL "OFF" CACHE STRING "Choose the logging level.") # Set the possible values of build type for cmake-gui set_property(CACHE RMM_LOGGING_LEVEL PROPERTY STRINGS "TRACE" "DEBUG" "INFO" "WARN" "ERROR" "CRITICAL" "OFF") message(STATUS "RMM_LOGGING_LEVEL = '${RMM_LOGGING_LEVEL}'.") target_compile_definitions(udfexamplesjni PUBLIC SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_${RMM_LOGGING_LEVEL}) ################################################################################################### # - link libraries -------------------------------------------------------------------------------- target_link_libraries(udfexamplesjni cudf::cudf) ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/benchmarks/CMakeLists.txt ================================================ #============================================================================= # Copyright (c) 2021-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #============================================================================= # Use an OBJECT library so we only compile these helper source files only once add_library(udf_benchmark_common OBJECT synchronization/synchronization.cpp) target_link_libraries(udf_benchmark_common PUBLIC benchmark::benchmark cudf) target_include_directories(udf_benchmark_common PUBLIC "$" "$" "$/src") function(ConfigureBench CMAKE_BENCH_NAME) add_executable(${CMAKE_BENCH_NAME} ${ARGN}) set_target_properties(${CMAKE_BENCH_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "$") target_link_libraries(${CMAKE_BENCH_NAME} PRIVATE udf_benchmark_common udfexamplesjni benchmark::benchmark_main) endfunction() ConfigureBench(COSINE_SIMILARITY_BENCH cosine_similarity/cosine_similarity_benchmark.cpp) ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/benchmarks/cosine_similarity/cosine_similarity_benchmark.cpp ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "benchmarks/fixture/benchmark_fixture.hpp" #include "benchmarks/synchronization/synchronization.hpp" #include "cosine_similarity.hpp" #include #include #include #include static void cosine_similarity_bench_args(benchmark::internal::Benchmark* b) { int const min_rows = 1 << 12; int const max_rows = 1 << 24; int const row_mult = 8; int const min_rowlen = 1 << 0; int const max_rowlen = 1 << 12; int const len_mult = 8; for (int row_count = min_rows; row_count <= max_rows; row_count *= row_mult) { for (int rowlen = min_rowlen; rowlen <= max_rowlen; rowlen *= len_mult) { // avoid generating combinations that exceed the cudf column limit size_t total_chars = static_cast(row_count) * rowlen; if (total_chars < std::numeric_limits::max()) { b->Args({row_count, rowlen}); } } } } static void BM_cosine_similarity(benchmark::State& state) { cudf::size_type const n_rows{static_cast(state.range(0))}; cudf::size_type const list_len{static_cast(state.range(1))}; auto val_start = cudf::make_fixed_width_scalar(1.0f); auto val_step = cudf::make_fixed_width_scalar(-1.0f); auto child_rows = n_rows * list_len; auto col1_child = cudf::sequence(child_rows, *val_start); auto col2_child = cudf::sequence(child_rows, *val_start, *val_step); auto offset_start = cudf::make_fixed_width_scalar(static_cast(0)); auto offset_step = cudf::make_fixed_width_scalar(list_len); auto offsets = cudf::sequence(n_rows + 1, *offset_start, *offset_step); auto col1 = cudf::make_lists_column( n_rows, std::make_unique(*offsets), std::move(col1_child), 0, cudf::create_null_mask(n_rows, cudf::mask_state::ALL_VALID)); auto lcol1 = cudf::lists_column_view(*col1); auto col2 = cudf::make_lists_column( n_rows, std::move(offsets), std::move(col2_child), 0, cudf::create_null_mask(n_rows, cudf::mask_state::ALL_VALID)); auto lcol2 = cudf::lists_column_view(*col2); for (auto _ : state) { cuda_event_timer raii(state, true, rmm::cuda_stream_default); auto output = cosine_similarity(lcol1, lcol2); } state.SetBytesProcessed(state.iterations() * child_rows * sizeof(float)); } class CosineSimilarity : public native_udf::benchmark { }; BENCHMARK_DEFINE_F(CosineSimilarity, cosine_similarity) (::benchmark::State& state) { BM_cosine_similarity(state); } BENCHMARK_REGISTER_F(CosineSimilarity, cosine_similarity) ->Apply(cosine_similarity_bench_args) ->Unit(benchmark::kMillisecond) ->UseManualTime(); ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/benchmarks/fixture/benchmark_fixture.hpp ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include #include #include #include #include namespace native_udf { namespace { // memory resource factory helpers inline auto make_cuda() { return std::make_shared(); } inline auto make_pool() { return rmm::mr::make_owning_wrapper(make_cuda()); } } // namespace /** * @brief Google Benchmark fixture for native UDF benchmarks * * Native UDF benchmarks should use a fixture derived from this fixture class to * ensure that the RAPIDS Memory Manager pool mode is used in benchmarks, which * eliminates memory allocation / deallocation performance overhead from the * benchmark. * * The SetUp and TearDown methods of this fixture initialize RMM into pool mode * and finalize it, respectively. These methods are called automatically by * Google Benchmark * * Example: * * template * class my_benchmark : public native_udf::benchmark { * public: * using TypeParam = T; * }; * * Then: * * BENCHMARK_TEMPLATE_DEFINE_F(my_benchmark, my_test_name, int) * (::benchmark::State& state) { * for (auto _ : state) { * // benchmark stuff * } * } * * BENCHMARK_REGISTER_F(my_benchmark, my_test_name)->Range(128, 512); */ class benchmark : public ::benchmark::Fixture { public: virtual void SetUp(const ::benchmark::State& state) { mr = make_pool(); rmm::mr::set_current_device_resource(mr.get()); // set default resource to pool } virtual void TearDown(const ::benchmark::State& state) { // reset default resource to the initial resource rmm::mr::set_current_device_resource(nullptr); mr.reset(); } // eliminate partial override warnings (see benchmark/benchmark.h) virtual void SetUp(::benchmark::State& st) { SetUp(const_cast(st)); } virtual void TearDown(::benchmark::State& st) { TearDown(const_cast(st)); } std::shared_ptr mr; }; } // namespace native_udf ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/benchmarks/synchronization/synchronization.cpp ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "synchronization.hpp" #include #include #include cuda_event_timer::cuda_event_timer(benchmark::State& state, bool flush_l2_cache, rmm::cuda_stream_view stream) : stream(stream), p_state(&state) { // flush all of L2$ if (flush_l2_cache) { int current_device = 0; CUDA_TRY(cudaGetDevice(¤t_device)); int l2_cache_bytes = 0; CUDA_TRY(cudaDeviceGetAttribute(&l2_cache_bytes, cudaDevAttrL2CacheSize, current_device)); if (l2_cache_bytes > 0) { const int memset_value = 0; rmm::device_buffer l2_cache_buffer(l2_cache_bytes, stream); CUDA_TRY( cudaMemsetAsync(l2_cache_buffer.data(), memset_value, l2_cache_bytes, stream.value())); } } CUDA_TRY(cudaEventCreate(&start)); CUDA_TRY(cudaEventCreate(&stop)); CUDA_TRY(cudaEventRecord(start, stream.value())); } cuda_event_timer::~cuda_event_timer() { CUDA_TRY(cudaEventRecord(stop, stream.value())); CUDA_TRY(cudaEventSynchronize(stop)); float milliseconds = 0.0f; CUDA_TRY(cudaEventElapsedTime(&milliseconds, start, stop)); p_state->SetIterationTime(milliseconds / (1000.0f)); CUDA_TRY(cudaEventDestroy(start)); CUDA_TRY(cudaEventDestroy(stop)); } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/benchmarks/synchronization/synchronization.hpp ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /** * @file synchronization.hpp * @brief This is the header file for `cuda_event_timer`. */ /** * @brief This class serves as a wrapper for using `cudaEvent_t` as the user * defined timer within the framework of google benchmark * (https://github.com/google/benchmark). * * It is built on top of the idea of Resource acquisition is initialization * (RAII). In the following we show a minimal example of how to use this class. #include static void sample_cuda_benchmark(benchmark::State& state) { for (auto _ : state){ rmm::cuda_stream_view stream{}; // default stream, could be another stream // Create (Construct) an object of this class. You HAVE to pass in the // benchmark::State object you are using. It measures the time from its // creation to its destruction that is spent on the specified CUDA stream. // It also clears the L2 cache by cudaMemset'ing a device buffer that is of // the size of the L2 cache (if flush_l2_cache is set to true and there is // an L2 cache on the current device). cuda_event_timer raii(state, true, stream); // flush_l2_cache = true // Now perform the operations that is to be benchmarked sample_kernel<<<1, 256, 0, stream.value()>>>(); // Possibly launching a CUDA kernel } } // Register the function as a benchmark. You will need to set the `UseManualTime()` // flag in order to use the timer embedded in this class. 
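// (Note: Google Benchmark only uses the time passed to SetIterationTime() when
// UseManualTime() is set; otherwise it times the host-side loop itself.)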
BENCHMARK(sample_cuda_benchmark)->UseManualTime(); */ #ifndef UDF_BENCH_SYNCHRONIZATION_H #define UDF_BENCH_SYNCHRONIZATION_H // Google Benchmark library #include #include #include class cuda_event_timer { public: /** * @brief This c'tor clears the L2$ by cudaMemset'ing a buffer of L2$ size * and starts the timer. * * @param[in,out] state This is the benchmark::State whose timer we are going * to update. * @param[in] flush_l2_cache_ whether or not to flush the L2 cache before * every iteration. * @param[in] stream_ The CUDA stream we are measuring time on. */ cuda_event_timer(benchmark::State& state, bool flush_l2_cache, rmm::cuda_stream_view stream = rmm::cuda_stream_default); // The user must provide a benchmark::State object to set // the timer so we disable the default c'tor. cuda_event_timer() = delete; // The d'tor stops the timer and performs a synchronization. // Time of the benchmark::State object provided to the c'tor // will be set to the value given by `cudaEventElapsedTime`. ~cuda_event_timer(); private: cudaEvent_t start; cudaEvent_t stop; rmm::cuda_stream_view stream; benchmark::State* p_state; }; #endif ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src/CosineSimilarityJni.cpp ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include "cosine_similarity.hpp" namespace { constexpr char const* RUNTIME_ERROR_CLASS = "java/lang/RuntimeException"; constexpr char const* ILLEGAL_ARG_CLASS = "java/lang/IllegalArgumentException"; /** * @brief Throw a Java exception * * @param env The Java environment * @param class_name The fully qualified Java class name of the exception * @param msg The message string to associate with the exception */ void throw_java_exception(JNIEnv* env, char const* class_name, char const* msg) { jclass ex_class = env->FindClass(class_name); if (ex_class != NULL) { env->ThrowNew(ex_class, msg); } } } // anonymous namespace extern "C" { /** * @brief The native implementation of CosineSimilarity.cosineSimilarity which * computes the cosine similarity between two LIST(FLOAT32) columns as a FLOAT32 * columnar result. * * @param env The Java environment * @param j_view1 The address of the cudf column view of the first LIST column * @param j_view2 The address of the cudf column view of the second LIST column * @return The address of the cudf column containing the FLOAT32 results */ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_udf_java_CosineSimilarity_cosineSimilarity(JNIEnv* env, jclass, jlong j_view1, jlong j_view2) { // Use a try block to translate C++ exceptions into Java exceptions to avoid // crashing the JVM if a C++ exception occurs. 
try { // turn the addresses into column_view pointers auto v1 = reinterpret_cast(j_view1); auto v2 = reinterpret_cast(j_view2); if (v1->type().id() != v2->type().id() || v1->type().id() != cudf::type_id::LIST) { throw_java_exception(env, ILLEGAL_ARG_CLASS, "inputs not list columns"); return 0; } // run the GPU kernel to compute the cosine similarity auto lv1 = cudf::lists_column_view(*v1); auto lv2 = cudf::lists_column_view(*v2); std::unique_ptr result = cosine_similarity(lv1, lv2); // take ownership of the column and return the column address to Java and release the underlying resources. return reinterpret_cast(result.release()); } catch (std::bad_alloc const& e) { auto msg = std::string("Unable to allocate native memory: ") + (e.what() == nullptr ? "" : e.what()); throw_java_exception(env, RUNTIME_ERROR_CLASS, msg.c_str()); } catch (std::invalid_argument const& e) { throw_java_exception(env, ILLEGAL_ARG_CLASS, e.what() == nullptr ? "" : e.what()); } catch (std::exception const& e) { auto msg = e.what() == nullptr ? "" : e.what(); throw_java_exception(env, RUNTIME_ERROR_CLASS, msg); } return 0; } } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src/StringWordCountJni.cpp ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include "string_word_count.hpp" namespace { constexpr char const* RUNTIME_ERROR_CLASS = "java/lang/RuntimeException"; /** * @brief Throw a Java exception * * @param env The Java environment * @param class_name The fully qualified Java class name of the exception * @param msg The message string to associate with the exception */ void throw_java_exception(JNIEnv* env, char const* class_name, char const* msg) { jclass ex_class = env->FindClass(class_name); if (ex_class != NULL) { env->ThrowNew(ex_class, msg); } } } // anonymous namespace extern "C" { /** * @brief The native implementation of StringWordCount.countWords which counts the * number of words per string in a string column. * * @param env The Java environment * @param j_strings The address of the cudf column view of the strings column * @return The address of the cudf column containing the word counts */ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_udf_hive_StringWordCount_countWords(JNIEnv* env, jclass, jlong j_strings) { // Use a try block to translate C++ exceptions into Java exceptions to avoid // crashing the JVM if a C++ exception occurs. try { // turn the addresses into column_view pointers auto strs = reinterpret_cast(j_strings); // run the GPU kernel to compute the word counts std::unique_ptr result = string_word_count(*strs); // take ownership of the column and return the column address to Java return reinterpret_cast(result.release()); } catch (std::bad_alloc const& e) { auto msg = std::string("Unable to allocate native memory: ") + (e.what() == nullptr ? 
"" : e.what()); throw_java_exception(env, RUNTIME_ERROR_CLASS, msg.c_str()); } catch (std::exception const& e) { auto msg = e.what() == nullptr ? "" : e.what(); throw_java_exception(env, RUNTIME_ERROR_CLASS, msg); } return 0; } } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src/cosine_similarity.cu ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "cosine_similarity.hpp" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace { /** * @brief Functor for computing the cosine similarity between two list of float columns */ struct cosine_similarity_functor { float const* const v1; float const* const v2; int32_t const* const v1_offsets; int32_t const* const v2_offsets; // This kernel executes thread-per-row which should be fine for relatively short lists // but may need to be revisited for performance if operating on long lists. __device__ float operator()(cudf::size_type row_idx) { auto const v1_start_idx = v1_offsets[row_idx]; auto const v1_num_elems = v1_offsets[row_idx + 1] - v1_start_idx; auto const v2_start_idx = v2_offsets[row_idx]; auto const v2_num_elems = v2_offsets[row_idx + 1] - v2_start_idx; auto const num_elems = std::min(v1_num_elems, v2_num_elems); double mag1 = 0; double mag2 = 0; double dot_product = 0; for (auto i = 0; i < num_elems; i++) { float const f1 = v1[v1_start_idx + i]; mag1 += f1 * f1; float const f2 = v2[v2_start_idx + i]; mag2 += f2 * f2; dot_product += f1 * f2; } mag1 = std::sqrt(mag1); mag2 = std::sqrt(mag2); return static_cast(dot_product / (mag1 * mag2)); } }; } // anonymous namespace /** * @brief Compute the cosine similarity between two LIST of FLOAT32 columns * * The input vectors must have matching shapes, i.e.: same row count and same number of * list elements per row. A null list row is supported, but null float entries within a * list are not supported. 
* * @param lv1 The first LIST of FLOAT32 column view * @param lv2 The second LIST of FLOAT32 column view * @return A FLOAT32 column containing the cosine similarity corresponding to each input row */ std::unique_ptr cosine_similarity(cudf::lists_column_view const& lv1, cudf::lists_column_view const& lv2) { // sanity-check the input types if (lv1.child().type().id() != lv2.child().type().id() || lv1.child().type().id() != cudf::type_id::FLOAT32) { throw std::invalid_argument("inputs are not lists of floats"); } // sanity check the input shape auto const row_count = lv1.size(); if (row_count != lv2.size()) { throw std::invalid_argument("input row counts do not match"); } if (row_count == 0) { return cudf::make_empty_column(cudf::data_type{cudf::type_id::FLOAT32}); } if (lv1.child().null_count() != 0 || lv2.child().null_count() != 0) { throw std::invalid_argument("null floats are not supported"); } auto const stream = rmm::cuda_stream_default; // Check if list sizes match by comparing offsets differences // Need to handle null lists: if either list is null, consider it valid (will be null in output) auto const lv1_offsets_ptr = lv1.offsets().data(); auto const lv2_offsets_ptr = lv2.offsets().data(); auto const lv1_null_mask = lv1.parent().null_mask(); auto const lv2_null_mask = lv2.parent().null_mask(); bool const are_offsets_equal = thrust::all_of(rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(row_count), [lv1_offsets_ptr, lv2_offsets_ptr, lv1_null_mask, lv2_null_mask] __device__(cudf::size_type idx) -> bool { // Check if either list is null - if so, consider valid // Use cudf::bit_is_set() for proper bitmask handling bool lv1_is_null = lv1_null_mask != nullptr && !cudf::bit_is_set(lv1_null_mask, idx); bool lv2_is_null = lv2_null_mask != nullptr && !cudf::bit_is_set(lv2_null_mask, idx); if (lv1_is_null || lv2_is_null) return true; // Both are valid, check sizes auto lv1_size = lv1_offsets_ptr[idx + 1] - lv1_offsets_ptr[idx]; auto lv2_size = lv2_offsets_ptr[idx + 1] - lv2_offsets_ptr[idx]; return lv1_size == lv2_size; }); if (not are_offsets_equal) { throw std::invalid_argument("input list lengths do not match for every row"); } // allocate the vector of float results rmm::device_uvector float_results(row_count, stream); // compute the cosine similarity auto const lv1_data = lv1.child().data(); auto const lv2_data = lv2.child().data(); auto const lv1_offsets = lv1.offsets().data(); auto const lv2_offsets = lv2.offsets().data(); thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(row_count), float_results.data(), cosine_similarity_functor({lv1_data, lv2_data, lv1_offsets, lv2_offsets})); // the validity of the output is the bitwise-and of the two input validity masks auto [null_mask, null_count] = cudf::bitmask_and(cudf::table_view({lv1.parent(), lv2.parent()})); return std::make_unique(cudf::data_type{cudf::type_id::FLOAT32}, row_count, float_results.release(), std::move(null_mask), null_count); } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src/cosine_similarity.hpp ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include /** * @brief Compute the cosine similarity between two LIST of FLOAT32 columns * * The input vectors must have matching shapes, i.e.: same row count and same number of * list elements per row. A null list row is supported, but null float entries within a * list are not supported. * * @param lv1 The first LIST of FLOAT32 column view * @param lv2 The second LIST of FLOAT32 column view * @return A FLOAT32 column containing the cosine similarity corresponding to each input row */ std::unique_ptr cosine_similarity(cudf::lists_column_view const& lv1, cudf::lists_column_view const& lv2); ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src/string_word_count.cu ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "string_word_count.hpp" #include #include #include #include #include #include #include #include namespace { // count the words separated by whitespace characters __device__ cudf::size_type count_words(cudf::column_device_view const& d_strings, cudf::size_type idx) { if (d_strings.is_null(idx)) return 0; cudf::string_view const d_str = d_strings.element(idx); cudf::size_type word_count = 0; // run of whitespace is considered a single delimiter bool spaces = true; auto itr = d_str.begin(); while (itr != d_str.end()) { cudf::char_utf8 ch = *itr; if (spaces == (ch <= ' ')) { itr++; } else { word_count += static_cast(spaces); spaces = !spaces; } } return word_count; } } // anonymous namespace /** * @brief Count the words in a string using whitespace as word boundaries * * @param strs The column containing the strings * @param stream The CUDA stream to use * @return The INT32 column containing the word count results per string */ std::unique_ptr string_word_count(cudf::column_view const& strs) { auto strings_count = strs.size(); if (strings_count == 0) { return cudf::make_empty_column(cudf::data_type{cudf::type_id::INT32}); } // the validity of the output matches the validity of the input rmm::device_buffer null_mask = cudf::copy_bitmask(strs); // allocate the column that will contain the word count results std::unique_ptr result = cudf::make_numeric_column( cudf::data_type{cudf::type_id::INT32}, strs.size(), std::move(null_mask), strs.null_count()); // compute the word counts, writing into the result column data buffer auto stream = rmm::cuda_stream_default; auto strs_device_view = cudf::column_device_view::create(strs, stream); auto d_strs_view = *strs_device_view; thrust::transform( rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(strings_count), result->mutable_view().data(), [d_strs_view] __device__(cudf::size_type idx) -> cudf::size_type { return count_words(d_strs_view, idx); }); return result; } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src/string_word_count.hpp ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include /** * @brief Count the words in a string separated by whitespace * * @param strs The column containing the strings to be examined * @return The INT32 column containing the word count results for each string */ std::unique_ptr string_word_count(cudf::column_view const& strs); ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/hive/DecimalFraction.java ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.rapids.udf.hive; import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.Scalar; import com.nvidia.spark.RapidsUDF; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; import java.math.BigDecimal; /** * A simple HiveGenericUDF demo for DecimalType, which extracts and returns * the fraction part of the input Decimal data. So, the output data has the * same precision and scale as the input one. */ public class DecimalFraction extends GenericUDF implements RapidsUDF { private transient PrimitiveObjectInspector inputOI; @Override public String getDisplayString(String[] strings) { return getStandardDisplayString("DecimalFraction", strings); } @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { if (arguments.length != 1) { throw new UDFArgumentException("One argument is supported, found: " + arguments.length); } if (!(arguments[0] instanceof PrimitiveObjectInspector)) { throw new UDFArgumentException("Unsupported argument type: " + arguments[0].getTypeName()); } inputOI = (PrimitiveObjectInspector) arguments[0]; if (inputOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.DECIMAL) { throw new UDFArgumentException("Unsupported primitive type: " + inputOI.getPrimitiveCategory()); } DecimalTypeInfo inputTypeInfo = (DecimalTypeInfo) inputOI.getTypeInfo(); return PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(inputTypeInfo); } @Override public Object evaluate(GenericUDF.DeferredObject[] arguments) throws HiveException { if (arguments[0] == null || arguments[0].get() == null) { return null; } Object input = arguments[0].get(); HiveDecimalWritable decimalWritable = (HiveDecimalWritable) inputOI.getPrimitiveWritableObject(input); BigDecimal decimalInput = decimalWritable.getHiveDecimal().bigDecimalValue(); BigDecimal decimalResult = decimalInput.subtract(new BigDecimal(decimalInput.toBigInteger())); HiveDecimalWritable result = new HiveDecimalWritable(decimalWritable); result.set(HiveDecimal.create(decimalResult)); return result; } @Override public ColumnVector evaluateColumnar(int numRows, ColumnVector... 
args) { if (args.length != 1) { throw new IllegalArgumentException("Unexpected argument count: " + args.length); } ColumnVector input = args[0]; if (numRows != input.getRowCount()) { throw new IllegalArgumentException("Expected " + numRows + " rows, received " + input.getRowCount()); } if (!input.getType().isDecimalType()) { throw new IllegalArgumentException("Argument type is not a decimal column: " + input.getType()); } try (Scalar nullScalar = Scalar.fromNull(input.getType()); ColumnVector nullPredicate = input.isNull(); ColumnVector integral = input.floor(); ColumnVector fraction = input.sub(integral, input.getType())) { return nullPredicate.ifElse(nullScalar, fraction); } } } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/hive/StringWordCount.java ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.rapids.udf.hive; import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.DType; import ai.rapids.cudf.NativeDepsLoader; import com.nvidia.spark.RapidsUDF; import com.nvidia.spark.rapids.udf.java.NativeUDFExamplesLoader; import org.apache.hadoop.hive.ql.exec.UDF; import java.io.IOException; /** * A user-defined function (UDF) that counts the words in a string. * This avoids the manifestation of intermediate results required when * splitting the string on whitespace and counting the split results. *

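 * <p>
 * A minimal usage sketch, assuming an illustrative function name {@code wordcount}
 * and table {@code docs} (neither is defined by this repository): register the UDF
 * in Spark SQL by its class name, then call it on a string column.
 * <pre>
 *   CREATE TEMPORARY FUNCTION wordcount AS 'com.nvidia.spark.rapids.udf.hive.StringWordCount';
 *   SELECT wordcount(body) FROM docs;
 * </pre>
 * With the RAPIDS Accelerator enabled, the columnar {@code evaluateColumnar} path can
 * run on the GPU; otherwise Hive's row-by-row {@code evaluate} method is used.
 *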
* This class demonstrates how to implement a Hive UDF with a RAPIDS * implementation that uses custom native code. */ public class StringWordCount extends UDF implements RapidsUDF { private volatile boolean isNativeCodeLoaded = false; /** Row-by-row implementation that executes on the CPU */ public Integer evaluate(String str) { if (str == null) { return null; } int numWords = 0; // run of whitespace is considered a single delimiter boolean spaces = true; for (int idx = 0; idx < str.length(); idx++) { char ch = str.charAt(idx); if (spaces != (ch <= ' ')) { if (spaces) { numWords++; } spaces = !spaces; } } return numWords; } /** Columnar implementation that runs on the GPU */ @Override public ColumnVector evaluateColumnar(int numRows, ColumnVector... args) { // The CPU implementation takes a single string argument, so similarly // there should only be one column argument of type STRING. if (args.length != 1) { throw new IllegalArgumentException("Unexpected argument count: " + args.length); } ColumnVector strs = args[0]; if (numRows != strs.getRowCount()) { throw new IllegalArgumentException("Expected " + numRows + " rows, received " + strs.getRowCount()); } if (!strs.getType().equals(DType.STRING)) { throw new IllegalArgumentException("type mismatch, expected strings but found " + strs.getType()); } // Load the native code if it has not been already loaded. This is done here // rather than in a static code block since the driver may not have the // required CUDA environment. NativeUDFExamplesLoader.ensureLoaded(); return new ColumnVector(countWords(strs.getNativeView())); } private static native long countWords(long stringsView); } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/hive/URLDecode.java ================================================ /* * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.rapids.udf.hive; import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.DType; import ai.rapids.cudf.Scalar; import com.nvidia.spark.RapidsUDF; import org.apache.hadoop.hive.ql.exec.UDF; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; /** * A Hive user-defined function (UDF) that decodes URL-encoded strings. * This class demonstrates how to implement a simple Hive UDF that also * provides a RAPIDS implementation that can run on the GPU when the query * is executed with the RAPIDS Accelerator for Apache Spark. 
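 * <p>
 * Illustrative behavior once registered (the function name {@code urldecode} and the
 * input literal are assumptions): {@code SELECT urldecode('a%20b+c')} returns
 * {@code a b c}, because the GPU path first replaces '+' with a space and then applies
 * cudf's urlDecode, matching {@code java.net.URLDecoder} on the CPU.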
*/ public class URLDecode extends UDF implements RapidsUDF { /** Row-by-row implementation that executes on the CPU */ public String evaluate(String s) { String result = null; if (s != null) { try { result = URLDecoder.decode(s, "utf-8"); } catch (IllegalArgumentException ignored) { result = s; } catch (UnsupportedEncodingException e) { // utf-8 is a builtin, standard encoding, so this should never happen throw new RuntimeException(e); } } return result; } /** Columnar implementation that runs on the GPU */ @Override public ColumnVector evaluateColumnar(int numRows, ColumnVector... args) { // The CPU implementation takes a single string argument, so similarly // there should only be one column argument of type STRING. if (args.length != 1) { throw new IllegalArgumentException("Unexpected argument count: " + args.length); } ColumnVector input = args[0]; if (numRows != input.getRowCount()) { throw new IllegalArgumentException("Expected " + numRows + " rows, received " + input.getRowCount()); } if (!input.getType().equals(DType.STRING)) { throw new IllegalArgumentException("Argument type is not a string column: " + input.getType()); } // The cudf urlDecode does not convert '+' to a space, so do that as a pre-pass first. // All intermediate results are closed to avoid leaking GPU resources. try (Scalar plusScalar = Scalar.fromString("+"); Scalar spaceScalar = Scalar.fromString(" "); ColumnVector replaced = input.stringReplace(plusScalar, spaceScalar)) { return replaced.urlDecode(); } } } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/hive/URLEncode.java ================================================ /* * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.rapids.udf.hive; import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.DType; import com.nvidia.spark.RapidsUDF; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorConverter; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.io.Text; import java.io.UnsupportedEncodingException; import java.net.URLEncoder; /** * A Hive user-defined function (UDF) that URL-encodes strings. * This class demonstrates how to implement a Hive GenericUDF that also * provides a RAPIDS implementation that can run on the GPU when the query * is executed with the RAPIDS Accelerator for Apache Spark. 
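 * <p>
 * Illustrative behavior once registered (the function name {@code urlencode} and the
 * input literal are assumptions): {@code SELECT urlencode('a b*~')} returns
 * {@code a%20b%2A~}, since the {@code java.net.URLEncoder} output is post-processed to
 * encode spaces as %20, encode '*', and leave '~' unescaped.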
*/ public class URLEncode extends GenericUDF implements RapidsUDF { private transient PrimitiveObjectInspectorConverter.TextConverter converter; private final Text textResult = new Text(); /** Standard getDisplayString method for implementing GenericUDF */ @Override public String getDisplayString(String[] children) { return getStandardDisplayString("urlencode", children); } /** Standard initialize method for implementing GenericUDF for a single string parameter */ @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { if (arguments.length != 1) { throw new UDFArgumentException("One argument is supported, found: " + arguments.length); } if (!(arguments[0] instanceof PrimitiveObjectInspector)) { throw new UDFArgumentException("Unsupported argument type: " + arguments[0].getTypeName()); } PrimitiveObjectInspector poi = (PrimitiveObjectInspector) arguments[0]; switch (poi.getPrimitiveCategory()) { case STRING: case CHAR: case VARCHAR: break; default: throw new UDFArgumentException("Unsupported primitive type: " + poi.getPrimitiveCategory()); } converter = new PrimitiveObjectInspectorConverter.TextConverter(poi); return PrimitiveObjectInspectorFactory.writableStringObjectInspector; } /** Row-by-row implementation that executes on the CPU */ @Override public Object evaluate(GenericUDF.DeferredObject[] arguments) throws HiveException { Text text = converter.convert(arguments[0].get()); if (text == null) { return null; } String encoded; try { encoded = URLEncoder.encode(text.toString(), "utf-8") .replace("+", "%20") .replace("*", "%2A") .replace("%7E", "~"); } catch (UnsupportedEncodingException e) { // utf-8 is a builtin, standard encoding, so this should never happen throw new RuntimeException(e); } textResult.set(encoded); return textResult; } /** Columnar implementation that runs on the GPU */ @Override public ColumnVector evaluateColumnar(int numRows, ColumnVector... args) { // The CPU implementation takes a single string argument, so similarly // there should only be one column argument of type STRING. if (args.length != 1) { throw new IllegalArgumentException("Unexpected argument count: " + args.length); } ColumnVector input = args[0]; if (numRows != input.getRowCount()) { throw new IllegalArgumentException("Expected " + numRows + " rows, received " + input.getRowCount()); } if (!input.getType().equals(DType.STRING)) { throw new IllegalArgumentException("Argument type is not a string column: " + input.getType()); } return input.urlEncode(); } } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/java/CosineSimilarity.java ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/
package com.nvidia.spark.rapids.udf.java;

import ai.rapids.cudf.ColumnVector;
import com.nvidia.spark.RapidsUDF;
import org.apache.spark.sql.api.java.UDF2;
import scala.collection.mutable.WrappedArray;

/**
 * A Spark Java UDF that computes the cosine similarity between two float vectors.
 * The input vectors must have matching shapes, i.e.: same number of elements.
 * A null vector is supported, but null entries within the vector are not supported.
 */
public class CosineSimilarity
    implements UDF2<WrappedArray<Float>, WrappedArray<Float>, Float>, RapidsUDF {

  /** Row-by-row implementation that executes on the CPU */
  @Override
  public Float call(WrappedArray<Float> v1, WrappedArray<Float> v2) {
    if (v1 == null || v2 == null) {
      return null;
    }
    if (v1.length() != v2.length()) {
      throw new IllegalArgumentException("Array lengths must match: " +
          v1.length() + " != " + v2.length());
    }
    double dotProduct = 0;
    for (int i = 0; i < v1.length(); i++) {
      float f1 = v1.apply(i);
      float f2 = v2.apply(i);
      dotProduct += f1 * f2;
    }
    double magProduct = magnitude(v1) * magnitude(v2);
    return (float) (dotProduct / magProduct);
  }

  private double magnitude(WrappedArray<Float> v) {
    double sum = 0;
    for (int i = 0; i < v.length(); i++) {
      float x = v.apply(i);
      sum += x * x;
    }
    return Math.sqrt(sum);
  }

  /** Columnar implementation that processes data on the GPU */
  @Override
  public ColumnVector evaluateColumnar(int numRows, ColumnVector... args) {
    if (args.length != 2) {
      throw new IllegalArgumentException("Unexpected argument count: " + args.length);
    }

    // Load the native code if it has not been already loaded. This is done here
    // rather than in a static code block since the driver may not have the
    // required CUDA environment.
    NativeUDFExamplesLoader.ensureLoaded();

    // We need to go into the native code as quickly as possible
    // because it is easier to write the code safely.
    // Then wrap returns in a column vector and own that resource.
    return new ColumnVector(cosineSimilarity(args[0].getNativeView(), args[1].getNativeView()));
  }

  /** Native implementation that computes on the GPU */
  private static native long cosineSimilarity(long vectorView1, long vectorView2);
}

================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/java/DecimalFraction.java ================================================
/*
 * Copyright (c) 2021-2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.nvidia.spark.rapids.udf.java;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Scalar;
import com.nvidia.spark.RapidsUDF;
import org.apache.spark.sql.api.java.UDF1;

import java.math.BigDecimal;

/**
 * A simple Java UDF demo for DecimalType, which extracts and returns the
 * fraction part of the input Decimal data. So, the output data has the
 * same precision and scale as the input one.
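 * <p>
 * A possible registration sketch (the UDF name {@code decimal_fraction}, the
 * {@code sales} table, the {@code price} column, and the chosen precision/scale are
 * illustrative assumptions, not part of this class):
 * <pre>
 *   spark.udf().register("decimal_fraction", new DecimalFraction(),
 *       DataTypes.createDecimalType(18, 10));
 *   spark.sql("SELECT decimal_fraction(price) FROM sales").show();
 * </pre>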
*/ public class DecimalFraction implements UDF1, RapidsUDF { @Override public BigDecimal call(BigDecimal dec) throws Exception { if (dec == null) { return null; } BigDecimal integral = new BigDecimal(dec.toBigInteger()); return dec.subtract(integral); } @Override public ColumnVector evaluateColumnar(int numRows, ColumnVector... args) { if (args.length != 1) { throw new IllegalArgumentException("Unexpected argument count: " + args.length); } ColumnVector input = args[0]; if (!input.getType().isDecimalType()) { throw new IllegalArgumentException("Argument type is not a decimal column: " + input.getType()); } try (Scalar nullScalar = Scalar.fromNull(input.getType()); ColumnVector nullPredicate = input.isNull(); ColumnVector integral = input.floor(); ColumnVector fraction = input.sub(integral, input.getType())) { return nullPredicate.ifElse(nullScalar, fraction); } } } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/java/NativeUDFExamplesLoader.java ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.rapids.udf.java; import ai.rapids.cudf.NativeDepsLoader; import java.io.IOException; /** Loads the native dependencies for UDF examples with a native implementation */ public class NativeUDFExamplesLoader { private static boolean isLoaded; /** Loads native UDF code if necessary */ public static synchronized void ensureLoaded() { if (!isLoaded) { try { NativeDepsLoader.loadNativeDeps(new String[]{"udfexamplesjni"}); isLoaded = true; } catch (IOException e) { throw new RuntimeException(e); } } } } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/java/URLDecode.java ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.rapids.udf.java; import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.DType; import ai.rapids.cudf.Scalar; import com.nvidia.spark.RapidsUDF; import org.apache.spark.sql.api.java.UDF1; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; /** * A Java user-defined function (UDF) that decodes URL-encoded strings. 
* This class demonstrates how to implement a Java UDF that also * provides a RAPIDS implementation that can run on the GPU when the query * is executed with the RAPIDS Accelerator for Apache Spark. */ public class URLDecode implements UDF1, RapidsUDF { /** Row-by-row implementation that executes on the CPU */ @Override public String call(String s) { String result = null; if (s != null) { try { result = URLDecoder.decode(s, "utf-8"); } catch (IllegalArgumentException ignored) { result = s; } catch (UnsupportedEncodingException e) { // utf-8 is a builtin, standard encoding, so this should never happen throw new RuntimeException(e); } } return result; } /** Columnar implementation that runs on the GPU */ @Override public ColumnVector evaluateColumnar(int numRows, ColumnVector... args) { // The CPU implementation takes a single string argument, so similarly // there should only be one column argument of type STRING. if (args.length != 1) { throw new IllegalArgumentException("Unexpected argument count: " + args.length); } ColumnVector input = args[0]; if (numRows != input.getRowCount()) { throw new IllegalArgumentException("Expected " + numRows + " rows, received " + input.getRowCount()); } if (!input.getType().equals(DType.STRING)) { throw new IllegalArgumentException("Argument type is not a string column: " + input.getType()); } // The cudf urlDecode does not convert '+' to a space, so do that as a pre-pass first. // All intermediate results are closed to avoid leaking GPU resources. try (Scalar plusScalar = Scalar.fromString("+"); Scalar spaceScalar = Scalar.fromString(" "); ColumnVector replaced = input.stringReplace(plusScalar, spaceScalar)) { return replaced.urlDecode(); } } } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/java/URLEncode.java ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.rapids.udf.java; import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.DType; import com.nvidia.spark.RapidsUDF; import org.apache.spark.sql.api.java.UDF1; import java.io.UnsupportedEncodingException; import java.net.URLEncoder; /** * A Java user-defined function (UDF) that URL-encodes strings. * This class demonstrates how to implement a Java UDF that also * provides a RAPIDS implementation that can run on the GPU when the query * is executed with the RAPIDS Accelerator for Apache Spark. 
*/ public class URLEncode implements UDF1, RapidsUDF { /** Row-by-row implementation that executes on the CPU */ @Override public String call(String s) { if (s == null) { return null; } try { return URLEncoder.encode(s, "utf-8") .replace("+", "%20") .replace("*", "%2A") .replace("%7E", "~"); } catch (UnsupportedEncodingException e) { // utf-8 is a builtin, standard encoding, so this should never happen throw new RuntimeException(e); } } /** Columnar implementation that runs on the GPU */ @Override public ColumnVector evaluateColumnar(int numRows, ColumnVector... args) { // The CPU implementation takes a single string argument, so similarly // there should only be one column argument of type STRING. if (args.length != 1) { throw new IllegalArgumentException("Unexpected argument count: " + args.length); } ColumnVector input = args[0]; if (numRows != input.getRowCount()) { throw new IllegalArgumentException("Expected " + numRows + " rows, received " + input.getRowCount()); } if (!input.getType().equals(DType.STRING)) { throw new IllegalArgumentException("Argument type is not a string column: " + input.getType()); } return input.urlEncode(); } } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/python/asserts.py ================================================ # Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
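# A minimal usage sketch of the helpers defined below, assuming a Spark session is
# provided via spark_session.py's with_cpu_session/with_gpu_session; the query itself
# is an arbitrary illustrative example, not something this module ships:
#
#   from asserts import assert_gpu_and_cpu_are_equal_collect
#   assert_gpu_and_cpu_are_equal_collect(
#       lambda spark: spark.range(1000).selectExpr("id", "id % 7 AS bucket"))
#
# The helper runs the same DataFrame-producing function in a CPU session and a GPU
# session and asserts that the collected results are equal.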
from conftest import is_incompat, should_sort_on_spark, should_sort_locally, get_float_check, get_limit, spark_jvm from datetime import date, datetime from decimal import Decimal import math from pyspark.sql import Row from py4j.protocol import Py4JJavaError import pytest from spark_session import with_cpu_session, with_gpu_session import time import types as pytypes import data_gen def _assert_equal(cpu, gpu, float_check, path): t = type(cpu) if (t is Row): assert len(cpu) == len(gpu), "CPU and GPU row have different lengths at {} CPU: {} GPU: {}".format(path, len(cpu), len(gpu)) if hasattr(cpu, "__fields__") and hasattr(gpu, "__fields__"): assert cpu.__fields__ == gpu.__fields__, "CPU and GPU row have different fields at {} CPU: {} GPU: {}".format(path, cpu.__fields__, gpu.__fields__) for field in cpu.__fields__: _assert_equal(cpu[field], gpu[field], float_check, path + [field]) else: for index in range(len(cpu)): _assert_equal(cpu[index], gpu[index], float_check, path + [index]) elif (t is list): assert len(cpu) == len(gpu), "CPU and GPU list have different lengths at {} CPU: {} GPU: {}".format(path, len(cpu), len(gpu)) for index in range(len(cpu)): _assert_equal(cpu[index], gpu[index], float_check, path + [index]) elif (t is tuple): assert len(cpu) == len(gpu), "CPU and GPU list have different lengths at {} CPU: {} GPU: {}".format(path, len(cpu), len(gpu)) for index in range(len(cpu)): _assert_equal(cpu[index], gpu[index], float_check, path + [index]) elif (t is pytypes.GeneratorType): index = 0 # generator has no zip :( so we have to do this the hard way done = False while not done: sub_cpu = None sub_gpu = None try: sub_cpu = next(cpu) except StopIteration: done = True try: sub_gpu = next(gpu) except StopIteration: done = True if done: assert sub_cpu == sub_gpu and sub_cpu == None, "CPU and GPU generators have different lengths at {}".format(path) else: _assert_equal(sub_cpu, sub_gpu, float_check, path + [index]) index = index + 1 elif (t is dict): # The order of key/values is not guaranteed in python dicts, nor are they guaranteed by Spark # so sort the items to do our best with ignoring the order of dicts cpu_items = list(cpu.items()).sort(key=_RowCmp) gpu_items = list(gpu.items()).sort(key=_RowCmp) _assert_equal(cpu_items, gpu_items, float_check, path + ["map"]) elif (t is int): assert cpu == gpu, "GPU and CPU int values are different at {}".format(path) elif (t is float): if (math.isnan(cpu)): assert math.isnan(gpu), "GPU and CPU float values are different at {}".format(path) else: assert float_check(cpu, gpu), "GPU and CPU float values are different {}".format(path) elif isinstance(cpu, str): assert cpu == gpu, "GPU and CPU string values are different at {}".format(path) elif isinstance(cpu, datetime): assert cpu == gpu, "GPU and CPU timestamp values are different at {}".format(path) elif isinstance(cpu, date): assert cpu == gpu, "GPU and CPU date values are different at {}".format(path) elif isinstance(cpu, bool): assert cpu == gpu, "GPU and CPU boolean values are different at {}".format(path) elif isinstance(cpu, Decimal): assert cpu == gpu, "GPU and CPU decimal values are different at {}".format(path) elif isinstance(cpu, bytearray): assert cpu == gpu, "GPU and CPU bytearray values are different at {}".format(path) elif (cpu == None): assert cpu == gpu, "GPU and CPU are not both null at {}".format(path) else: assert False, "Found unexpected type {} at {}".format(t, path) def assert_equal(cpu, gpu): """Verify that the result from the CPU and the GPU are equal""" try: 
_assert_equal(cpu, gpu, float_check=get_float_check(), path=[]) except: print("CPU OUTPUT: %s" % cpu) print("GPU OUTPUT: %s" % gpu) raise def _has_incompat_conf(conf): return ('spark.rapids.sql.incompatibleOps.enabled' in conf and conf['spark.rapids.sql.incompatibleOps.enabled'].lower() == 'true') class _RowCmp(object): """Allows for sorting Rows in a consistent way""" def __init__(self, wrapped): if isinstance(wrapped, Row) or isinstance(wrapped, list) or isinstance(wrapped, tuple): self.wrapped = [_RowCmp(c) for c in wrapped] elif isinstance(wrapped, dict): def sort_dict(e): return _RowCmp(e) tmp = [(k, v) for k, v in wrapped.items()] tmp.sort(key=sort_dict) self.wrapped = [_RowCmp(c) for c in tmp] else: self.wrapped = wrapped if isinstance(wrapped, float): self.is_nan = math.isnan(wrapped) else: self.is_nan = False def cmp(self, other): try: #None comes before anything else #NaN comes next if (self.wrapped is None and other.wrapped is None): return 0 elif (self.wrapped is None): return -1 elif (other.wrapped is None): return 1 elif self.is_nan and other.is_nan: return 0 elif self.is_nan: return -1 elif other.is_nan: return 1 elif self.wrapped == other.wrapped: return 0 elif self.wrapped < other.wrapped: return -1 else: return 1 except TypeError as te: print("ERROR TRYING TO COMPARE {} to {} {}".format(self.wrapped, other.wrapped, te)) raise te def __lt__(self, other): return self.cmp(other) < 0 def __gt__(self, other): return self.cmp(other) > 0 def __eq__(self, other): return self.cmp(other) == 0 def __le__(self, other): return self.cmp(other) <= 0 def __ge__(self, other): return self.cmp(other) >= 0 def __ne__(self, other): return self.cmp(other) != 0 def _prep_func_for_compare(func, mode): sort_locally = should_sort_locally() if should_sort_on_spark(): def with_sorted(spark): df = func(spark) return df.sort(df.columns) sorted_func = with_sorted else: sorted_func = func limit_val = get_limit() if limit_val > 0: def with_limit(spark): df = sorted_func(spark) return df.limit(limit_val) limit_func = with_limit else: limit_func = sorted_func if mode == 'COLLECT': bring_back = lambda spark: limit_func(spark).collect() collect_type = 'COLLECT' elif mode == 'COUNT': bring_back = lambda spark: limit_func(spark).count() collect_type = 'COUNT' elif mode == 'COLLECT_WITH_DATAFRAME': def bring_back(spark): df = limit_func(spark) return (df.collect(), df) collect_type = 'COLLECT' return (bring_back, collect_type) else: bring_back = lambda spark: limit_func(spark).toLocalIterator() collect_type = 'ITERATOR' if sort_locally: raise RuntimeError('Local Sort is only supported on a collect') return (bring_back, collect_type) def _prep_incompat_conf(conf): if is_incompat(): conf = dict(conf) # Make a copy before we change anything conf['spark.rapids.sql.incompatibleOps.enabled'] = 'true' elif _has_incompat_conf(conf): raise AssertionError("incompat must be enabled by the incompat fixture") return conf def _assert_gpu_and_cpu_writes_are_equal( write_func, read_func, base_path, mode, conf={}): conf = _prep_incompat_conf(conf) print('### CPU RUN ###') cpu_start = time.time() cpu_path = base_path + '/CPU' with_cpu_session(lambda spark : write_func(spark, cpu_path), conf=conf) cpu_end = time.time() print('### GPU RUN ###') gpu_start = time.time() gpu_path = base_path + '/GPU' with_gpu_session(lambda spark : write_func(spark, gpu_path), conf=conf) gpu_end = time.time() print('### WRITE: GPU TOOK {} CPU TOOK {} ###'.format( gpu_end - gpu_start, cpu_end - cpu_start)) (cpu_bring_back, cpu_collect_type) = 
_prep_func_for_compare( lambda spark: read_func(spark, cpu_path), mode) (gpu_bring_back, gpu_collect_type) = _prep_func_for_compare( lambda spark: read_func(spark, gpu_path), mode) from_cpu = with_cpu_session(cpu_bring_back, conf=conf) from_gpu = with_cpu_session(gpu_bring_back, conf=conf) if should_sort_locally(): from_cpu.sort(key=_RowCmp) from_gpu.sort(key=_RowCmp) assert_equal(from_cpu, from_gpu) def assert_gpu_and_cpu_writes_are_equal_collect(write_func, read_func, base_path, conf={}): """ Assert when running write_func on both the CPU and the GPU and reading using read_func ont he CPU that the results are equal. In this case the data is collected back to the driver and compared here, so be careful about the amount of data returned. """ _assert_gpu_and_cpu_writes_are_equal(write_func, read_func, base_path, 'COLLECT', conf=conf) def assert_gpu_and_cpu_writes_are_equal_iterator(write_func, read_func, base_path, conf={}): """ Assert when running write_func on both the CPU and the GPU and reading using read_func ont he CPU that the results are equal. In this case the data is pulled back to the driver in chunks and compared here so any amount of data can work, just be careful about how long it might take. """ _assert_gpu_and_cpu_writes_are_equal(write_func, read_func, base_path, 'ITERATOR', conf=conf) def assert_gpu_fallback_write(write_func, read_func, base_path, cpu_fallback_class_name, conf={}): conf = _prep_incompat_conf(conf) print('### CPU RUN ###') cpu_start = time.time() cpu_path = base_path + '/CPU' with_cpu_session(lambda spark : write_func(spark, cpu_path), conf=conf) cpu_end = time.time() print('### GPU RUN ###') jvm = spark_jvm() jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.startCapture() gpu_start = time.time() gpu_path = base_path + '/GPU' with_gpu_session(lambda spark : write_func(spark, gpu_path), conf=conf) gpu_end = time.time() jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertCapturedAndGpuFellBack(cpu_fallback_class_name, 10000) print('### WRITE: GPU TOOK {} CPU TOOK {} ###'.format( gpu_end - gpu_start, cpu_end - cpu_start)) (cpu_bring_back, cpu_collect_type) = _prep_func_for_compare( lambda spark: read_func(spark, cpu_path), 'COLLECT') (gpu_bring_back, gpu_collect_type) = _prep_func_for_compare( lambda spark: read_func(spark, gpu_path), 'COLLECT') from_cpu = with_cpu_session(cpu_bring_back, conf=conf) from_gpu = with_cpu_session(gpu_bring_back, conf=conf) if should_sort_locally(): from_cpu.sort(key=_RowCmp) from_gpu.sort(key=_RowCmp) assert_equal(from_cpu, from_gpu) def assert_cpu_and_gpu_are_equal_collect_with_capture(func, exist_classes='', non_exist_classes='', conf={}): (bring_back, collect_type) = _prep_func_for_compare(func, 'COLLECT_WITH_DATAFRAME') conf = _prep_incompat_conf(conf) print('### CPU RUN ###') cpu_start = time.time() from_cpu, cpu_df = with_cpu_session(bring_back, conf=conf) cpu_end = time.time() print('### GPU RUN ###') gpu_start = time.time() from_gpu, gpu_df = with_gpu_session(bring_back, conf=conf) gpu_end = time.time() jvm = spark_jvm() if exist_classes: for clz in exist_classes.split(','): jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertContains(gpu_df._jdf, clz) if non_exist_classes: for clz in non_exist_classes.split(','): jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertNotContain(gpu_df._jdf, clz) print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type, gpu_end - gpu_start, cpu_end - cpu_start)) if should_sort_locally(): from_cpu.sort(key=_RowCmp) 
from_gpu.sort(key=_RowCmp) assert_equal(from_cpu, from_gpu) def assert_cpu_and_gpu_are_equal_sql_with_capture(df_fun, sql, table_name, exist_classes='', non_exist_classes='', conf=None, debug=False): if conf is None: conf = {} def do_it_all(spark): df = df_fun(spark) df.createOrReplaceTempView(table_name) if debug: return data_gen.debug_df(spark.sql(sql)) else: return spark.sql(sql) assert_cpu_and_gpu_are_equal_collect_with_capture(do_it_all, exist_classes, non_exist_classes, conf) def assert_gpu_fallback_collect(func, cpu_fallback_class_name, conf={}): (bring_back, collect_type) = _prep_func_for_compare(func, 'COLLECT_WITH_DATAFRAME') conf = _prep_incompat_conf(conf) print('### CPU RUN ###') cpu_start = time.time() from_cpu, cpu_df = with_cpu_session(bring_back, conf=conf) cpu_end = time.time() print('### GPU RUN ###') gpu_start = time.time() from_gpu, gpu_df = with_gpu_session(bring_back, conf=conf) gpu_end = time.time() jvm = spark_jvm() jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertDidFallBack(gpu_df._jdf, cpu_fallback_class_name) print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type, gpu_end - gpu_start, cpu_end - cpu_start)) if should_sort_locally(): from_cpu.sort(key=_RowCmp) from_gpu.sort(key=_RowCmp) assert_equal(from_cpu, from_gpu) def assert_gpu_sql_fallback_collect(df_fun, cpu_fallback_class_name, table_name, sql, conf=None, debug=False): if conf is None: conf = {} def do_it_all(spark): df = df_fun(spark) df.createOrReplaceTempView(table_name) if debug: return data_gen.debug_df(spark.sql(sql)) else: return spark.sql(sql) assert_gpu_fallback_collect(do_it_all, cpu_fallback_class_name, conf) def _assert_gpu_and_cpu_are_equal(func, mode, conf={}, is_cpu_first=True): (bring_back, collect_type) = _prep_func_for_compare(func, mode) conf = _prep_incompat_conf(conf) def run_on_cpu(): print('### CPU RUN ###') global cpu_start cpu_start = time.time() global from_cpu from_cpu = with_cpu_session(bring_back, conf=conf) global cpu_end cpu_end = time.time() def run_on_gpu(): print('### GPU RUN ###') global gpu_start gpu_start = time.time() global from_gpu from_gpu = with_gpu_session(bring_back, conf=conf) global gpu_end gpu_end = time.time() if is_cpu_first: run_on_cpu() run_on_gpu() else: run_on_gpu() run_on_cpu() print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type, gpu_end - gpu_start, cpu_end - cpu_start)) if should_sort_locally(): from_cpu.sort(key=_RowCmp) from_gpu.sort(key=_RowCmp) assert_equal(from_cpu, from_gpu) def run_with_cpu(func, mode, conf={}): (bring_back, collect_type) = _prep_func_for_compare(func, mode) conf = _prep_incompat_conf(conf) print("run_with_cpu") def run_on_cpu(): print('### CPU RUN ###') global cpu_start cpu_start = time.time() global from_cpu from_cpu = with_cpu_session(bring_back, conf=conf) global cpu_end cpu_end = time.time() run_on_cpu() print('### {}: CPU TOOK {} ###'.format(collect_type, cpu_end - cpu_start)) if should_sort_locally(): from_cpu.sort(key=_RowCmp) return from_cpu def run_with_cpu_and_gpu(func, mode, conf={}): (bring_back, collect_type) = _prep_func_for_compare(func, mode) conf = _prep_incompat_conf(conf) def run_on_cpu(): print('### CPU RUN ###') global cpu_start cpu_start = time.time() global from_cpu from_cpu = with_cpu_session(bring_back, conf=conf) global cpu_end cpu_end = time.time() def run_on_gpu(): print('### GPU RUN ###') global gpu_start gpu_start = time.time() global from_gpu from_gpu = with_gpu_session(bring_back, conf=conf) global gpu_end gpu_end = time.time() run_on_cpu() run_on_gpu() 
print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type, gpu_end - gpu_start, cpu_end - cpu_start)) if should_sort_locally(): from_cpu.sort(key=_RowCmp) from_gpu.sort(key=_RowCmp) return (from_cpu, from_gpu) def assert_gpu_and_cpu_are_equal_collect(func, conf={}, is_cpu_first=True): """ Assert when running func on both the CPU and the GPU that the results are equal. In this case the data is collected back to the driver and compared here, so be careful about the amount of data returned. """ _assert_gpu_and_cpu_are_equal(func, 'COLLECT', conf=conf, is_cpu_first=is_cpu_first) def assert_gpu_and_cpu_are_equal_iterator(func, conf={}, is_cpu_first=True): """ Assert when running func on both the CPU and the GPU that the results are equal. In this case the data is pulled back to the driver in chunks and compared here so any amount of data can work, just be careful about how long it might take. """ _assert_gpu_and_cpu_are_equal(func, 'ITERATOR', conf=conf, is_cpu_first=is_cpu_first) def assert_gpu_and_cpu_row_counts_equal(func, conf={}, is_cpu_first=True): """ Assert that the row counts from running the func are the same on both the CPU and GPU. This function runs count() to only get the number of rows and compares that count between the CPU and GPU. It does NOT compare any underlying data. """ _assert_gpu_and_cpu_are_equal(func, 'COUNT', conf=conf, is_cpu_first=is_cpu_first) def assert_gpu_and_cpu_are_equal_sql(df_fun, table_name, sql, conf=None, debug=False, is_cpu_first=True, validate_execs_in_gpu_plan=[]): """ Assert that the specified SQL query produces equal results on CPU and GPU. :param df_fun: a function that will create the dataframe :param table_name: Name of table to be created with the dataframe :param sql: SQL query to be run on the specified table :param conf: Any user-specified confs. Empty by default. :param debug: Boolean to indicate if the SQL output should be printed :param is_cpu_first: Boolean to indicate if the CPU should be run first or not :param validate_execs_in_gpu_plan: String list of expressions to be validated in the GPU plan. :return: Assertion failure, if results from CPU and GPU do not match. """ if conf is None: conf = {} def do_it_all(spark): df = df_fun(spark) df.createOrReplaceTempView(table_name) # we hold off on setting the validate execs until after creating the temp view spark.conf.set('spark.rapids.sql.test.validateExecsInGpuPlan', ','.join(validate_execs_in_gpu_plan)) if debug: return data_gen.debug_df(spark.sql(sql)) else: return spark.sql(sql) assert_gpu_and_cpu_are_equal_collect(do_it_all, conf, is_cpu_first=is_cpu_first) def assert_py4j_exception(func, error_message): """ Assert that a specific Java exception is thrown :param func: a function to be verified :param error_message: a string such as the one produce by java.lang.Exception.toString :return: Assertion failure if no exception matching error_message has occurred. 
""" with pytest.raises(Py4JJavaError) as py4jError: func() assert error_message in str(py4jError.value.java_exception) def assert_gpu_and_cpu_error(df_fun, conf, error_message): """ Assert that GPU and CPU execution results in a specific Java exception thrown :param df_fun: a function to be verified :param conf: Spark config :param error_message: a string such as the one produce by java.lang.Exception.toString :return: Assertion failure if either GPU or CPU versions has not generated error messages expected """ assert_py4j_exception(lambda: with_cpu_session(df_fun, conf), error_message) assert_py4j_exception(lambda: with_gpu_session(df_fun, conf), error_message) def with_cpu_sql(df_fun, table_name, sql, conf=None, debug=False): if conf is None: conf = {} def do_it_all(spark): df = df_fun(spark) df.createOrReplaceTempView(table_name) if debug: return data_gen.debug_df(spark.sql(sql)) else: return spark.sql(sql) assert_gpu_and_cpu_are_equal_collect(do_it_all, conf, is_cpu_first=is_cpu_first) ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/python/conftest.py ================================================ # Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pytest import random from spark_init_internal import get_spark_i_know_what_i_am_doing from pyspark.sql.dataframe import DataFrame _approximate_float_args = None def get_float_check(): if not _approximate_float_args is None: return lambda lhs,rhs: lhs == pytest.approx(rhs, **_approximate_float_args) else: return lambda lhs,rhs: lhs == rhs _incompat = False def is_incompat(): return _incompat _sort_on_spark = False _sort_locally = False def should_sort_on_spark(): return _sort_on_spark def should_sort_locally(): return _sort_locally _allow_any_non_gpu = False _non_gpu_allowed = [] def is_allowing_any_non_gpu(): return _allow_any_non_gpu def get_non_gpu_allowed(): return _non_gpu_allowed def get_validate_execs_in_gpu_plan(): return _validate_execs_in_gpu_plan _runtime_env = "apache" def runtime_env(): return _runtime_env.lower() def is_apache_runtime(): return runtime_env() == "apache" def is_databricks_runtime(): return runtime_env() == "databricks" def is_emr_runtime(): return runtime_env() == "emr" def is_dataproc_runtime(): return runtime_env() == "dataproc" _is_nightly_run = False _is_precommit_run = False def is_nightly_run(): return _is_nightly_run def is_at_least_precommit_run(): return _is_nightly_run or _is_precommit_run def skip_unless_nightly_tests(description): if (_is_nightly_run): raise AssertionError(description + ' during nightly test run') else: pytest.skip(description) def skip_unless_precommit_tests(description): if (_is_nightly_run): raise AssertionError(description + ' during nightly test run') elif (_is_precommit_run): raise AssertionError(description + ' during pre-commit test run') else: pytest.skip(description) _limit = -1 def get_limit(): return _limit def _get_limit_from_mark(mark): if mark.args: return mark.args[0] 
else: return mark.kwargs.get('num_rows', 100000) def pytest_runtest_setup(item): global _sort_on_spark global _sort_locally order = item.get_closest_marker('ignore_order') if order: if order.kwargs.get('local', False): _sort_on_spark = False _sort_locally = True else: _sort_on_spark = True _sort_locally = False else: _sort_on_spark = False _sort_locally = False global _incompat if item.get_closest_marker('incompat'): _incompat = True else: _incompat = False global _approximate_float_args app_f = item.get_closest_marker('approximate_float') if app_f: _approximate_float_args = app_f.kwargs else: _approximate_float_args = None global _allow_any_non_gpu global _non_gpu_allowed _non_gpu_allowed_databricks = [] _allow_any_non_gpu_databricks = False non_gpu_databricks = item.get_closest_marker('allow_non_gpu_databricks') non_gpu = item.get_closest_marker('allow_non_gpu') if non_gpu_databricks: if is_databricks_runtime(): if non_gpu_databricks.kwargs and non_gpu_databricks.kwargs['any']: _allow_any_non_gpu_databricks = True elif non_gpu_databricks.args: _non_gpu_allowed_databricks = non_gpu_databricks.args else: pytest.warn('allow_non_gpu_databricks marker without anything allowed') if non_gpu: if non_gpu.kwargs and non_gpu.kwargs['any']: _allow_any_non_gpu = True _non_gpu_allowed = [] elif non_gpu.args: _allow_any_non_gpu = False _non_gpu_allowed = non_gpu.args else: pytest.warn('allow_non_gpu marker without anything allowed') _allow_any_non_gpu = False _non_gpu_allowed = [] else: _allow_any_non_gpu = False _non_gpu_allowed = [] _allow_any_non_gpu = _allow_any_non_gpu | _allow_any_non_gpu_databricks if _non_gpu_allowed and _non_gpu_allowed_databricks: _non_gpu_allowed = _non_gpu_allowed + _non_gpu_allowed_databricks elif _non_gpu_allowed_databricks: _non_gpu_allowed = _non_gpu_allowed_databricks global _validate_execs_in_gpu_plan validate_execs = item.get_closest_marker('validate_execs_in_gpu_plan') if validate_execs and validate_execs.args: _validate_execs_in_gpu_plan = validate_execs.args else: _validate_execs_in_gpu_plan = [] global _limit limit_mrk = item.get_closest_marker('limit') if limit_mrk: _limit = _get_limit_from_mark(limit_mrk) else: _limit = -1 def pytest_configure(config): global _runtime_env _runtime_env = config.getoption('runtime_env') global _is_nightly_run global _is_precommit_run test_type = config.getoption('test_type').lower() if "nightly" == test_type: _is_nightly_run = True elif "pre-commit" == test_type: _is_precommit_run = True elif "developer" != test_type: raise Exception("not supported test type {}".format(test_type)) def pytest_collection_modifyitems(config, items): for item in items: extras = [] order = item.get_closest_marker('ignore_order') if order: if order.kwargs: extras.append('IGNORE_ORDER(' + str(order.kwargs) + ')') else: extras.append('IGNORE_ORDER') if item.get_closest_marker('incompat'): extras.append('INCOMPAT') app_f = item.get_closest_marker('approximate_float') if app_f: if app_f.kwargs: extras.append('APPROXIMATE_FLOAT(' + str(app_f.kwargs) + ')') else: extras.append('APPROXIMATE_FLOAT') non_gpu = item.get_closest_marker('allow_non_gpu') if non_gpu: if non_gpu.kwargs and non_gpu.kwargs['any']: extras.append('ALLOW_NON_GPU(ANY)') elif non_gpu.args: extras.append('ALLOW_NON_GPU(' + ','.join(non_gpu.args) + ')') limit_mrk = item.get_closest_marker('limit') if limit_mrk: extras.append('LIMIT({})'.format(_get_limit_from_mark(limit_mrk))) if extras: # This is not ideal because we are reaching into an internal value item._nodeid = item.nodeid + '[' + ', 
'.join(extras) + ']' @pytest.fixture(scope="session") def std_input_path(request): path = request.config.getoption("std_input_path") if path is None: skip_unless_precommit_tests("std_input_path is not configured") else: yield path @pytest.fixture def spark_tmp_path(request): debug = request.config.getoption('debug_tmp_path') ret = request.config.getoption('tmp_path') if ret is None: ret = '/tmp/pyspark_tests/' ret = ret + '/' + str(random.randint(0, 1000000)) + '/' # Make sure it is there and accessible sc = get_spark_i_know_what_i_am_doing().sparkContext config = sc._jsc.hadoopConfiguration() path = sc._jvm.org.apache.hadoop.fs.Path(ret) fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(config) fs.mkdirs(path) yield ret if not debug: fs.delete(path) class TmpTableFactory: def __init__(self, base_id): self.base_id = base_id self.running_id = 0 def get(self): ret = '{}_{}'.format(self.base_id, self.running_id) self.running_id = self.running_id + 1 return ret @pytest.fixture def spark_tmp_table_factory(request): base_id = 'tmp_table_{}'.format(random.randint(0, 1000000)) yield TmpTableFactory(base_id) sp = get_spark_i_know_what_i_am_doing() tables = sp.sql("SHOW TABLES".format(base_id)).collect() for row in tables: t_name = row['tableName'] if (t_name.startswith(base_id)): sp.sql("DROP TABLE IF EXISTS {}".format(t_name)) def _get_jvm_session(spark): return spark._jsparkSession def _get_jvm(spark): return spark.sparkContext._jvm def spark_jvm(): return _get_jvm(get_spark_i_know_what_i_am_doing()) class MortgageRunner: def __init__(self, mortgage_format, mortgage_acq_path, mortgage_perf_path): self.mortgage_format = mortgage_format self.mortgage_acq_path = mortgage_acq_path self.mortgage_perf_path = mortgage_perf_path def do_test_query(self, spark): jvm_session = _get_jvm_session(spark) jvm = _get_jvm(spark) acq = self.mortgage_acq_path perf = self.mortgage_perf_path run = jvm.com.nvidia.spark.rapids.tests.mortgage.Run if self.mortgage_format == 'csv': df = run.csv(jvm_session, perf, acq) elif self.mortgage_format == 'parquet': df = run.parquet(jvm_session, perf, acq) elif self.mortgage_format == 'orc': df = run.orc(jvm_session, perf, acq) else: raise AssertionError('Not Supported Format {}'.format(self.mortgage_format)) return DataFrame(df, spark.getActiveSession()) @pytest.fixture(scope="session") def mortgage(request): mortgage_format = request.config.getoption("mortgage_format") mortgage_path = request.config.getoption("mortgage_path") if mortgage_path is None: std_path = request.config.getoption("std_input_path") if std_path is None: skip_unless_precommit_tests("Mortgage tests are not configured to run") else: yield MortgageRunner('parquet', std_path + '/parquet_acq', std_path + '/parquet_perf') else: yield MortgageRunner(mortgage_format, mortgage_path + '/acq', mortgage_path + '/perf') @pytest.fixture(scope="session") def enable_cudf_udf(request): enable_udf_cudf = request.config.getoption("cudf_udf") if not enable_udf_cudf: # cudf_udf tests are not required for any test runs pytest.skip("cudf_udf not configured to run") @pytest.fixture(scope="session") def enable_rapids_udf_example_native(request): native_enabled = request.config.getoption("rapids_udf_example_native") if not native_enabled: # udf_example_native tests are not required for any test runs pytest.skip("rapids_udf_example_native is not configured to run") ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/python/data_gen.py 
================================================ # Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy from datetime import date, datetime, timedelta, timezone from decimal import * import math from pyspark.context import SparkContext from pyspark.sql import Row from pyspark.sql.types import * import pyspark.sql.functions as f import pytest import random from spark_session import is_tz_utc import sre_yield import struct from conftest import skip_unless_precommit_tests class DataGen: """Base class for data generation""" def __repr__(self): if not self.nullable: return self.__class__.__name__[:-3] + '(not_null)' return self.__class__.__name__[:-3] def __hash__(self): return hash(str(self)) def __eq__(self, other): return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) def __init__(self, data_type, nullable=True, special_cases =[]): self.data_type = data_type self.list_of_special_cases = special_cases self._special_cases = [] if isinstance(nullable, tuple): self.nullable = nullable[0] weight = nullable[1] else: self.nullable = nullable weight = 5.0 if self.nullable: self.with_special_case(None, weight) # Special cases can be a value or a tuple of (value, weight). If the # special_case itself is a tuple as in the case of StructGen, it MUST be added with a # weight like : ((special_case_tuple_v1, special_case_tuple_v2), weight). for element in special_cases: if isinstance(element, tuple): self.with_special_case(element[0], element[1]) else: self.with_special_case(element) def copy_special_case(self, special_case, weight=1.0): # it would be good to do a deepcopy, but sre_yield is not happy with that. c = copy.copy(self) c._special_cases = copy.deepcopy(self._special_cases) return c.with_special_case(special_case, weight=weight) def with_special_case(self, special_case, weight=1.0): """ Add in a special case with a given weight. A special case can either be a function that takes an instance of Random and returns the generated data or it can be a constant. By default the weight is 1.0, and the default number generation's weight is 100.0. 
The number of lines that are generate in the data set should be proportional to the its weight/sum weights """ if callable(special_case): sc = special_case else: sc = lambda rand: special_case self._special_cases.append((weight, sc)) return self def get_types(self): return 'DataType: {}, nullable: {}, special_cases: {}'.format(self.data_type, self.nullable, self.list_of_special_cases) def start(self, rand): """Start data generation using the given rand""" raise TypeError('Children should implement this method and call _start') def _start(self, rand, gen_func): """Start internally, but use the given gen_func as the base""" if not self._special_cases: self._gen_func = gen_func else: weighted_choices = [(100.0, lambda rand: gen_func())] weighted_choices.extend(self._special_cases) total = float(sum(weight for weight,gen in weighted_choices)) normalized_choices = [(weight/total, gen) for weight,gen in weighted_choices] def choose_one(): pick = rand.random() total = 0 for (weight, gen) in normalized_choices: total += weight if total >= pick: return gen(rand) raise RuntimeError('Random did not pick something we expected') self._gen_func = choose_one def gen(self, force_no_nulls=False): """generate the next line""" if not self._gen_func: raise RuntimeError('start must be called before generating any data') v = self._gen_func() if force_no_nulls: while v is None: v = self._gen_func() return v def contains_ts(self): """Checks if this contains a TimestampGen""" return False class ConvertGen(DataGen): """Provides a way to modify the data before it is returned""" def __init__(self, child_gen, func, data_type=None, nullable=True): if data_type is None: data_type = child_gen.data_type super().__init__(data_type, nullable=nullable) self._child_gen = child_gen self._func = func def __repr__(self): return super().__repr__() + '(' + str(self._child_gen) + ')' def start(self, rand): self._child_gen.start(rand) def modify(): return self._func(self._child_gen.gen()) self._start(rand, modify) _MAX_CHOICES = 1 << 64 class StringGen(DataGen): """Generate strings that match a pattern""" def __init__(self, pattern="(.|\n){1,30}", flags=0, charset=sre_yield.CHARSET, nullable=True): super().__init__(StringType(), nullable=nullable) self.base_strs = sre_yield.AllStrings(pattern, flags=flags, charset=charset, max_count=_MAX_CHOICES) def with_special_pattern(self, pattern, flags=0, charset=sre_yield.CHARSET, weight=1.0): """ Like with_special_case but you can provide a regexp pattern instead of a hard coded string value. 
""" strs = sre_yield.AllStrings(pattern, flags=flags, charset=charset, max_count=_MAX_CHOICES) try: length = int(len(strs)) except OverflowError: length = _MAX_CHOICES return self.with_special_case(lambda rand : strs[rand.randrange(0, length)], weight=weight) def start(self, rand): strs = self.base_strs try: length = int(len(strs)) except OverflowError: length = _MAX_CHOICES self._start(rand, lambda : strs[rand.randrange(0, length)]) BYTE_MIN = -(1 << 7) BYTE_MAX = (1 << 7) - 1 class ByteGen(DataGen): """Generate Bytes""" def __init__(self, nullable=True, min_val = BYTE_MIN, max_val = BYTE_MAX, special_cases=[]): super().__init__(ByteType(), nullable=nullable, special_cases=special_cases) self._min_val = min_val self._max_val = max_val def start(self, rand): self._start(rand, lambda : rand.randint(self._min_val, self._max_val)) SHORT_MIN = -(1 << 15) SHORT_MAX = (1 << 15) - 1 class ShortGen(DataGen): """Generate Shorts, which some built in corner cases.""" def __init__(self, nullable=True, min_val = SHORT_MIN, max_val = SHORT_MAX, special_cases = [SHORT_MIN, SHORT_MAX, 0, 1, -1]): super().__init__(ShortType(), nullable=nullable, special_cases=special_cases) self._min_val = min_val self._max_val = max_val def start(self, rand): self._start(rand, lambda : rand.randint(self._min_val, self._max_val)) INT_MIN = -(1 << 31) INT_MAX = (1 << 31) - 1 class IntegerGen(DataGen): """Generate Ints, which some built in corner cases.""" def __init__(self, nullable=True, min_val = INT_MIN, max_val = INT_MAX, special_cases = [INT_MIN, INT_MAX, 0, 1, -1]): super().__init__(IntegerType(), nullable=nullable, special_cases=special_cases) self._min_val = min_val self._max_val = max_val def start(self, rand): self._start(rand, lambda : rand.randint(self._min_val, self._max_val)) class DecimalGen(DataGen): """Generate Decimals, with some built in corner cases.""" def __init__(self, precision=None, scale=None, nullable=True, special_cases=[]): if precision is None: #Maximum number of decimal digits a Long can represent is 18 precision = 18 scale = 0 DECIMAL_MIN = Decimal('-' + ('9' * precision) + 'e' + str(-scale)) DECIMAL_MAX = Decimal(('9'* precision) + 'e' + str(-scale)) super().__init__(DecimalType(precision, scale), nullable=nullable, special_cases=special_cases) self.scale = scale self.precision = precision pattern = "[0-9]{1,"+ str(precision) + "}e" + str(-scale) self.base_strs = sre_yield.AllStrings(pattern, flags=0, charset=sre_yield.CHARSET, max_count=_MAX_CHOICES) def __repr__(self): return super().__repr__() + '(' + str(self.precision) + ',' + str(self.scale) + ')' def start(self, rand): strs = self.base_strs try: length = int(strs.length) except OverflowError: length = _MAX_CHOICES self._start(rand, lambda : Decimal(strs[rand.randrange(0, length)])) LONG_MIN = -(1 << 63) LONG_MAX = (1 << 63) - 1 class LongGen(DataGen): """Generate Longs, which some built in corner cases.""" def __init__(self, nullable=True, min_val = LONG_MIN, max_val = LONG_MAX, special_cases = []): _special_cases = [min_val, max_val, 0, 1, -1] if not special_cases else special_cases super().__init__(LongType(), nullable=nullable, special_cases=_special_cases) self._min_val = min_val self._max_val = max_val def start(self, rand): self._start(rand, lambda : rand.randint(self._min_val, self._max_val)) class LongRangeGen(DataGen): """Generate Longs in incrementing order.""" def __init__(self, nullable=False, start_val=0, direction="inc"): super().__init__(LongType(), nullable=nullable) self._start_val = start_val self._current_val = 
start_val if (direction == "dec"): def dec_it(): tmp = self._current_val self._current_val -= 1 return tmp self._do_it = dec_it else: def inc_it(): tmp = self._current_val self._current_val += 1 return tmp self._do_it = inc_it def start(self, rand): self._current_val = self._start_val self._start(rand, self._do_it) class RepeatSeqGen(DataGen): """Generate Repeated seq of `length` random items""" def __init__(self, child, length): super().__init__(child.data_type, nullable=False) self.nullable = child.nullable self._child = child self._vals = [] self._length = length self._index = 0 def __repr__(self): return super().__repr__() + '(' + str(self._child) + ')' def _loop_values(self): ret = self._vals[self._index] self._index = (self._index + 1) % self._length return ret def start(self, rand): self._index = 0 self._child.start(rand) self._start(rand, self._loop_values) self._vals = [self._child.gen() for _ in range(0, self._length)] class SetValuesGen(DataGen): """A set of values that are randomly selected""" def __init__(self, data_type, data): super().__init__(data_type, nullable=False) self.nullable = any(x is None for x in data) self._vals = data def __repr__(self): return super().__repr__() + '(' + str(self._child) + ')' def start(self, rand): data = self._vals length = len(data) self._start(rand, lambda : data[rand.randrange(0, length)]) FLOAT_MIN = -3.4028235E38 FLOAT_MAX = 3.4028235E38 NEG_FLOAT_NAN_MIN_VALUE = struct.unpack('f', struct.pack('I', 0xffffffff))[0] NEG_FLOAT_NAN_MAX_VALUE = struct.unpack('f', struct.pack('I', 0xff800001))[0] POS_FLOAT_NAN_MIN_VALUE = struct.unpack('f', struct.pack('I', 0x7f800001))[0] POS_FLOAT_NAN_MAX_VALUE = struct.unpack('f', struct.pack('I', 0x7fffffff))[0] class FloatGen(DataGen): """Generate floats, which some built in corner cases.""" def __init__(self, nullable=True, no_nans=False, special_cases=None): self._no_nans = no_nans if special_cases is None: special_cases = [FLOAT_MIN, FLOAT_MAX, 0.0, -0.0, 1.0, -1.0] if not no_nans: special_cases.append(float('inf')) special_cases.append(float('-inf')) special_cases.append(float('nan')) special_cases.append(NEG_FLOAT_NAN_MAX_VALUE) super().__init__(FloatType(), nullable=nullable, special_cases=special_cases) def _fixup_nans(self, v): if self._no_nans and (math.isnan(v) or v == math.inf or v == -math.inf): v = None if self.nullable else 0.0 return v def start(self, rand): def gen_float(): i = rand.randint(INT_MIN, INT_MAX) p = struct.pack('i', i) return self._fixup_nans(struct.unpack('f', p)[0]) self._start(rand, gen_float) DOUBLE_MIN_EXP = -1022 DOUBLE_MAX_EXP = 1023 DOUBLE_MAX_FRACTION = int('1'*52, 2) DOUBLE_MIN = -1.7976931348623157E308 DOUBLE_MAX = 1.7976931348623157E308 NEG_DOUBLE_NAN_MIN_VALUE = struct.unpack('d', struct.pack('L', 0xffffffffffffffff))[0] NEG_DOUBLE_NAN_MAX_VALUE = struct.unpack('d', struct.pack('L', 0xfff0000000000001))[0] POS_DOUBLE_NAN_MIN_VALUE = struct.unpack('d', struct.pack('L', 0x7ff0000000000001))[0] POS_DOUBLE_NAN_MAX_VALUE = struct.unpack('d', struct.pack('L', 0x7fffffffffffffff))[0] class DoubleGen(DataGen): """Generate doubles, which some built in corner cases.""" def __init__(self, min_exp=DOUBLE_MIN_EXP, max_exp=DOUBLE_MAX_EXP, no_nans=False, nullable=True, special_cases = None): self._min_exp = min_exp self._max_exp = max_exp self._no_nans = no_nans self._use_full_range = (self._min_exp == DOUBLE_MIN_EXP) and (self._max_exp == DOUBLE_MAX_EXP) if special_cases is None: special_cases = [ self.make_from(1, self._max_exp, DOUBLE_MAX_FRACTION), self.make_from(0, 
self._max_exp, DOUBLE_MAX_FRACTION), self.make_from(1, self._min_exp, DOUBLE_MAX_FRACTION), self.make_from(0, self._min_exp, DOUBLE_MAX_FRACTION) ] if self._min_exp <= 0 and self._max_exp >= 0: special_cases.append(0.0) special_cases.append(-0.0) if self._min_exp <= 3 and self._max_exp >= 3: special_cases.append(1.0) special_cases.append(-1.0) if not no_nans: special_cases.append(float('inf')) special_cases.append(float('-inf')) special_cases.append(float('nan')) special_cases.append(NEG_DOUBLE_NAN_MAX_VALUE) super().__init__(DoubleType(), nullable=nullable, special_cases=special_cases) @staticmethod def make_from(sign, exp, fraction): sign = sign & 1 # 1 bit exp = (exp + 1023) & 0x7FF # add bias and 11 bits fraction = fraction & DOUBLE_MAX_FRACTION i = (sign << 63) | (exp << 52) | fraction p = struct.pack('L', i) ret = struct.unpack('d', p)[0] return ret def _fixup_nans(self, v): if self._no_nans and (math.isnan(v) or v == math.inf or v == -math.inf): v = None if self.nullable else 0.0 return v def start(self, rand): if self._use_full_range: def gen_double(): i = rand.randint(LONG_MIN, LONG_MAX) p = struct.pack('l', i) return self._fixup_nans(struct.unpack('d', p)[0]) self._start(rand, gen_double) else: def gen_part_double(): sign = rand.getrandbits(1) exp = rand.randint(self._min_exp, self._max_exp) fraction = rand.getrandbits(52) return self._fixup_nans(self.make_from(sign, exp, fraction)) self._start(rand, gen_part_double) class BooleanGen(DataGen): """Generate Bools (True/False)""" def __init__(self, nullable=True): super().__init__(BooleanType(), nullable=nullable) def start(self, rand): self._start(rand, lambda : bool(rand.getrandbits(1))) class StructGen(DataGen): """Generate a Struct""" def __init__(self, children, nullable=True, special_cases=[]): """ Initialize the struct with children. The children should be of the form: [('name', Gen),('name_2', Gen2)] Where name is the name of the strict field and Gens are Generators of the type for that entry. 
""" tmp = [StructField(name, child.data_type, nullable=child.nullable) for name, child in children] super().__init__(StructType(tmp), nullable=nullable, special_cases=special_cases) self.children = children def __repr__(self): return super().__repr__() + '(' + ','.join([str(i) for i in self.children]) + ')' def start(self, rand): for name, child in self.children: child.start(rand) def make_tuple(): data = [child.gen() for name, child in self.children] return tuple(data) self._start(rand, make_tuple) def contains_ts(self): return any(child[1].contains_ts() for child in self.children) class DateGen(DataGen): """Generate Dates in a given range""" def __init__(self, start=None, end=None, nullable=True): super().__init__(DateType(), nullable=nullable) if start is None: # Spark supports times starting at # "0001-01-01 00:00:00.000000" start = date(1, 1, 1) elif not isinstance(start, date): raise RuntimeError('Unsupported type passed in for start {}'.format(start)) if end is None: # Spark supports time through # "9999-12-31 23:59:59.999999" end = date(9999, 12, 31) elif isinstance(end, timedelta): end = start + end elif not isinstance(start, date): raise RuntimeError('Unsupported type passed in for end {}'.format(end)) self._start_day = self._to_days_since_epoch(start) self._end_day = self._to_days_since_epoch(end) self.with_special_case(start) self.with_special_case(end) # we want a few around the leap year if possible step = int((end.year - start.year) / 5.0) if (step != 0): years = {self._guess_leap_year(y) for y in range(start.year, end.year, step)} for y in years: leap_day = date(y, 2, 29) if (leap_day > start and leap_day < end): self.with_special_case(leap_day) next_day = date(y, 3, 1) if (next_day > start and next_day < end): self.with_special_case(next_day) @staticmethod def _guess_leap_year(t): y = int(math.ceil(t/4.0)) * 4 if ((y % 100) == 0) and ((y % 400) != 0): y = y + 4 if (y == 10000): y = y - 4 return y _epoch = date(1970, 1, 1) _days = timedelta(days=1) def _to_days_since_epoch(self, val): return int((val - self._epoch)/self._days) def _from_days_since_epoch(self, days): return self._epoch + timedelta(days=days) def start(self, rand): start = self._start_day end = self._end_day self._start(rand, lambda : self._from_days_since_epoch(rand.randint(start, end))) class TimestampGen(DataGen): """Generate Timestamps in a given range. 
All timezones are UTC by default.""" def __init__(self, start=None, end=None, nullable=True): super().__init__(TimestampType(), nullable=nullable) if start is None: # Spark supports times starting at # "0001-01-01 00:00:00.000000" # but it has issues if you get really close to that because it tries to do things # in a different format which causes roundoff, so we have to add a few days, # just to be sure start = datetime(1, 1, 3, tzinfo=timezone.utc) elif not isinstance(start, datetime): raise RuntimeError('Unsupported type passed in for start {}'.format(start)) if end is None: # Spark supports time through # "9999-12-31 23:59:59.999999" end = datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc) elif isinstance(end, timedelta): end = start + end elif not isinstance(start, date): raise RuntimeError('Unsupported type passed in for end {}'.format(end)) self._start_time = self._to_ms_since_epoch(start) self._end_time = self._to_ms_since_epoch(end) if (self._epoch >= start and self._epoch <= end): self.with_special_case(self._epoch) _epoch = datetime(1970, 1, 1, tzinfo=timezone.utc) _ms = timedelta(milliseconds=1) def _to_ms_since_epoch(self, val): return int((val - self._epoch)/self._ms) def _from_ms_since_epoch(self, ms): return self._epoch + timedelta(milliseconds=ms) def start(self, rand): start = self._start_time end = self._end_time self._start(rand, lambda : self._from_ms_since_epoch(rand.randint(start, end))) def contains_ts(self): return True class ArrayGen(DataGen): """Generate Arrays of data.""" def __init__(self, child_gen, min_length=0, max_length=20, nullable=True, all_null=False): super().__init__(ArrayType(child_gen.data_type, containsNull=child_gen.nullable), nullable=nullable) self._min_length = min_length self._max_length = max_length self._child_gen = child_gen self.all_null = all_null def __repr__(self): return super().__repr__() + '(' + str(self._child_gen) + ')' def start(self, rand): self._child_gen.start(rand) def gen_array(): if self.all_null: return None length = rand.randint(self._min_length, self._max_length) return [self._child_gen.gen() for _ in range(0, length)] self._start(rand, gen_array) def contains_ts(self): return self._child_gen.contains_ts() class MapGen(DataGen): """Generate a Map""" def __init__(self, key_gen, value_gen, min_length=0, max_length=20, nullable=True, special_cases=[]): # keys cannot be nullable assert not key_gen.nullable self._min_length = min_length self._max_length = max_length self._key_gen = key_gen self._value_gen = value_gen super().__init__(MapType(key_gen.data_type, value_gen.data_type, valueContainsNull=value_gen.nullable), nullable=nullable, special_cases=special_cases) def __repr__(self): return super().__repr__() + '(' + str(self._key_gen) + ',' + str(self._value_gen) + ')' def start(self, rand): self._key_gen.start(rand) self._value_gen.start(rand) def make_dict(): length = rand.randint(self._min_length, self._max_length) return {self._key_gen.gen(): self._value_gen.gen() for idx in range(0, length)} self._start(rand, make_dict) def contains_ts(self): return self._key_gen.contains_ts() or self._value_gen.contains_ts() class NullGen(DataGen): """Generate NullType values""" def __init__(self): super().__init__(NullType(), nullable=True) def start(self, rand): def make_null(): return None self._start(rand, make_null) def skip_if_not_utc(): if (not is_tz_utc()): skip_unless_precommit_tests('The java system time zone is not set to UTC') def gen_df(spark, data_gen, length=2048, seed=0, num_slices=None): """Generate a 
spark dataframe from the given data generators.""" if isinstance(data_gen, list): src = StructGen(data_gen, nullable=False) else: src = data_gen # we cannot create a data frame from a nullable struct assert not data_gen.nullable # Before we get too far we need to verify that we can run with timestamps if src.contains_ts(): skip_if_not_utc() rand = random.Random(seed) src.start(rand) data = [src.gen() for index in range(0, length)] # We use `numSlices` to create an RDD with the specific number of partitions, # which is then turned into a dataframe. If not specified, it is `None` (default spark value) return spark.createDataFrame( SparkContext.getOrCreate().parallelize(data, numSlices=num_slices), src.data_type) def _mark_as_lit(data, data_type): # To support nested types, 'data_type' is required. assert data_type is not None if data is None: return f.lit(data).cast(data_type) if isinstance(data_type, ArrayType): assert isinstance(data, list) # Sadly you cannot create a literal from just an array in pyspark return f.array([_mark_as_lit(x, data_type.elementType) for x in data]) elif isinstance(data_type, StructType): assert isinstance(data, tuple) and len(data) == len(data_type.fields) # Sadly you cannot create a literal from just a dict/tuple in pyspark children = zip(data, data_type.fields) return f.struct([_mark_as_lit(x, fd.dataType).alias(fd.name) for x, fd in children]) elif isinstance(data_type, DateType): # Due to https://bugs.python.org/issue13305 we need to zero pad for years prior to 1000, # but this works for all of them dateString = data.strftime("%Y-%m-%d").zfill(10) return f.lit(dateString).cast(data_type) elif isinstance(data_type, MapType): assert isinstance(data, dict) # Sadly you cannot create a literal from just a dict/tuple in pyspark col_array = [] for k in data: col_array.append(_mark_as_lit(k, data_type.keyType)) col_array.append(_mark_as_lit(data[k], data_type.valueType)) return f.create_map(*col_array) else: # lit does not take a data type so we might have to cast it return f.lit(data).cast(data_type) def _gen_scalars_common(data_gen, count, seed=0): if isinstance(data_gen, list): src = StructGen(data_gen, nullable=False) else: src = data_gen # Before we get too far we need to verify that we can run with timestamps if src.contains_ts(): skip_if_not_utc() rand = random.Random(seed) src.start(rand) return src def gen_scalars(data_gen, count, seed=0, force_no_nulls=False): """Generate scalar values.""" if force_no_nulls: assert(not isinstance(data_gen, NullGen)) src = _gen_scalars_common(data_gen, count, seed=seed) data_type = src.data_type return (_mark_as_lit(src.gen(force_no_nulls=force_no_nulls), data_type) for i in range(0, count)) def gen_scalar(data_gen, seed=0, force_no_nulls=False): """Generate a single scalar value.""" v = list(gen_scalars(data_gen, 1, seed=seed, force_no_nulls=force_no_nulls)) return v[0] def gen_scalar_values(data_gen, count, seed=0, force_no_nulls=False): """Generate scalar values.""" src = _gen_scalars_common(data_gen, count, seed=seed) return (src.gen(force_no_nulls=force_no_nulls) for i in range(0, count)) def gen_scalar_value(data_gen, seed=0, force_no_nulls=False): """Generate a single scalar value.""" v = list(gen_scalar_values(data_gen, 1, seed=seed, force_no_nulls=force_no_nulls)) return v[0] def debug_df(df, path = None, file_format = 'json', num_parts = 1): """Print out or save the contents and the schema of a dataframe for debugging.""" if path is not None: # Save the dataframe and its schema # The schema can be re-created by 
using DataType.fromJson and used # for loading the dataframe file_name = f"{path}.{file_format}" schema_file_name = f"{path}.schema.json" df.coalesce(num_parts).write.format(file_format).save(file_name) print(f"SAVED df output for debugging at {file_name}") schema_json = df.schema.json() schema_file = open(schema_file_name , 'w') schema_file.write(schema_json) schema_file.close() print(f"SAVED df schema for debugging along in the output dir") else: print('COLLECTED\n{}'.format(df.collect())) df.explain() df.printSchema() return df def print_params(data_gen): print('Test Datagen Params=' + str([(a, b.get_types()) for a, b in data_gen])) def idfn(val): """Provide an API to provide display names for data type generators.""" return str(val) def meta_idfn(meta): def tmp(something): return meta + idfn(something) return tmp def three_col_df(spark, a_gen, b_gen, c_gen, length=2048, seed=0, num_slices=None): gen = StructGen([('a', a_gen),('b', b_gen),('c', c_gen)], nullable=False) return gen_df(spark, gen, length=length, seed=seed, num_slices=num_slices) def two_col_df(spark, a_gen, b_gen, length=2048, seed=0, num_slices=None): gen = StructGen([('a', a_gen),('b', b_gen)], nullable=False) return gen_df(spark, gen, length=length, seed=seed, num_slices=num_slices) def binary_op_df(spark, gen, length=2048, seed=0, num_slices=None): return two_col_df(spark, gen, gen, length=length, seed=seed, num_slices=num_slices) def unary_op_df(spark, gen, length=2048, seed=0, num_slices=None): return gen_df(spark, StructGen([('a', gen)], nullable=False), length=length, seed=seed, num_slices=num_slices) def to_cast_string(spark_type): if isinstance(spark_type, ByteType): return 'BYTE' elif isinstance(spark_type, ShortType): return 'SHORT' elif isinstance(spark_type, IntegerType): return 'INT' elif isinstance(spark_type, LongType): return 'LONG' elif isinstance(spark_type, FloatType): return 'FLOAT' elif isinstance(spark_type, DoubleType): return 'DOUBLE' elif isinstance(spark_type, BooleanType): return 'BOOLEAN' elif isinstance(spark_type, DateType): return 'DATE' elif isinstance(spark_type, TimestampType): return 'TIMESTAMP' elif isinstance(spark_type, StringType): return 'STRING' elif isinstance(spark_type, DecimalType): return 'DECIMAL({}, {})'.format(spark_type.precision, spark_type.scale) elif isinstance(spark_type, ArrayType): return 'ARRAY<{}>'.format(to_cast_string(spark_type.elementType)) elif isinstance(spark_type, StructType): children = [fd.name + ':' + to_cast_string(fd.dataType) for fd in spark_type.fields] return 'STRUCT<{}>'.format(','.join(children)) else: raise RuntimeError('CAST TO TYPE {} NOT SUPPORTED YET'.format(spark_type)) def get_null_lit_string(spark_type): if isinstance(spark_type, NullType): return 'null' else: string_type = to_cast_string(spark_type) return 'CAST(null as {})'.format(string_type) def _convert_to_sql(spark_type, data): if isinstance(data, str): d = "'" + data.replace("'", "\\'") + "'" elif isinstance(data, datetime): d = "'" + data.strftime('%Y-%m-%d T%H:%M:%S.%f').zfill(26) + "'" elif isinstance(data, date): d = "'" + data.strftime('%Y-%m-%d').zfill(10) + "'" elif isinstance(data, list): assert isinstance(spark_type, ArrayType) d = "array({})".format(",".join([_convert_to_sql(spark_type.elementType, x) for x in data])) elif isinstance(data, tuple): assert isinstance(spark_type, StructType) and len(data) == len(spark_type.fields) # Format of each child: 'name',data children = ["'{}'".format(fd.name) + ',' + _convert_to_sql(fd.dataType, x) for fd, x in zip(spark_type.fields, 
data)] d = "named_struct({})".format(','.join(children)) elif not data: # data is None d = "null" else: d = "'{}'".format(str(data)) if isinstance(spark_type, NullType): return d else: return 'CAST({} as {})'.format(d, to_cast_string(spark_type)) def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False): """Generate scalar values, but strings that can be used in selectExpr or SQL""" src = _gen_scalars_common(data_gen, count, seed=seed) if isinstance(data_gen, NullGen): assert not force_no_nulls return ('null' for i in range(0, count)) spark_type = data_gen.data_type return (_convert_to_sql(spark_type, src.gen(force_no_nulls=force_no_nulls)) for i in range(0, count)) byte_gen = ByteGen() short_gen = ShortGen() int_gen = IntegerGen() long_gen = LongGen() float_gen = FloatGen() double_gen = DoubleGen() string_gen = StringGen() boolean_gen = BooleanGen() date_gen = DateGen() timestamp_gen = TimestampGen() decimal_gen_default = DecimalGen() decimal_gen_neg_scale = DecimalGen(precision=7, scale=-3) decimal_gen_scale_precision = DecimalGen(precision=7, scale=3) decimal_gen_same_scale_precision = DecimalGen(precision=7, scale=7) decimal_gen_64bit = DecimalGen(precision=12, scale=2) decimal_gen_12_2 = DecimalGen(precision=12, scale=2) decimal_gen_18_3 = DecimalGen(precision=18, scale=3) decimal_gen_128bit = DecimalGen(precision=20, scale=2) decimal_gen_20_2 = DecimalGen(precision=20, scale=2) decimal_gen_30_2 = DecimalGen(precision=30, scale=2) decimal_gen_36_5 = DecimalGen(precision=36, scale=5) decimal_gen_36_neg5 = DecimalGen(precision=36, scale=-5) decimal_gen_38_0 = DecimalGen(precision=38, scale=0) decimal_gen_38_10 = DecimalGen(precision=38, scale=10) decimal_gen_38_neg10 = DecimalGen(precision=38, scale=-10) null_gen = NullGen() numeric_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen] integral_gens = [byte_gen, short_gen, int_gen, long_gen] # A lot of mathematical expressions only support a double as input # by parametrizing even for a single param for the test it makes the tests consistent double_gens = [double_gen] double_n_long_gens = [double_gen, long_gen] int_n_long_gens = [int_gen, long_gen] decimal_gens_no_neg = [decimal_gen_default, decimal_gen_scale_precision, decimal_gen_same_scale_precision, decimal_gen_64bit] decimal_gens = [decimal_gen_neg_scale] + decimal_gens_no_neg decimal_128_gens_no_neg = [decimal_gen_20_2, decimal_gen_30_2, decimal_gen_36_5, decimal_gen_38_0, decimal_gen_38_10] decimal_128_gens = decimal_128_gens_no_neg + [decimal_gen_36_neg5, decimal_gen_38_neg10] # all of the basic gens all_basic_gens_no_null = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, date_gen, timestamp_gen] all_basic_gens = all_basic_gens_no_null + [null_gen] all_basic_gens_no_nan = [byte_gen, short_gen, int_gen, long_gen, FloatGen(no_nans=True), DoubleGen(no_nans=True), string_gen, boolean_gen, date_gen, timestamp_gen, null_gen] # TODO add in some array generators to this once that is supported for sorting # a selection of generators that should be orderable (sortable and compareable) orderable_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, date_gen, timestamp_gen, null_gen] + decimal_gens # TODO add in some array generators to this once that is supported for these operations # a selection of generators that can be compared for equality eq_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, date_gen, timestamp_gen, null_gen] # 
Include decimal type while testing equalTo and notEqualTo eq_gens_with_decimal_gen = eq_gens + decimal_gens date_gens = [date_gen] date_n_time_gens = [date_gen, timestamp_gen] boolean_gens = [boolean_gen] single_level_array_gens = [ArrayGen(sub_gen) for sub_gen in all_basic_gens + decimal_gens] single_array_gens_sample_with_decimal128 = [ArrayGen(sub_gen) for sub_gen in decimal_128_gens] single_level_array_gens_no_null = [ArrayGen(sub_gen) for sub_gen in all_basic_gens_no_null + decimal_gens_no_neg] single_level_array_gens_no_nan = [ArrayGen(sub_gen) for sub_gen in all_basic_gens_no_nan + decimal_gens] single_level_array_gens_no_decimal = [ArrayGen(sub_gen) for sub_gen in all_basic_gens] map_string_string_gen = [MapGen(StringGen(pattern='key_[0-9]', nullable=False), StringGen())] # Be careful to not make these too large of data generation takes for ever # This is only a few nested array gens, because nesting can be very deep nested_array_gens_sample = [ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10), ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10), ArrayGen(StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]]))] # Some array gens, but not all because of nesting array_gens_sample = single_level_array_gens + nested_array_gens_sample array_gens_sample_with_decimal128 = single_level_array_gens + nested_array_gens_sample + single_array_gens_sample_with_decimal128 # all of the basic types in a single struct all_basic_struct_gen = StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(all_basic_gens)]) # Some struct gens, but not all because of nesting nonempty_struct_gens_sample = [all_basic_struct_gen, StructGen([['child0', byte_gen], ['child1', all_basic_struct_gen]]), StructGen([['child0', ArrayGen(short_gen)], ['child1', double_gen]])] struct_gens_sample = nonempty_struct_gens_sample + [StructGen([])] struct_gen_decimal128 = StructGen( [['child' + str(ind), sub_gen] for ind, sub_gen in enumerate(decimal_128_gens)]) struct_gens_sample_with_decimal128 = struct_gens_sample + [ struct_gen_decimal128] simple_string_to_string_map_gen = MapGen(StringGen(pattern='key_[0-9]', nullable=False), StringGen(), max_length=10) all_basic_map_gens = [MapGen(f(nullable=False), f()) for f in [BooleanGen, ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, DateGen, TimestampGen]] + [simple_string_to_string_map_gen] decimal_64_map_gens = [MapGen(key_gen=gen, value_gen=gen, nullable=False) for gen in [DecimalGen(7, 3, nullable=False), DecimalGen(12, 2, nullable=False), DecimalGen(18, -3, nullable=False)]] decimal_128_map_gens = [MapGen(key_gen=gen, value_gen=gen, nullable=False) for gen in [DecimalGen(20, 2, nullable=False), DecimalGen(36, 5, nullable=False), DecimalGen(38, 38, nullable=False), DecimalGen(36, -5, nullable=False)]] decimal_128_no_neg_map_gens = [MapGen(key_gen=gen, value_gen=gen, nullable=False) for gen in [DecimalGen(20, 2, nullable=False), DecimalGen(36, 5, nullable=False), DecimalGen(38, 38, nullable=False)]] # Some map gens, but not all because of nesting map_gens_sample = all_basic_map_gens + [MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(string_gen), max_length=10), MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10), MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)] allow_negative_scale_of_decimal_conf = {'spark.sql.legacy.allowNegativeScaleOfDecimal': 'true'} def copy_and_update(conf, *more_confs): local_conf = conf.copy() for more in 
more_confs: local_conf.update(more) return local_conf all_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(), FloatGen(), DoubleGen(), BooleanGen(), DateGen(), TimestampGen(), decimal_gen_default, decimal_gen_scale_precision, decimal_gen_same_scale_precision, decimal_gen_64bit, decimal_gen_128bit, decimal_gen_36_5, decimal_gen_38_10] # Pyarrow will complain the error as below if the timestamp is out of range for both CPU and GPU, # so narrow down the time range to avoid exceptions causing test failures. # # "pyarrow.lib.ArrowInvalid: Casting from timestamp[us, tz=UTC] to timestamp[ns] # would result in out of bounds timestamp: 51496791452587000" # # This issue has been fixed in pyarrow by the PR https://github.com/apache/arrow/pull/7169 # However it still requires PySpark to specify the new argument "timestamp_as_object". arrow_common_gen = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, date_gen, TimestampGen(start=datetime(1970, 1, 1, tzinfo=timezone.utc), end=datetime(2262, 1, 1, tzinfo=timezone.utc))] arrow_array_gens = [ArrayGen(subGen) for subGen in arrow_common_gen] + nested_array_gens_sample arrow_one_level_struct_gen = StructGen([ ['child'+str(i), sub_gen] for i, sub_gen in enumerate(arrow_common_gen)]) arrow_struct_gens = [arrow_one_level_struct_gen, StructGen([['child0', ArrayGen(short_gen)], ['child1', arrow_one_level_struct_gen]])] # This function adds a new column named uniq_int where each row # has a new unique integer value. It just starts at 0 and # increments by 1 for each row. # This can be used to add a column to a dataframe if you need to # sort on a column with unique values. # This collects the data to driver though so can be expensive. def append_unique_int_col_to_df(spark, dataframe): def append_unique_to_rows(rows): new = [] for item in range(len(rows)): row_dict = rows[item].asDict() row_dict['uniq_int'] = item new_row = Row(**row_dict) new.append(new_row) return new collected = dataframe.collect() if (len(collected) > INT_MAX): raise RuntimeError('To many rows to add unique integer values starting from 0 to') existing_schema = dataframe.schema new_rows = append_unique_to_rows(collected) new_schema = StructType(existing_schema.fields + [StructField("uniq_int", IntegerType(), False)]) return spark.createDataFrame(new_rows, new_schema) ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/python/rapids_udf_test.py ================================================ # Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
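# Integration tests for the RAPIDS-accelerated UDF examples. Each test below
# registers one of the example Hive or Java UDFs (including the native C++/CUDA
# backed examples gated behind the rapids_udf_example_native marker), generates
# input data with the generators from data_gen.py, and checks that CPU and GPU
# runs produce the same results via the assert_gpu_and_cpu_are_equal_* helpers.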
import pytest from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql from data_gen import * from spark_session import with_spark_session from pyspark.sql.utils import AnalysisException encoded_url_gen = StringGen('([^%]{0,1}(%[0-9A-F][0-9A-F]){0,1}){0,30}') def drop_udf(spark, udfname): spark.sql("DROP TEMPORARY FUNCTION IF EXISTS {}".format(udfname)) def skip_if_no_hive(spark): if spark.conf.get("spark.sql.catalogImplementation") != "hive": raise RuntimeError('The Spark session does not have Hive support') def load_hive_udf_or_skip_test(spark, udfname, udfclass): drop_udf(spark, udfname) spark.sql("CREATE TEMPORARY FUNCTION {} AS '{}'".format(udfname, udfclass)) def test_hive_simple_udf(): with_spark_session(skip_if_no_hive) data_gens = [["i", int_gen], ["s", encoded_url_gen]] def evalfn(spark): load_hive_udf_or_skip_test(spark, "urldecode", "com.nvidia.spark.rapids.udf.hive.URLDecode") return gen_df(spark, data_gens) assert_gpu_and_cpu_are_equal_sql( evalfn, "hive_simple_udf_test_table", "SELECT i, urldecode(s) FROM hive_simple_udf_test_table") def test_hive_generic_udf(): with_spark_session(skip_if_no_hive) def evalfn(spark): load_hive_udf_or_skip_test(spark, "urlencode", "com.nvidia.spark.rapids.udf.hive.URLEncode") return gen_df(spark, [["s", StringGen('.{0,30}')]]) assert_gpu_and_cpu_are_equal_sql( evalfn, "hive_generic_udf_test_table", "SELECT urlencode(s) FROM hive_generic_udf_test_table") def evalfn_decimal(spark): load_hive_udf_or_skip_test(spark, "fraction", "com.nvidia.spark.rapids.udf.hive.DecimalFraction") return gen_df(spark, [["dec", DecimalGen(38, 18)]]) assert_gpu_and_cpu_are_equal_sql( evalfn_decimal, "hive_generic_udf_test_table", "SELECT fraction(dec) FROM hive_generic_udf_test_table") @pytest.mark.rapids_udf_example_native def test_hive_simple_udf_native(): with_spark_session(skip_if_no_hive) data_gens = [["s", StringGen('.{0,30}')]] def evalfn(spark): load_hive_udf_or_skip_test(spark, "wordcount", "com.nvidia.spark.rapids.udf.hive.StringWordCount") return gen_df(spark, data_gens) assert_gpu_and_cpu_are_equal_sql( evalfn, "hive_native_udf_test_table", "SELECT wordcount(s) FROM hive_native_udf_test_table") def load_java_udf_or_skip_test(spark, udfname, udfclass, udf_return_type=None): drop_udf(spark, udfname) spark.udf.registerJavaFunction(udfname, udfclass, udf_return_type) def test_java_url_decode(): def evalfn(spark): load_java_udf_or_skip_test(spark, 'urldecode', 'com.nvidia.spark.rapids.udf.java.URLDecode') return unary_op_df(spark, encoded_url_gen).selectExpr("urldecode(a)") assert_gpu_and_cpu_are_equal_collect(evalfn) def test_java_url_encode(): def evalfn(spark): load_java_udf_or_skip_test(spark, 'urlencode', 'com.nvidia.spark.rapids.udf.java.URLEncode') return unary_op_df(spark, StringGen('.{0,30}')).selectExpr("urlencode(a)") assert_gpu_and_cpu_are_equal_collect(evalfn) def test_java_decimal_fraction(): def evalfn(spark): from pyspark.sql.types import DecimalType load_java_udf_or_skip_test(spark, 'fraction', 'com.nvidia.spark.rapids.udf.java.DecimalFraction') load_java_udf_or_skip_test(spark, 'fraction_dec64_s10', 'com.nvidia.spark.rapids.udf.java.DecimalFraction', DecimalType(18, 10)) load_java_udf_or_skip_test(spark, 'fraction_dec32_s3', 'com.nvidia.spark.rapids.udf.java.DecimalFraction', DecimalType(8, 3)) return three_col_df(spark, DecimalGen(38, 18), DecimalGen(18, 10), DecimalGen(8, 3) ).selectExpr("fraction(a)", "fraction_dec64_s10(b)", "fraction_dec32_s3(c)") assert_gpu_and_cpu_are_equal_collect(evalfn) 
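# For reference, registering one of these example Java UDFs outside the test
# harness is just the standard PySpark call shown in the (hypothetical) sketch
# below; it assumes the built UDF examples jar is already on the Spark
# driver/executor classpath:
#
#   from pyspark.sql.types import StringType
#   spark.udf.registerJavaFunction(
#       "urldecode", "com.nvidia.spark.rapids.udf.java.URLDecode", StringType())
#   spark.sql("SELECT urldecode('Hello%20World') AS decoded").show()
#
# With the RAPIDS Accelerator enabled, the plugin detects the RapidsUDF columnar
# implementation and runs it on the GPU; otherwise the row-by-row CPU code path
# of the UDF class is used.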
@pytest.mark.rapids_udf_example_native def test_java_cosine_similarity_reasonable_range(): def evalfn(spark): class RangeFloatGen(FloatGen): def start(self, rand): self._start(rand, lambda: rand.uniform(-1000.0, 1000.0)) load_java_udf_or_skip_test(spark, "cosine_similarity", "com.nvidia.spark.rapids.udf.java.CosineSimilarity") arraygen = ArrayGen(RangeFloatGen(nullable=False, no_nans=True, special_cases=[]), min_length=8, max_length=8) df = binary_op_df(spark, arraygen) return df.selectExpr("cosine_similarity(a, b)") assert_gpu_and_cpu_are_equal_collect(evalfn) @pytest.mark.rapids_udf_example_native def test_java_cosine_similarity_with_nans(): def evalfn(spark): load_java_udf_or_skip_test(spark, "cosine_similarity", "com.nvidia.spark.rapids.udf.java.CosineSimilarity") arraygen = ArrayGen(FloatGen(nullable=False), min_length=8, max_length=8) return binary_op_df(spark, arraygen).selectExpr("cosine_similarity(a, b)") assert_gpu_and_cpu_are_equal_collect(evalfn) ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/python/spark_init_internal.py ================================================ # Copyright (c) 2020-2021, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os try: import pyspark except ImportError as error: import findspark findspark.init() import pyspark _DRIVER_ENV = 'PYSP_TEST_spark_driver_extraJavaOptions' def _spark__init(): #Force the RapidsPlugin to be enabled, so it blows up if the classpath is not set properly # DO NOT SET ANY OTHER CONFIGS HERE!!! 
# due to bugs in pyspark/pytest it looks like any configs set here # can be reset in the middle of a test if specific operations are done (some types of cast etc) _sb = pyspark.sql.SparkSession.builder _sb.config('spark.plugins', 'com.nvidia.spark.SQLPlugin') \ .config("spark.sql.adaptive.enabled", "false") \ .config('spark.sql.queryExecutionListeners', 'org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback') for key, value in os.environ.items(): if key.startswith('PYSP_TEST_') and key != _DRIVER_ENV: _sb.config(key[10:].replace('_', '.'), value) driver_opts = os.environ.get(_DRIVER_ENV, "") if ('PYTEST_XDIST_WORKER' in os.environ): wid = os.environ['PYTEST_XDIST_WORKER'] _handle_derby_dir(_sb, driver_opts, wid) _handle_event_log_dir(_sb, wid) else: _sb.config('spark.driver.extraJavaOptions', driver_opts) _handle_event_log_dir(_sb, 'gw0') # enableHiveSupport() is needed for parquet bucket tests _s = _sb.enableHiveSupport() \ .appName('rapids spark plugin integration tests (python)').getOrCreate() #TODO catch the ClassNotFound error that happens if the classpath is not set up properly and # make it a better error message _s.sparkContext.setLogLevel("WARN") return _s def _handle_derby_dir(sb, driver_opts, wid): d = "./derby_{}".format(wid) if not os.path.exists(d): os.makedirs(d) sb.config('spark.driver.extraJavaOptions', driver_opts + ' -Dderby.system.home={}'.format(d)) def _handle_event_log_dir(sb, wid): if os.environ.get('SPARK_EVENTLOG_ENABLED', str(True)).lower() in [ str(False).lower(), 'off', '0' ]: print('Automatic configuration for spark event log disabled') return spark_conf = pyspark.SparkConf() master_url = os.environ.get('PYSP_TEST_spark_master', spark_conf.get("spark.master", 'local')) event_log_config = os.environ.get('PYSP_TEST_spark_eventLog_enabled', spark_conf.get('spark.eventLog.enabled', str(False).lower())) event_log_codec = os.environ.get('PYSP_TEST_spark_eventLog_compression_codec', 'zstd') if not master_url.startswith('local') or event_log_config != str(False).lower(): print("SPARK_EVENTLOG_ENABLED is ignored for non-local Spark master and when " "it's pre-configured by the user") return d = "./eventlog_{}".format(wid) if not os.path.exists(d): os.makedirs(d) print('Spark event logs will appear under {}. Set the environmnet variable ' 'SPARK_EVENTLOG_ENABLED=false if you want to disable it'.format(d)) sb\ .config('spark.eventLog.dir', "file://{}".format(os.path.abspath(d))) \ .config('spark.eventLog.compress', True) \ .config('spark.eventLog.enabled', True) \ .config('spark.eventLog.compression.codec', event_log_codec) _spark = _spark__init() def get_spark_i_know_what_i_am_doing(): """ Get the current SparkSession. This should almost never be called directly instead you should call with_spark_session, with_cpu_session, or with_gpu_session for spark_session. This is to guarantee that the session and it's config is setup in a repeatable way. """ return _spark def spark_version(): return _spark.version ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/python/spark_session.py ================================================ # Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from conftest import is_allowing_any_non_gpu, get_non_gpu_allowed, get_validate_execs_in_gpu_plan, is_databricks_runtime from pyspark.sql import SparkSession, DataFrame from spark_init_internal import get_spark_i_know_what_i_am_doing, spark_version def _from_scala_map(scala_map): ret = {} # The value we get is a scala map, not a java map, so we need to jump through some hoops keys = scala_map.keys().iterator() while keys.hasNext(): key = keys.next() ret[key] = scala_map.get(key).get() return ret _spark = get_spark_i_know_what_i_am_doing() # Have to reach into a private member to get access to the API we need _orig_conf = _from_scala_map(_spark.conf._jconf.getAll()) _orig_conf_keys = _orig_conf.keys() def is_tz_utc(spark=_spark): """ true if the tz is UTC else false """ # Now we have to do some kind of ugly internal java stuff jvm = spark.sparkContext._jvm utc = jvm.java.time.ZoneId.of('UTC').normalized() sys_tz = jvm.java.time.ZoneId.systemDefault().normalized() return utc == sys_tz def _set_all_confs(conf): for key, value in conf.items(): if _spark.conf.get(key, None) != value: _spark.conf.set(key, value) def reset_spark_session_conf(): """Reset all of the configs for a given spark session.""" _set_all_confs(_orig_conf) #We should clear the cache _spark.catalog.clearCache() # Have to reach into a private member to get access to the API we need current_keys = _from_scala_map(_spark.conf._jconf.getAll()).keys() for key in current_keys: if key not in _orig_conf_keys: _spark.conf.unset(key) def _check_for_proper_return_values(something): """We don't want to return an DataFrame or Dataset from a with_spark_session. You will not get what you expect""" if (isinstance(something, DataFrame)): raise RuntimeError("You should never return a DataFrame from a with_*_session, you will not get the results that you expect") def with_spark_session(func, conf={}): """Run func that takes a spark session as input with the given configs set.""" reset_spark_session_conf() _add_job_description(conf) _set_all_confs(conf) ret = func(_spark) _check_for_proper_return_values(ret) return ret def _add_job_description(conf): is_gpu_job = conf.get('spark.rapids.sql.enabled', False) job_type = 'GPU' if str(is_gpu_job).lower() == str(True).lower() else 'CPU' job_desc = '{}[{}]'.format(os.environ.get('PYTEST_CURRENT_TEST'), job_type) _spark.sparkContext.setJobDescription(job_desc) def with_cpu_session(func, conf={}): """Run func that takes a spark session as input with the given configs set on the CPU.""" copy = dict(conf) copy['spark.rapids.sql.enabled'] = 'false' return with_spark_session(func, conf=copy) def with_gpu_session(func, conf={}): """ Run func that takes a spark session as input with the given configs set on the GPU. Note that this forces you into test mode unless. It is not a requirement, but is simplest for right now. 
""" copy = dict(conf) copy['spark.rapids.sql.enabled'] = 'true' if is_allowing_any_non_gpu(): copy['spark.rapids.sql.test.enabled'] = 'false' else: copy['spark.rapids.sql.test.enabled'] = 'true' copy['spark.rapids.sql.test.allowedNonGpu'] = ','.join(get_non_gpu_allowed()) copy['spark.rapids.sql.test.validateExecsInGpuPlan'] = ','.join(get_validate_execs_in_gpu_plan()) return with_spark_session(func, conf=copy) def is_before_spark_311(): return spark_version() < "3.1.0" def is_before_spark_320(): return spark_version() < "3.2.0" def is_before_spark_330(): return spark_version() < "3.3.0" def is_databricks91_or_later(): spark = get_spark_i_know_what_i_am_doing() return spark.conf.get("spark.databricks.clusterUsageTags.sparkVersion", "") >= "9.1" ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/scala/com/nvidia/spark/rapids/udf/scala/URLDecode.scala ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.rapids.udf.scala import java.net.URLDecoder import ai.rapids.cudf.{ColumnVector, DType, Scalar} import com.nvidia.spark.RapidsUDF /** * A Scala user-defined function (UDF) that decodes URL-encoded strings. * This class demonstrates how to implement a Scala UDF that also * provides a RAPIDS implementation that can run on the GPU when the query * is executed with the RAPIDS Accelerator for Apache Spark. */ class URLDecode extends Function[String, String] with RapidsUDF with Serializable { /** Row-by-row implementation that executes on the CPU */ override def apply(s: String): String = { Option(s).map { s => try { URLDecoder.decode(s, "utf-8") } catch { case _: IllegalArgumentException => s } }.orNull } /** Columnar implementation that runs on the GPU */ override def evaluateColumnar(numRows: Int, args: ColumnVector*): ColumnVector = { // The CPU implementation takes a single string argument, so similarly // there should only be one column argument of type STRING. require(args.length == 1, s"Unexpected argument count: ${args.length}") val input = args.head require(numRows == input.getRowCount, s"Expected $numRows rows, received ${input.getRowCount}") require(input.getType == DType.STRING, s"Argument type is not a string: ${input.getType}") // The cudf urlDecode does not convert '+' to a space, so do that as a pre-pass first. // All intermediate results are closed to avoid leaking GPU resources. val plusScalar = Scalar.fromString("+") try { val spaceScalar = Scalar.fromString(" ") try { val replaced = input.stringReplace(plusScalar, spaceScalar) try { replaced.urlDecode() } finally { replaced.close() } } finally { spaceScalar.close() } } finally { plusScalar.close() } } } ================================================ FILE: examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/scala/com/nvidia/spark/rapids/udf/scala/URLEncode.scala ================================================ /* * Copyright (c) 2021-2022, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.rapids.udf.scala import java.net.URLEncoder import ai.rapids.cudf.{ColumnVector, DType} import com.nvidia.spark.RapidsUDF /** * A Scala user-defined function (UDF) that URL-encodes strings. * This class demonstrates how to implement a Scala UDF that also * provides a RAPIDS implementation that can run on the GPU when the query * is executed with the RAPIDS Accelerator for Apache Spark. */ class URLEncode extends Function[String, String] with RapidsUDF with Serializable { /** Row-by-row implementation that executes on the CPU */ override def apply(s: String): String = { Option(s).map { s => URLEncoder.encode(s, "utf-8") .replace("+", "%20") .replace("*", "%2A") .replace("%7E", "~") }.orNull } /** Columnar implementation that runs on the GPU */ override def evaluateColumnar(numRows: Int, args: ColumnVector*): ColumnVector = { // The CPU implementation takes a single string argument, so similarly // there should only be one column argument of type STRING. require(args.length == 1, s"Unexpected argument count: ${args.length}") val input = args.head require(numRows == input.getRowCount, s"Expected $numRows rows, received ${input.getRowCount}") require(input.getType == DType.STRING, s"Argument type is not a string: ${input.getType}") input.urlEncode() } } ================================================ FILE: examples/XGBoost-Examples/.gitignore ================================================ samples.zip ================================================ FILE: examples/XGBoost-Examples/README.md ================================================ # Spark XGBoost Examples Spark XGBoost examples here showcase the need for ETL+Training pipeline GPU acceleration. The Scala based XGBoost examples here use [DMLC’s version](https://repo1.maven.org/maven2/ml/dmlc/xgboost4j-spark_2.12/). The pyspark based XGBoost examples requires [installing RAPIDS via pip](https://rapids.ai/pip.html#install). Most data scientists spend a lot of time not only on Training models but also processing the large amounts of data needed to train these models. As you can see below, Pyspark+XGBoost training on GPUs can be up to 13X and data processing using RAPIDS Accelerator can also be accelerated with an end-to-end speed-up of 11X on GPU compared to CPU. In the public cloud, better performance can lead to significantly lower costs as demonstrated in this [blog](https://developer.nvidia.com/blog/gpu-accelerated-spark-xgboost/). ![mortgage-speedup](/docs/img/guides/mortgage-perf.png) Note that the Training test result is based on 4 years [Fannie Mea Single-Family Loan Performance Data](https://capitalmarkets.fanniemae.com/credit-risk-transfer/single-family-credit-risk-transfer/fannie-mae-single-family-loan-performance-data) with a 8 A100 GPU and 1024 CPU vcores cluster, the performance is affected by many aspects, including data size and type of GPU. In this folder, there are three blue prints for users to learn about using Spark XGBoost and RAPIDS Accelerator on GPUs : 1. 
Mortgage Prediction 2. Agaricus Classification 3. Taxi Fare Prediction For each of these examples we have prepared a [sample dataset](/datasets) in this folder for testing. These datasets are only provided for convenience. In order to test for performance, please download the larger dataset from their respectives sources. There are three sections in this readme section. In the first section, we will list the notebooks that can be run on Jupyter with Python or Scala ([Spylon Kernel](https://pypi.org/project/spylon-kernel/) or [Apache Toree Kernel](https://toree.apache.org/)). In the second section, we have sample jar files and source code if users would like to build and run this as a Scala or a PySpark Spark-XGBoost application. In the last section, we provide basic “Getting Started Guides” for setting up GPU Spark-XGBoost on different environments based on the Apache Spark scheduler such as YARN, Standalone or Kubernetes. ## SECTION 1: SPARK-XGBOOST EXAMPLE NOTEBOOKS 1. Mortgage Notebooks - Python - [Mortgage ETL](mortgage/notebooks/python/MortgageETL.ipynb) - [Mortgage Training Prediction](mortgage/notebooks/python/mortgage-gpu.ipynb) - [Mortgage ETL + XGBoost Training](mortgage/notebooks/python/MortgageETL+XGBoost.ipynb) - Scala - [Mortgage ETL](mortgage/notebooks/scala/mortgage-ETL.ipynb) - [Mortgage Training Prediction](mortgage/notebooks/scala/mortgage-gpu.ipynb) 2. Agaricus Notebooks - Python - [Agaricus Training Classification](agaricus/notebooks/python/agaricus-gpu.ipynb) - Scala - [Agaricus Training Classification](agaricus/notebooks/scala/agaricus-gpu.ipynb) 3. Taxi Notebook - Python - [Taxi Training Classification](taxi/notebooks/python/taxi-gpu.ipynb) - Scala - [Taxi Training Classification](taxi/notebooks/scala/taxi-gpu.ipynb) ## SECTION 2: BUILDING A PYSPARK OR A SCALA XGBOOST APPLICATION The first step to build a Spark application is preparing packages and datasets needed to build the jars. Please use the instructions below for building the - [Scala](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md) - [Python](/docs/get-started/xgboost-examples/prepare-package-data/preparation-python.md) In addition, we have the source code for building reference applications. 
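Once a sample jar or Python package is built, the application is launched with `spark-submit`. The snippet below is only a rough sketch: the master URL, jar paths, and application arguments are placeholders, and the main class simply follows the Scala source layout linked below; see the preparation guides above and the getting-started guides in SECTION 3 for complete, tested commands for your cluster.

```
spark-submit \
  --master <spark-master-url> \
  --conf spark.plugins=com.nvidia.spark.SQLPlugin \
  --conf spark.executor.resource.gpu.amount=1 \
  --conf spark.task.resource.gpu.amount=1 \
  --jars <path-to-rapids-4-spark-jar> \
  --class com.nvidia.spark.examples.mortgage.Main \
  <path-to-sample-xgboost-apps-jar> \
  <application arguments described in the guides>
```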
Below are source codes for the example Spark jobs: - Mortgage: [Scala](mortgage/scala/src/com/nvidia/spark/examples/mortgage), [Python](mortgage/python/com/nvidia/spark/examples/mortgage) - Taxi: [Scala](taxi/scala/src/com/nvidia/spark/examples/taxi), [Python](taxi/python/com/nvidia/spark/examples/taxi) - Agaricus: [Scala](agaricus/scala/src/com/nvidia/spark/examples/agaricus), [Python](agaricus/python/com/nvidia/spark/examples/agaricus) ## SECTION 3: SETTING UP THE ENVIRONMENT Please follow below steps to run the example Spark jobs in different Spark environments: - Getting started on on-premises clusters - [Standalone cluster for Scala](/docs/get-started/xgboost-examples/on-prem-cluster/standalone-scala.md) - [Standalone cluster for Python](/docs/get-started/xgboost-examples/on-prem-cluster/standalone-python.md) - [YARN for Scala](/docs/get-started/xgboost-examples/on-prem-cluster/yarn-scala.md) - [YARN for Python](/docs/get-started/xgboost-examples/on-prem-cluster/yarn-python.md) - [Kubernetes](/docs/get-started/xgboost-examples/on-prem-cluster/kubernetes-scala.md) - Getting started on cloud service providers - Amazon AWS - [EC2](/docs/get-started/xgboost-examples/csp/aws/ec2.md) - [Databricks](/docs/get-started/xgboost-examples/csp/databricks/databricks.md) - [GCP](/docs/get-started/xgboost-examples/csp/dataproc/gcp.md) Please follow below steps to run the example notebooks in different notebook environments: - Getting started for Jupyter Notebook applications - [Apache Toree Notebook for Scala](/docs/get-started/xgboost-examples/notebook/toree.md) - [Jupyter Notebook with spylon kernel](/docs/get-started/xgboost-examples/notebook/spylon.md) - [Jupyter Notebook for Python](/docs/get-started/xgboost-examples/notebook/python-notebook.md) Note: Update the default value of `spark.sql.execution.arrow.maxRecordsPerBatch` to a larger number(such as 200000) will significantly improve performance by accelerating data transfer between JVM and Python process. For the CrossValidator job, we need to set `spark.task.resource.gpu.amount=1` to allow only 1 training task running on 1 GPU(executor), otherwise the customized CrossValidator may schedule more than 1 xgboost training tasks into one executor simultaneously and trigger [issue-131](https://github.com/NVIDIA/spark-rapids-examples/issues/131). For XGBoost job, if the number of shuffle stage tasks before training is less than the num_worker, the training tasks will be scheduled to run on part of nodes instead of all nodes due to Spark Data Locality feature. The workaround is to increase the partitions of the shuffle stage by setting `spark.sql.files.maxPartitionBytes=RightNum`. If you are running XGBoost scala notebooks on Dataproc, please make sure to update below configs to avoid job failure: ``` spark.dynamicAllocation.enabled=false spark.task.resource.gpu.amount=1 ``` ================================================ FILE: examples/XGBoost-Examples/agaricus/.gitignore ================================================ .idea target *.iml ================================================ FILE: examples/XGBoost-Examples/agaricus/notebooks/python/agaricus-gpu.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to XGBoost Spark with GPU\n", "\n", "Agaricus is an example of xgboost classifier for multiple classification. This notebook will show you how to load data, train the xgboost model.\n", "\n", "A few libraries required for this notebook:\n", " 1. 
cudf-cu11\n", " 2. xgboost\n", " 3. scikit-learn\n", " 4. numpy\n", " \n", "This notebook also illustrates the ease of porting a sample CPU based Spark xgboost4j code into GPU. There is no change required for running Spark XGBoost on GPU because both CPU and GPU call the same API. For CPU run, we need to vectorize the trained dataset before fitting data to classifier." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Import All Libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel\n", "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.types import FloatType, StructField, StructType\n", "from time import time\n", "from pyspark.conf import SparkConf\n", "import os\n", "# if you pass/unpack the archive file and enable the environment\n", "# os.environ['PYSPARK_PYTHON'] = \"./environment/bin/python\"\n", "# os.environ['PYSPARK_DRIVER_PYTHON'] = \"./environment/bin/python\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Besides CPU version requires two extra libraries.\n", "```Python\n", "from pyspark.ml.feature import VectorAssembler\n", "from pyspark.sql.functions import col\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create Spark Session and Data Reader" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-11-30 06:57:40,306 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "2022-11-30 06:57:40,550 WARN resource.ResourceUtils: The configuration of cores (exec = 2 task = 1, runnable tasks = 2) will result in wasted resources due to resource gpu limiting the number of runnable tasks per executor to: 1. Please adjust your configuration.\n", "2022-11-30 06:57:54,195 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1.\n", "2022-11-30 06:57:54,210 WARN rapids.RapidsPluginUtils: spark.rapids.sql.multiThreadedRead.numThreads is set to 20.\n", "2022-11-30 06:57:54,214 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\n", "2022-11-30 06:57:54,214 WARN rapids.RapidsPluginUtils: spark.rapids.sql.explain is set to `NOT_ON_GPU`. 
Set it to 'NONE' to suppress the diagnostics logging about the query placement on the GPU.\n", "2022-11-30 06:57:54,685 WARN yarn.Client: Neither spark.yarn.jars nor spark.yarn.archive is set, falling back to uploading libraries under SPARK_HOME.\n" ] } ], "source": [ "SPARK_MASTER_URL = os.getenv(\"SPARK_MASTER_URL\", \"/your-url\")\n", "\n", "RAPIDS_JAR = os.getenv(\"RAPIDS_JAR\", \"/your-jar-path\")\n", "\n", "# You need to update with your real hardware resource \n", "driverMem = os.getenv(\"DRIVER_MEM\", \"2g\")\n", "executorMem = os.getenv(\"EXECUTOR_MEM\", \"2g\")\n", "pinnedPoolSize = os.getenv(\"PINNED_POOL_SIZE\", \"2g\")\n", "concurrentGpuTasks = os.getenv(\"CONCURRENT_GPU_TASKS\", \"2\")\n", "executorCores = int(os.getenv(\"EXECUTOR_CORES\", \"2\"))\n", "# Common spark settings\n", "conf = SparkConf()\n", "conf.setMaster(SPARK_MASTER_URL)\n", "conf.setAppName(\"Microbenchmark on GPU\")\n", "conf.set(\"spark.executor.instances\",\"1\")\n", "conf.set(\"spark.driver.memory\", driverMem)\n", "## The tasks will run on GPU memory, so there is no need to set a high host memory\n", "conf.set(\"spark.executor.memory\", executorMem)\n", "## The tasks will run on GPU cores, so there is no need to use many cpu cores\n", "conf.set(\"spark.executor.cores\", executorCores)\n", "\n", "\n", "# Plugin settings\n", "conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", "conf.set(\"spark.rapids.sql.concurrentGpuTasks\", concurrentGpuTasks)\n", "conf.set(\"spark.rapids.memory.pinnedPool.size\", pinnedPoolSize)\n", "# since pyspark and xgboost share the same GPU, we disable RMM to avoid GPU OOM while training \n", "conf.set(\"spark.rapids.memory.gpu.pool\", \"NONE\")\n", "conf.set(\"spark.locality.wait\",\"0\")\n", "##############note: only support value=1 https://github.com/dmlc/xgboost/blame/master/python-package/xgboost/spark/core.py#L370-L374\n", "conf.set(\"spark.task.resource.gpu.amount\", 1) \n", "conf.set(\"spark.rapids.sql.enabled\", \"true\") \n", "conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n", "conf.set(\"spark.sql.cache.serializer\",\"com.nvidia.spark.ParquetCachedBatchSerializer\")\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", 200000) \n", "conf.set(\"spark.driver.extraClassPath\", RAPIDS_JAR)\n", "conf.set(\"spark.executor.extraClassPath\", RAPIDS_JAR)\n", "\n", "# if you pass/unpack the archive file and enable the environment\n", "# conf.set(\"spark.yarn.dist.archives\", \"your_pyspark_venv.tar.gz#environment\")\n", "# Create spark session\n", "spark = SparkSession.builder.config(conf=conf).getOrCreate()\n", "\n", "reader = spark.read" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Specify the Data Schema and Load the Data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "label = 'label'\n", "features = [ 'feature_' + str(i) for i in range(0, 126) ]\n", "schema = StructType([ StructField(x, FloatType()) for x in [label] + features ])\n", "\n", "# You need to update them to your real paths!\n", "dataRoot = os.getenv(\"DATA_ROOT\", \"/data\")\n", "train_path = dataRoot + \"/agaricus/csv/train\"\n", "eval_path = dataRoot + \"/agaricus/csv/eval\"\n", "\n", "data_format = 'csv'\n", "has_header = 'true'\n", "if data_format == 'csv':\n", " train_data = reader.schema(schema).option('header',has_header).csv(train_path)\n", " trans_data = reader.schema(schema).option('header',has_header).csv(eval_path)\n", "else :\n", " train_data = reader.load(train_path)\n", " trans_data = 
reader.load(eval_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note on CPU version, vectorization is required before fitting data to classifier, which means you need to assemble all feature columns into one column.\n", "\n", "```Python\n", "def vectorize(data_frame):\n", " to_floats = [ col(x.name).cast(FloatType()) for x in data_frame.schema ]\n", " return (VectorAssembler()\n", " .setInputCols(features)\n", " .setOutputCol('features')\n", " .transform(data_frame.select(to_floats))\n", " .select(col('features'), col(label)))\n", "\n", "train_data = vectorize(train_data)\n", "trans_data = vectorize(trans_data)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create a XGBoostClassifier" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "params = { \n", " \"tree_method\": \"hist\",\n", " \"grow_policy\": \"depthwise\",\n", " \"num_workers\": 1,\n", " \"device\": \"cuda\",\n", "}\n", "params['features_col'] = features\n", "params['label_col'] = label\n", " \n", "classifier = SparkXGBClassifier(**params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The parameter `num_workers` should be set to the number of GPUs in Spark cluster for GPU version, while for CPU version it is usually equal to the number of the CPU cores.\n", "\n", "Concerning the device, GPU version supports `cuda` currently, while `cpu` is designed and used here for CPU training.\n", "\n", "An example of CPU classifier:\n", "```\n", "classifier = SparkXGBClassifier(\n", " feature_col=features,\n", " label_col=label, \n", " num_workers=1024,\n", ")\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Train the Data with Benchmark" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "If features_cols param set, then features_col param is ignored.\n", "2022-11-30 07:00:45,526 WARN util.package: Truncated the string representation of a plan since it was too large. 
This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n", "[Stage 5:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training takes 13.92 seconds\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r", "/data/home/yuanli/work/reviews/pr252/pyspark_venv_20221125/lib/python3.8/site-packages/xgboost/sklearn.py:808: UserWarning: Loading a native XGBoost model with Scikit-Learn interface.\n", " warnings.warn(\"Loading a native XGBoost model with Scikit-Learn interface.\")\n" ] } ], "source": [ "def with_benchmark(phrase, action):\n", " start = time()\n", " result = action()\n", " end = time()\n", " print('{} takes {} seconds'.format(phrase, round(end - start, 2)))\n", " return result\n", "model = with_benchmark('Training', lambda: classifier.fit(train_data))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Save and Reload the Model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "If features_cols param set, then features_col param is ignored.\n" ] } ], "source": [ "model.write().overwrite().save(dataRoot + '/model/agaricus')" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "loaded_model = SparkXGBClassifierModel().load(dataRoot + '/model/agaricus')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Transformation and Show Result Sample" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-11-30 07:01:07,030 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction#798, probability#1062]\n", " @Expression label#254 could run on GPU\n", " @Expression feature_0#255 could run on GPU\n", " @Expression feature_1#256 could run on GPU\n", " @Expression feature_2#257 could run on GPU\n", " @Expression feature_3#258 could run on GPU\n", " @Expression feature_4#259 could run on GPU\n", " @Expression feature_5#260 could run on GPU\n", " @Expression feature_6#261 could run on GPU\n", " @Expression feature_7#262 could run on GPU\n", " @Expression feature_8#263 could run on GPU\n", " @Expression feature_9#264 could run on GPU\n", " @Expression feature_10#265 could run on GPU\n", " @Expression feature_11#266 could run on GPU\n", " @Expression feature_12#267 could run on GPU\n", " @Expression feature_13#268 could run on GPU\n", " @Expression feature_14#269 could run on GPU\n", " @Expression feature_15#270 could run on GPU\n", " @Expression feature_16#271 could run on GPU\n", " @Expression feature_17#272 could run on GPU\n", " @Expression feature_18#273 could run on GPU\n", " @Expression feature_19#274 could run on GPU\n", " @Expression feature_20#275 could run on GPU\n", " @Expression feature_21#276 could run on GPU\n", " @Expression feature_22#277 could run on GPU\n", " @Expression feature_23#278 could run on GPU\n", " @Expression feature_24#279 could run on GPU\n", " @Expression feature_25#280 could run on GPU\n", " @Expression feature_26#281 could run on GPU\n", " @Expression feature_27#282 could run on GPU\n", " @Expression feature_28#283 could run on GPU\n", " @Expression feature_29#284 could run on GPU\n", " @Expression feature_30#285 could run on GPU\n", " @Expression feature_31#286 could run on GPU\n", " @Expression feature_32#287 could run on 
GPU\n", " @Expression feature_33#288 could run on GPU\n", " @Expression feature_34#289 could run on GPU\n", " @Expression feature_35#290 could run on GPU\n", " @Expression feature_36#291 could run on GPU\n", " @Expression feature_37#292 could run on GPU\n", " @Expression feature_38#293 could run on GPU\n", " @Expression feature_39#294 could run on GPU\n", " @Expression feature_40#295 could run on GPU\n", " @Expression feature_41#296 could run on GPU\n", " @Expression feature_42#297 could run on GPU\n", " @Expression feature_43#298 could run on GPU\n", " @Expression feature_44#299 could run on GPU\n", " @Expression feature_45#300 could run on GPU\n", " @Expression feature_46#301 could run on GPU\n", " @Expression feature_47#302 could run on GPU\n", " @Expression feature_48#303 could run on GPU\n", " @Expression feature_49#304 could run on GPU\n", " @Expression feature_50#305 could run on GPU\n", " @Expression feature_51#306 could run on GPU\n", " @Expression feature_52#307 could run on GPU\n", " @Expression feature_53#308 could run on GPU\n", " @Expression feature_54#309 could run on GPU\n", " @Expression feature_55#310 could run on GPU\n", " @Expression feature_56#311 could run on GPU\n", " @Expression feature_57#312 could run on GPU\n", " @Expression feature_58#313 could run on GPU\n", " @Expression feature_59#314 could run on GPU\n", " @Expression feature_60#315 could run on GPU\n", " @Expression feature_61#316 could run on GPU\n", " @Expression feature_62#317 could run on GPU\n", " @Expression feature_63#318 could run on GPU\n", " @Expression feature_64#319 could run on GPU\n", " @Expression feature_65#320 could run on GPU\n", " @Expression feature_66#321 could run on GPU\n", " @Expression feature_67#322 could run on GPU\n", " @Expression feature_68#323 could run on GPU\n", " @Expression feature_69#324 could run on GPU\n", " @Expression feature_70#325 could run on GPU\n", " @Expression feature_71#326 could run on GPU\n", " @Expression feature_72#327 could run on GPU\n", " @Expression feature_73#328 could run on GPU\n", " @Expression feature_74#329 could run on GPU\n", " @Expression feature_75#330 could run on GPU\n", " @Expression feature_76#331 could run on GPU\n", " @Expression feature_77#332 could run on GPU\n", " @Expression feature_78#333 could run on GPU\n", " @Expression feature_79#334 could run on GPU\n", " @Expression feature_80#335 could run on GPU\n", " @Expression feature_81#336 could run on GPU\n", " @Expression feature_82#337 could run on GPU\n", " @Expression feature_83#338 could run on GPU\n", " @Expression feature_84#339 could run on GPU\n", " @Expression feature_85#340 could run on GPU\n", " @Expression feature_86#341 could run on GPU\n", " @Expression feature_87#342 could run on GPU\n", " @Expression feature_88#343 could run on GPU\n", " @Expression feature_89#344 could run on GPU\n", " @Expression feature_90#345 could run on GPU\n", " @Expression feature_91#346 could run on GPU\n", " @Expression feature_92#347 could run on GPU\n", " @Expression feature_93#348 could run on GPU\n", " @Expression feature_94#349 could run on GPU\n", " @Expression feature_95#350 could run on GPU\n", " @Expression feature_96#351 could run on GPU\n", " @Expression feature_97#352 could run on GPU\n", " @Expression feature_98#353 could run on GPU\n", " @Expression feature_99#354 could run on GPU\n", " @Expression feature_100#355 could run on GPU\n", " @Expression feature_101#356 could run on GPU\n", " @Expression feature_102#357 could run on GPU\n", " @Expression feature_103#358 could run on 
GPU\n", " @Expression feature_104#359 could run on GPU\n", " @Expression feature_105#360 could run on GPU\n", " @Expression feature_106#361 could run on GPU\n", " @Expression feature_107#362 could run on GPU\n", " @Expression feature_108#363 could run on GPU\n", " @Expression feature_109#364 could run on GPU\n", " @Expression feature_110#365 could run on GPU\n", " @Expression feature_111#366 could run on GPU\n", " @Expression feature_112#367 could run on GPU\n", " @Expression feature_113#368 could run on GPU\n", " @Expression feature_114#369 could run on GPU\n", " @Expression feature_115#370 could run on GPU\n", " @Expression feature_116#371 could run on GPU\n", " @Expression feature_117#372 could run on GPU\n", " @Expression feature_118#373 could run on GPU\n", " @Expression feature_119#374 could run on GPU\n", " @Expression feature_120#375 could run on GPU\n", " @Expression feature_121#376 could run on GPU\n", " @Expression feature_122#377 could run on GPU\n", " @Expression feature_123#378 could run on GPU\n", " @Expression feature_124#379 could run on GPU\n", " @Expression feature_125#380 could run on GPU\n", " !Expression UDF(pythonUDF0#1327.rawPrediction) AS rawPrediction#798 cannot run on GPU because expression Alias UDF(pythonUDF0#1327.rawPrediction) AS rawPrediction#798 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#1327.rawPrediction) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\n", " !Expression UDF(pythonUDF0#1327.rawPrediction) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3659/488666387 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#1327.rawPrediction) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression pythonUDF0#1327.rawPrediction could run on GPU\n", " @Expression pythonUDF0#1327 could run on GPU\n", " @Expression pythonUDF0#1327.prediction AS prediction#931 could run on GPU\n", " @Expression pythonUDF0#1327.prediction could run on GPU\n", " @Expression pythonUDF0#1327 could run on GPU\n", " !Expression UDF(pythonUDF0#1327.probability) AS probability#1062 cannot run on GPU because expression Alias UDF(pythonUDF0#1327.probability) AS probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#1327.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\n", " !Expression UDF(pythonUDF0#1327.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3659/488666387 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#1327.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression pythonUDF0#1327.probability could run on GPU\n", " @Expression pythonUDF0#1327 could run on GPU\n", "\n", "2022-11-30 07:01:07,071 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction#798, probability#1062]\n", " @Expression label#254 could run on GPU\n", " @Expression feature_0#255 could run on GPU\n", " @Expression 
feature_1#256 could run on GPU\n", " @Expression feature_2#257 could run on GPU\n", " @Expression feature_3#258 could run on GPU\n", " @Expression feature_4#259 could run on GPU\n", " @Expression feature_5#260 could run on GPU\n", " @Expression feature_6#261 could run on GPU\n", " @Expression feature_7#262 could run on GPU\n", " @Expression feature_8#263 could run on GPU\n", " @Expression feature_9#264 could run on GPU\n", " @Expression feature_10#265 could run on GPU\n", " @Expression feature_11#266 could run on GPU\n", " @Expression feature_12#267 could run on GPU\n", " @Expression feature_13#268 could run on GPU\n", " @Expression feature_14#269 could run on GPU\n", " @Expression feature_15#270 could run on GPU\n", " @Expression feature_16#271 could run on GPU\n", " @Expression feature_17#272 could run on GPU\n", " @Expression feature_18#273 could run on GPU\n", " @Expression feature_19#274 could run on GPU\n", " @Expression feature_20#275 could run on GPU\n", " @Expression feature_21#276 could run on GPU\n", " @Expression feature_22#277 could run on GPU\n", " @Expression feature_23#278 could run on GPU\n", " @Expression feature_24#279 could run on GPU\n", " @Expression feature_25#280 could run on GPU\n", " @Expression feature_26#281 could run on GPU\n", " @Expression feature_27#282 could run on GPU\n", " @Expression feature_28#283 could run on GPU\n", " @Expression feature_29#284 could run on GPU\n", " @Expression feature_30#285 could run on GPU\n", " @Expression feature_31#286 could run on GPU\n", " @Expression feature_32#287 could run on GPU\n", " @Expression feature_33#288 could run on GPU\n", " @Expression feature_34#289 could run on GPU\n", " @Expression feature_35#290 could run on GPU\n", " @Expression feature_36#291 could run on GPU\n", " @Expression feature_37#292 could run on GPU\n", " @Expression feature_38#293 could run on GPU\n", " @Expression feature_39#294 could run on GPU\n", " @Expression feature_40#295 could run on GPU\n", " @Expression feature_41#296 could run on GPU\n", " @Expression feature_42#297 could run on GPU\n", " @Expression feature_43#298 could run on GPU\n", " @Expression feature_44#299 could run on GPU\n", " @Expression feature_45#300 could run on GPU\n", " @Expression feature_46#301 could run on GPU\n", " @Expression feature_47#302 could run on GPU\n", " @Expression feature_48#303 could run on GPU\n", " @Expression feature_49#304 could run on GPU\n", " @Expression feature_50#305 could run on GPU\n", " @Expression feature_51#306 could run on GPU\n", " @Expression feature_52#307 could run on GPU\n", " @Expression feature_53#308 could run on GPU\n", " @Expression feature_54#309 could run on GPU\n", " @Expression feature_55#310 could run on GPU\n", " @Expression feature_56#311 could run on GPU\n", " @Expression feature_57#312 could run on GPU\n", " @Expression feature_58#313 could run on GPU\n", " @Expression feature_59#314 could run on GPU\n", " @Expression feature_60#315 could run on GPU\n", " @Expression feature_61#316 could run on GPU\n", " @Expression feature_62#317 could run on GPU\n", " @Expression feature_63#318 could run on GPU\n", " @Expression feature_64#319 could run on GPU\n", " @Expression feature_65#320 could run on GPU\n", " @Expression feature_66#321 could run on GPU\n", " @Expression feature_67#322 could run on GPU\n", " @Expression feature_68#323 could run on GPU\n", " @Expression feature_69#324 could run on GPU\n", " @Expression feature_70#325 could run on GPU\n", " @Expression feature_71#326 could run on GPU\n", " @Expression 
feature_72#327 could run on GPU\n", " @Expression feature_73#328 could run on GPU\n", " @Expression feature_74#329 could run on GPU\n", " @Expression feature_75#330 could run on GPU\n", " @Expression feature_76#331 could run on GPU\n", " @Expression feature_77#332 could run on GPU\n", " @Expression feature_78#333 could run on GPU\n", " @Expression feature_79#334 could run on GPU\n", " @Expression feature_80#335 could run on GPU\n", " @Expression feature_81#336 could run on GPU\n", " @Expression feature_82#337 could run on GPU\n", " @Expression feature_83#338 could run on GPU\n", " @Expression feature_84#339 could run on GPU\n", " @Expression feature_85#340 could run on GPU\n", " @Expression feature_86#341 could run on GPU\n", " @Expression feature_87#342 could run on GPU\n", " @Expression feature_88#343 could run on GPU\n", " @Expression feature_89#344 could run on GPU\n", " @Expression feature_90#345 could run on GPU\n", " @Expression feature_91#346 could run on GPU\n", " @Expression feature_92#347 could run on GPU\n", " @Expression feature_93#348 could run on GPU\n", " @Expression feature_94#349 could run on GPU\n", " @Expression feature_95#350 could run on GPU\n", " @Expression feature_96#351 could run on GPU\n", " @Expression feature_97#352 could run on GPU\n", " @Expression feature_98#353 could run on GPU\n", " @Expression feature_99#354 could run on GPU\n", " @Expression feature_100#355 could run on GPU\n", " @Expression feature_101#356 could run on GPU\n", " @Expression feature_102#357 could run on GPU\n", " @Expression feature_103#358 could run on GPU\n", " @Expression feature_104#359 could run on GPU\n", " @Expression feature_105#360 could run on GPU\n", " @Expression feature_106#361 could run on GPU\n", " @Expression feature_107#362 could run on GPU\n", " @Expression feature_108#363 could run on GPU\n", " @Expression feature_109#364 could run on GPU\n", " @Expression feature_110#365 could run on GPU\n", " @Expression feature_111#366 could run on GPU\n", " @Expression feature_112#367 could run on GPU\n", " @Expression feature_113#368 could run on GPU\n", " @Expression feature_114#369 could run on GPU\n", " @Expression feature_115#370 could run on GPU\n", " @Expression feature_116#371 could run on GPU\n", " @Expression feature_117#372 could run on GPU\n", " @Expression feature_118#373 could run on GPU\n", " @Expression feature_119#374 could run on GPU\n", " @Expression feature_120#375 could run on GPU\n", " @Expression feature_121#376 could run on GPU\n", " @Expression feature_122#377 could run on GPU\n", " @Expression feature_123#378 could run on GPU\n", " @Expression feature_124#379 could run on GPU\n", " @Expression feature_125#380 could run on GPU\n", " !Expression rawPrediction#798 cannot run on GPU because expression AttributeReference rawPrediction#798 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression prediction#931 could run on GPU\n", " !Expression probability#1062 cannot run on GPU because expression AttributeReference probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-11-30 07:01:09,857 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", " !Exec cannot run on GPU because unsupported data types in input: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#1062, rawPrediction#798]; not all expressions can be replaced\n", " @Expression cast(label#254 as string) AS label#3936 could run on GPU\n", " @Expression cast(label#254 as string) could run on GPU\n", " @Expression label#254 could run on GPU\n", " @Expression cast(rawPrediction#798 as string) AS rawPrediction#3937 could run on GPU\n", " !Expression cast(rawPrediction#798 as string) cannot run on GPU because Cast from org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 to StringType is not supported\n", " !Expression rawPrediction#798 cannot run on GPU because expression AttributeReference rawPrediction#798 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression cast(probability#1062 as string) AS probability#3938 could run on GPU\n", " !Expression cast(probability#1062 as string) cannot run on GPU because Cast from org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 to StringType is not supported\n", " !Expression probability#1062 cannot run on GPU because expression AttributeReference probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression cast(prediction#931 as string) AS prediction#3939 could run on GPU\n", " @Expression cast(prediction#931 as string) could run on GPU\n", " @Expression prediction#931 could run on GPU\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#1062, rawPrediction#798]; not all expressions can be replaced\n", " @Expression label#254 could run on GPU\n", " @Expression prediction#931 could run on GPU\n", " !Expression probability#1062 cannot run on GPU because expression AttributeReference probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression rawPrediction#798 cannot run on GPU because expression AttributeReference rawPrediction#798 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Transformation takes 3.26 seconds\n", "+-----+--------------------+--------------------+----------+\n", "|label| rawPrediction| probability|prediction|\n", "+-----+--------------------+--------------------+----------+\n", "| 1.0|[-9.6646747589111...|[6.35385513305664...| 1.0|\n", "| 0.0|[-8.3923015594482...|[2.26557254791259...| 1.0|\n", "| 0.0|[-8.0568389892578...|[3.16858291625976...| 1.0|\n", "| 0.0|[1.91234850883483...|[0.87128275632858...| 0.0|\n", "| 0.0|[-8.5582475662231...|[1.91867351531982...| 1.0|\n", "+-----+--------------------+--------------------+----------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "def transform():\n", " result = loaded_model.transform(trans_data).cache()\n", " result.foreachPartition(lambda _: None)\n", " return result\n", "result = with_benchmark('Transformation', transform)\n", "result.select(label, 'rawPrediction', 'probability', 'prediction').show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Evaluation" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-11-30 07:01:10,292 
WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#931, label#5899, 1.0#5900, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(label,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#931 could run on GPU\n", " @Expression label#5899 could run on GPU\n", " @Expression 1.0#5900 could run on GPU\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\n", " !Expression probability#1062 cannot run on GPU because expression AttributeReference probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression obj#5905 cannot run on GPU because expression AttributeReference obj#5905 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", " !Exec cannot run on GPU because not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#1062]; unsupported data types in input: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#1062]\n", " @Expression prediction#931 could run on GPU\n", " @Expression cast(label#254 as double) AS label#5899 could run on GPU\n", " @Expression cast(label#254 as double) could run on GPU\n", " @Expression label#254 could run on GPU\n", " @Expression 1.0 AS 1.0#5900 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " !Expression probability#1062 cannot run on GPU because expression AttributeReference probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#1062]\n", " @Expression label#254 could run on GPU\n", " @Expression prediction#931 could run on GPU\n", " !Expression probability#1062 cannot run on GPU because expression AttributeReference probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluation takes 1.0 seconds\n", "Accuracy is 0.9069677632722861\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", "[Stage 12:> (0 + 1) / 1]\r", "\r", " \r" ] } ], "source": [ "accuracy = with_benchmark(\n", " 'Evaluation',\n", " lambda: MulticlassClassificationEvaluator().setLabelCol(label).evaluate(result))\n", "print('Accuracy is ' + str(accuracy))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Stop" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "spark.stop()" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/XGBoost-Examples/agaricus/notebooks/scala/agaricus-gpu.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to XGBoost Spark3.0 with GPU\n", "\n", "Agaricus is an example of XGBoost classifier for multiple classification. This notebook will show you how to load data, train the xgboost model. Comparing to original XGBoost Spark code, there're only one API difference.\n", "\n", "## Load libraries\n", "First load some common libraries will be used by both GPU version and CPU version XGBoost." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}\n", "import org.apache.spark.sql.SparkSession\n", "import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\n", "import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Besides CPU version requires some extra libraries, such as:\n", "\n", "```scala\n", "import org.apache.spark.ml.feature.VectorAssembler\n", "import org.apache.spark.sql.functions._\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set the dataset paths" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "lastException = null\n", "dataRoot = /data\n", "trainPath = /data/agaricus/csv/train/\n", "evalPath = /data/agaricus/csv/test/\n", "transPath = /data/agaricus/csv/test/\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "/data/agaricus/csv/test/" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// You need to update them to your real paths!\n", "val dataRoot = sys.env.getOrElse(\"DATA_ROOT\", \"/data\")\n", "val trainPath = dataRoot + \"/agaricus/csv/train/\"\n", "val evalPath = dataRoot + \"/agaricus/csv/test/\"\n", "val transPath = dataRoot + \"/agaricus/csv/test/\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build the schema of the dataset\n", "\n", "For agaricus example, the data has 126 dimensions, being named as \"feature_0\", \"feature_1\" ... \"feature_125\". The schema will be used to load data in the future." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "labelName = label\n", "dataSchema = StructType(StructField(label,DoubleType,true), StructField(feature_0,DoubleType,true), StructField(feature_1,DoubleType,true), StructField(feature_2,DoubleType,true), StructField(feature_3,DoubleType,true), StructField(feature_4,DoubleType,true), StructField(feature_5,DoubleType,true), StructField(feature_6,DoubleType,true), StructField(feature_7,DoubleType,true), StructField(feature_8,DoubleType,true), StructField(feature_9,DoubleType,true), StructField(feature_10,DoubleType,true), StructField(feature_11,DoubleType,true), StructField(feature_12,DoubleType,true), StructField(feature_13,DoubleType,true), StructFiel...\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "columnNames: (length: Int)List[String]\n", "schema: (length: Int)org.apache.spark.sql.types.StructType\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "StructType(StructField(label,DoubleType,true), StructField(feature_0,DoubleType,true), StructField(feature_1,DoubleType,true), StructField(feature_2,DoubleType,true), StructField(feature_3,DoubleType,true), StructField(feature_4,DoubleType,true), StructField(feature_5,DoubleType,true), StructField(feature_6,DoubleType,true), StructField(feature_7,DoubleType,true), StructField(feature_8,DoubleType,true), StructField(feature_9,DoubleType,true), StructField(feature_10,DoubleType,true), StructField(feature_11,DoubleType,true), StructField(feature_12,DoubleType,true), StructField(feature_13,DoubleType,true), StructFiel..." ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val labelName = \"label\"\n", "def columnNames(length: Int): List[String] =\n", " 0.until(length).map(i => s\"feature_$i\").toList.+:(labelName)\n", "\n", "def schema(length: Int): StructType =\n", " StructType(columnNames(length).map(n => StructField(n, DoubleType)))\n", "\n", "val dataSchema = schema(126)\n", "\n", "// Build the column name list for features.\n", "val featureCols = dataSchema.filter(_.name != labelName).map(_.name).toArray" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a new spark session and load data\n", "\n", "A new spark session should be created to continue all the following spark operations.\n", "\n", "NOTE: in this notebook, the dependency jars have been loaded when installing toree kernel. Alternatively the jars can be loaded into notebook by [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, there's one restriction for `%AddJar`: the jar uploaded can only be available when `AddJar` is called just after a new spark session is created. Do it as below:\n", "\n", "```scala\n", "import org.apache.spark.sql.SparkSession\n", "val spark = SparkSession.builder().appName(\"agaricus-GPU\").getOrCreate\n", "%AddJar file:/data/libs/rapids-4-spark-XXX.jar\n", "%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\n", "%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\n", "// ...\n", "```\n", "\n", "##### Please note the new jar \"rapids-4-spark-XXX.jar\" is only needed for GPU version, you can not add it to dependence list for CPU version." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "sparkSession = org.apache.spark.sql.SparkSession@3886ba44\n", "dataReader = org.apache.spark.sql.DataFrameReader@5c8be07f\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "org.apache.spark.sql.DataFrameReader@5c8be07f" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// Build the spark session and data reader as usual\n", "val sparkSession = SparkSession.builder.appName(\"agaricus-gpu\").getOrCreate\n", "val dataReader = sparkSession.read.option(\"header\", true).schema(dataSchema)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "trainSet = [label: double, feature_0: double ... 125 more fields]\n", "evalSet = [label: double, feature_0: double ... 125 more fields]\n", "transSet = [label: double, feature_0: double ... 125 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[label: double, feature_0: double ... 125 more fields]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// load all the dataset\n", "val trainSet = dataReader.csv(trainPath)\n", "val evalSet = dataReader.csv(evalPath)\n", "val transSet = dataReader.csv(transPath)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set XGBoost parameters and build a XGBoostClassifier\n", "\n", "For CPU version, `num_workers` is recommended being equal to the number of CPU cores, while for GPU version, it should be set to the number of GPUs in Spark cluster.\n", "\n", "Besides the `device` for CPU version is also different from that for GPU version. Now only \"cuda\" is supported for training on GPU.\n", "\n", "```scala\n", "// difference in parameters\n", " \"num_workers\" -> 12,\n", " \"device\" -> \"cpu\",\n", "```" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "paramMap = Map(num_workers -> 1, tree_method -> hist, device -> cuda, num_round -> 100)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Map(num_workers -> 1, tree_method -> hist, device -> cuda, num_round -> 100)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// build XGBoost classifier\n", "val paramMap = Map(\n", " \"num_workers\" -> 1,\n", " \"tree_method\" -> \"hist\",\n", " \"device\" -> \"cuda\",\n", " \"num_round\" -> 100\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "xgbClassifier = xgbc_57e2d7fc657a\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "xgbc_57e2d7fc657a" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val xgbClassifier = new XGBoostClassifier(paramMap)\n", " .setLabelCol(labelName)\n", " .setFeaturesCol(featureCols)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Benchmark and train\n", "The object `benchmark` is used to compute the elapsed time of some operations.\n", "\n", "Training with evaluation dataset is also supported, the same as CPU version's behavior:\n", "\n", "* Call API `setEvalDataset` after initializing an XGBoostClassifier\n", "\n", "```scala\n", "xgbClassifier.setEvalDataset(evalSet)\n", "```" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { 
"text/plain": [ "xgbc_57e2d7fc657a" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xgbClassifier.setEvalDataset(evalSet)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "defined object Benchmark\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "object Benchmark {\n", " def time[R](phase: String)(block: => R): (R, Float) = {\n", " val t0 = System.currentTimeMillis\n", " val result = block // call-by-name\n", " val t1 = System.currentTimeMillis\n", " println(\"Elapsed time [\" + phase + \"]: \" + ((t1 - t0).toFloat / 1000) + \"s\")\n", " (result, (t1 - t0).toFloat / 1000)\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "------ Training ------\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=34739, DMLC_NUM_WORKER=1}\n" ] }, { "data": { "text/plain": [ "xgbClassificationModel = xgbc_57e2d7fc657a\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Elapsed time [train]: 11.177s\n" ] }, { "data": { "text/plain": [ "xgbc_57e2d7fc657a" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// start training\n", "println(\"\\n------ Training ------\")\n", "val (xgbClassificationModel, _) = Benchmark.time(\"train\") {\n", " xgbClassifier.fit(trainSet)\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transformation and evaluation\n", "Here uses `transSet` to evaluate our model and prints some useful columns to show our prediction result. After that `MulticlassClassificationEvaluator` is used to calculate an overall accuracy of our predictions." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "------ Transforming ------\n", "Elapsed time [transform]: 2.51s\n", "+-----+--------------------+--------------------+----------+\n", "|label| rawPrediction| probability|prediction|\n", "+-----+--------------------+--------------------+----------+\n", "| 1.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 0.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 0.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 0.0|[-4.4405460357666...|[0.99995559453964...| 0.0|\n", "| 0.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 1.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 0.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 1.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 0.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 1.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "+-----+--------------------+--------------------+----------+\n", "only showing top 10 rows\n", "\n", "\n", "------Accuracy of Evaluation------\n", "accuracy == 0.9069677632722861\n" ] }, { "data": { "text/plain": [ "results = [label: double, feature_0: double ... 
128 more fields]\n", "evaluator = MulticlassClassificationEvaluator: uid=mcEval_8f89b3a17d4b, metricName=f1, metricLabel=0.0, beta=1.0, eps=1.0E-15\n", "accuracy = 0.9069677632722861\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.9069677632722861" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// start transform\n", "println(\"\\n------ Transforming ------\")\n", "val (results, _) = Benchmark.time(\"transform\") {\n", " val ret = xgbClassificationModel.transform(transSet).cache()\n", " ret.foreachPartition((_: Iterator[_]) => ())\n", " ret\n", "}\n", "results.select(labelName, \"rawPrediction\", \"probability\", \"prediction\").show(10)\n", "\n", "println(\"\\n------Accuracy of Evaluation------\")\n", "val evaluator = new MulticlassClassificationEvaluator()\n", "evaluator.setLabelCol(labelName)\n", "val accuracy = evaluator.evaluate(results)\n", "\n", "println(s\"accuracy == $accuracy\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save the model to disk and load model\n", "Save the model to disk and then load it to memory. After that use the loaded model to do a new prediction." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Elapsed time [transform2]: 0.069s\n" ] }, { "data": { "text/plain": [ "modelFromDisk = xgbc_57e2d7fc657a\n", "results2 = [label: double, feature_0: double ... 128 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------------------+--------------------+----------+\n", "|label| rawPrediction| probability|prediction|\n", "+-----+--------------------+--------------------+----------+\n", "| 1.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 0.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 0.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 0.0|[-4.4405460357666...|[0.99995559453964...| 0.0|\n", "| 0.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 1.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 0.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 1.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 0.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "| 1.0|[-0.9999903440475...|[9.65595245361328...| 1.0|\n", "+-----+--------------------+--------------------+----------+\n", "only showing top 10 rows\n", "\n" ] }, { "data": { "text/plain": [ "[label: double, feature_0: double ... 
128 more fields]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xgbClassificationModel.write.overwrite.save(dataRoot + \"/model/agaricus\")\n", "\n", "val modelFromDisk = XGBoostClassificationModel.load(dataRoot + \"/model/agaricus\")\n", "val (results2, _) = Benchmark.time(\"transform2\") {\n", " modelFromDisk.transform(transSet)\n", "}\n", "results2.select(labelName, \"rawPrediction\", \"probability\", \"prediction\").show(10)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "sparkSession.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "XGBoost4j-Spark - Scala", "language": "scala", "name": "XGBoost4j-Spark_scala" }, "language_info": { "codemirror_mode": "text/x-scala", "file_extension": ".scala", "mimetype": "text/x-scala", "name": "scala", "pygments_lexer": "scala", "version": "2.12.15" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/XGBoost-Examples/agaricus/pom.xml ================================================ sample_xgboost_examples com.nvidia 0.2.3-SNAPSHOT 4.0.0 spark_examples_agaricus_${scala.binary.version} 8 8 com.nvidia spark_examples_utility_${scala.binary.version} ${project.version} compile scala/src ================================================ FILE: examples/XGBoost-Examples/agaricus/python/com/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/agaricus/python/com/nvidia/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/agaricus/python/com/nvidia/spark/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/agaricus/python/com/nvidia/spark/examples/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/agaricus/python/com/nvidia/spark/examples/agaricus/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/agaricus/python/com/nvidia/spark/examples/agaricus/main.py ================================================ # # Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from pyspark.sql.types import * from com.nvidia.spark.examples.utility.utils import * from pyspark.sql import SparkSession from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel label = 'label' feature_names = ['feature_' + str(i) for i in range(0, 126)] schema = StructType([StructField(x, FloatType()) for x in [label] + feature_names]) def main(args, xgboost_args): spark = (SparkSession .builder .appName(args.mainClass) .getOrCreate()) train_data, eval_data, trans_data = valid_input_data(spark, args, '', schema) if args.mode in ['all', 'train']: if train_data is None: print('-' * 80) print('Usage: train data path required when mode is all or train') print('-' * 80) exit(1) train_data, features = transform_data(train_data, label, args.use_gpu) xgboost_args['features_col'] = features xgboost_args['label_col'] = label classifier = SparkXGBClassifier(**xgboost_args) if eval_data: # TODO pass model = with_benchmark('Training', lambda: classifier.fit(train_data)) if args.modelPath: writer = model.write().overwrite() if args.overwrite else model writer.save(args.modelPath) else: model = SparkXGBClassifierModel.load(args.modelPath) if args.mode in ['all', 'transform']: if trans_data is None: print('-' * 80) print('Usage: trans data path required when mode is all or transform') print('-' * 80) exit(1) trans_data, _ = transform_data(trans_data, label, args.use_gpu) def transform(): result = model.transform(trans_data).cache() result.foreachPartition(lambda _: None) return result result = with_benchmark('Transformation', transform) show_sample(args, result, label) with_benchmark('Evaluation', lambda: check_classification_accuracy(result, label)) spark.stop() ================================================ FILE: examples/XGBoost-Examples/agaricus/scala/src/com/nvidia/spark/examples/agaricus/Main.scala ================================================ /* * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
 */

package com.nvidia.spark.examples.agaricus

import com.nvidia.spark.examples.utility.{Benchmark, SparkSetup, XGBoostArgs}
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.types.{FloatType, StructField, StructType}

object Main {
  def main(args: Array[String]): Unit = {
    val labelName = "label"

    def featureNames(length: Int): List[String] =
      0.until(length).map(i => s"feature_$i").toList.+:(labelName)

    def schema(length: Int): StructType =
      StructType(featureNames(length).map(n => StructField(n, FloatType)))

    val dataSchema = schema(126)
    val xgboostArgs = XGBoostArgs.parse(args)
    val processor = this.getClass.getSimpleName.stripSuffix("$").substring(0, 3)
    val appInfo = Seq("Agaricus", processor, xgboostArgs.format)

    // build spark session
    val spark = SparkSetup(args, appInfo.mkString("-"))
    val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2))
    // build data reader
    val dataReader = spark.read

    // load data
    val pathsArray = xgboostArgs.getDataPaths
    // train, eval, transform
    var datasets = pathsArray.map { paths =>
      if (paths.nonEmpty) {
        xgboostArgs.format match {
          case "csv" => Some(dataReader.option("header", xgboostArgs.hasHeader).schema(dataSchema).csv(paths: _*))
          case "orc" => Some(dataReader.orc(paths: _*))
          case "parquet" => Some(dataReader.parquet(paths: _*))
          case _ => throw new IllegalArgumentException("Unsupported data file format!")
        }
      } else None
    }

    val featureCols = dataSchema.filter(_.name != labelName).map(_.name).toArray

    val xgbClassificationModel = if (xgboostArgs.isToTrain) {
      // build XGBoost classifier
      val paramMap = xgboostArgs.xgboostParams(Map(
        "objective" -> "binary:logistic",
      ))
      val xgbClassifier = new XGBoostClassifier(paramMap)
        .setLabelCol(labelName)
        .setFeaturesCol(featureCols)

      datasets(1).foreach(xgbClassifier.setEvalDataset(_))

      println("\n------ Training ------")
      val (model, _) = benchmark.time("train") {
        xgbClassifier.fit(datasets(0).get)
      }
      // Save model if modelPath exists
      xgboostArgs.modelPath.foreach(path =>
        if (xgboostArgs.isOverwrite) model.write.overwrite().save(path) else model.save(path))
      model
    } else {
      XGBoostClassificationModel.load(xgboostArgs.modelPath.get)
    }

    if (xgboostArgs.isToTransform) {
      // start transform
      println("\n------ Transforming ------")
      var (results, _) = benchmark.time("transform") {
        val ret = xgbClassificationModel.transform(datasets(2).get).cache()
        ret.foreachPartition((_: Iterator[_]) => ())
        ret
      }
      results = if (xgboostArgs.isShowFeatures) {
        results
      } else {
        results.select(labelName, "rawPrediction", "probability", "prediction")
      }
      results.show(xgboostArgs.numRows)

      println("\n------Accuracy of Evaluation------")
      val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelName)
      evaluator.evaluate(results) match {
        case accuracy if !accuracy.isNaN =>
          benchmark.value(accuracy, "Accuracy", "Accuracy for")
        // Throw an exception when NaN ?
      }
    }

    spark.close()
  }
}

================================================ FILE: examples/XGBoost-Examples/aggregator/.gitignore ================================================
.idea
target
*.iml
*.xml

================================================ FILE: examples/XGBoost-Examples/app-parameters/supported_xgboost_parameters_python.md ================================================
Supported Parameters
============================

This is a description of all the parameters available when you are running examples in this repo:

1.
All [xgboost parameters](https://xgboost.readthedocs.io/en/latest/parameter.html) are supported.
    * Please use the `camelCase`, e.g., `--treeMethod=hist`.
    * `lambda` is replaced with `lambda_`, because `lambda` is a keyword in Python.
2. `--mainClass=[app class]`: The entry class of the application to be started. The available values are the classes below.
    * com.nvidia.spark.examples.agaricus.main
    * com.nvidia.spark.examples.mortgage.main
    * com.nvidia.spark.examples.taxi.main
    * com.nvidia.spark.examples.mortgage.etl_main
    * com.nvidia.spark.examples.taxi.etl_main
3. `--format=[csv|parquet|orc]`: The format of the data for training/transforming. Currently only 'csv', 'parquet' and 'orc' are supported. *Required*.
4. `--mode=[all|train|transform]`: The behavior of the XGBoost application (meaning CPUMain and GPUMain). Default is 'all' if not specified.
    * all: Do both training and transforming; the model is saved to 'modelPath' if specified.
    * train: Do training only; the model is saved to 'modelPath' if specified.
    * transform: Do transforming only; 'modelPath' is required to locate the model data to be loaded.
5. `--dataPath=[prefix]::[path]`: Path to the input data file(s), or path for the output data files. Use it repeatedly to specify multiple data paths.
    * `--dataPath=train::[path]`: Path to the training data file(s), required when mode is NOT 'transform'.
    * `--dataPath=trans::[path]`: Path to the transforming data file(s), required when mode is NOT 'train'.
    * `--dataPath=eval::[path]`: Path to the evaluation data file(s) used during training. Optional.
    * `--dataPath=rawTrain::[path]`: Path to the raw data files for training; currently only used by taxi/CPUMain and taxi/GPUMain to support E2E training.
    * `--dataPath=rawTrans::[path]`: Path to the raw data files for transforming; currently only used by taxi/CPUMain and taxi/GPUMain to support E2E transformation.
    * `--dataPath=rawEval::[path]`: Path to the raw data files used as evaluation data during training. Optional.
    * `--dataPath=raw::[path]`: Path to the raw data files to be transformed by taxi/ETLMain.
    * `--dataPath=perf::[path]`, `--dataPath=acq::[path]`: Paths to the raw data files to be transformed by mortgage/ETLMain.
    * `--dataPath=out::[path]`: Path where the output data files are placed, for both mortgage/ETLMain and taxi/ETLMain.
    * `--dataPath=tmp::[path]`: Path where the output data files are placed when converting the raw csv format to parquet.
6. `--modelPath=[path]`: Path to save the model after training, or to load the model from when transforming only. Required only when mode is 'transform'.
7. `--overwrite=[true|false]`: Whether to overwrite the existing model data under 'modelPath'. Default is false. You may need to set it to true to avoid an IOException when saving the model to a path that already exists.
8. `--hasHeader=[true|false]`: Indicates whether the csv files have a header row.
9. `--numRows=[int value]`: The number of rows to show after the transformation is done. Default is 5.
10. `--showFeatures=[true|false]`: Whether to show the feature columns after the transformation is done. Default is true.
11. `--dataRatios=[trainRatio:transformRatio]`: The ratios of data used for train and transform; the remaining ratio (100 - trainRatio - transformRatio) is used for evaluation. Default is 80:20, i.e., no evaluation data. This is currently only used by taxi/ETLMain to generate the output data. A sketch of how these flags combine in a submission is shown below.
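
Putting these flags together, a GPU run of the mortgage Python sample might be submitted as in the sketch below. This is only an illustration: the variables `SPARK_MASTER`, `RAPIDS_JAR`, `SAMPLE_ZIP` and `DATA_ROOT`, as well as the data and model paths, are hypothetical placeholders that you must point at your own build artifacts and dataset locations (see the package preparation docs), and your cluster may need additional RAPIDS/Spark resource configuration.

```bash
# Hypothetical submission of the mortgage example; adjust every variable and path to your environment.
spark-submit \
  --master ${SPARK_MASTER} \
  --jars ${RAPIDS_JAR} \
  --py-files ${SAMPLE_ZIP} \
  main.py \
  --mainClass=com.nvidia.spark.examples.mortgage.main \
  --format=parquet \
  --mode=all \
  --dataPath=train::${DATA_ROOT}/mortgage/train \
  --dataPath=trans::${DATA_ROOT}/mortgage/eval \
  --modelPath=${DATA_ROOT}/mortgage/model \
  --overwrite=true \
  --treeMethod=hist
```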
================================================ FILE: examples/XGBoost-Examples/app-parameters/supported_xgboost_parameters_scala.md ================================================
Supported Parameters
============================

This is a description of all the parameters available when you are running examples in this repo:

1. All [xgboost parameters](https://xgboost.readthedocs.io/en/latest/parameter.html) are supported.
2. `-format=[csv|parquet|orc]`: The format of the data for training/transforming. Currently only 'csv', 'parquet' and 'orc' are supported. *Required*.
3. `-mode=[all|train|transform]`: The behavior of the XGBoost application (meaning CPUMain and GPUMain). Default is 'all' if not specified.
    * all: Do both training and transforming; the model is saved to 'modelPath' if specified.
    * train: Do training only; the model is saved to 'modelPath' if specified.
    * transform: Do transforming only; 'modelPath' is required to locate the model data to be loaded.
4. `-dataPath=[prefix]::[path]`: Path to the input data file(s), or path for the output data files. Use it repeatedly to specify multiple data paths.
    * `-dataPath=train::[path]`: Path to the training data file(s), required when mode is NOT 'transform'.
    * `-dataPath=trans::[path]`: Path to the transforming data file(s), required when mode is NOT 'train'.
    * `-dataPath=eval::[path]`: Path to the evaluation data file(s) used during training. Optional.
    * `-dataPath=rawTrain::[path]`: Path to the raw data files for training; currently only used by taxi/CPUMain and taxi/GPUMain to support E2E training.
    * `-dataPath=rawTrans::[path]`: Path to the raw data files for transforming; currently only used by taxi/CPUMain and taxi/GPUMain to support E2E transformation.
    * `-dataPath=rawEval::[path]`: Path to the raw data files used as evaluation data during training. Optional.
    * `-dataPath=raw::[path]`: Path to the raw data files to be transformed by taxi/ETLMain.
    * `-dataPath=perf::[path]`, `-dataPath=acq::[path]`: Paths to the raw data files to be transformed by mortgage/ETLMain.
    * `-dataPath=out::[path]`: Path where the output data files are placed, for both mortgage/ETLMain and taxi/ETLMain.
    * `-dataPath=tmp::[path]`: Path where the output data files are placed when converting the raw csv format to parquet.
5. `-modelPath=[path]`: Path to save the model after training, or to load the model from when transforming only. Required only when mode is 'transform'.
6. `-overwrite=[true|false]`: Whether to overwrite the existing model data under 'modelPath'. Default is false. You may need to set it to true to avoid an IOException when saving the model to a path that already exists.
7. `-hasHeader=[true|false]`: Indicates whether the csv files have a header row.
8. `-numRows=[int value]`: The number of rows to show after the transformation is done. Default is 5.
9. `-showFeatures=[true|false]`: Whether to show the feature columns after the transformation is done. Default is true.
10. `-dataRatios=[trainRatio:transformRatio]`: The ratios of data used for train and transform; the remaining ratio (100 - trainRatio - transformRatio) is used for evaluation. Default is 80:20, i.e., no evaluation data. This is currently only used by taxi/ETLMain to generate the output data.

================================================ FILE: examples/XGBoost-Examples/assembly/assembly-no-scala.xml ================================================
jar-with-dependencies_${scala.binary.version} jar false org.scala-lang*:scala-* / true true runtime

================================================ FILE: examples/XGBoost-Examples/main.py ================================================
#
# Copyright (c) 2019, NVIDIA CORPORATION.
All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from com.nvidia.spark.examples.main import main main() ================================================ FILE: examples/XGBoost-Examples/mortgage/.gitignore ================================================ .idea target *.iml ================================================ FILE: examples/XGBoost-Examples/mortgage/notebooks/python/MortgageETL+XGBoost.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Dataset\n", "\n", "Dataset is derived from Fannie Mae’s [Single-Family Loan Performance Data](http://www.fanniemae.com/portal/funding-the-market/data/loan-performance-data.html) with all rights reserved by Fannie Mae. Refer to these [instructions](https://github.com/NVIDIA/spark-rapids-examples/blob/branch-23.10/docs/get-started/xgboost-examples/dataset/mortgage.md) to download the dataset.\n", "\n", "# ETL + XGBoost train & transform\n", "\n", "This notebook is an end-to-end example of ETL + XGBoost Train & Transform by using [Spark-Rapids](https://github.com/NVIDIA/spark-rapids) and [XGBoost](https://github.com/dmlc/xgboost) with GPU accelerated.\n", "
The main steps:\n", "1. Run ETL to generate 2 datasets for train and test
\n", " You can choose to save the datasets or not by setting \"is_save_dataset\" to True or False.
\n", " It means you don't need to save the dataset to disk after ETL and directly feed the dataframe to XGBoost train or transform.\n", "2. Run XGBoost train with the train dataset\n", "3. Run XGBoost transform with the test dataset" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import time\n", "import os\n", "from pyspark import broadcast\n", "from pyspark.conf import SparkConf\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.functions import *\n", "from pyspark.sql.types import *\n", "from pyspark.sql.window import Window\n", "# if you pass/unpack the archive file and enable the environment\n", "# os.environ['PYSPARK_PYTHON'] = \"./environment/bin/python\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define Part\n", "### 1. Define the paths\n", "You need to update them to your real paths." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# The input path of dataset\n", "dataRoot = os.getenv(\"DATA_ROOT\", \"/data\")\n", "orig_raw_path = dataRoot + \"/mortgage/input/\"\n", "orig_raw_path_csv2parquet = dataRoot + \"/mortgage/output/csv2parquet/\"" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "SPARK_MASTER_URL = os.getenv(\"SPARK_MASTER_URL\", \"/your-url\")\n", "RAPIDS_JAR = os.getenv(\"RAPIDS_JAR\", \"/your-jar-path\")\n", "\n", "# You need to update with your real hardware resource \n", "driverMem = os.getenv(\"DRIVER_MEM\", \"10g\")\n", "executorMem = os.getenv(\"EXECUTOR_MEM\", \"10g\")\n", "pinnedPoolSize = os.getenv(\"PINNED_POOL_SIZE\", \"2g\")\n", "concurrentGpuTasks = os.getenv(\"CONCURRENT_GPU_TASKS\", \"2\")\n", "executorCores = int(os.getenv(\"EXECUTOR_CORES\", \"4\"))\n", "\n", "# Common spark settings\n", "conf = SparkConf()\n", "conf.setMaster(SPARK_MASTER_URL)\n", "conf.setAppName(\"Microbenchmark on GPU\")\n", "conf.set(\"spark.driver.memory\", driverMem)\n", "## The tasks will run on GPU memory, so there is no need to set a high host memory\n", "conf.set(\"spark.executor.memory\", executorMem)\n", "## The tasks will run on GPU cores, so there is no need to use many cpu cores\n", "conf.set(\"spark.executor.cores\", executorCores)\n", "\n", "# Plugin settings\n", "conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", "conf.set(\"spark.rapids.sql.concurrentGpuTasks\", concurrentGpuTasks)\n", "conf.set(\"spark.rapids.memory.pinnedPool.size\", pinnedPoolSize)\n", "##############note: only support value=1 see https://github.com/dmlc/xgboost/blame/master/python-package/xgboost/spark/core.py#L370-L374\n", "conf.set(\"spark.task.resource.gpu.amount\", 1) \n", "# since pyspark and xgboost share the same GPU, we disable RMM to avoid GPU OOM while training \n", "conf.set(\"spark.rapids.memory.gpu.pool\", \"NONE\")\n", "conf.set(\"spark.rapids.sql.enabled\", \"true\") \n", "conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n", "conf.set(\"spark.sql.cache.serializer\",\"com.nvidia.spark.ParquetCachedBatchSerializer\")\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", 200000) \n", "conf.set(\"spark.driver.extraClassPath\", RAPIDS_JAR)\n", "conf.set(\"spark.executor.extraClassPath\", RAPIDS_JAR)\n", "conf.set(\"spark.jars\", RAPIDS_JAR)\n", "\n", "# if you pass/unpack the archive file and enable the environment\n", "# conf.set(\"spark.yarn.dist.archives\", \"your_pyspark_venv.tar.gz#environment\")\n", "\n", "# Create spark session\n", "spark = 
SparkSession.builder.config(conf=conf).getOrCreate()\n", "reader = spark.read" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Set True to save processed dataset after ETL\n", "# Set False, the dataset after ETL will be directly used in XGBoost train and transform\n", "\n", "is_save_dataset=True\n", "output_path_data=dataRoot + \"/mortgage/output/data/\"\n", "# the path to save the xgboost model\n", "output_path_model=dataRoot + \"/mortgage/output/model/\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Define the constants" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# File schema\n", "\n", "_csv_raw_schema = StructType([\n", " StructField(\"reference_pool_id\", StringType()),\n", " StructField(\"loan_id\", LongType()),\n", " StructField(\"monthly_reporting_period\", StringType()),\n", " StructField(\"orig_channel\", StringType()),\n", " StructField(\"seller_name\", StringType()),\n", " StructField(\"servicer\", StringType()),\n", " StructField(\"master_servicer\", StringType()),\n", " StructField(\"orig_interest_rate\", DoubleType()),\n", " StructField(\"interest_rate\", DoubleType()),\n", " StructField(\"orig_upb\", DoubleType()),\n", " StructField(\"upb_at_issuance\", StringType()),\n", " StructField(\"current_actual_upb\", DoubleType()),\n", " StructField(\"orig_loan_term\", IntegerType()),\n", " StructField(\"orig_date\", StringType()),\n", " StructField(\"first_pay_date\", StringType()), \n", " StructField(\"loan_age\", DoubleType()),\n", " StructField(\"remaining_months_to_legal_maturity\", DoubleType()),\n", " StructField(\"adj_remaining_months_to_maturity\", DoubleType()),\n", " StructField(\"maturity_date\", StringType()),\n", " StructField(\"orig_ltv\", DoubleType()),\n", " StructField(\"orig_cltv\", DoubleType()),\n", " StructField(\"num_borrowers\", DoubleType()),\n", " StructField(\"dti\", DoubleType()),\n", " StructField(\"borrower_credit_score\", DoubleType()),\n", " StructField(\"coborrow_credit_score\", DoubleType()),\n", " StructField(\"first_home_buyer\", StringType()),\n", " StructField(\"loan_purpose\", StringType()),\n", " StructField(\"property_type\", StringType()),\n", " StructField(\"num_units\", IntegerType()),\n", " StructField(\"occupancy_status\", StringType()),\n", " StructField(\"property_state\", StringType()),\n", " StructField(\"msa\", DoubleType()),\n", " StructField(\"zip\", IntegerType()),\n", " StructField(\"mortgage_insurance_percent\", DoubleType()),\n", " StructField(\"product_type\", StringType()),\n", " StructField(\"prepayment_penalty_indicator\", StringType()),\n", " StructField(\"interest_only_loan_indicator\", StringType()),\n", " StructField(\"interest_only_first_principal_and_interest_payment_date\", StringType()),\n", " StructField(\"months_to_amortization\", StringType()),\n", " StructField(\"current_loan_delinquency_status\", IntegerType()),\n", " StructField(\"loan_payment_history\", StringType()),\n", " StructField(\"mod_flag\", StringType()),\n", " StructField(\"mortgage_insurance_cancellation_indicator\", StringType()),\n", " StructField(\"zero_balance_code\", StringType()),\n", " StructField(\"zero_balance_effective_date\", StringType()),\n", " StructField(\"upb_at_the_time_of_removal\", StringType()),\n", " StructField(\"repurchase_date\", StringType()),\n", " StructField(\"scheduled_principal_current\", StringType()),\n", " StructField(\"total_principal_current\", StringType()),\n", " 
StructField(\"unscheduled_principal_current\", StringType()),\n", " StructField(\"last_paid_installment_date\", StringType()),\n", " StructField(\"foreclosed_after\", StringType()),\n", " StructField(\"disposition_date\", StringType()),\n", " StructField(\"foreclosure_costs\", DoubleType()),\n", " StructField(\"prop_preservation_and_repair_costs\", DoubleType()),\n", " StructField(\"asset_recovery_costs\", DoubleType()),\n", " StructField(\"misc_holding_expenses\", DoubleType()),\n", " StructField(\"holding_taxes\", DoubleType()),\n", " StructField(\"net_sale_proceeds\", DoubleType()),\n", " StructField(\"credit_enhancement_proceeds\", DoubleType()),\n", " StructField(\"repurchase_make_whole_proceeds\", StringType()),\n", " StructField(\"other_foreclosure_proceeds\", DoubleType()),\n", " StructField(\"non_interest_bearing_upb\", DoubleType()),\n", " StructField(\"principal_forgiveness_upb\", StringType()),\n", " StructField(\"original_list_start_date\", StringType()),\n", " StructField(\"original_list_price\", StringType()),\n", " StructField(\"current_list_start_date\", StringType()),\n", " StructField(\"current_list_price\", StringType()),\n", " StructField(\"borrower_credit_score_at_issuance\", StringType()),\n", " StructField(\"co-borrower_credit_score_at_issuance\", StringType()),\n", " StructField(\"borrower_credit_score_current\", StringType()),\n", " StructField(\"co-Borrower_credit_score_current\", StringType()),\n", " StructField(\"mortgage_insurance_type\", DoubleType()),\n", " StructField(\"servicing_activity_indicator\", StringType()),\n", " StructField(\"current_period_modification_loss_amount\", StringType()),\n", " StructField(\"cumulative_modification_loss_amount\", StringType()),\n", " StructField(\"current_period_credit_event_net_gain_or_loss\", StringType()),\n", " StructField(\"cumulative_credit_event_net_gain_or_loss\", StringType()),\n", " StructField(\"homeready_program_indicator\", StringType()),\n", " StructField(\"foreclosure_principal_write_off_amount\", StringType()),\n", " StructField(\"relocation_mortgage_indicator\", StringType()),\n", " StructField(\"zero_balance_code_change_date\", StringType()),\n", " StructField(\"loan_holdback_indicator\", StringType()),\n", " StructField(\"loan_holdback_effective_date\", StringType()),\n", " StructField(\"delinquent_accrued_interest\", StringType()),\n", " StructField(\"property_valuation_method\", StringType()),\n", " StructField(\"high_balance_loan_indicator\", StringType()),\n", " StructField(\"arm_initial_fixed-rate_period_lt_5_yr_indicator\", StringType()),\n", " StructField(\"arm_product_type\", StringType()),\n", " StructField(\"initial_fixed-rate_period\", StringType()),\n", " StructField(\"interest_rate_adjustment_frequency\", StringType()),\n", " StructField(\"next_interest_rate_adjustment_date\", StringType()),\n", " StructField(\"next_payment_change_date\", StringType()),\n", " StructField(\"index\", StringType()),\n", " StructField(\"arm_cap_structure\", StringType()),\n", " StructField(\"initial_interest_rate_cap_up_percent\", StringType()),\n", " StructField(\"periodic_interest_rate_cap_up_percent\", StringType()),\n", " StructField(\"lifetime_interest_rate_cap_up_percent\", StringType()),\n", " StructField(\"mortgage_margin\", StringType()),\n", " StructField(\"arm_balloon_indicator\", StringType()),\n", " StructField(\"arm_plan_number\", StringType()),\n", " StructField(\"borrower_assistance_plan\", StringType()),\n", " StructField(\"hltv_refinance_option_indicator\", StringType()),\n", " 
StructField(\"deal_name\", StringType()),\n", " StructField(\"repurchase_make_whole_proceeds_flag\", StringType()),\n", " StructField(\"alternative_delinquency_resolution\", StringType()),\n", " StructField(\"alternative_delinquency_resolution_count\", StringType()),\n", " StructField(\"total_deferral_amount\", StringType())\n", " ])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# name mappings\n", "_name_mapping = [\n", " (\"WITMER FUNDING, LLC\", \"Witmer\"),\n", " (\"WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015\", \"Wells Fargo\"),\n", " (\"WELLS FARGO BANK, NA\" , \"Wells Fargo\"),\n", " (\"WELLS FARGO BANK, N.A.\" , \"Wells Fargo\"),\n", " (\"WELLS FARGO BANK, NA\" , \"Wells Fargo\"),\n", " (\"USAA FEDERAL SAVINGS BANK\" , \"USAA\"),\n", " (\"UNITED SHORE FINANCIAL SERVICES, LLC D\\\\/B\\\\/A UNITED WHOLESALE MORTGAGE\" , \"United Seq(e\"),\n", " (\"U.S. BANK N.A.\" , \"US Bank\"),\n", " (\"SUNTRUST MORTGAGE INC.\" , \"Suntrust\"),\n", " (\"STONEGATE MORTGAGE CORPORATION\" , \"Stonegate Mortgage\"),\n", " (\"STEARNS LENDING, LLC\" , \"Stearns Lending\"),\n", " (\"STEARNS LENDING, INC.\" , \"Stearns Lending\"),\n", " (\"SIERRA PACIFIC MORTGAGE COMPANY, INC.\" , \"Sierra Pacific Mortgage\"),\n", " (\"REGIONS BANK\" , \"Regions\"),\n", " (\"RBC MORTGAGE COMPANY\" , \"RBC\"),\n", " (\"QUICKEN LOANS INC.\" , \"Quicken Loans\"),\n", " (\"PULTE MORTGAGE, L.L.C.\" , \"Pulte Mortgage\"),\n", " (\"PROVIDENT FUNDING ASSOCIATES, L.P.\" , \"Provident Funding\"),\n", " (\"PROSPECT MORTGAGE, LLC\" , \"Prospect Mortgage\"),\n", " (\"PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC\" , \"Principal Residential\"),\n", " (\"PNC BANK, N.A.\" , \"PNC\"),\n", " (\"PMT CREDIT RISK TRANSFER TRUST 2015-2\" , \"PennyMac\"),\n", " (\"PHH MORTGAGE CORPORATION\" , \"PHH Mortgage\"),\n", " (\"PENNYMAC CORP.\" , \"PennyMac\"),\n", " (\"PACIFIC UNION FINANCIAL, LLC\" , \"Other\"),\n", " (\"OTHER\" , \"Other\"),\n", " (\"NYCB MORTGAGE COMPANY, LLC\" , \"NYCB\"),\n", " (\"NEW YORK COMMUNITY BANK\" , \"NYCB\"),\n", " (\"NETBANK FUNDING SERVICES\" , \"Netbank\"),\n", " (\"NATIONSTAR MORTGAGE, LLC\" , \"Nationstar Mortgage\"),\n", " (\"METLIFE BANK, NA\" , \"Metlife\"),\n", " (\"LOANDEPOT.COM, LLC\" , \"LoanDepot.com\"),\n", " (\"J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1\" , \"JP Morgan Chase\"),\n", " (\"J.P. 
MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1\" , \"JP Morgan Chase\"),\n", " (\"JPMORGAN CHASE BANK, NATIONAL ASSOCIATION\" , \"JP Morgan Chase\"),\n", " (\"JPMORGAN CHASE BANK, NA\" , \"JP Morgan Chase\"),\n", " (\"JP MORGAN CHASE BANK, NA\" , \"JP Morgan Chase\"),\n", " (\"IRWIN MORTGAGE, CORPORATION\" , \"Irwin Mortgage\"),\n", " (\"IMPAC MORTGAGE CORP.\" , \"Impac Mortgage\"),\n", " (\"HSBC BANK USA, NATIONAL ASSOCIATION\" , \"HSBC\"),\n", " (\"HOMEWARD RESIDENTIAL, INC.\" , \"Homeward Mortgage\"),\n", " (\"HOMESTREET BANK\" , \"Other\"),\n", " (\"HOMEBRIDGE FINANCIAL SERVICES, INC.\" , \"HomeBridge\"),\n", " (\"HARWOOD STREET FUNDING I, LLC\" , \"Harwood Mortgage\"),\n", " (\"GUILD MORTGAGE COMPANY\" , \"Guild Mortgage\"),\n", " (\"GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)\" , \"GMAC\"),\n", " (\"GMAC MORTGAGE, LLC\" , \"GMAC\"),\n", " (\"GMAC (USAA)\" , \"GMAC\"),\n", " (\"FREMONT BANK\" , \"Fremont Bank\"),\n", " (\"FREEDOM MORTGAGE CORP.\" , \"Freedom Mortgage\"),\n", " (\"FRANKLIN AMERICAN MORTGAGE COMPANY\" , \"Franklin America\"),\n", " (\"FLEET NATIONAL BANK\" , \"Fleet National\"),\n", " (\"FLAGSTAR CAPITAL MARKETS CORPORATION\" , \"Flagstar Bank\"),\n", " (\"FLAGSTAR BANK, FSB\" , \"Flagstar Bank\"),\n", " (\"FIRST TENNESSEE BANK NATIONAL ASSOCIATION\" , \"Other\"),\n", " (\"FIFTH THIRD BANK\" , \"Fifth Third Bank\"),\n", " (\"FEDERAL HOME LOAN BANK OF CHICAGO\" , \"Fedral Home of Chicago\"),\n", " (\"FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB\" , \"FDIC\"),\n", " (\"DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A.\" , \"Downey Mortgage\"),\n", " (\"DITECH FINANCIAL LLC\" , \"Ditech\"),\n", " (\"CITIMORTGAGE, INC.\" , \"Citi\"),\n", " (\"CHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY\" , \"Chicago Mortgage\"),\n", " (\"CHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY\" , \"Chicago Mortgage\"),\n", " (\"CHASE HOME FINANCE, LLC\" , \"JP Morgan Chase\"),\n", " (\"CHASE HOME FINANCE FRANKLIN AMERICAN MORTGAGE COMPANY\" , \"JP Morgan Chase\"),\n", " (\"CHASE HOME FINANCE (CIE 1)\" , \"JP Morgan Chase\"),\n", " (\"CHASE HOME FINANCE\" , \"JP Morgan Chase\"),\n", " (\"CASHCALL, INC.\" , \"CashCall\"),\n", " (\"CAPITAL ONE, NATIONAL ASSOCIATION\" , \"Capital One\"),\n", " (\"CALIBER HOME LOANS, INC.\" , \"Caliber Funding\"),\n", " (\"BISHOPS GATE RESIDENTIAL MORTGAGE TRUST\" , \"Bishops Gate Mortgage\"),\n", " (\"BANK OF AMERICA, N.A.\" , \"Bank of America\"),\n", " (\"AMTRUST BANK\" , \"AmTrust\"),\n", " (\"AMERISAVE MORTGAGE CORPORATION\" , \"Amerisave\"),\n", " (\"AMERIHOME MORTGAGE COMPANY, LLC\" , \"AmeriHome Mortgage\"),\n", " (\"ALLY BANK\" , \"Ally Bank\"),\n", " (\"ACADEMY MORTGAGE CORPORATION\" , \"Academy Mortgage\"),\n", " (\"NO CASH-OUT REFINANCE\" , \"OTHER REFINANCE\"),\n", " (\"REFINANCE - NOT SPECIFIED\" , \"OTHER REFINANCE\"),\n", " (\"Other REFINANCE\" , \"OTHER REFINANCE\")]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# String columns\n", "cate_col_names = [\n", " \"orig_channel\",\n", " \"first_home_buyer\",\n", " \"loan_purpose\",\n", " \"property_type\",\n", " \"occupancy_status\",\n", " \"property_state\",\n", " \"product_type\",\n", " \"relocation_mortgage_indicator\",\n", " \"seller_name\",\n", " \"mod_flag\"\n", "]\n", "# Numeric columns\n", "label_col_name = \"delinquency_12\"\n", "numeric_col_names = [\n", " \"orig_interest_rate\",\n", " \"orig_upb\",\n", " \"orig_loan_term\",\n", " \"orig_ltv\",\n", " \"orig_cltv\",\n", " \"num_borrowers\",\n", " \"dti\",\n", " 
\"borrower_credit_score\",\n", " \"num_units\",\n", " \"zip\",\n", " \"mortgage_insurance_percent\",\n", " \"current_loan_delinquency_status\",\n", " \"current_actual_upb\",\n", " \"interest_rate\",\n", " \"loan_age\",\n", " \"msa\",\n", " \"non_interest_bearing_upb\",\n", " label_col_name\n", "]\n", "all_col_names = cate_col_names + numeric_col_names" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Define ETL Process\n", "\n", "Define the function to do the ETL process\n", "\n", "#### 3.1 Define Functions to Read Raw CSV File\n", "\n", "* Define function to get quarter from input CSV file name" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def _get_quarter_from_csv_file_name():\n", " return substring_index(substring_index(input_file_name(), \".\", 1), \"/\", -1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Define function to read raw CSV data file" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def read_raw_csv(spark, path):\n", " return spark.read.format('csv') \\\n", " .option('nullValue', '') \\\n", " .option('header', False) \\\n", " .option('delimiter', '|') \\\n", " .schema(_csv_raw_schema) \\\n", " .load(path) \\\n", " .withColumn('quarter', _get_quarter_from_csv_file_name())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Functions to extract perf and acq columns from raw schema" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def extract_perf_columns(rawDf):\n", " perfDf = rawDf.select(\n", " col(\"loan_id\"),\n", " date_format(to_date(col(\"monthly_reporting_period\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"monthly_reporting_period\"),\n", " upper(col(\"servicer\")).alias(\"servicer\"),\n", " col(\"interest_rate\"),\n", " col(\"current_actual_upb\"),\n", " col(\"loan_age\"),\n", " col(\"remaining_months_to_legal_maturity\"),\n", " col(\"adj_remaining_months_to_maturity\"),\n", " date_format(to_date(col(\"maturity_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"maturity_date\"),\n", " col(\"msa\"),\n", " col(\"current_loan_delinquency_status\"),\n", " col(\"mod_flag\"),\n", " col(\"zero_balance_code\"),\n", " date_format(to_date(col(\"zero_balance_effective_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"zero_balance_effective_date\"),\n", " date_format(to_date(col(\"last_paid_installment_date\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"last_paid_installment_date\"),\n", " date_format(to_date(col(\"foreclosed_after\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"foreclosed_after\"),\n", " date_format(to_date(col(\"disposition_date\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"disposition_date\"),\n", " col(\"foreclosure_costs\"),\n", " col(\"prop_preservation_and_repair_costs\"),\n", " col(\"asset_recovery_costs\"),\n", " col(\"misc_holding_expenses\"),\n", " col(\"holding_taxes\"),\n", " col(\"net_sale_proceeds\"),\n", " col(\"credit_enhancement_proceeds\"),\n", " col(\"repurchase_make_whole_proceeds\"),\n", " col(\"other_foreclosure_proceeds\"),\n", " col(\"non_interest_bearing_upb\"),\n", " col(\"principal_forgiveness_upb\"),\n", " col(\"repurchase_make_whole_proceeds_flag\"),\n", " col(\"foreclosure_principal_write_off_amount\"),\n", " col(\"servicing_activity_indicator\"),\n", " col('quarter')\n", " )\n", " return perfDf.select(\"*\").filter(\"current_actual_upb != 0.0\")\n", "\n", "def extract_acq_columns(rawDf):\n", " acqDf = rawDf.select(\n", " col(\"loan_id\"),\n", " col(\"orig_channel\"),\n", " 
upper(col(\"seller_name\")).alias(\"seller_name\"),\n", " col(\"orig_interest_rate\"),\n", " col(\"orig_upb\"),\n", " col(\"orig_loan_term\"),\n", " date_format(to_date(col(\"orig_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"orig_date\"),\n", " date_format(to_date(col(\"first_pay_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"first_pay_date\"),\n", " col(\"orig_ltv\"),\n", " col(\"orig_cltv\"),\n", " col(\"num_borrowers\"),\n", " col(\"dti\"),\n", " col(\"borrower_credit_score\"),\n", " col(\"first_home_buyer\"),\n", " col(\"loan_purpose\"),\n", " col(\"property_type\"),\n", " col(\"num_units\"),\n", " col(\"occupancy_status\"),\n", " col(\"property_state\"),\n", " col(\"zip\"),\n", " col(\"mortgage_insurance_percent\"),\n", " col(\"product_type\"),\n", " col(\"coborrow_credit_score\"),\n", " col(\"mortgage_insurance_type\"),\n", " col(\"relocation_mortgage_indicator\"),\n", " dense_rank().over(Window.partitionBy(\"loan_id\").orderBy(to_date(col(\"monthly_reporting_period\"),\"MMyyyy\"))).alias(\"rank\"),\n", " col('quarter')\n", " )\n", "\n", " return acqDf.select(\"*\").filter(col(\"rank\")==1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.2 Define ETL Process\n", "\n", "* Define function to parse dates in Performance data" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def _parse_dates(perf):\n", " return perf \\\n", " .withColumn(\"monthly_reporting_period\", to_date(col(\"monthly_reporting_period\"), \"MM/dd/yyyy\")) \\\n", " .withColumn(\"monthly_reporting_period_month\", month(col(\"monthly_reporting_period\"))) \\\n", " .withColumn(\"monthly_reporting_period_year\", year(col(\"monthly_reporting_period\"))) \\\n", " .withColumn(\"monthly_reporting_period_day\", dayofmonth(col(\"monthly_reporting_period\"))) \\\n", " .withColumn(\"last_paid_installment_date\", to_date(col(\"last_paid_installment_date\"), \"MM/dd/yyyy\")) \\\n", " .withColumn(\"foreclosed_after\", to_date(col(\"foreclosed_after\"), \"MM/dd/yyyy\")) \\\n", " .withColumn(\"disposition_date\", to_date(col(\"disposition_date\"), \"MM/dd/yyyy\")) \\\n", " .withColumn(\"maturity_date\", to_date(col(\"maturity_date\"), \"MM/yyyy\")) \\\n", " .withColumn(\"zero_balance_effective_date\", to_date(col(\"zero_balance_effective_date\"), \"MM/yyyy\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Define function to create deliquency data frame from Performance data" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "def _create_perf_deliquency(spark, perf):\n", " aggDF = perf.select(\n", " col(\"quarter\"),\n", " col(\"loan_id\"),\n", " col(\"current_loan_delinquency_status\"),\n", " when(col(\"current_loan_delinquency_status\") >= 1, col(\"monthly_reporting_period\")).alias(\"delinquency_30\"),\n", " when(col(\"current_loan_delinquency_status\") >= 3, col(\"monthly_reporting_period\")).alias(\"delinquency_90\"),\n", " when(col(\"current_loan_delinquency_status\") >= 6, col(\"monthly_reporting_period\")).alias(\"delinquency_180\")) \\\n", " .groupBy(\"quarter\", \"loan_id\") \\\n", " .agg(\n", " max(\"current_loan_delinquency_status\").alias(\"delinquency_12\"),\n", " min(\"delinquency_30\").alias(\"delinquency_30\"),\n", " min(\"delinquency_90\").alias(\"delinquency_90\"),\n", " min(\"delinquency_180\").alias(\"delinquency_180\")) \\\n", " .select(\n", " col(\"quarter\"),\n", " col(\"loan_id\"),\n", " (col(\"delinquency_12\") >= 1).alias(\"ever_30\"),\n", " (col(\"delinquency_12\") >= 
3).alias(\"ever_90\"),\n", " (col(\"delinquency_12\") >= 6).alias(\"ever_180\"),\n", " col(\"delinquency_30\"),\n", " col(\"delinquency_90\"),\n", " col(\"delinquency_180\"))\n", " joinedDf = perf \\\n", " .withColumnRenamed(\"monthly_reporting_period\", \"timestamp\") \\\n", " .withColumnRenamed(\"monthly_reporting_period_month\", \"timestamp_month\") \\\n", " .withColumnRenamed(\"monthly_reporting_period_year\", \"timestamp_year\") \\\n", " .withColumnRenamed(\"current_loan_delinquency_status\", \"delinquency_12\") \\\n", " .withColumnRenamed(\"current_actual_upb\", \"upb_12\") \\\n", " .select(\"quarter\", \"loan_id\", \"timestamp\", \"delinquency_12\", \"upb_12\", \"timestamp_month\", \"timestamp_year\") \\\n", " .join(aggDF, [\"loan_id\", \"quarter\"], \"left_outer\")\n", "\n", " # calculate the 12 month delinquency and upb values\n", " months = 12\n", " monthArray = [lit(x) for x in range(0, 12)]\n", " # explode on a small amount of data is actually slightly more efficient than a cross join\n", " testDf = joinedDf \\\n", " .withColumn(\"month_y\", explode(array(monthArray))) \\\n", " .select(\n", " col(\"quarter\"),\n", " floor(((col(\"timestamp_year\") * 12 + col(\"timestamp_month\")) - 24000) / months).alias(\"josh_mody\"),\n", " floor(((col(\"timestamp_year\") * 12 + col(\"timestamp_month\")) - 24000 - col(\"month_y\")) / months).alias(\"josh_mody_n\"),\n", " col(\"ever_30\"),\n", " col(\"ever_90\"),\n", " col(\"ever_180\"),\n", " col(\"delinquency_30\"),\n", " col(\"delinquency_90\"),\n", " col(\"delinquency_180\"),\n", " col(\"loan_id\"),\n", " col(\"month_y\"),\n", " col(\"delinquency_12\"),\n", " col(\"upb_12\")) \\\n", " .groupBy(\"quarter\", \"loan_id\", \"josh_mody_n\", \"ever_30\", \"ever_90\", \"ever_180\", \"delinquency_30\", \"delinquency_90\", \"delinquency_180\", \"month_y\") \\\n", " .agg(max(\"delinquency_12\").alias(\"delinquency_12\"), min(\"upb_12\").alias(\"upb_12\")) \\\n", " .withColumn(\"timestamp_year\", floor((lit(24000) + (col(\"josh_mody_n\") * lit(months)) + (col(\"month_y\") - 1)) / lit(12))) \\\n", " .selectExpr(\"*\", \"pmod(24000 + (josh_mody_n * {}) + month_y, 12) as timestamp_month_tmp\".format(months)) \\\n", " .withColumn(\"timestamp_month\", when(col(\"timestamp_month_tmp\") == lit(0), lit(12)).otherwise(col(\"timestamp_month_tmp\"))) \\\n", " .withColumn(\"delinquency_12\", ((col(\"delinquency_12\") > 3).cast(\"int\") + (col(\"upb_12\") == 0).cast(\"int\")).alias(\"delinquency_12\")) \\\n", " .drop(\"timestamp_month_tmp\", \"josh_mody_n\", \"month_y\")\n", "\n", " return perf.withColumnRenamed(\"monthly_reporting_period_month\", \"timestamp_month\") \\\n", " .withColumnRenamed(\"monthly_reporting_period_year\", \"timestamp_year\") \\\n", " .join(testDf, [\"quarter\", \"loan_id\", \"timestamp_year\", \"timestamp_month\"], \"left\") \\\n", " .drop(\"timestamp_year\", \"timestamp_month\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Define function to create acquisition data frame from Acquisition data" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def _create_acquisition(spark, acq):\n", " nameMapping = spark.createDataFrame(_name_mapping, [\"from_seller_name\", \"to_seller_name\"])\n", " return acq.join(nameMapping, col(\"seller_name\") == col(\"from_seller_name\"), \"left\") \\\n", " .drop(\"from_seller_name\") \\\n", " .withColumn(\"old_name\", col(\"seller_name\")) \\\n", " .withColumn(\"seller_name\", coalesce(col(\"to_seller_name\"), col(\"seller_name\"))) \\\n", " 
.drop(\"to_seller_name\") \\\n", " .withColumn(\"orig_date\", to_date(col(\"orig_date\"), \"MM/yyyy\")) \\\n", " .withColumn(\"first_pay_date\", to_date(col(\"first_pay_date\"), \"MM/yyyy\")) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.3 Define Casting Process\n", "This part is casting String column to Numeric one. \n", "Example:\n", "```\n", "col_1\n", " \"a\"\n", " \"b\"\n", " \"c\"\n", " \"a\"\n", "# After String ====> Numeric\n", "col_1\n", " 0\n", " 1\n", " 2\n", " 0\n", "``` \n", "
\n", "\n", "* Define function to get column dictionary\n", "\n", " Example\n", " ```\n", " col1 = [row(data=\"a\",id=0), row(data=\"b\",id=1)]\n", " ```" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def _gen_dictionary(etl_df, col_names):\n", " cnt_table = etl_df.select(posexplode(array([col(i) for i in col_names])))\\\n", " .withColumnRenamed(\"pos\", \"column_id\")\\\n", " .withColumnRenamed(\"col\", \"data\")\\\n", " .filter(\"data is not null\")\\\n", " .groupBy(\"column_id\", \"data\")\\\n", " .count()\n", " windowed = Window.partitionBy(\"column_id\").orderBy(desc(\"count\"))\n", " return cnt_table.withColumn(\"id\", row_number().over(windowed)).drop(\"count\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Define function to convert string columns to numeric" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def _cast_string_columns_to_numeric(spark, input_df):\n", " cached_dict_df = _gen_dictionary(input_df, cate_col_names).cache()\n", " output_df = input_df\n", " # Generate the final table with all columns being numeric.\n", " for col_pos, col_name in enumerate(cate_col_names):\n", " col_dict_df = cached_dict_df.filter(col(\"column_id\") == col_pos)\\\n", " .drop(\"column_id\")\\\n", " .withColumnRenamed(\"data\", col_name)\n", " \n", " output_df = output_df.join(broadcast(col_dict_df), col_name, \"left\")\\\n", " .drop(col_name)\\\n", " .withColumnRenamed(\"id\", col_name)\n", " return output_df " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.4 Define Main Function\n", "In this function:\n", "1. Parse date in Performance data by calling _parse_dates (parsed_perf)\n", "2. Create deliqency dataframe(perf_deliqency) form Performance data by calling _create_perf_deliquency\n", "3. Create cleaned acquisition dataframe(cleaned_acq) from Acquisition data by calling _create_acquisition\n", "4. Join deliqency dataframe(perf_deliqency) and cleaned acquisition dataframe(cleaned_acq), get clean_df\n", "5. Cast String column to Numeric in clean_df by calling _cast_string_columns_to_numeric, get casted_clean_df\n", "6. Return casted_clean_df as final result" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def run_mortgage(spark, perf, acq):\n", " parsed_perf = _parse_dates(perf)\n", " perf_deliqency = _create_perf_deliquency(spark, parsed_perf)\n", " cleaned_acq = _create_acquisition(spark, acq)\n", " clean_df = perf_deliqency.join(cleaned_acq, [\"loan_id\", \"quarter\"], \"inner\").drop(\"quarter\")\n", " casted_clean_df = _cast_string_columns_to_numeric(spark, clean_df)\\\n", " .select(all_col_names)\\\n", " .withColumn(label_col_name, when(col(label_col_name) > 0, 1).otherwise(0))\\\n", " .fillna(float(0))\n", " return casted_clean_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run Part\n", "### Run ETL\n", "#### 1. 
Add additional Spark settings" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# GPU run, set to true\n", "spark.conf.set(\"spark.rapids.sql.enabled\", \"true\")\n", "# CPU run, set to false, it can only make ETL run on CPU when is_save_dataset=True.\n", "# spark.conf.set(\"spark.rapids.sql.enabled\", \"false\")\n", "spark.conf.set(\"spark.sql.files.maxPartitionBytes\", \"1G\")\n", "spark.conf.set(\"spark.rapids.sql.explain\", \"ALL\")\n", "spark.conf.set(\"spark.rapids.sql.batchSizeBytes\", \"512M\")\n", "spark.conf.set(\"spark.rapids.sql.reader.batchSizeBytes\", \"768M\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2. Read Raw Data and Run ETL Process, Save the Result" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ETL takes 135.9117729663849\n" ] } ], "source": [ "\n", "# read raw dataset\n", "rawDf = read_raw_csv(spark, orig_raw_path)\n", "rawDf.write.parquet(orig_raw_path_csv2parquet, mode='overwrite')\n", "rawDf = spark.read.parquet(orig_raw_path_csv2parquet)\n", "\n", "acq = extract_acq_columns(rawDf)\n", "perf = extract_perf_columns(rawDf)\n", "\n", "# run main function to process data\n", "out = run_mortgage(spark, perf, acq)\n", "\n", "# save processed data\n", "if is_save_dataset:\n", " start = time.time()\n", " out.write.parquet(output_path_data, mode=\"overwrite\")\n", " end = time.time()\n", " print(\"ETL takes {}\".format(end - start))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## XGBoost Spark with GPU" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "###### Import ML Libraries" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel\n", "from pyspark.ml.evaluation import MulticlassClassificationEvaluator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "###### Create Data Reader" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# Make sure it runs on GPU\n", "spark.conf.set(\"spark.rapids.sql.enabled\", \"true\")\n", "\n", "reader = spark.read" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "###### Specify the Data Schema and Load the Data" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "label = \"delinquency_12\"\n", "schema = StructType([\n", " StructField(\"orig_channel\", FloatType()),\n", " StructField(\"first_home_buyer\", FloatType()),\n", " StructField(\"loan_purpose\", FloatType()),\n", " StructField(\"property_type\", FloatType()),\n", " StructField(\"occupancy_status\", FloatType()),\n", " StructField(\"property_state\", FloatType()),\n", " StructField(\"product_type\", FloatType()),\n", " StructField(\"relocation_mortgage_indicator\", FloatType()),\n", " StructField(\"seller_name\", FloatType()),\n", " StructField(\"mod_flag\", FloatType()),\n", " StructField(\"orig_interest_rate\", FloatType()),\n", " StructField(\"orig_upb\", DoubleType()),\n", " StructField(\"orig_loan_term\", IntegerType()),\n", " StructField(\"orig_ltv\", FloatType()),\n", " StructField(\"orig_cltv\", FloatType()),\n", " StructField(\"num_borrowers\", FloatType()),\n", " StructField(\"dti\", FloatType()),\n", " StructField(\"borrower_credit_score\", FloatType()),\n", " StructField(\"num_units\", IntegerType()),\n", " StructField(\"zip\", IntegerType()),\n", " 
StructField(\"mortgage_insurance_percent\", FloatType()),\n", " StructField(\"current_loan_delinquency_status\", IntegerType()),\n", " StructField(\"current_actual_upb\", FloatType()),\n", " StructField(\"interest_rate\", FloatType()),\n", " StructField(\"loan_age\", FloatType()),\n", " StructField(\"msa\", FloatType()),\n", " StructField(\"non_interest_bearing_upb\", FloatType()),\n", " StructField(label, IntegerType()),\n", "])\n", "features = [ x.name for x in schema if x.name != label ]\n", "\n", "if is_save_dataset:\n", " # load dataset from file\n", " etlDf = reader.parquet(output_path_data)\n", " splits = etlDf.randomSplit([0.8, 0.2])\n", " train_data = splits[0]\n", " test_data = splits[1]\n", "else:\n", " # use Dataframe from ETL directly\n", " splits = out.randomSplit([0.8, 0.2])\n", " train_data = splits[0]\n", " test_data = splits[1]" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# This sample uses 1 worker(GPU) to run XGBoost training, you can change according to your GPU resources\n", "params = { \n", " \"tree_method\": \"hist\",\n", " \"grow_policy\": \"depthwise\",\n", " \"num_workers\": 1,\n", " \"device\": \"cuda\",\n", "}\n", "params['features_col'] = features\n", "params['label_col'] = label\n", " \n", "classifier = SparkXGBClassifier(**params)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training takes 18.92583155632019 seconds\n" ] } ], "source": [ "def with_benchmark(phrase, action):\n", " start = time.time()\n", " result = action()\n", " end = time.time()\n", " print(\"{} takes {} seconds\".format(phrase, end - start))\n", " return result\n", "model = with_benchmark(\"Training\", lambda: classifier.fit(train_data))" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "model.write().overwrite().save(output_path_model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loaded_model = SparkXGBClassifierModel().load(output_path_model)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Transformation takes 8.959877967834473 seconds\n", "+--------------+--------------------+--------------------+----------+\n", "|delinquency_12| rawPrediction| probability|prediction|\n", "+--------------+--------------------+--------------------+----------+\n", "| 0|[7.92072248458862...|[0.99963699193904...| 0.0|\n", "| 0|[7.92072248458862...|[0.99963699193904...| 0.0|\n", "| 0|[8.43130302429199...|[0.99978211015695...| 0.0|\n", "| 0|[8.20779895782470...|[0.99972755435737...| 0.0|\n", "| 0|[8.885986328125,-...|[0.99986170543706...| 0.0|\n", "+--------------+--------------------+--------------------+----------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "def transform():\n", " result = loaded_model.transform(test_data).cache()\n", " result.foreachPartition(lambda _: None)\n", " return result\n", "result = with_benchmark(\"Transformation\", transform)\n", "result.select(label, \"rawPrediction\", \"probability\", \"prediction\").show(5)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluation takes 0.6158628463745117 seconds\n", "Accuracy is 0.9861453808970397\n" ] } ], "source": [ "accuracy = with_benchmark(\n", " \"Evaluation\",\n", " lambda: 
MulticlassClassificationEvaluator().setLabelCol(label).evaluate(result))\n", "print(\"Accuracy is \" + str(accuracy))" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "spark.stop()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" }, "name": "gpu-mortgage", "notebookId": 4440374682851873 }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: examples/XGBoost-Examples/mortgage/notebooks/python/MortgageETL.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Prerequirement\n", "### 1. Download data\n", "Dataset is derived from Fannie Mae’s [Single-Family Loan Performance Data](http://www.fanniemae.com/portal/funding-the-market/data/loan-performance-data.html) with all rights reserved by Fannie Mae. Refer to these [instructions](https://github.com/NVIDIA/spark-rapids-examples/blob/branch-24.12/docs/get-started/xgboost-examples/dataset/mortgage.md) to download the dataset.\n", "\n", "### 2. Download needed jars\n", "* [rapids-4-spark_2.12-26.02.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)\n", "\n", "\n", "### 3. Start Spark Standalone\n", "Before running the script, please setup Spark standalone mode\n", "\n", "### 4. Add ENV\n", "```\n", "$ export SPARK_JARS=rapids-4-spark_2.12-26.02.0.jar\n", "$ export PYSPARK_DRIVER_PYTHON=jupyter \n", "$ export PYSPARK_DRIVER_PYTHON_OPTS=notebook\n", "```\n", "\n", "### 5. Start Jupyter Notebook with plugin config\n", "\n", "```\n", "$ pyspark --master ${SPARK_MASTER} \\\n", "--jars ${SPARK_JARS} \\\n", "--conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n", "--conf spark.rapids.sql.incompatibleDateFormats.enabled=true \\\n", "--conf spark.rapids.sql.csv.read.double.enabled=true \\\n", "--py-files ${SPARK_PY_FILES}\n", "```\n", "\n", "## Import Libs" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "import time\n", "import os\n", "from pyspark import broadcast\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.functions import *\n", "from pyspark.sql.types import *\n", "from pyspark.sql.window import Window\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create Spark Session" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "spark = (SparkSession\n", " .builder\n", " .appName(\"MortgageETL\")\n", " .getOrCreate())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Function Define\n", "### 1. 
Define the constants\n", "\n", "* Define input file schema (Performance and Acquisition)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# File schema\n", "_csv_raw_schema = StructType([\n", " StructField(\"reference_pool_id\", StringType()),\n", " StructField(\"loan_id\", LongType()),\n", " StructField(\"monthly_reporting_period\", StringType()),\n", " StructField(\"orig_channel\", StringType()),\n", " StructField(\"seller_name\", StringType()),\n", " StructField(\"servicer\", StringType()),\n", " StructField(\"master_servicer\", StringType()),\n", " StructField(\"orig_interest_rate\", DoubleType()),\n", " StructField(\"interest_rate\", DoubleType()),\n", " StructField(\"orig_upb\", DoubleType()),\n", " StructField(\"upb_at_issuance\", StringType()),\n", " StructField(\"current_actual_upb\", DoubleType()),\n", " StructField(\"orig_loan_term\", IntegerType()),\n", " StructField(\"orig_date\", StringType()),\n", " StructField(\"first_pay_date\", StringType()), \n", " StructField(\"loan_age\", DoubleType()),\n", " StructField(\"remaining_months_to_legal_maturity\", DoubleType()),\n", " StructField(\"adj_remaining_months_to_maturity\", DoubleType()),\n", " StructField(\"maturity_date\", StringType()),\n", " StructField(\"orig_ltv\", DoubleType()),\n", " StructField(\"orig_cltv\", DoubleType()),\n", " StructField(\"num_borrowers\", DoubleType()),\n", " StructField(\"dti\", DoubleType()),\n", " StructField(\"borrower_credit_score\", DoubleType()),\n", " StructField(\"coborrow_credit_score\", DoubleType()),\n", " StructField(\"first_home_buyer\", StringType()),\n", " StructField(\"loan_purpose\", StringType()),\n", " StructField(\"property_type\", StringType()),\n", " StructField(\"num_units\", IntegerType()),\n", " StructField(\"occupancy_status\", StringType()),\n", " StructField(\"property_state\", StringType()),\n", " StructField(\"msa\", DoubleType()),\n", " StructField(\"zip\", IntegerType()),\n", " StructField(\"mortgage_insurance_percent\", DoubleType()),\n", " StructField(\"product_type\", StringType()),\n", " StructField(\"prepayment_penalty_indicator\", StringType()),\n", " StructField(\"interest_only_loan_indicator\", StringType()),\n", " StructField(\"interest_only_first_principal_and_interest_payment_date\", StringType()),\n", " StructField(\"months_to_amortization\", StringType()),\n", " StructField(\"current_loan_delinquency_status\", IntegerType()),\n", " StructField(\"loan_payment_history\", StringType()),\n", " StructField(\"mod_flag\", StringType()),\n", " StructField(\"mortgage_insurance_cancellation_indicator\", StringType()),\n", " StructField(\"zero_balance_code\", StringType()),\n", " StructField(\"zero_balance_effective_date\", StringType()),\n", " StructField(\"upb_at_the_time_of_removal\", StringType()),\n", " StructField(\"repurchase_date\", StringType()),\n", " StructField(\"scheduled_principal_current\", StringType()),\n", " StructField(\"total_principal_current\", StringType()),\n", " StructField(\"unscheduled_principal_current\", StringType()),\n", " StructField(\"last_paid_installment_date\", StringType()),\n", " StructField(\"foreclosed_after\", StringType()),\n", " StructField(\"disposition_date\", StringType()),\n", " StructField(\"foreclosure_costs\", DoubleType()),\n", " StructField(\"prop_preservation_and_repair_costs\", DoubleType()),\n", " StructField(\"asset_recovery_costs\", DoubleType()),\n", " StructField(\"misc_holding_expenses\", DoubleType()),\n", " StructField(\"holding_taxes\", DoubleType()),\n", " 
StructField(\"net_sale_proceeds\", DoubleType()),\n", " StructField(\"credit_enhancement_proceeds\", DoubleType()),\n", " StructField(\"repurchase_make_whole_proceeds\", StringType()),\n", " StructField(\"other_foreclosure_proceeds\", DoubleType()),\n", " StructField(\"non_interest_bearing_upb\", DoubleType()),\n", " StructField(\"principal_forgiveness_upb\", StringType()),\n", " StructField(\"original_list_start_date\", StringType()),\n", " StructField(\"original_list_price\", StringType()),\n", " StructField(\"current_list_start_date\", StringType()),\n", " StructField(\"current_list_price\", StringType()),\n", " StructField(\"borrower_credit_score_at_issuance\", StringType()),\n", " StructField(\"co-borrower_credit_score_at_issuance\", StringType()),\n", " StructField(\"borrower_credit_score_current\", StringType()),\n", " StructField(\"co-Borrower_credit_score_current\", StringType()),\n", " StructField(\"mortgage_insurance_type\", DoubleType()),\n", " StructField(\"servicing_activity_indicator\", StringType()),\n", " StructField(\"current_period_modification_loss_amount\", StringType()),\n", " StructField(\"cumulative_modification_loss_amount\", StringType()),\n", " StructField(\"current_period_credit_event_net_gain_or_loss\", StringType()),\n", " StructField(\"cumulative_credit_event_net_gain_or_loss\", StringType()),\n", " StructField(\"homeready_program_indicator\", StringType()),\n", " StructField(\"foreclosure_principal_write_off_amount\", StringType()),\n", " StructField(\"relocation_mortgage_indicator\", StringType()),\n", " StructField(\"zero_balance_code_change_date\", StringType()),\n", " StructField(\"loan_holdback_indicator\", StringType()),\n", " StructField(\"loan_holdback_effective_date\", StringType()),\n", " StructField(\"delinquent_accrued_interest\", StringType()),\n", " StructField(\"property_valuation_method\", StringType()),\n", " StructField(\"high_balance_loan_indicator\", StringType()),\n", " StructField(\"arm_initial_fixed-rate_period_lt_5_yr_indicator\", StringType()),\n", " StructField(\"arm_product_type\", StringType()),\n", " StructField(\"initial_fixed-rate_period\", StringType()),\n", " StructField(\"interest_rate_adjustment_frequency\", StringType()),\n", " StructField(\"next_interest_rate_adjustment_date\", StringType()),\n", " StructField(\"next_payment_change_date\", StringType()),\n", " StructField(\"index\", StringType()),\n", " StructField(\"arm_cap_structure\", StringType()),\n", " StructField(\"initial_interest_rate_cap_up_percent\", StringType()),\n", " StructField(\"periodic_interest_rate_cap_up_percent\", StringType()),\n", " StructField(\"lifetime_interest_rate_cap_up_percent\", StringType()),\n", " StructField(\"mortgage_margin\", StringType()),\n", " StructField(\"arm_balloon_indicator\", StringType()),\n", " StructField(\"arm_plan_number\", StringType()),\n", " StructField(\"borrower_assistance_plan\", StringType()),\n", " StructField(\"hltv_refinance_option_indicator\", StringType()),\n", " StructField(\"deal_name\", StringType()),\n", " StructField(\"repurchase_make_whole_proceeds_flag\", StringType()),\n", " StructField(\"alternative_delinquency_resolution\", StringType()),\n", " StructField(\"alternative_delinquency_resolution_count\", StringType()),\n", " StructField(\"total_deferral_amount\", StringType())\n", " ])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Define seller name mapping" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# name mappings\n", 
"_name_mapping = [\n", " (\"WITMER FUNDING, LLC\", \"Witmer\"),\n", " (\"WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015\", \"Wells Fargo\"),\n", " (\"WELLS FARGO BANK, NA\" , \"Wells Fargo\"),\n", " (\"WELLS FARGO BANK, N.A.\" , \"Wells Fargo\"),\n", " (\"WELLS FARGO BANK, NA\" , \"Wells Fargo\"),\n", " (\"USAA FEDERAL SAVINGS BANK\" , \"USAA\"),\n", " (\"UNITED SHORE FINANCIAL SERVICES, LLC D\\\\/B\\\\/A UNITED WHOLESALE MORTGAGE\" , \"United Seq(e\"),\n", " (\"U.S. BANK N.A.\" , \"US Bank\"),\n", " (\"SUNTRUST MORTGAGE INC.\" , \"Suntrust\"),\n", " (\"STONEGATE MORTGAGE CORPORATION\" , \"Stonegate Mortgage\"),\n", " (\"STEARNS LENDING, LLC\" , \"Stearns Lending\"),\n", " (\"STEARNS LENDING, INC.\" , \"Stearns Lending\"),\n", " (\"SIERRA PACIFIC MORTGAGE COMPANY, INC.\" , \"Sierra Pacific Mortgage\"),\n", " (\"REGIONS BANK\" , \"Regions\"),\n", " (\"RBC MORTGAGE COMPANY\" , \"RBC\"),\n", " (\"QUICKEN LOANS INC.\" , \"Quicken Loans\"),\n", " (\"PULTE MORTGAGE, L.L.C.\" , \"Pulte Mortgage\"),\n", " (\"PROVIDENT FUNDING ASSOCIATES, L.P.\" , \"Provident Funding\"),\n", " (\"PROSPECT MORTGAGE, LLC\" , \"Prospect Mortgage\"),\n", " (\"PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC\" , \"Principal Residential\"),\n", " (\"PNC BANK, N.A.\" , \"PNC\"),\n", " (\"PMT CREDIT RISK TRANSFER TRUST 2015-2\" , \"PennyMac\"),\n", " (\"PHH MORTGAGE CORPORATION\" , \"PHH Mortgage\"),\n", " (\"PENNYMAC CORP.\" , \"PennyMac\"),\n", " (\"PACIFIC UNION FINANCIAL, LLC\" , \"Other\"),\n", " (\"OTHER\" , \"Other\"),\n", " (\"NYCB MORTGAGE COMPANY, LLC\" , \"NYCB\"),\n", " (\"NEW YORK COMMUNITY BANK\" , \"NYCB\"),\n", " (\"NETBANK FUNDING SERVICES\" , \"Netbank\"),\n", " (\"NATIONSTAR MORTGAGE, LLC\" , \"Nationstar Mortgage\"),\n", " (\"METLIFE BANK, NA\" , \"Metlife\"),\n", " (\"LOANDEPOT.COM, LLC\" , \"LoanDepot.com\"),\n", " (\"J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1\" , \"JP Morgan Chase\"),\n", " (\"J.P. 
MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1\" , \"JP Morgan Chase\"),\n", " (\"JPMORGAN CHASE BANK, NATIONAL ASSOCIATION\" , \"JP Morgan Chase\"),\n", " (\"JPMORGAN CHASE BANK, NA\" , \"JP Morgan Chase\"),\n", " (\"JP MORGAN CHASE BANK, NA\" , \"JP Morgan Chase\"),\n", " (\"IRWIN MORTGAGE, CORPORATION\" , \"Irwin Mortgage\"),\n", " (\"IMPAC MORTGAGE CORP.\" , \"Impac Mortgage\"),\n", " (\"HSBC BANK USA, NATIONAL ASSOCIATION\" , \"HSBC\"),\n", " (\"HOMEWARD RESIDENTIAL, INC.\" , \"Homeward Mortgage\"),\n", " (\"HOMESTREET BANK\" , \"Other\"),\n", " (\"HOMEBRIDGE FINANCIAL SERVICES, INC.\" , \"HomeBridge\"),\n", " (\"HARWOOD STREET FUNDING I, LLC\" , \"Harwood Mortgage\"),\n", " (\"GUILD MORTGAGE COMPANY\" , \"Guild Mortgage\"),\n", " (\"GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)\" , \"GMAC\"),\n", " (\"GMAC MORTGAGE, LLC\" , \"GMAC\"),\n", " (\"GMAC (USAA)\" , \"GMAC\"),\n", " (\"FREMONT BANK\" , \"Fremont Bank\"),\n", " (\"FREEDOM MORTGAGE CORP.\" , \"Freedom Mortgage\"),\n", " (\"FRANKLIN AMERICAN MORTGAGE COMPANY\" , \"Franklin America\"),\n", " (\"FLEET NATIONAL BANK\" , \"Fleet National\"),\n", " (\"FLAGSTAR CAPITAL MARKETS CORPORATION\" , \"Flagstar Bank\"),\n", " (\"FLAGSTAR BANK, FSB\" , \"Flagstar Bank\"),\n", " (\"FIRST TENNESSEE BANK NATIONAL ASSOCIATION\" , \"Other\"),\n", " (\"FIFTH THIRD BANK\" , \"Fifth Third Bank\"),\n", " (\"FEDERAL HOME LOAN BANK OF CHICAGO\" , \"Fedral Home of Chicago\"),\n", " (\"FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB\" , \"FDIC\"),\n", " (\"DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A.\" , \"Downey Mortgage\"),\n", " (\"DITECH FINANCIAL LLC\" , \"Ditech\"),\n", " (\"CITIMORTGAGE, INC.\" , \"Citi\"),\n", " (\"CHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY\" , \"Chicago Mortgage\"),\n", " (\"CHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY\" , \"Chicago Mortgage\"),\n", " (\"CHASE HOME FINANCE, LLC\" , \"JP Morgan Chase\"),\n", " (\"CHASE HOME FINANCE FRANKLIN AMERICAN MORTGAGE COMPANY\" , \"JP Morgan Chase\"),\n", " (\"CHASE HOME FINANCE (CIE 1)\" , \"JP Morgan Chase\"),\n", " (\"CHASE HOME FINANCE\" , \"JP Morgan Chase\"),\n", " (\"CASHCALL, INC.\" , \"CashCall\"),\n", " (\"CAPITAL ONE, NATIONAL ASSOCIATION\" , \"Capital One\"),\n", " (\"CALIBER HOME LOANS, INC.\" , \"Caliber Funding\"),\n", " (\"BISHOPS GATE RESIDENTIAL MORTGAGE TRUST\" , \"Bishops Gate Mortgage\"),\n", " (\"BANK OF AMERICA, N.A.\" , \"Bank of America\"),\n", " (\"AMTRUST BANK\" , \"AmTrust\"),\n", " (\"AMERISAVE MORTGAGE CORPORATION\" , \"Amerisave\"),\n", " (\"AMERIHOME MORTGAGE COMPANY, LLC\" , \"AmeriHome Mortgage\"),\n", " (\"ALLY BANK\" , \"Ally Bank\"),\n", " (\"ACADEMY MORTGAGE CORPORATION\" , \"Academy Mortgage\"),\n", " (\"NO CASH-OUT REFINANCE\" , \"OTHER REFINANCE\"),\n", " (\"REFINANCE - NOT SPECIFIED\" , \"OTHER REFINANCE\"),\n", " (\"Other REFINANCE\" , \"OTHER REFINANCE\")]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Define category (string) column and numeric column" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# String columns\n", "cate_col_names = [\n", " \"orig_channel\",\n", " \"first_home_buyer\",\n", " \"loan_purpose\",\n", " \"property_type\",\n", " \"occupancy_status\",\n", " \"property_state\",\n", " \"product_type\",\n", " \"relocation_mortgage_indicator\",\n", " \"seller_name\",\n", " \"mod_flag\"\n", "]\n", "# Numberic columns\n", "label_col_name = \"delinquency_12\"\n", "numeric_col_names = [\n", " \"orig_interest_rate\",\n", " \"orig_upb\",\n", " 
\"orig_loan_term\",\n", " \"orig_ltv\",\n", " \"orig_cltv\",\n", " \"num_borrowers\",\n", " \"dti\",\n", " \"borrower_credit_score\",\n", " \"num_units\",\n", " \"zip\",\n", " \"mortgage_insurance_percent\",\n", " \"current_loan_delinquency_status\",\n", " \"current_actual_upb\",\n", " \"interest_rate\",\n", " \"loan_age\",\n", " \"msa\",\n", " \"non_interest_bearing_upb\",\n", " label_col_name\n", "]\n", "all_col_names = cate_col_names + numeric_col_names" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Functions to extract perf and acq columns from raw schema" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "def extract_perf_columns(rawDf):\n", " perfDf = rawDf.select(\n", " col(\"loan_id\"),\n", " date_format(to_date(col(\"monthly_reporting_period\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"monthly_reporting_period\"),\n", " upper(col(\"servicer\")).alias(\"servicer\"),\n", " col(\"interest_rate\"),\n", " col(\"current_actual_upb\"),\n", " col(\"loan_age\"),\n", " col(\"remaining_months_to_legal_maturity\"),\n", " col(\"adj_remaining_months_to_maturity\"),\n", " date_format(to_date(col(\"maturity_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"maturity_date\"),\n", " col(\"msa\"),\n", " col(\"current_loan_delinquency_status\"),\n", " col(\"mod_flag\"),\n", " col(\"zero_balance_code\"),\n", " date_format(to_date(col(\"zero_balance_effective_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"zero_balance_effective_date\"),\n", " date_format(to_date(col(\"last_paid_installment_date\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"last_paid_installment_date\"),\n", " date_format(to_date(col(\"foreclosed_after\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"foreclosed_after\"),\n", " date_format(to_date(col(\"disposition_date\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"disposition_date\"),\n", " col(\"foreclosure_costs\"),\n", " col(\"prop_preservation_and_repair_costs\"),\n", " col(\"asset_recovery_costs\"),\n", " col(\"misc_holding_expenses\"),\n", " col(\"holding_taxes\"),\n", " col(\"net_sale_proceeds\"),\n", " col(\"credit_enhancement_proceeds\"),\n", " col(\"repurchase_make_whole_proceeds\"),\n", " col(\"other_foreclosure_proceeds\"),\n", " col(\"non_interest_bearing_upb\"),\n", " col(\"principal_forgiveness_upb\"),\n", " col(\"repurchase_make_whole_proceeds_flag\"),\n", " col(\"foreclosure_principal_write_off_amount\"),\n", " col(\"servicing_activity_indicator\"),\n", " col('quarter')\n", " )\n", "\n", " return perfDf.select(\"*\").filter(\"current_actual_upb != 0.0\")\n", "\n", "def extract_acq_columns(rawDf):\n", " acqDf = rawDf.select(\n", " col(\"loan_id\"),\n", " col(\"orig_channel\"),\n", " upper(col(\"seller_name\")).alias(\"seller_name\"),\n", " col(\"orig_interest_rate\"),\n", " col(\"orig_upb\"),\n", " col(\"orig_loan_term\"),\n", " date_format(to_date(col(\"orig_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"orig_date\"),\n", " date_format(to_date(col(\"first_pay_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"first_pay_date\"),\n", " col(\"orig_ltv\"),\n", " col(\"orig_cltv\"),\n", " col(\"num_borrowers\"),\n", " col(\"dti\"),\n", " col(\"borrower_credit_score\"),\n", " col(\"first_home_buyer\"),\n", " col(\"loan_purpose\"),\n", " col(\"property_type\"),\n", " col(\"num_units\"),\n", " col(\"occupancy_status\"),\n", " col(\"property_state\"),\n", " col(\"zip\"),\n", " col(\"mortgage_insurance_percent\"),\n", " col(\"product_type\"),\n", " col(\"coborrow_credit_score\"),\n", " col(\"mortgage_insurance_type\"),\n", " col(\"relocation_mortgage_indicator\"),\n", " 
dense_rank().over(Window.partitionBy(\"loan_id\").orderBy(to_date(col(\"monthly_reporting_period\"),\"MMyyyy\"))).alias(\"rank\"),\n", " col('quarter')\n", " )\n", "\n", " return acqDf.select(\"*\").filter(col(\"rank\")==1)" ] }, {
 "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Define ETL Process\n", "\n", "Define the functions that make up the ETL process.\n", "\n", "#### 2.1 Define Functions to Read Raw CSV File\n", "\n", "* Define function to get quarter from input CSV file name" ] }, {
 "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "def _get_quarter_from_csv_file_name():\n", " # e.g. a file named '2000Q1.csv' yields quarter '2000Q1'\n", " return substring_index(substring_index(input_file_name(), '.', 1), '/', -1)" ] }, {
 "cell_type": "markdown", "metadata": {}, "source": [ "* Define function to read raw CSV data file" ] }, {
 "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "def read_raw_csv(spark, path):\n", " return spark.read.format('csv') \\\n", " .option('nullValue', '') \\\n", " .option('header', False) \\\n", " .option('delimiter', '|') \\\n", " .schema(_csv_raw_schema) \\\n", " .load(path) \\\n", " .withColumn('quarter', _get_quarter_from_csv_file_name())" ] }, {
 "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.2 Define ETL Process\n", "\n", "* Define function to parse dates in Performance data" ] }, {
 "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "def _parse_dates(perf):\n", " return perf \\\n", " .withColumn('monthly_reporting_period', to_date(col('monthly_reporting_period'), 'MM/dd/yyyy')) \\\n", " .withColumn('monthly_reporting_period_month', month(col('monthly_reporting_period'))) \\\n", " .withColumn('monthly_reporting_period_year', year(col('monthly_reporting_period'))) \\\n", " .withColumn('monthly_reporting_period_day', dayofmonth(col('monthly_reporting_period'))) \\\n", " .withColumn('last_paid_installment_date', to_date(col('last_paid_installment_date'), 'MM/dd/yyyy')) \\\n", " .withColumn('foreclosed_after', to_date(col('foreclosed_after'), 'MM/dd/yyyy')) \\\n", " .withColumn('disposition_date', to_date(col('disposition_date'), 'MM/dd/yyyy')) \\\n", " .withColumn('maturity_date', to_date(col('maturity_date'), 'MM/yyyy')) \\\n", " .withColumn('zero_balance_effective_date', to_date(col('zero_balance_effective_date'), 'MM/yyyy'))" ] }, {
 "cell_type": "markdown", "metadata": {}, "source": [ "* Define function to create delinquency data frame from Performance data" ] }, {
 "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "def _create_perf_deliquency(spark, perf):\n", " aggDF = perf.select(\n", " col(\"quarter\"),\n", " col(\"loan_id\"),\n", " col(\"current_loan_delinquency_status\"),\n", " when(col(\"current_loan_delinquency_status\") >= 1, col(\"monthly_reporting_period\")).alias(\"delinquency_30\"),\n", " when(col(\"current_loan_delinquency_status\") >= 3, col(\"monthly_reporting_period\")).alias(\"delinquency_90\"),\n", " when(col(\"current_loan_delinquency_status\") >= 6, col(\"monthly_reporting_period\")).alias(\"delinquency_180\")) \\\n", " .groupBy(\"quarter\", \"loan_id\") \\\n", " .agg(\n", " max(\"current_loan_delinquency_status\").alias(\"delinquency_12\"),\n", " min(\"delinquency_30\").alias(\"delinquency_30\"),\n", " min(\"delinquency_90\").alias(\"delinquency_90\"),\n", " min(\"delinquency_180\").alias(\"delinquency_180\")) \\\n", " .select(\n", " col(\"quarter\"),\n", " col(\"loan_id\"),\n", " (col(\"delinquency_12\") >= 1).alias(\"ever_30\"),\n", "
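# delinquency_12 here is max(current_loan_delinquency_status) per loan and quarter; thresholds of 1, 3 and 6 months flag loans that were ever 30, 90 or 180 days delinquent\n",
            "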
(col(\"delinquency_12\") >= 3).alias(\"ever_90\"),\n", " (col(\"delinquency_12\") >= 6).alias(\"ever_180\"),\n", " col(\"delinquency_30\"),\n", " col(\"delinquency_90\"),\n", " col(\"delinquency_180\"))\n", " joinedDf = perf \\\n", " .withColumnRenamed(\"monthly_reporting_period\", \"timestamp\") \\\n", " .withColumnRenamed(\"monthly_reporting_period_month\", \"timestamp_month\") \\\n", " .withColumnRenamed(\"monthly_reporting_period_year\", \"timestamp_year\") \\\n", " .withColumnRenamed(\"current_loan_delinquency_status\", \"delinquency_12\") \\\n", " .withColumnRenamed(\"current_actual_upb\", \"upb_12\") \\\n", " .select(\"quarter\", \"loan_id\", \"timestamp\", \"delinquency_12\", \"upb_12\", \"timestamp_month\", \"timestamp_year\") \\\n", " .join(aggDF, [\"loan_id\", \"quarter\"], \"left_outer\")\n", "\n", " # calculate the 12 month delinquency and upb values\n", " months = 12\n", " monthArray = [lit(x) for x in range(0, 12)]\n", " # explode on a small amount of data is actually slightly more efficient than a cross join\n", " testDf = joinedDf \\\n", " .withColumn(\"month_y\", explode(array(monthArray))) \\\n", " .select(\n", " col(\"quarter\"),\n", " floor(((col(\"timestamp_year\") * 12 + col(\"timestamp_month\")) - 24000) / months).alias(\"josh_mody\"),\n", " floor(((col(\"timestamp_year\") * 12 + col(\"timestamp_month\")) - 24000 - col(\"month_y\")) / months).alias(\"josh_mody_n\"),\n", " col(\"ever_30\"),\n", " col(\"ever_90\"),\n", " col(\"ever_180\"),\n", " col(\"delinquency_30\"),\n", " col(\"delinquency_90\"),\n", " col(\"delinquency_180\"),\n", " col(\"loan_id\"),\n", " col(\"month_y\"),\n", " col(\"delinquency_12\"),\n", " col(\"upb_12\")) \\\n", " .groupBy(\"quarter\", \"loan_id\", \"josh_mody_n\", \"ever_30\", \"ever_90\", \"ever_180\", \"delinquency_30\", \"delinquency_90\", \"delinquency_180\", \"month_y\") \\\n", " .agg(max(\"delinquency_12\").alias(\"delinquency_12\"), min(\"upb_12\").alias(\"upb_12\")) \\\n", " .withColumn(\"timestamp_year\", floor((lit(24000) + (col(\"josh_mody_n\") * lit(months)) + (col(\"month_y\") - 1)) / lit(12))) \\\n", " .selectExpr('*', 'pmod(24000 + (josh_mody_n * {}) + month_y, 12) as timestamp_month_tmp'.format(months)) \\\n", " .withColumn(\"timestamp_month\", when(col(\"timestamp_month_tmp\") == lit(0), lit(12)).otherwise(col(\"timestamp_month_tmp\"))) \\\n", " .withColumn(\"delinquency_12\", ((col(\"delinquency_12\") > 3).cast(\"int\") + (col(\"upb_12\") == 0).cast(\"int\")).alias(\"delinquency_12\")) \\\n", " .drop(\"timestamp_month_tmp\", \"josh_mody_n\", \"month_y\")\n", "\n", " return perf.withColumnRenamed(\"monthly_reporting_period_month\", \"timestamp_month\") \\\n", " .withColumnRenamed(\"monthly_reporting_period_year\", \"timestamp_year\") \\\n", " .join(testDf, [\"quarter\", \"loan_id\", \"timestamp_year\", \"timestamp_month\"], \"left\") \\\n", " .drop(\"timestamp_year\", \"timestamp_month\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Define function to create acquisition data frame from Acquisition data" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "def _create_acquisition(spark, acq):\n", " nameMapping = spark.createDataFrame(_name_mapping, [\"from_seller_name\", \"to_seller_name\"])\n", " return acq.join(nameMapping, col(\"seller_name\") == col(\"from_seller_name\"), \"left\") \\\n", " .drop(\"from_seller_name\") \\\n", " .withColumn(\"old_name\", col(\"seller_name\")) \\\n", " .withColumn(\"seller_name\", coalesce(col(\"to_seller_name\"), 
col(\"seller_name\"))) \\\n", " .drop(\"to_seller_name\") \\\n", " .withColumn(\"orig_date\", to_date(col(\"orig_date\"), \"MM/yyyy\")) \\\n", " .withColumn(\"first_pay_date\", to_date(col(\"first_pay_date\"), \"MM/yyyy\")) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.3 Define Casting Process\n", "This part is casting String column to Numbric. \n", "Example:\n", "```\n", "col_1\n", " \"a\"\n", " \"b\"\n", " \"c\"\n", " \"a\"\n", "# After String ====> Numberic\n", "col_1\n", " 0\n", " 1\n", " 2\n", " 0\n", "``` \n", "
\n", "\n", "* Define function to get column dictionary\n", "\n", " Example\n", " ```\n", " col1 = [row(data=\"a\",id=0), row(data=\"b\",id=1)]\n", " ```" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "def _gen_dictionary(etl_df, col_names):\n", " cnt_table = etl_df.select(posexplode(array([col(i) for i in col_names])))\\\n", " .withColumnRenamed(\"pos\", \"column_id\")\\\n", " .withColumnRenamed(\"col\", \"data\")\\\n", " .filter(\"data is not null\")\\\n", " .groupBy(\"column_id\", \"data\")\\\n", " .count()\n", " windowed = Window.partitionBy(\"column_id\").orderBy(desc(\"count\"))\n", " return cnt_table.withColumn(\"id\", row_number().over(windowed)).drop(\"count\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Define function to convert string columns to numeric" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "def _cast_string_columns_to_numeric(spark, input_df):\n", " cached_dict_df = _gen_dictionary(input_df, cate_col_names).cache()\n", " output_df = input_df\n", " # Generate the final table with all columns being numeric.\n", " for col_pos, col_name in enumerate(cate_col_names):\n", " col_dict_df = cached_dict_df.filter(col(\"column_id\") == col_pos)\\\n", " .drop(\"column_id\")\\\n", " .withColumnRenamed(\"data\", col_name)\n", " \n", " output_df = output_df.join(broadcast(col_dict_df), col_name, \"left\")\\\n", " .drop(col_name)\\\n", " .withColumnRenamed(\"id\", col_name)\n", " return output_df " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.4 Define Main Function\n", "In this function:\n", "1. Parse date in Performance data by calling _parse_dates (parsed_perf)\n", "2. Create deliqency dataframe(perf_deliqency) form Performance data by calling _create_perf_deliquency\n", "3. Create cleaned acquisition dataframe(cleaned_acq) from Acquisition data by calling _create_acquisition\n", "4. Join deliqency dataframe(perf_deliqency) and cleaned acquisition dataframe(cleaned_acq), get clean_df\n", "5. Cast String column to Numbric in clean_df by calling _cast_string_columns_to_numeric, get casted_clean_df\n", "6. Return casted_clean_df as final result" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "def run_mortgage(spark, perf, acq):\n", " parsed_perf = _parse_dates(perf)\n", " perf_deliqency = _create_perf_deliquency(spark, parsed_perf)\n", " cleaned_acq = _create_acquisition(spark, acq)\n", " clean_df = perf_deliqency.join(cleaned_acq, [\"loan_id\", \"quarter\"], \"inner\").drop(\"quarter\")\n", " casted_clean_df = _cast_string_columns_to_numeric(spark, clean_df)\\\n", " .select(all_col_names)\\\n", " .withColumn(label_col_name, when(col(label_col_name) > 0, 1).otherwise(0))\\\n", " .fillna(float(0))\n", " return casted_clean_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Script Settings\n", "\n", "### 1. 
File Path Settings\n", "* Define input file path" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "# You need to update them to your real paths!\n", "dataRoot = os.getenv(\"DATA_ROOT\", \"/data\")\n", "orig_raw_path = dataRoot + '/mortgage/input/'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Define output folder path" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "output_path = dataRoot + '/mortgage/output/data/'\n", "output_csv2parquet = dataRoot + '/mortgage/output/csv2parquet/'\n", "output_path_train = dataRoot + '/mortgage/output/train/'\n", "output_path_eval = dataRoot + '/mortgage/output/eval/'\n", "save_train_eval_dataset = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Common Spark Settings" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "spark.conf.set('spark.rapids.sql.explain', 'ALL')\n", "spark.conf.set('spark.rapids.sql.batchSizeBytes', '512M')\n", "spark.conf.set('spark.rapids.sql.reader.batchSizeBytes', '768M')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run Part" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Read Raw File" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rawDf = read_raw_csv(spark, orig_raw_path)\n", "rawDf.write.parquet(output_csv2parquet, mode='overwrite')\n", "rawDf = spark.read.parquet(output_csv2parquet)\n", "\n", "acq = extract_acq_columns(rawDf)\n", "perf = extract_perf_columns(rawDf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run ETL\n", "#### 1. Add additional Spark settings" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "# GPU run, set to true\n", "spark.conf.set('spark.rapids.sql.enabled', 'true')\n", "# CPU run, set to false\n", "# spark.conf.set('spark.rapids.sql.enabled', 'false')\n", "spark.conf.set('spark.sql.files.maxPartitionBytes', '1G')\n", "# use GPU to read CSV\n", "spark.conf.set(\"spark.rapids.sql.csv.read.double.enabled\", \"true\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.Read Parquet File and Run ETL Process, Save the Result" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "== Physical Plan ==\n", "GpuColumnarToRow false\n", "+- GpuProject [gpucoalesce(orig_channel#3146, 0) AS orig_channel#5143, gpucoalesce(first_home_buyer#3351, 0) AS first_home_buyer#5144, gpucoalesce(loan_purpose#3556, 0) AS loan_purpose#5145, gpucoalesce(property_type#3761, 0) AS property_type#5146, gpucoalesce(occupancy_status#3966, 0) AS occupancy_status#5147, gpucoalesce(property_state#4171, 0) AS property_state#5148, gpucoalesce(product_type#4376, 0) AS product_type#5149, gpucoalesce(relocation_mortgage_indicator#4581, 0) AS relocation_mortgage_indicator#5150, gpucoalesce(seller_name#4786, 0) AS seller_name#5151, gpucoalesce(id#2956, 0) AS mod_flag#5152, gpucoalesce(gpunanvl(orig_interest_rate#1606, null), 0.0) AS orig_interest_rate#5153, gpucoalesce(orig_upb#1607, 0) AS orig_upb#5154, gpucoalesce(orig_loan_term#1608, 0) AS orig_loan_term#5155, gpucoalesce(gpunanvl(orig_ltv#1611, null), 0.0) AS orig_ltv#5156, gpucoalesce(gpunanvl(orig_cltv#1612, null), 0.0) AS orig_cltv#5157, gpucoalesce(gpunanvl(num_borrowers#1613, null), 0.0) AS num_borrowers#5158, gpucoalesce(gpunanvl(dti#1614, null), 0.0) AS 
dti#5159, gpucoalesce(gpunanvl(borrower_credit_score#1615, null), 0.0) AS borrower_credit_score#5160, gpucoalesce(num_units#1619, 0) AS num_units#5161, gpucoalesce(zip#1622, 0) AS zip#5162, gpucoalesce(gpunanvl(mortgage_insurance_percent#1623, null), 0.0) AS mortgage_insurance_percent#5163, gpucoalesce(current_loan_delinquency_status#1549, 0) AS current_loan_delinquency_status#5164, gpucoalesce(gpunanvl(current_actual_upb#1543, null), 0.0) AS current_actual_upb#5165, gpucoalesce(gpunanvl(interest_rate#1542, null), 0.0) AS interest_rate#5166, ... 4 more fields]\n", " +- GpuBroadcastHashJoin [mod_flag#1550], [mod_flag#4855], LeftOuter, GpuBuildRight\n", " :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, num_units#1619, zip#1622, mortgage_insurance_percent#1623, orig_channel#3146, first_home_buyer#3351, loan_purpose#3556, property_type#3761, occupancy_status#3966, ... 4 more fields]\n", " : +- GpuBroadcastHashJoin [seller_name#2689], [seller_name#4650], LeftOuter, GpuBuildRight\n", " : :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, num_units#1619, zip#1622, mortgage_insurance_percent#1623, orig_channel#3146, first_home_buyer#3351, loan_purpose#3556, property_type#3761, ... 4 more fields]\n", " : : +- GpuBroadcastHashJoin [relocation_mortgage_indicator#1627], [relocation_mortgage_indicator#4445], LeftOuter, GpuBuildRight\n", " : : :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, num_units#1619, zip#1622, mortgage_insurance_percent#1623, relocation_mortgage_indicator#1627, orig_channel#3146, first_home_buyer#3351, loan_purpose#3556, ... 4 more fields]\n", " : : : +- GpuBroadcastHashJoin [product_type#1624], [product_type#4240], LeftOuter, GpuBuildRight\n", " : : : :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, num_units#1619, zip#1622, mortgage_insurance_percent#1623, product_type#1624, relocation_mortgage_indicator#1627, orig_channel#3146, first_home_buyer#3351, ... 
4 more fields]\n", " : : : : +- GpuBroadcastHashJoin [property_state#1621], [property_state#4035], LeftOuter, GpuBuildRight\n", " : : : : :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, num_units#1619, property_state#1621, zip#1622, mortgage_insurance_percent#1623, product_type#1624, relocation_mortgage_indicator#1627, orig_channel#3146, ... 4 more fields]\n", " : : : : : +- GpuBroadcastHashJoin [occupancy_status#1620], [occupancy_status#3830], LeftOuter, GpuBuildRight\n", " : : : : : :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, num_units#1619, occupancy_status#1620, property_state#1621, zip#1622, mortgage_insurance_percent#1623, product_type#1624, relocation_mortgage_indicator#1627, ... 4 more fields]\n", " : : : : : : +- GpuBroadcastHashJoin [property_type#1618], [property_type#3625], LeftOuter, GpuBuildRight\n", " : : : : : : :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, property_type#1618, num_units#1619, occupancy_status#1620, property_state#1621, zip#1622, mortgage_insurance_percent#1623, product_type#1624, ... 4 more fields]\n", " : : : : : : : +- GpuBroadcastHashJoin [loan_purpose#1617], [loan_purpose#3420], LeftOuter, GpuBuildRight\n", " : : : : : : : :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, loan_purpose#1617, property_type#1618, num_units#1619, occupancy_status#1620, property_state#1621, zip#1622, mortgage_insurance_percent#1623, ... 4 more fields]\n", " : : : : : : : : +- GpuBroadcastHashJoin [first_home_buyer#1616], [first_home_buyer#3215], LeftOuter, GpuBuildRight\n", " : : : : : : : : :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, first_home_buyer#1616, loan_purpose#1617, property_type#1618, num_units#1619, occupancy_status#1620, property_state#1621, zip#1622, ... 
4 more fields]\n", " : : : : : : : : : +- GpuBroadcastHashJoin [orig_channel#1604], [orig_channel#3010], LeftOuter, GpuBuildRight\n", " : : : : : : : : : :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, orig_channel#1604, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, first_home_buyer#1616, loan_purpose#1617, property_type#1618, num_units#1619, occupancy_status#1620, property_state#1621, ... 4 more fields]\n", " : : : : : : : : : : +- GpuShuffledHashJoin [loan_id#1539L, quarter#1570], [loan_id#1603L, quarter#1629], Inner, GpuBuildRight, false\n", " : : : : : : : : : : :- GpuShuffleCoalesce 536870912\n", " : : : : : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(loan_id#1539L, quarter#1570, 192), ENSURE_REQUIREMENTS, [id=#3885]\n", " : : : : : : : : : : : +- GpuProject [quarter#1570, loan_id#1539L, interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353]\n", " : : : : : : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : : : : +- SortMergeJoin [quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint)], [quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L], LeftOuter\n", " : : : : : : : : : : : :- *(2) Sort [quarter#1570 ASC NULLS FIRST, loan_id#1539L ASC NULLS FIRST, cast(timestamp_year#2417 as bigint) ASC NULLS FIRST, cast(timestamp_month#2381 as bigint) ASC NULLS FIRST], false, 0\n", " : : : : : : : : : : : : +- Exchange hashpartitioning(quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint), 192), ENSURE_REQUIREMENTS, [id=#3847]\n", " : : : : : : : : : : : : +- *(1) Project [loan_id#1539L, interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, quarter#1570, month(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2381, year(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2417]\n", " : : : : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : : : : +- GpuFilter (gpuisnotnull(loan_id#1539L) AND gpuisnotnull(quarter#1570)), true\n", " : : : : : : : : : : : : +- GpuFileGpuScan parquet [loan_id#1539L,monthly_reporting_period#1540,interest_rate#1542,current_actual_upb#1543,loan_age#1544,msa#1548,current_loan_delinquency_status#1549,mod_flag#1550,non_interest_bearing_upb#1565,quarter#1570] Batched: true, DataFilters: [isnotnull(loan_id#1539L), isnotnull(quarter#1570)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN 
(current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\n", " : : : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : : : +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\n", " : : : : : : : : : : : +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct\n", " : : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : : +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\n", " : : : : : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\n", " : : : : : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : : : +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\n", " : : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\n", " : : : : : : : : : : +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\n", " : : : : : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : : : +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\n", " : : : : : : : : : : +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\n", " : : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : : +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, 
delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\n", " : : : : : : : : : : +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\n", " : : : : : : : : : : +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\n", " : : : : : : : : : : :- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : : : : +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\n", " : : : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : : : +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\n", " : : : : : : : : : : : +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct\n", " : : : : : : : : : : +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\n", " : : : : : : : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\n", " : : : : : : : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : : : +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\n", " : : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : : +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\n", " : : : : : : : : : : +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], 
PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct\n", " : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\n", " : : : : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\n", " : : : : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : : +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\n", " : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\n", " : : : : : : : : : +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\n", " : : : : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : : +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\n", " : : : : : : : : : +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\n", " : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\n", " : : : : : : : : : +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\n", " : : : : : : : : : +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\n", " : : : : : : : : : :- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : : : +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, 
year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\n", " : : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : : +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\n", " : : : : : : : : : : +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct\n", " : : : : : : : : : +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\n", " : : : : : : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\n", " : : : : : : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : : +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\n", " : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\n", " : : : : : : : : : +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct\n", " : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\n", " : : : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\n", " : : : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, 
josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\n", " : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\n", " : : : : : : : : +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\n", " : : : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\n", " : : : : : : : : +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\n", " : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\n", " : : : : : : : : +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\n", " : : : : : : : : +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\n", " : : : : : : : : :- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : : +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\n", " : : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : : +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\n", " : : : : : : : : : +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct\n", " : : : : : : : : +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, 
true]),false), [id=#1096]\n", " : : : : : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\n", " : : : : : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\n", " : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\n", " : : : : : : : : +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct\n", " : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\n", " : : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\n", " : : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\n", " : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\n", " : : : : : : : +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\n", " : : : : : : : +- GpuRowToColumnar 
targetsize(536870912)\n", " : : : : : : : +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\n", " : : : : : : : +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\n", " : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\n", " : : : : : : : +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\n", " : : : : : : : +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\n", " : : : : : : : :- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : : +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\n", " : : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : : +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\n", " : : : : : : : : +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct\n", " : : : : : : : +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\n", " : : : : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\n", " : : : : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : +- 
*(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\n", " : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\n", " : : : : : : : +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct\n", " : : : : : : +- GpuColumnarToRow false\n", " : : : : : : +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\n", " : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\n", " : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\n", " : : : : : : +- GpuColumnarToRow false\n", " : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\n", " : : : : : : +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\n", " : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\n", " : : : : : : +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as 
double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\n", " : : : : : : +- GpuColumnarToRow false\n", " : : : : : : +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\n", " : : : : : : +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\n", " : : : : : : +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\n", " : : : : : : :- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : : +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\n", " : : : : : : : +- GpuColumnarToRow false\n", " : : : : : : : +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\n", " : : : : : : : +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct\n", " : : : : : : +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\n", " : : : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\n", " : : : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\n", " : : : : : : +- GpuColumnarToRow false\n", " : : : : : : +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\n", " : : : : : : +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] 
Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct\n", " : : : : : +- GpuColumnarToRow false\n", " : : : : : +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\n", " : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\n", " : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\n", " : : : : : +- GpuColumnarToRow false\n", " : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\n", " : : : : : +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\n", " : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\n", " : : : : : +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\n", " : : : : : +- GpuColumnarToRow false\n", " : : : : : +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\n", " : : : : : +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\n", " : : : : : +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\n", " : : : : : :- GpuRowToColumnar targetsize(536870912)\n", " : : : : : : +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS 
timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\n", " : : : : : : +- GpuColumnarToRow false\n", " : : : : : : +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\n", " : : : : : : +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct\n", " : : : : : +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\n", " : : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\n", " : : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : : +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\n", " : : : : : +- GpuColumnarToRow false\n", " : : : : : +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\n", " : : : : : +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct\n", " : : : : +- GpuColumnarToRow false\n", " : : : : +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\n", " : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\n", " : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, 
month_y#2248], functions=[])\n", " : : : : +- GpuColumnarToRow false\n", " : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\n", " : : : : +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\n", " : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\n", " : : : : +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\n", " : : : : +- GpuColumnarToRow false\n", " : : : : +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\n", " : : : : +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\n", " : : : : +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\n", " : : : : :- GpuRowToColumnar targetsize(536870912)\n", " : : : : : +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\n", " : : : : : +- GpuColumnarToRow false\n", " : : : : : +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\n", " : : : : : +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct\n", " : : : : +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\n", " : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, 
None, None, None))\n", " : : : : +- GpuShuffleCoalesce 536870912\n", " : : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\n", " : : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : : +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\n", " : : : : +- GpuColumnarToRow false\n", " : : : : +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\n", " : : : : +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct\n", " : : : +- GpuColumnarToRow false\n", " : : : +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\n", " : : : +- GpuShuffleCoalesce 536870912\n", " : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\n", " : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\n", " : : : +- GpuColumnarToRow false\n", " : : : +- GpuShuffleCoalesce 536870912\n", " : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\n", " : : : +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\n", " : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\n", " : : : +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) 
+ timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\n", " : : : +- GpuColumnarToRow false\n", " : : : +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\n", " : : : +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\n", " : : : +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\n", " : : : :- GpuRowToColumnar targetsize(536870912)\n", " : : : : +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\n", " : : : : +- GpuColumnarToRow false\n", " : : : : +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\n", " : : : : +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct\n", " : : : +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\n", " : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : +- GpuShuffleCoalesce 536870912\n", " : : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\n", " : : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : : +- GpuRowToColumnar targetsize(536870912)\n", " : : : +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, 
MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\n", " : : : +- GpuColumnarToRow false\n", " : : : +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\n", " : : : +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct\n", " : : +- GpuColumnarToRow false\n", " : : +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\n", " : : +- GpuShuffleCoalesce 536870912\n", " : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\n", " : : +- GpuRowToColumnar targetsize(536870912)\n", " : : +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\n", " : : +- GpuColumnarToRow false\n", " : : +- GpuShuffleCoalesce 536870912\n", " : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\n", " : : +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\n", " : : +- GpuRowToColumnar targetsize(536870912)\n", " : : +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\n", " : : +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\n", " : : +- GpuColumnarToRow false\n", " : : +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\n", " : : +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\n", " : : +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\n", 
" : : :- GpuRowToColumnar targetsize(536870912)\n", " : : : +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\n", " : : : +- GpuColumnarToRow false\n", " : : : +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\n", " : : : +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct\n", " : : +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\n", " : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : +- GpuShuffleCoalesce 536870912\n", " : : +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\n", " : : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : : +- GpuRowToColumnar targetsize(536870912)\n", " : : +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\n", " : : +- GpuColumnarToRow false\n", " : : +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\n", " : : +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct\n", " : +- GpuColumnarToRow false\n", " : +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\n", " : +- GpuShuffleCoalesce 536870912\n", " : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\n", " : +- GpuRowToColumnar targetsize(536870912)\n", " : +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, 
ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\n", " : +- GpuColumnarToRow false\n", " : +- GpuShuffleCoalesce 536870912\n", " : +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\n", " : +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\n", " : +- GpuRowToColumnar targetsize(536870912)\n", " : +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\n", " : +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\n", " : +- GpuColumnarToRow false\n", " : +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\n", " : +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\n", " : +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\n", " : :- GpuRowToColumnar targetsize(536870912)\n", " : : +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\n", " : : +- GpuColumnarToRow false\n", " : : +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\n", " : : +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct\n", " : +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\n", " : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, 
None, None, None))\n", " : +- GpuShuffleCoalesce 536870912\n", " : +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\n", " : +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\n", " : +- GpuRowToColumnar targetsize(536870912)\n", " : +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\n", " : +- GpuColumnarToRow false\n", " : +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\n", " : +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#2153, delinquency_12#2255, 1.0#2256, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#2153 could run on GPU\n", " @Expression delinquency_12#2255 could run on GPU\n", " @Expression 1.0#2256 could run on GPU\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\n", " ! 
newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\n", " !Expression probability#2186 cannot run on GPU because expression AttributeReference probability#2186 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression obj#2261 cannot run on GPU because expression AttributeReference obj#2261 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", " !Exec cannot run on GPU because not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#2186]\n", " @Expression pythonUDF0#2552.prediction AS prediction#2153 could run on GPU\n", " @Expression pythonUDF0#2552.prediction could run on GPU\n", " @Expression pythonUDF0#2552 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) AS delinquency_12#2255 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) could run on GPU\n", " @Expression delinquency_12#27 could run on GPU\n", " @Expression 1.0 AS 1.0#2256 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " !Expression UDF(pythonUDF0#2552.probability) AS probability#2186 cannot run on GPU because input expression ScalaUDF UDF(pythonUDF0#2552.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported); expression Alias UDF(pythonUDF0#2552.probability) AS probability#2186 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression UDF(pythonUDF0#2552.probability) cannot run on GPU because expression ScalaUDF UDF(pythonUDF0#2552.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled\n", " @Expression pythonUDF0#2552.probability could run on GPU\n", " @Expression pythonUDF0#2552 could run on GPU\n", "\n", "If features_cols param set, then features_col param is ignored. \n", "2022-11-25 09:35:34,074 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#4415, delinquency_12#4517, 1.0#4518, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#4415 could run on GPU\n", " @Expression delinquency_12#4517 could run on GPU\n", " @Expression 1.0#4518 could run on GPU\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\n", " ! 
newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\n", " !Expression probability#4448 cannot run on GPU because expression AttributeReference probability#4448 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression obj#4523 cannot run on GPU because expression AttributeReference obj#4523 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#4448]; not all expressions can be replaced\n", " @Expression pythonUDF0#4814.prediction AS prediction#4415 could run on GPU\n", " @Expression pythonUDF0#4814.prediction could run on GPU\n", " @Expression pythonUDF0#4814 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) AS delinquency_12#4517 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) could run on GPU\n", " @Expression delinquency_12#27 could run on GPU\n", " @Expression 1.0 AS 1.0#4518 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " !Expression UDF(pythonUDF0#4814.probability) AS probability#4448 cannot run on GPU because expression Alias UDF(pythonUDF0#4814.probability) AS probability#4448 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#4814.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\n", " !Expression UDF(pythonUDF0#4814.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#4814.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression pythonUDF0#4814.probability could run on GPU\n", " @Expression pythonUDF0#4814 could run on GPU\n", "\n", "If features_cols param set, then features_col param is ignored.\n", "2022-11-25 09:35:37,859 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#6677, delinquency_12#6779, 1.0#6780, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#6677 could run on GPU\n", " @Expression delinquency_12#6779 could run on GPU\n", " @Expression 1.0#6780 could run on GPU\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\n", " ! 
newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\n", " !Expression probability#6710 cannot run on GPU because expression AttributeReference probability#6710 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression obj#6785 cannot run on GPU because expression AttributeReference obj#6785 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#6710]; not all expressions can be replaced\n", " @Expression pythonUDF0#7076.prediction AS prediction#6677 could run on GPU\n", " @Expression pythonUDF0#7076.prediction could run on GPU\n", " @Expression pythonUDF0#7076 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) AS delinquency_12#6779 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) could run on GPU\n", " @Expression delinquency_12#27 could run on GPU\n", " @Expression 1.0 AS 1.0#6780 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " !Expression UDF(pythonUDF0#7076.probability) AS probability#6710 cannot run on GPU because input expression ScalaUDF UDF(pythonUDF0#7076.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported); expression Alias UDF(pythonUDF0#7076.probability) AS probability#6710 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression UDF(pythonUDF0#7076.probability) cannot run on GPU because expression ScalaUDF UDF(pythonUDF0#7076.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled\n", " @Expression pythonUDF0#7076.probability could run on GPU\n", " @Expression pythonUDF0#7076 could run on GPU\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "If features_cols param set, then features_col param is ignored.\n", "2022-11-25 09:35:41,551 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#8939, delinquency_12#9041, 1.0#9042, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#8939 could run on GPU\n", " @Expression delinquency_12#9041 could run on GPU\n", " @Expression 1.0#9042 could run on GPU\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\n", " ! 
newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\n", " !Expression probability#8972 cannot run on GPU because expression AttributeReference probability#8972 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression obj#9047 cannot run on GPU because expression AttributeReference obj#9047 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#8972]; not all expressions can be replaced\n", " @Expression pythonUDF0#9338.prediction AS prediction#8939 could run on GPU\n", " @Expression pythonUDF0#9338.prediction could run on GPU\n", " @Expression pythonUDF0#9338 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) AS delinquency_12#9041 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) could run on GPU\n", " @Expression delinquency_12#27 could run on GPU\n", " @Expression 1.0 AS 1.0#9042 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " !Expression UDF(pythonUDF0#9338.probability) AS probability#8972 cannot run on GPU because input expression ScalaUDF UDF(pythonUDF0#9338.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported); expression Alias UDF(pythonUDF0#9338.probability) AS probability#8972 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression UDF(pythonUDF0#9338.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#9338.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression pythonUDF0#9338.probability could run on GPU\n", " @Expression pythonUDF0#9338 could run on GPU\n", "\n", "If features_cols param set, then features_col param is ignored.\n", "2022-11-25 09:35:45,231 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#11491, delinquency_12#11593, 1.0#11594, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#11491 could run on GPU\n", " @Expression delinquency_12#11593 could run on GPU\n", " @Expression 1.0#11594 could run on GPU\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\n", " ! 
newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\n", " !Expression probability#11524 cannot run on GPU because expression AttributeReference probability#11524 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression obj#11599 cannot run on GPU because expression AttributeReference obj#11599 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", " !Exec cannot run on GPU because not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#11524]\n", " @Expression pythonUDF0#11890.prediction AS prediction#11491 could run on GPU\n", " @Expression pythonUDF0#11890.prediction could run on GPU\n", " @Expression pythonUDF0#11890 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) AS delinquency_12#11593 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) could run on GPU\n", " @Expression delinquency_12#27 could run on GPU\n", " @Expression 1.0 AS 1.0#11594 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " !Expression UDF(pythonUDF0#11890.probability) AS probability#11524 cannot run on GPU because expression Alias UDF(pythonUDF0#11890.probability) AS probability#11524 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#11890.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\n", " !Expression UDF(pythonUDF0#11890.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#11890.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression pythonUDF0#11890.probability could run on GPU\n", " @Expression pythonUDF0#11890 could run on GPU\n", "\n", "If features_cols param set, then features_col param is ignored.\n", "2022-11-25 09:35:49,003 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#13753, delinquency_12#13855, 1.0#13856, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#13753 could run on GPU\n", " @Expression delinquency_12#13855 could run on GPU\n", " @Expression 1.0#13856 could run on GPU\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\n", " ! 
newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\n", " !Expression probability#13786 cannot run on GPU because expression AttributeReference probability#13786 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression obj#13861 cannot run on GPU because expression AttributeReference obj#13861 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#13786]; not all expressions can be replaced\n", " @Expression pythonUDF0#14152.prediction AS prediction#13753 could run on GPU\n", " @Expression pythonUDF0#14152.prediction could run on GPU\n", " @Expression pythonUDF0#14152 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) AS delinquency_12#13855 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) could run on GPU\n", " @Expression delinquency_12#27 could run on GPU\n", " @Expression 1.0 AS 1.0#13856 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " !Expression UDF(pythonUDF0#14152.probability) AS probability#13786 cannot run on GPU because expression Alias UDF(pythonUDF0#14152.probability) AS probability#13786 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#14152.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\n", " !Expression UDF(pythonUDF0#14152.probability) cannot run on GPU because expression ScalaUDF UDF(pythonUDF0#14152.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled\n", " @Expression pythonUDF0#14152.probability could run on GPU\n", " @Expression pythonUDF0#14152 could run on GPU\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "If features_cols param set, then features_col param is ignored.\n", "2022-11-25 09:35:52,578 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#16015, delinquency_12#16117, 1.0#16118, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#16015 could run on GPU\n", " @Expression delinquency_12#16117 could run on GPU\n", " @Expression 1.0#16118 could run on GPU\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\n", " ! 
newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\n", " !Expression probability#16048 cannot run on GPU because expression AttributeReference probability#16048 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression obj#16123 cannot run on GPU because expression AttributeReference obj#16123 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#16048]; not all expressions can be replaced\n", " @Expression pythonUDF0#16414.prediction AS prediction#16015 could run on GPU\n", " @Expression pythonUDF0#16414.prediction could run on GPU\n", " @Expression pythonUDF0#16414 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) AS delinquency_12#16117 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) could run on GPU\n", " @Expression delinquency_12#27 could run on GPU\n", " @Expression 1.0 AS 1.0#16118 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " !Expression UDF(pythonUDF0#16414.probability) AS probability#16048 cannot run on GPU because expression Alias UDF(pythonUDF0#16414.probability) AS probability#16048 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#16414.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\n", " !Expression UDF(pythonUDF0#16414.probability) cannot run on GPU because expression ScalaUDF UDF(pythonUDF0#16414.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled\n", " @Expression pythonUDF0#16414.probability could run on GPU\n", " @Expression pythonUDF0#16414 could run on GPU\n", "\n", "If features_cols param set, then features_col param is ignored.\n", "2022-11-25 09:35:56,267 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#18277, delinquency_12#18379, 1.0#18380, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#18277 could run on GPU\n", " @Expression delinquency_12#18379 could run on GPU\n", " @Expression 1.0#18380 could run on GPU\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\n", " ! 
newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\n", " !Expression probability#18310 cannot run on GPU because expression AttributeReference probability#18310 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression obj#18385 cannot run on GPU because expression AttributeReference obj#18385 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#18310]; not all expressions can be replaced\n", " @Expression pythonUDF0#18676.prediction AS prediction#18277 could run on GPU\n", " @Expression pythonUDF0#18676.prediction could run on GPU\n", " @Expression pythonUDF0#18676 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) AS delinquency_12#18379 could run on GPU\n", " @Expression cast(delinquency_12#27 as double) could run on GPU\n", " @Expression delinquency_12#27 could run on GPU\n", " @Expression 1.0 AS 1.0#18380 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " !Expression UDF(pythonUDF0#18676.probability) AS probability#18310 cannot run on GPU because expression Alias UDF(pythonUDF0#18676.probability) AS probability#18310 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#18676.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\n", " !Expression UDF(pythonUDF0#18676.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#18676.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression pythonUDF0#18676.probability could run on GPU\n", " @Expression pythonUDF0#18676 could run on GPU\n", "\n", "If features_cols param set, then features_col param is ignored.\n", "[Stage 69:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Cross-Validation takes 59.46 seconds\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "def with_benchmark(phrase, action):\n", " start = time()\n", " result = action()\n", " end = time()\n", " print('{} takes {} seconds'.format(phrase, round(end - start, 2)))\n", " return result\n", "model = with_benchmark('Cross-Validation', lambda: cross_validator.fit(train_data)).bestModel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Transform On the Best Model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-11-25 09:35:59,886 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction#18908, probability#18974]; not all expressions can be replaced\n", " @Expression orig_channel#56 could run on GPU\n", " @Expression first_home_buyer#57 could run on GPU\n", " @Expression loan_purpose#58 could run on GPU\n", " @Expression property_type#59 could run on GPU\n", " @Expression occupancy_status#60 could run on GPU\n", " @Expression property_state#61 could run on GPU\n", " @Expression product_type#62 could run on GPU\n", " 
@Expression relocation_mortgage_indicator#63 could run on GPU\n", " @Expression seller_name#64 could run on GPU\n", " @Expression mod_flag#65 could run on GPU\n", " @Expression orig_interest_rate#66 could run on GPU\n", " @Expression orig_upb#67 could run on GPU\n", " @Expression orig_loan_term#68 could run on GPU\n", " @Expression orig_ltv#69 could run on GPU\n", " @Expression orig_cltv#70 could run on GPU\n", " @Expression num_borrowers#71 could run on GPU\n", " @Expression dti#72 could run on GPU\n", " @Expression borrower_credit_score#73 could run on GPU\n", " @Expression num_units#74 could run on GPU\n", " @Expression zip#75 could run on GPU\n", " @Expression mortgage_insurance_percent#76 could run on GPU\n", " @Expression current_loan_delinquency_status#77 could run on GPU\n", " @Expression current_actual_upb#78 could run on GPU\n", " @Expression interest_rate#79 could run on GPU\n", " @Expression loan_age#80 could run on GPU\n", " @Expression msa#81 could run on GPU\n", " @Expression non_interest_bearing_upb#82 could run on GPU\n", " @Expression delinquency_12#83 could run on GPU\n", " !Expression UDF(pythonUDF0#19041.rawPrediction) AS rawPrediction#18908 cannot run on GPU because input expression ScalaUDF UDF(pythonUDF0#19041.rawPrediction) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported); expression Alias UDF(pythonUDF0#19041.rawPrediction) AS rawPrediction#18908 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression UDF(pythonUDF0#19041.rawPrediction) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#19041.rawPrediction) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression pythonUDF0#19041.rawPrediction could run on GPU\n", " @Expression pythonUDF0#19041 could run on GPU\n", " @Expression pythonUDF0#19041.prediction AS prediction#18942 could run on GPU\n", " @Expression pythonUDF0#19041.prediction could run on GPU\n", " @Expression pythonUDF0#19041 could run on GPU\n", " !Expression UDF(pythonUDF0#19041.probability) AS probability#18974 cannot run on GPU because expression Alias UDF(pythonUDF0#19041.probability) AS probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#19041.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\n", " !Expression UDF(pythonUDF0#19041.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#19041.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression pythonUDF0#19041.probability could run on GPU\n", " @Expression pythonUDF0#19041 could run on GPU\n", "\n", "2022-11-25 09:35:59,893 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction#18908, probability#18974]; not all expressions can be replaced\n", " @Expression orig_channel#56 could run on GPU\n", " @Expression first_home_buyer#57 could run on GPU\n", " 
@Expression loan_purpose#58 could run on GPU\n", " @Expression property_type#59 could run on GPU\n", " @Expression occupancy_status#60 could run on GPU\n", " @Expression property_state#61 could run on GPU\n", " @Expression product_type#62 could run on GPU\n", " @Expression relocation_mortgage_indicator#63 could run on GPU\n", " @Expression seller_name#64 could run on GPU\n", " @Expression mod_flag#65 could run on GPU\n", " @Expression orig_interest_rate#66 could run on GPU\n", " @Expression orig_upb#67 could run on GPU\n", " @Expression orig_loan_term#68 could run on GPU\n", " @Expression orig_ltv#69 could run on GPU\n", " @Expression orig_cltv#70 could run on GPU\n", " @Expression num_borrowers#71 could run on GPU\n", " @Expression dti#72 could run on GPU\n", " @Expression borrower_credit_score#73 could run on GPU\n", " @Expression num_units#74 could run on GPU\n", " @Expression zip#75 could run on GPU\n", " @Expression mortgage_insurance_percent#76 could run on GPU\n", " @Expression current_loan_delinquency_status#77 could run on GPU\n", " @Expression current_actual_upb#78 could run on GPU\n", " @Expression interest_rate#79 could run on GPU\n", " @Expression loan_age#80 could run on GPU\n", " @Expression msa#81 could run on GPU\n", " @Expression non_interest_bearing_upb#82 could run on GPU\n", " @Expression delinquency_12#83 could run on GPU\n", " !Expression rawPrediction#18908 cannot run on GPU because expression AttributeReference rawPrediction#18908 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression prediction#18942 could run on GPU\n", " !Expression probability#18974 cannot run on GPU because expression AttributeReference probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", "\n", "2022-11-25 09:36:00,975 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", " !Exec cannot run on GPU because unsupported data types in input: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#18974, rawPrediction#18908]; not all expressions can be replaced\n", " @Expression cast(delinquency_12#83 as string) AS delinquency_12#19670 could run on GPU\n", " @Expression cast(delinquency_12#83 as string) could run on GPU\n", " @Expression delinquency_12#83 could run on GPU\n", " @Expression cast(rawPrediction#18908 as string) AS rawPrediction#19671 could run on GPU\n", " !Expression cast(rawPrediction#18908 as string) cannot run on GPU because Cast from org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 to StringType is not supported\n", " !Expression rawPrediction#18908 cannot run on GPU because expression AttributeReference rawPrediction#18908 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression cast(probability#18974 as string) AS probability#19672 could run on GPU\n", " !Expression cast(probability#18974 as string) cannot run on GPU because Cast from org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 to StringType is not supported\n", " !Expression probability#18974 cannot run on GPU because expression AttributeReference probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression cast(prediction#18942 as string) AS prediction#19673 could run on GPU\n", " @Expression cast(prediction#18942 as string) could run on GPU\n", " @Expression prediction#18942 could run on GPU\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#18974, rawPrediction#18908]; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; not all expressions can be replaced\n", " @Expression delinquency_12#83 could run on GPU\n", " @Expression prediction#18942 could run on GPU\n", " !Expression probability#18974 cannot run on GPU because expression AttributeReference probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression rawPrediction#18908 cannot run on GPU because expression AttributeReference rawPrediction#18908 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Transforming takes 1.15 seconds\n", "+--------------+--------------------+--------------------+----------+\n", "|delinquency_12| rawPrediction| probability|prediction|\n", "+--------------+--------------------+--------------------+----------+\n", "| 0|[10.2152490615844...|[0.99996340274810...| 0.0|\n", "| 0|[8.85215473175048...|[0.99985694885253...| 0.0|\n", "| 0|[8.85215473175048...|[0.99985694885253...| 0.0|\n", "| 0|[8.85215473175048...|[0.99985694885253...| 0.0|\n", "| 0|[10.2152490615844...|[0.99996340274810...| 0.0|\n", "+--------------+--------------------+--------------------+----------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "def transform():\n", " result = model.transform(trans_data).cache()\n", " result.foreachPartition(lambda _: None)\n", " return result\n", "result = with_benchmark('Transforming', transform)\n", "result.select(label, 'rawPrediction', 'probability', 'prediction').show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Evaluation" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, 
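Editor's note: the GpuOverrides warning above reports that `CollectLimitExec` (triggered by the `show(5)` call in this cell) stays on the CPU because its GPU replacement is disabled by default. The warning itself names the switch to flip; a minimal sketch, assuming the running `spark` session from this notebook, is shown below.

```python
# Sketch only (not part of the original notebook): enable the GPU CollectLimit
# replacement that the GpuOverrides warning mentions. The config key is copied
# verbatim from the warning text; whether enabling it helps depends on how many
# rows per batch are transferred from GPU to CPU.
spark.conf.set("spark.rapids.sql.exec.CollectLimitExec", "true")
```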
"outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-11-25 09:36:01,155 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#18942, delinquency_12#20148, 1.0#20149, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#18942 could run on GPU\n", " @Expression delinquency_12#20148 could run on GPU\n", " @Expression 1.0#20149 could run on GPU\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\n", " !Expression probability#18974 cannot run on GPU because expression AttributeReference probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression obj#20154 cannot run on GPU because expression AttributeReference obj#20154 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", " !Exec cannot run on GPU because unsupported data types in input: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#18974]; not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#18974]\n", " @Expression prediction#18942 could run on GPU\n", " @Expression cast(delinquency_12#83 as double) AS delinquency_12#20148 could run on GPU\n", " @Expression cast(delinquency_12#83 as double) could run on GPU\n", " @Expression delinquency_12#83 could run on GPU\n", " @Expression 1.0 AS 1.0#20149 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " !Expression probability#18974 cannot run on GPU because expression AttributeReference probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#18974]\n", " @Expression delinquency_12#83 could run on GPU\n", " @Expression prediction#18942 could run on GPU\n", " !Expression probability#18974 cannot run on GPU because expression AttributeReference probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", "\n", "[Stage 72:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluation takes 1.41 seconds\n", "Accuracy is 1.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "accuracy = with_benchmark(\n", " 'Evaluation',\n", " lambda: MulticlassClassificationEvaluator().setLabelCol(label).evaluate(result))\n", "print('Accuracy is ' + str(accuracy))" ] }, { "cell_type": "code", 
"execution_count": 8, "metadata": {}, "outputs": [], "source": [ "spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/XGBoost-Examples/mortgage/notebooks/python/mortgage-gpu.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to XGBoost Spark with GPU\n", "\n", "The goal of this notebook is to show how to train a XGBoost Model with Spark RAPIDS XGBoost library on GPUs. The dataset used with this notebook is derived from Fannie Mae’s Single-Family Loan Performance Data with all rights reserved by Fannie Mae. This processed dataset is redistributed with permission and consent from Fannie Mae. This notebook uses XGBoost to train 12-month mortgage loan delinquency prediction model .\n", "\n", "A few libraries required for this notebook:\n", " 1. cudf-cu11\n", " 2. xgboost\n", " 3. scikit-learn\n", " 4. numpy\n", "\n", "This notebook also illustrates the ease of porting a sample CPU based Spark xgboost4j code into GPU. There is no change required for running Spark XGBoost on GPU because both CPU and GPU call the same API. For CPU run, we need to vectorize the trained dataset before fitting data to classifier." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Import All Libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# if you pass/unpack the archive file and enable the environment\n", "# os.environ['PYSPARK_PYTHON'] = \"./environment/bin/python\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel\n", "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.types import FloatType, IntegerType, StructField, StructType, DoubleType\n", "from pyspark.conf import SparkConf\n", "from time import time" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Besides CPU version requires two extra libraries.\n", "```Python\n", "from pyspark.ml.feature import VectorAssembler\n", "from pyspark.sql.functions import col\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create Spark Session and Data Reader" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "22/11/24 06:14:05 WARN org.apache.spark.resource.ResourceUtils: The configuration of cores (exec = 4 task = 1, runnable tasks = 4) will result in wasted resources due to resource gpu limiting the number of runnable tasks per executor to: 1. 
Please adjust your configuration.\n", "22/11/24 06:14:06 INFO org.apache.spark.SparkEnv: Registering MapOutputTracker\n", "22/11/24 06:14:06 INFO org.apache.spark.SparkEnv: Registering BlockManagerMaster\n", "22/11/24 06:14:06 INFO org.apache.spark.SparkEnv: Registering BlockManagerMasterHeartbeat\n", "22/11/24 06:14:06 INFO org.apache.spark.SparkEnv: Registering OutputCommitCoordinator\n", "22/11/24 06:14:07 WARN com.nvidia.spark.rapids.RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1.\n", "22/11/24 06:14:07 WARN com.nvidia.spark.rapids.RapidsPluginUtils: spark.rapids.sql.multiThreadedRead.numThreads is set to 20.\n", "22/11/24 06:14:07 WARN com.nvidia.spark.rapids.RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\n", "22/11/24 06:14:07 WARN com.nvidia.spark.rapids.RapidsPluginUtils: spark.rapids.sql.explain is set to `NOT_ON_GPU`. Set it to 'NONE' to suppress the diagnostics logging about the query placement on the GPU.\n" ] } ], "source": [ "SPARK_MASTER_URL = os.getenv(\"SPARK_MASTER_URL\", \"/your-url\")\n", "RAPIDS_JAR = os.getenv(\"RAPIDS_JAR\", \"/your-jar-path\")\n", "\n", "# You need to update with your real hardware resource \n", "driverMem = os.getenv(\"DRIVER_MEM\", \"10g\")\n", "executorMem = os.getenv(\"EXECUTOR_MEM\", \"10g\")\n", "pinnedPoolSize = os.getenv(\"PINNED_POOL_SIZE\", \"2g\")\n", "concurrentGpuTasks = os.getenv(\"CONCURRENT_GPU_TASKS\", \"2\")\n", "executorCores = int(os.getenv(\"EXECUTOR_CORES\", \"4\"))\n", "\n", "# Common spark settings\n", "conf = SparkConf()\n", "conf.setMaster(SPARK_MASTER_URL)\n", "conf.setAppName(\"Microbenchmark on GPU\")\n", "conf.set(\"spark.driver.memory\", driverMem)\n", "## The tasks will run on GPU memory, so there is no need to set a high host memory\n", "conf.set(\"spark.executor.memory\", executorMem)\n", "## The tasks will run on GPU cores, so there is no need to use many cpu cores\n", "conf.set(\"spark.executor.cores\", executorCores)\n", "\n", "# Plugin settings\n", "conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", "conf.set(\"spark.rapids.sql.concurrentGpuTasks\", concurrentGpuTasks)\n", "conf.set(\"spark.rapids.memory.pinnedPool.size\", pinnedPoolSize)\n", "##############note: only support value=1 see https://github.com/dmlc/xgboost/blame/master/python-package/xgboost/spark/core.py#L370-L374\n", "conf.set(\"spark.task.resource.gpu.amount\", 1) \n", "# since pyspark and xgboost share the same GPU, we disable RMM to avoid GPU OOM while training \n", "conf.set(\"spark.rapids.memory.gpu.pool\", \"NONE\")\n", "conf.set(\"spark.rapids.sql.enabled\", \"true\") \n", "conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n", "conf.set(\"spark.sql.cache.serializer\",\"com.nvidia.spark.ParquetCachedBatchSerializer\")\n", "conf.set(\"spark.driver.extraClassPath\", RAPIDS_JAR)\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", 200000) \n", "conf.set(\"spark.executor.extraClassPath\", RAPIDS_JAR)\n", "conf.set(\"spark.jars\", RAPIDS_JAR)\n", "\n", "# if you pass/unpack the archive file and enable the environment\n", "# conf.set(\"spark.yarn.dist.archives\", \"your_pyspark_venv.tar.gz#environment\")\n", "\n", "# Create spark session\n", "spark = SparkSession.builder.config(conf=conf).getOrCreate()\n", "reader = spark.read" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Specify the Data Schema and Load the Data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "label = 
'delinquency_12'\n", "schema = StructType([\n", " StructField('orig_channel', FloatType()),\n", " StructField('first_home_buyer', FloatType()),\n", " StructField('loan_purpose', FloatType()),\n", " StructField('property_type', FloatType()),\n", " StructField('occupancy_status', FloatType()),\n", " StructField('property_state', FloatType()),\n", " StructField('product_type', FloatType()),\n", " StructField('relocation_mortgage_indicator', FloatType()),\n", " StructField('seller_name', FloatType()),\n", " StructField('mod_flag', FloatType()),\n", " StructField('orig_interest_rate', FloatType()),\n", " StructField('orig_upb', DoubleType()),\n", " StructField('orig_loan_term', IntegerType()),\n", " StructField('orig_ltv', FloatType()),\n", " StructField('orig_cltv', FloatType()),\n", " StructField('num_borrowers', FloatType()),\n", " StructField('dti', FloatType()),\n", " StructField('borrower_credit_score', FloatType()),\n", " StructField('num_units', IntegerType()),\n", " StructField('zip', IntegerType()),\n", " StructField('mortgage_insurance_percent', FloatType()),\n", " StructField('current_loan_delinquency_status', IntegerType()),\n", " StructField('current_actual_upb', FloatType()),\n", " StructField('interest_rate', FloatType()),\n", " StructField('loan_age', FloatType()),\n", " StructField('msa', FloatType()),\n", " StructField('non_interest_bearing_upb', FloatType()),\n", " StructField(label, IntegerType()),\n", "])\n", "features = [ x.name for x in schema if x.name != label ]\n", "\n", "# You need to update them to your real paths!\n", "dataRoot = os.getenv(\"DATA_ROOT\", \"/data\")\n", "train_path = dataRoot + \"/mortgage/output/train\"\n", "eval_path = dataRoot + \"/mortgage/output/eval\"\n", "\n", "data_format = 'parquet'\n", "has_header = 'true'\n", "if data_format == 'csv':\n", " train_data = reader.schema(schema).option('header',has_header).csv(train_path)\n", " trans_data = reader.schema(schema).option('header',has_header).csv(eval_path)\n", "else :\n", " train_data = reader.load(train_path)\n", " trans_data = reader.load(eval_path)\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note on CPU version, vectorization is required before fitting data to classifier, which means you need to assemble all feature columns into one column.\n", "\n", "```Python\n", "def vectorize(data_frame):\n", " to_floats = [ col(x.name).cast(FloatType()) for x in data_frame.schema ]\n", " return (VectorAssembler()\n", " .setInputCols(features)\n", " .setOutputCol('features')\n", " .transform(data_frame.select(to_floats))\n", " .select(col('features'), col(label)))\n", "\n", "train_data = vectorize(train_data)\n", "trans_data = vectorize(trans_data)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create a XGBoostClassifier" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "params = { \n", " \"tree_method\": \"hist\",\n", " \"grow_policy\": \"depthwise\",\n", " \"num_workers\": 1,\n", " \"device\": \"cuda\",\n", "}\n", "params['features_col'] = features\n", "params['label_col'] = label\n", " \n", "classifier = SparkXGBClassifier(**params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The parameter `num_workers` should be set to the number of GPUs in Spark cluster for GPU version, while for CPU version it is usually equal to the number of the CPU cores.\n", "\n", "Concerning the device, GPU version only supports `cuda` currently, while `cpu` is designed and used here for CPU training.\n", "\n", 
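Editor's note on the CPU example that follows: the rest of this notebook passes `features_col`/`label_col` (see the `params` dict above), while the snippet below uses `feature_col`, which looks like a typo for the same keyword. A hedged sketch of the CPU-only setup with the consistent spelling (the worker count is illustrative, not from the original):

```python
# Sketch only: CPU configuration using the same parameter names as elsewhere in this
# notebook; num_workers is illustrative and would normally match the CPU cores available.
cpu_classifier = SparkXGBClassifier(
    features_col=features,
    label_col=label,
    num_workers=16,
)
```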
"An example of CPU classifier:\n", "```\n", "classifier = SparkXGBClassifier(\n", " feature_col=features,\n", " label_col=label, \n", " num_workers=1024,\n", ")\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Train the Data with Benchmark" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "If features_cols param set, then features_col param is ignored.\n", "22/11/24 06:14:44 WARN org.apache.spark.sql.catalyst.util.package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n", "[Stage 12:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[06:15:10] WARNING: ../src/learner.cc:553: \n", " If you are loading a serialized model (like pickle in Python, RDS in R) generated by\n", " older XGBoost, please export the model by calling `Booster.save_model` from that version\n", " first, then load it back in current version. See:\n", "\n", " https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html\n", "\n", " for more details about differences between saving model and serializing.\n", "\n", "Training takes 28.6 seconds\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r", "/home/yuali_nvidia_com/.local/lib/python3.8/site-packages/xgboost/sklearn.py:808: UserWarning: Loading a native XGBoost model with Scikit-Learn interface.\n", " warnings.warn(\"Loading a native XGBoost model with Scikit-Learn interface.\")\n" ] } ], "source": [ "def with_benchmark(phrase, action):\n", " start = time()\n", " result = action()\n", " end = time()\n", " print('{} takes {} seconds'.format(phrase, round(end - start, 2)))\n", " return result\n", "model = with_benchmark('Training', lambda: classifier.fit(train_data))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Save and Reload the Model" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "If features_cols param set, then features_col param is ignored.\n", " \r" ] } ], "source": [ "model.write().overwrite().save(dataRoot + '/model/mortgage')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "loaded_model = SparkXGBClassifierModel().load(dataRoot + '/model/mortgage')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Transformation and Show Result Sample" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "22/11/24 06:15:13 WARN com.nvidia.spark.rapids.GpuOverrides: \n", "!Exec cannot run on GPU because not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction#209, probability#275]\n", " @Expression orig_channel#56 could run on GPU\n", " @Expression first_home_buyer#57 could run on GPU\n", " @Expression loan_purpose#58 could run on GPU\n", " @Expression property_type#59 could run on GPU\n", " @Expression occupancy_status#60 could run on GPU\n", " @Expression property_state#61 could run on GPU\n", " @Expression product_type#62 could run on GPU\n", " @Expression relocation_mortgage_indicator#63 could run on GPU\n", " @Expression seller_name#64 could run on GPU\n", " @Expression mod_flag#65 could run on GPU\n", " @Expression orig_interest_rate#66 could run on GPU\n", " @Expression orig_upb#67 
could run on GPU\n", " @Expression orig_loan_term#68 could run on GPU\n", " @Expression orig_ltv#69 could run on GPU\n", " @Expression orig_cltv#70 could run on GPU\n", " @Expression num_borrowers#71 could run on GPU\n", " @Expression dti#72 could run on GPU\n", " @Expression borrower_credit_score#73 could run on GPU\n", " @Expression num_units#74 could run on GPU\n", " @Expression zip#75 could run on GPU\n", " @Expression mortgage_insurance_percent#76 could run on GPU\n", " @Expression current_loan_delinquency_status#77 could run on GPU\n", " @Expression current_actual_upb#78 could run on GPU\n", " @Expression interest_rate#79 could run on GPU\n", " @Expression loan_age#80 could run on GPU\n", " @Expression msa#81 could run on GPU\n", " @Expression non_interest_bearing_upb#82 could run on GPU\n", " @Expression delinquency_12#83 could run on GPU\n", " !Expression UDF(pythonUDF0#342.rawPrediction) AS rawPrediction#209 cannot run on GPU because input expression ScalaUDF UDF(pythonUDF0#342.rawPrediction) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported); expression Alias UDF(pythonUDF0#342.rawPrediction) AS rawPrediction#209 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression UDF(pythonUDF0#342.rawPrediction) cannot run on GPU because expression ScalaUDF UDF(pythonUDF0#342.rawPrediction) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3898/645590696 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled\n", " @Expression pythonUDF0#342.rawPrediction could run on GPU\n", " @Expression pythonUDF0#342 could run on GPU\n", " @Expression pythonUDF0#342.prediction AS prediction#243 could run on GPU\n", " @Expression pythonUDF0#342.prediction could run on GPU\n", " @Expression pythonUDF0#342 could run on GPU\n", " !Expression UDF(pythonUDF0#342.probability) AS probability#275 cannot run on GPU because input expression ScalaUDF UDF(pythonUDF0#342.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported); expression Alias UDF(pythonUDF0#342.probability) AS probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression UDF(pythonUDF0#342.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3898/645590696 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#342.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression pythonUDF0#342.probability could run on GPU\n", " @Expression pythonUDF0#342 could run on GPU\n", "\n", "22/11/24 06:15:13 WARN com.nvidia.spark.rapids.GpuOverrides: \n", "!Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction#209, probability#275]\n", " @Expression orig_channel#56 could run on GPU\n", " @Expression first_home_buyer#57 could run on GPU\n", " @Expression loan_purpose#58 could run on GPU\n", " @Expression property_type#59 could run on GPU\n", " @Expression occupancy_status#60 could run on GPU\n", " @Expression property_state#61 could run on GPU\n", " @Expression product_type#62 could run on GPU\n", " @Expression 
relocation_mortgage_indicator#63 could run on GPU\n", " @Expression seller_name#64 could run on GPU\n", " @Expression mod_flag#65 could run on GPU\n", " @Expression orig_interest_rate#66 could run on GPU\n", " @Expression orig_upb#67 could run on GPU\n", " @Expression orig_loan_term#68 could run on GPU\n", " @Expression orig_ltv#69 could run on GPU\n", " @Expression orig_cltv#70 could run on GPU\n", " @Expression num_borrowers#71 could run on GPU\n", " @Expression dti#72 could run on GPU\n", " @Expression borrower_credit_score#73 could run on GPU\n", " @Expression num_units#74 could run on GPU\n", " @Expression zip#75 could run on GPU\n", " @Expression mortgage_insurance_percent#76 could run on GPU\n", " @Expression current_loan_delinquency_status#77 could run on GPU\n", " @Expression current_actual_upb#78 could run on GPU\n", " @Expression interest_rate#79 could run on GPU\n", " @Expression loan_age#80 could run on GPU\n", " @Expression msa#81 could run on GPU\n", " @Expression non_interest_bearing_upb#82 could run on GPU\n", " @Expression delinquency_12#83 could run on GPU\n", " !Expression rawPrediction#209 cannot run on GPU because expression AttributeReference rawPrediction#209 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression prediction#243 could run on GPU\n", " !Expression probability#275 cannot run on GPU because expression AttributeReference probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", "\n", "22/11/24 06:15:28 WARN com.nvidia.spark.rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", " !Exec cannot run on GPU because not all expressions can be replaced; unsupported data types in input: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#275, rawPrediction#209]\n", " @Expression cast(delinquency_12#83 as string) AS delinquency_12#971 could run on GPU\n", " @Expression cast(delinquency_12#83 as string) could run on GPU\n", " @Expression delinquency_12#83 could run on GPU\n", " @Expression cast(rawPrediction#209 as string) AS rawPrediction#972 could run on GPU\n", " !Expression cast(rawPrediction#209 as string) cannot run on GPU because Cast from org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 to StringType is not supported\n", " !Expression rawPrediction#209 cannot run on GPU because expression AttributeReference rawPrediction#209 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression cast(probability#275 as string) AS probability#973 could run on GPU\n", " !Expression cast(probability#275 as string) cannot run on GPU because Cast from org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 to StringType is not supported\n", " !Expression probability#275 cannot run on GPU because expression AttributeReference probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " @Expression cast(prediction#243 as string) AS prediction#974 could run on GPU\n", " @Expression cast(prediction#243 as string) could run on GPU\n", " @Expression prediction#243 could run on GPU\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#275, rawPrediction#209]; not all expressions can be replaced\n", " @Expression delinquency_12#83 could run on GPU\n", " @Expression prediction#243 could run on GPU\n", " !Expression probability#275 cannot run on GPU because expression AttributeReference probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression rawPrediction#209 cannot run on GPU because expression AttributeReference rawPrediction#209 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Transformation takes 15.62 seconds\n", "+--------------+--------------------+--------------------+----------+\n", "|delinquency_12| rawPrediction| probability|prediction|\n", "+--------------+--------------------+--------------------+----------+\n", "| 0|[8.84631538391113...|[0.99985611438751...| 0.0|\n", "| 0|[9.41864871978759...|[0.99991881847381...| 0.0|\n", "| 0|[9.41864871978759...|[0.99991881847381...| 0.0|\n", "| 0|[9.41864871978759...|[0.99991881847381...| 0.0|\n", "| 0|[8.84631538391113...|[0.99985611438751...| 0.0|\n", "+--------------+--------------------+--------------------+----------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "def transform():\n", " result = loaded_model.transform(trans_data).cache()\n", " result.foreachPartition(lambda _: None)\n", " return result\n", "result = with_benchmark('Transformation', transform)\n", "result.select(label, 'rawPrediction', 'probability', 'prediction').show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Evaluation" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def 
check_classification_accuracy(data_frame, label):\n", " accuracy = (MulticlassClassificationEvaluator()\n", " .setLabelCol(label)\n", " .evaluate(data_frame))\n", " print('-' * 100)\n", " print('Accuracy is ' + str(accuracy))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "22/11/24 06:15:28 WARN com.nvidia.spark.rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#243, delinquency_12#1450, 1.0#1449, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#243 could run on GPU\n", " @Expression delinquency_12#1450 could run on GPU\n", " @Expression 1.0#1449 could run on GPU\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\n", " ! newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\n", " !Expression probability#275 cannot run on GPU because expression AttributeReference probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Expression obj#1455 cannot run on GPU because expression AttributeReference obj#1455 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#275]; unsupported data types in input: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#275]; not all expressions can be replaced\n", " @Expression prediction#243 could run on GPU\n", " @Expression cast(delinquency_12#83 as double) AS delinquency_12#1450 could run on GPU\n", " @Expression cast(delinquency_12#83 as double) could run on GPU\n", " @Expression delinquency_12#83 could run on GPU\n", " @Expression 1.0 AS 1.0#1449 could run on GPU\n", " @Expression 1.0 could run on GPU\n", " !Expression probability#275 cannot run on GPU because expression AttributeReference probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", " !Exec cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#275]; not all expressions can be replaced\n", " @Expression delinquency_12#83 could run on GPU\n", " @Expression prediction#243 could run on GPU\n", " !Expression probability#275 cannot run on GPU because expression AttributeReference probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\n", "\n", "[Stage 19:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"----------------------------------------------------------------------------------------------------\n", "Accuracy is 1.0\n", "Evaluation takes 2.29 seconds\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "with_benchmark('Evaluation', lambda: check_classification_accuracy(result, label))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "spark.stop()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-ETL.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "e82e9fb4", "metadata": {}, "source": [ "# Introduction to Mortgage ETL Job\n", "This is the mortgage ETL job to generate the input datasets for the mortgage Xgboost job.\n" ] }, { "cell_type": "markdown", "id": "d0c8c3fa", "metadata": {}, "source": [ "## Prerequirement\n", "### 1. Download data\n", "\n", "Refer to these [instructions](https://github.com/NVIDIA/spark-rapids-examples/blob/branch-23.12/docs/get-started/xgboost-examples/dataset/mortgage.md) to download the dataset.\n", "\n", "### 2. Download needed jars\n", "* [rapids-4-spark_2.12-26.02.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)\n", "\n", "### 3. Start Spark Standalone\n", "Before Running the script, please setup Spark standalone mode\n", "\n", "### 4. Add ENV\n", "```\n", "$ export SPARK_JARS=rapids-4-spark_2.12-26.02.0.jar\n", "\n", "```\n", "\n", "### 5.Start Jupyter Notebook with spylon-kernel or toree\n", "\n", "```\n", "$ jupyter notebook --allow-root --notebook-dir=${your-dir} --config=${your-configs}\n", "```\n", "\n", "## Import Libs" ] }, { "cell_type": "code", "execution_count": 1, "id": "3ecc912c", "metadata": {}, "outputs": [], "source": [ "import org.apache.hadoop.fs.Path\n", "import org.apache.spark.sql.expressions.Window\n", "import org.apache.spark.sql.functions._\n", "import org.apache.spark.sql.types._\n", "import org.apache.spark.sql.{Column, DataFrame, SparkSession}" ] }, { "cell_type": "markdown", "id": "b58fcd6d", "metadata": {}, "source": [ "## Script Settings\n", "\n", "### 1. File Path Settings\n", "* Define input file path" ] }, { "cell_type": "code", "execution_count": null, "id": "b2834c06", "metadata": {}, "outputs": [], "source": [ "val dataRoot = sys.env.getOrElse(\"DATA_ROOT\", \"/data\")\n", "val dataOut = sys.env.getOrElse(\"DATA_OUT\", \"/data\")\n", "val dataPath = dataRoot + \"/mortgage/input\"\n", "val outPath = dataOut + \"/mortgage/output\"\n", "val output_csv2parquet = dataOut + \"/mortgage/output/csv2parquet/\"\n", "val saveTrainEvalDataset = true" ] }, { "cell_type": "markdown", "id": "775a2c7b", "metadata": {}, "source": [ "## Function and Object Define\n", "### 1. 
Define the constants\n", "\n", "* Define input/output file schema (Performance and Acquisition)" ] }, { "cell_type": "code", "execution_count": 3, "id": "e557beb0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "rawSchema = StructType(StructField(reference_pool_id,StringType,true), StructField(loan_id,LongType,true), StructField(monthly_reporting_period,StringType,true), StructField(orig_channel,StringType,true), StructField(seller_name,StringType,true), StructField(servicer,StringType,true), StructField(master_servicer,StringType,true), StructField(orig_interest_rate,DoubleType,true), StructField(interest_rate,DoubleType,true), StructField(orig_upb,IntegerType,true), StructField(upb_at_issuance,StringType,true), StructField(current_actual_upb,DoubleType,true), StructField(orig_loan_term,IntegerType,true), StructField(orig_date,StringType,true), StructField(first_pay_date,StringType,true), StructField(loan_age,DoubleType,true), StructField(remaining_months...\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "StructType(StructField(reference_pool_id,StringType,true), StructField(loan_id,LongType,true), StructField(monthly_reporting_period,StringType,true), StructField(orig_channel,StringType,true), StructField(seller_name,StringType,true), StructField(servicer,StringType,true), StructField(master_servicer,StringType,true), StructField(orig_interest_rate,DoubleType,true), StructField(interest_rate,DoubleType,true), StructField(orig_upb,IntegerType,true), StructField(upb_at_issuance,StringType,true), StructField(current_actual_upb,DoubleType,true), StructField(orig_loan_term,IntegerType,true), StructField(orig_date,StringType,true), StructField(first_pay_date,StringType,true), StructField(loan_age,DoubleType,true), StructField(remaining_months..." 
] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// File schema\n", "val rawSchema = StructType(Array(\n", " StructField(\"reference_pool_id\", StringType),\n", " StructField(\"loan_id\", LongType),\n", " StructField(\"monthly_reporting_period\", StringType),\n", " StructField(\"orig_channel\", StringType),\n", " StructField(\"seller_name\", StringType),\n", " StructField(\"servicer\", StringType),\n", " StructField(\"master_servicer\", StringType),\n", " StructField(\"orig_interest_rate\", DoubleType),\n", " StructField(\"interest_rate\", DoubleType),\n", " StructField(\"orig_upb\", DoubleType),\n", " StructField(\"upb_at_issuance\", StringType),\n", " StructField(\"current_actual_upb\", DoubleType),\n", " StructField(\"orig_loan_term\", IntegerType),\n", " StructField(\"orig_date\", StringType),\n", " StructField(\"first_pay_date\", StringType), \n", " StructField(\"loan_age\", DoubleType),\n", " StructField(\"remaining_months_to_legal_maturity\", DoubleType),\n", " StructField(\"adj_remaining_months_to_maturity\", DoubleType),\n", " StructField(\"maturity_date\", StringType),\n", " StructField(\"orig_ltv\", DoubleType),\n", " StructField(\"orig_cltv\", DoubleType),\n", " StructField(\"num_borrowers\", DoubleType),\n", " StructField(\"dti\", DoubleType),\n", " StructField(\"borrower_credit_score\", DoubleType),\n", " StructField(\"coborrow_credit_score\", DoubleType),\n", " StructField(\"first_home_buyer\", StringType),\n", " StructField(\"loan_purpose\", StringType),\n", " StructField(\"property_type\", StringType),\n", " StructField(\"num_units\", IntegerType),\n", " StructField(\"occupancy_status\", StringType),\n", " StructField(\"property_state\", StringType),\n", " StructField(\"msa\", DoubleType),\n", " StructField(\"zip\", IntegerType),\n", " StructField(\"mortgage_insurance_percent\", DoubleType),\n", " StructField(\"product_type\", StringType),\n", " StructField(\"prepayment_penalty_indicator\", StringType),\n", " StructField(\"interest_only_loan_indicator\", StringType),\n", " StructField(\"interest_only_first_principal_and_interest_payment_date\", StringType),\n", " StructField(\"months_to_amortization\", StringType),\n", " StructField(\"current_loan_delinquency_status\", IntegerType),\n", " StructField(\"loan_payment_history\", StringType),\n", " StructField(\"mod_flag\", StringType),\n", " StructField(\"mortgage_insurance_cancellation_indicator\", StringType),\n", " StructField(\"zero_balance_code\", StringType),\n", " StructField(\"zero_balance_effective_date\", StringType),\n", " StructField(\"upb_at_the_time_of_removal\", StringType),\n", " StructField(\"repurchase_date\", StringType),\n", " StructField(\"scheduled_principal_current\", StringType),\n", " StructField(\"total_principal_current\", StringType),\n", " StructField(\"unscheduled_principal_current\", StringType),\n", " StructField(\"last_paid_installment_date\", StringType),\n", " StructField(\"foreclosed_after\", StringType),\n", " StructField(\"disposition_date\", StringType),\n", " StructField(\"foreclosure_costs\", DoubleType),\n", " StructField(\"prop_preservation_and_repair_costs\", DoubleType),\n", " StructField(\"asset_recovery_costs\", DoubleType),\n", " StructField(\"misc_holding_expenses\", DoubleType),\n", " StructField(\"holding_taxes\", DoubleType),\n", " StructField(\"net_sale_proceeds\", DoubleType),\n", " StructField(\"credit_enhancement_proceeds\", DoubleType),\n", " StructField(\"repurchase_make_whole_proceeds\", StringType),\n", " 
StructField(\"other_foreclosure_proceeds\", DoubleType),\n", " StructField(\"non_interest_bearing_upb\", DoubleType),\n", " StructField(\"principal_forgiveness_upb\", StringType),\n", " StructField(\"original_list_start_date\", StringType),\n", " StructField(\"original_list_price\", StringType),\n", " StructField(\"current_list_start_date\", StringType),\n", " StructField(\"current_list_price\", StringType),\n", " StructField(\"borrower_credit_score_at_issuance\", StringType),\n", " StructField(\"co-borrower_credit_score_at_issuance\", StringType),\n", " StructField(\"borrower_credit_score_current\", StringType),\n", " StructField(\"co-Borrower_credit_score_current\", StringType),\n", " StructField(\"mortgage_insurance_type\", DoubleType),\n", " StructField(\"servicing_activity_indicator\", StringType),\n", " StructField(\"current_period_modification_loss_amount\", StringType),\n", " StructField(\"cumulative_modification_loss_amount\", StringType),\n", " StructField(\"current_period_credit_event_net_gain_or_loss\", StringType),\n", " StructField(\"cumulative_credit_event_net_gain_or_loss\", StringType),\n", " StructField(\"homeready_program_indicator\", StringType),\n", " StructField(\"foreclosure_principal_write_off_amount\", StringType),\n", " StructField(\"relocation_mortgage_indicator\", StringType),\n", " StructField(\"zero_balance_code_change_date\", StringType),\n", " StructField(\"loan_holdback_indicator\", StringType),\n", " StructField(\"loan_holdback_effective_date\", StringType),\n", " StructField(\"delinquent_accrued_interest\", StringType),\n", " StructField(\"property_valuation_method\", StringType),\n", " StructField(\"high_balance_loan_indicator\", StringType),\n", " StructField(\"arm_initial_fixed-rate_period_lt_5_yr_indicator\", StringType),\n", " StructField(\"arm_product_type\", StringType),\n", " StructField(\"initial_fixed-rate_period\", StringType),\n", " StructField(\"interest_rate_adjustment_frequency\", StringType),\n", " StructField(\"next_interest_rate_adjustment_date\", StringType),\n", " StructField(\"next_payment_change_date\", StringType),\n", " StructField(\"index\", StringType),\n", " StructField(\"arm_cap_structure\", StringType),\n", " StructField(\"initial_interest_rate_cap_up_percent\", StringType),\n", " StructField(\"periodic_interest_rate_cap_up_percent\", StringType),\n", " StructField(\"lifetime_interest_rate_cap_up_percent\", StringType),\n", " StructField(\"mortgage_margin\", StringType),\n", " StructField(\"arm_balloon_indicator\", StringType),\n", " StructField(\"arm_plan_number\", StringType),\n", " StructField(\"borrower_assistance_plan\", StringType),\n", " StructField(\"hltv_refinance_option_indicator\", StringType),\n", " StructField(\"deal_name\", StringType),\n", " StructField(\"repurchase_make_whole_proceeds_flag\", StringType),\n", " StructField(\"alternative_delinquency_resolution\", StringType),\n", " StructField(\"alternative_delinquency_resolution_count\", StringType),\n", " StructField(\"total_deferral_amount\", StringType)\n", " )\n", " )" ] }, { "cell_type": "markdown", "id": "86af48b6", "metadata": {}, "source": [ "* Define seller name mapping" ] }, { "cell_type": "code", "execution_count": 4, "id": "69f193d7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "defined object NameMapping\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "object NameMapping {\n", " /**\n", " * Returns a dataframe with two columns named based off of the column names passed in.\n", " * The fromColName has the original 
name we want to clean up, the toColName\n", " * will have the name we want to go to, the unambiguous name.\n", " */\n", " def apply(spark: SparkSession, fromColName: String, toColName: String): DataFrame = {\n", " import spark.sqlContext.implicits._\n", " broadcast(Seq(\n", " (\"WITMER FUNDING, LLC\", \"Witmer\"),\n", " (\"WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015\", \"Wells Fargo\"),\n", " (\"WELLS FARGO BANK, NA\" , \"Wells Fargo\"),\n", " (\"WELLS FARGO BANK, N.A.\" , \"Wells Fargo\"),\n", " (\"WELLS FARGO BANK, NA\" , \"Wells Fargo\"),\n", " (\"USAA FEDERAL SAVINGS BANK\" , \"USAA\"),\n", " (\"UNITED SHORE FINANCIAL SERVICES, LLC D\\\\/B\\\\/A UNITED WHOLESALE MORTGAGE\" , \"United Seq(e\"),\n", " (\"U.S. BANK N.A.\" , \"US Bank\"),\n", " (\"SUNTRUST MORTGAGE INC.\" , \"Suntrust\"),\n", " (\"STONEGATE MORTGAGE CORPORATION\" , \"Stonegate Mortgage\"),\n", " (\"STEARNS LENDING, LLC\" , \"Stearns Lending\"),\n", " (\"STEARNS LENDING, INC.\" , \"Stearns Lending\"),\n", " (\"SIERRA PACIFIC MORTGAGE COMPANY, INC.\" , \"Sierra Pacific Mortgage\"),\n", " (\"REGIONS BANK\" , \"Regions\"),\n", " (\"RBC MORTGAGE COMPANY\" , \"RBC\"),\n", " (\"QUICKEN LOANS INC.\" , \"Quicken Loans\"),\n", " (\"PULTE MORTGAGE, L.L.C.\" , \"Pulte Mortgage\"),\n", " (\"PROVIDENT FUNDING ASSOCIATES, L.P.\" , \"Provident Funding\"),\n", " (\"PROSPECT MORTGAGE, LLC\" , \"Prospect Mortgage\"),\n", " (\"PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC\" , \"Principal Residential\"),\n", " (\"PNC BANK, N.A.\" , \"PNC\"),\n", " (\"PMT CREDIT RISK TRANSFER TRUST 2015-2\" , \"PennyMac\"),\n", " (\"PHH MORTGAGE CORPORATION\" , \"PHH Mortgage\"),\n", " (\"PENNYMAC CORP.\" , \"PennyMac\"),\n", " (\"PACIFIC UNION FINANCIAL, LLC\" , \"Other\"),\n", " (\"OTHER\" , \"Other\"),\n", " (\"NYCB MORTGAGE COMPANY, LLC\" , \"NYCB\"),\n", " (\"NEW YORK COMMUNITY BANK\" , \"NYCB\"),\n", " (\"NETBANK FUNDING SERVICES\" , \"Netbank\"),\n", " (\"NATIONSTAR MORTGAGE, LLC\" , \"Nationstar Mortgage\"),\n", " (\"METLIFE BANK, NA\" , \"Metlife\"),\n", " (\"LOANDEPOT.COM, LLC\" , \"LoanDepot.com\"),\n", " (\"J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1\" , \"JP Morgan Chase\"),\n", " (\"J.P. 
MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1\" , \"JP Morgan Chase\"),\n", " (\"JPMORGAN CHASE BANK, NATIONAL ASSOCIATION\" , \"JP Morgan Chase\"),\n", " (\"JPMORGAN CHASE BANK, NA\" , \"JP Morgan Chase\"),\n", " (\"JP MORGAN CHASE BANK, NA\" , \"JP Morgan Chase\"),\n", " (\"IRWIN MORTGAGE, CORPORATION\" , \"Irwin Mortgage\"),\n", " (\"IMPAC MORTGAGE CORP.\" , \"Impac Mortgage\"),\n", " (\"HSBC BANK USA, NATIONAL ASSOCIATION\" , \"HSBC\"),\n", " (\"HOMEWARD RESIDENTIAL, INC.\" , \"Homeward Mortgage\"),\n", " (\"HOMESTREET BANK\" , \"Other\"),\n", " (\"HOMEBRIDGE FINANCIAL SERVICES, INC.\" , \"HomeBridge\"),\n", " (\"HARWOOD STREET FUNDING I, LLC\" , \"Harwood Mortgage\"),\n", " (\"GUILD MORTGAGE COMPANY\" , \"Guild Mortgage\"),\n", " (\"GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)\" , \"GMAC\"),\n", " (\"GMAC MORTGAGE, LLC\" , \"GMAC\"),\n", " (\"GMAC (USAA)\" , \"GMAC\"),\n", " (\"FREMONT BANK\" , \"Fremont Bank\"),\n", " (\"FREEDOM MORTGAGE CORP.\" , \"Freedom Mortgage\"),\n", " (\"FRANKLIN AMERICAN MORTGAGE COMPANY\" , \"Franklin America\"),\n", " (\"FLEET NATIONAL BANK\" , \"Fleet National\"),\n", " (\"FLAGSTAR CAPITAL MARKETS CORPORATION\" , \"Flagstar Bank\"),\n", " (\"FLAGSTAR BANK, FSB\" , \"Flagstar Bank\"),\n", " (\"FIRST TENNESSEE BANK NATIONAL ASSOCIATION\" , \"Other\"),\n", " (\"FIFTH THIRD BANK\" , \"Fifth Third Bank\"),\n", " (\"FEDERAL HOME LOAN BANK OF CHICAGO\" , \"Fedral Home of Chicago\"),\n", " (\"FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB\" , \"FDIC\"),\n", " (\"DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A.\" , \"Downey Mortgage\"),\n", " (\"DITECH FINANCIAL LLC\" , \"Ditech\"),\n", " (\"CITIMORTGAGE, INC.\" , \"Citi\"),\n", " (\"CHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY\" , \"Chicago Mortgage\"),\n", " (\"CHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY\" , \"Chicago Mortgage\"),\n", " (\"CHASE HOME FINANCE, LLC\" , \"JP Morgan Chase\"),\n", " (\"CHASE HOME FINANCE FRANKLIN AMERICAN MORTGAGE COMPANY\" , \"JP Morgan Chase\"),\n", " (\"CHASE HOME FINANCE (CIE 1)\" , \"JP Morgan Chase\"),\n", " (\"CHASE HOME FINANCE\" , \"JP Morgan Chase\"),\n", " (\"CASHCALL, INC.\" , \"CashCall\"),\n", " (\"CAPITAL ONE, NATIONAL ASSOCIATION\" , \"Capital One\"),\n", " (\"CALIBER HOME LOANS, INC.\" , \"Caliber Funding\"),\n", " (\"BISHOPS GATE RESIDENTIAL MORTGAGE TRUST\" , \"Bishops Gate Mortgage\"),\n", " (\"BANK OF AMERICA, N.A.\" , \"Bank of America\"),\n", " (\"AMTRUST BANK\" , \"AmTrust\"),\n", " (\"AMERISAVE MORTGAGE CORPORATION\" , \"Amerisave\"),\n", " (\"AMERIHOME MORTGAGE COMPANY, LLC\" , \"AmeriHome Mortgage\"),\n", " (\"ALLY BANK\" , \"Ally Bank\"),\n", " (\"ACADEMY MORTGAGE CORPORATION\" , \"Academy Mortgage\"),\n", " (\"NO CASH-OUT REFINANCE\" , \"OTHER REFINANCE\"),\n", " (\"REFINANCE - NOT SPECIFIED\" , \"OTHER REFINANCE\"),\n", " (\"Other REFINANCE\" , \"OTHER REFINANCE\")\n", " ).toDF(fromColName, toColName))\n", " }\n", "}" ] }, { "cell_type": "markdown", "id": "42098a5a", "metadata": {}, "source": [ "### 2. 
Define ETL Process\n", "\n", "Define the function to do the ETL process\n", "\n", "* Define function to get quarter from input CSV file name" ] }, { "cell_type": "code", "execution_count": 5, "id": "f18cab51", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "defined object GetQuarterFromCsvFileName\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "object GetQuarterFromCsvFileName {\n", " // The format is path/TYPE_yyyy\\QQ.txt followed by a (_index)* where index is a single digit number [0-9]\n", " // i.e. mortgage/perf/Performance_2003Q4.txt_0_1\n", " // So we strip off the .txt and everything after it\n", " // and then take everything after the last remaining _\n", " def apply(): Column = substring_index(\n", " substring_index(input_file_name(), \".\", 1), \"/\", -1)\n", "}" ] }, { "cell_type": "markdown", "id": "ead44543", "metadata": {}, "source": [ "* Define category (string) column and numeric column" ] }, { "cell_type": "code", "execution_count": 6, "id": "9936e221", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "labelColName = delinquency_12\n", "categaryCols = List((orig_channel,FloatType), (first_home_buyer,FloatType), (loan_purpose,FloatType), (property_type,FloatType), (occupancy_status,FloatType), (property_state,FloatType), (product_type,FloatType), (relocation_mortgage_indicator,FloatType), (seller_name,FloatType), (mod_flag,FloatType))\n", "numericCols = List((orig_interest_rate,FloatType), (orig_upb,IntegerType), (orig_loan_term,IntegerType), (orig_ltv,FloatType), (orig_cltv,FloatType), (num_borrowers,FloatType), (dti,FloatType), (borrower_credit_score,FloatType), (num_units,IntegerType), (zip,IntegerType), (mortgage_insurance_percent,FloatType...\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "List((orig_interest_rate,FloatType), (orig_upb,IntegerType), (orig_loan_term,IntegerType), (orig_ltv,FloatType), (orig_cltv,FloatType), (num_borrowers,FloatType), (dti,FloatType), (borrower_credit_score,FloatType), (num_units,IntegerType), (zip,IntegerType), (mortgage_insurance_percent,FloatType..." 
] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val labelColName = \"delinquency_12\"\n", "val categaryCols = List(\n", " (\"orig_channel\", FloatType),\n", " (\"first_home_buyer\", FloatType),\n", " (\"loan_purpose\", FloatType),\n", " (\"property_type\", FloatType),\n", " (\"occupancy_status\", FloatType),\n", " (\"property_state\", FloatType),\n", " (\"product_type\", FloatType),\n", " (\"relocation_mortgage_indicator\", FloatType),\n", " (\"seller_name\", FloatType),\n", " (\"mod_flag\", FloatType)\n", " )\n", "\n", "val numericCols = List(\n", " (\"orig_interest_rate\", FloatType),\n", " (\"orig_upb\", DoubleType),\n", " (\"orig_loan_term\", IntegerType),\n", " (\"orig_ltv\", FloatType),\n", " (\"orig_cltv\", FloatType),\n", " (\"num_borrowers\", FloatType),\n", " (\"dti\", FloatType),\n", " (\"borrower_credit_score\", FloatType),\n", " (\"num_units\", IntegerType),\n", " (\"zip\", IntegerType),\n", " (\"mortgage_insurance_percent\", FloatType),\n", " (\"current_loan_delinquency_status\", IntegerType),\n", " (\"current_actual_upb\", FloatType),\n", " (\"interest_rate\", FloatType),\n", " (\"loan_age\", FloatType),\n", " (\"msa\", FloatType),\n", " (\"non_interest_bearing_upb\", FloatType),\n", " (labelColName, IntegerType)\n", " )\n", "\n", "var cachedDictDF: DataFrame = _" ] }, { "cell_type": "markdown", "id": "6177b6b8", "metadata": {}, "source": [ "* Define Casting Process\n", "This part is casting String column to Numbric. \n", "Example:\n", "```\n", "col_1\n", " \"a\"\n", " \"b\"\n", " \"c\"\n", " \"a\"\n", "# After String ====> Numberic\n", "col_1\n", " 0\n", " 1\n", " 2\n", " 0\n", "``` \n", "
\n", "\n", "* Define function to get column dictionary\n", "\n", " Example\n", " ```\n", " col1 = [row(data=\"a\",id=0), row(data=\"b\",id=1)]\n", " ```" ] }, { "cell_type": "code", "execution_count": 7, "id": "5091c8a1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "genDictionary: (etlDF: org.apache.spark.sql.DataFrame, colNames: Seq[String])org.apache.spark.sql.DataFrame\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def genDictionary(etlDF: DataFrame, colNames: Seq[String]): DataFrame = {\n", " val cntTable = etlDF\n", " .select(posexplode(array(colNames.map(col(_)): _*)))\n", " .withColumnRenamed(\"pos\", \"column_id\")\n", " .withColumnRenamed(\"col\", \"data\")\n", " .filter(\"data is not null\")\n", " .groupBy(\"column_id\", \"data\")\n", " .count()\n", " val windowed = Window.partitionBy(\"column_id\").orderBy(desc(\"count\"))\n", " cntTable\n", " .withColumn(\"id\", row_number().over(windowed))\n", " .drop(\"count\")\n", " }" ] }, { "cell_type": "markdown", "id": "1466af65", "metadata": {}, "source": [ "* Define function to convert string columns to numeric" ] }, { "cell_type": "code", "execution_count": 8, "id": "9df8fe60", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "castStringColumnsToNumeric: (inputDF: org.apache.spark.sql.DataFrame, spark: org.apache.spark.sql.SparkSession)org.apache.spark.sql.DataFrame\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def castStringColumnsToNumeric(inputDF: DataFrame, spark: SparkSession): DataFrame = {\n", " val cateColNames = categaryCols.map(_._1)\n", " cachedDictDF = genDictionary(inputDF, cateColNames).cache()\n", "\n", " // Generate the final table with all columns being numeric.\n", " cateColNames.foldLeft(inputDF) {\n", " case (df, colName) =>\n", " val colPos = cateColNames.indexOf(colName)\n", " val colDictDF = cachedDictDF\n", " .filter(col(\"column_id\") === colPos)\n", " .drop(\"column_id\")\n", " .withColumnRenamed(\"data\", colName)\n", " df.join(broadcast(colDictDF), Seq(colName), \"left\")\n", " .drop(colName)\n", " .withColumnRenamed(\"id\", colName)\n", " }\n", " }" ] }, { "cell_type": "code", "execution_count": 9, "id": "9e1fbb61", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "defined object extractPerfColumns\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "object extractPerfColumns{\n", " def apply(rawDf : DataFrame) : DataFrame = {\n", " val perfDf = rawDf.select(\n", " col(\"loan_id\"),\n", " date_format(to_date(col(\"monthly_reporting_period\"),\"MMyyyy\"), \"MM/dd/yyyy\").as(\"monthly_reporting_period\"),\n", " upper(col(\"servicer\")).as(\"servicer\"),\n", " col(\"interest_rate\"),\n", " col(\"current_actual_upb\"),\n", " col(\"loan_age\"),\n", " col(\"remaining_months_to_legal_maturity\"),\n", " col(\"adj_remaining_months_to_maturity\"),\n", " date_format(to_date(col(\"maturity_date\"),\"MMyyyy\"), \"MM/yyyy\").as(\"maturity_date\"),\n", " col(\"msa\"),\n", " col(\"current_loan_delinquency_status\"),\n", " col(\"mod_flag\"),\n", " col(\"zero_balance_code\"),\n", " date_format(to_date(col(\"zero_balance_effective_date\"),\"MMyyyy\"), \"MM/yyyy\").as(\"zero_balance_effective_date\"),\n", " date_format(to_date(col(\"last_paid_installment_date\"),\"MMyyyy\"), \"MM/dd/yyyy\").as(\"last_paid_installment_date\"),\n", " date_format(to_date(col(\"foreclosed_after\"),\"MMyyyy\"), \"MM/dd/yyyy\").as(\"foreclosed_after\"),\n", " date_format(to_date(col(\"disposition_date\"),\"MMyyyy\"), 
\"MM/dd/yyyy\").as(\"disposition_date\"),\n", " col(\"foreclosure_costs\"),\n", " col(\"prop_preservation_and_repair_costs\"),\n", " col(\"asset_recovery_costs\"),\n", " col(\"misc_holding_expenses\"),\n", " col(\"holding_taxes\"),\n", " col(\"net_sale_proceeds\"),\n", " col(\"credit_enhancement_proceeds\"),\n", " col(\"repurchase_make_whole_proceeds\"),\n", " col(\"other_foreclosure_proceeds\"),\n", " col(\"non_interest_bearing_upb\"),\n", " col(\"principal_forgiveness_upb\"),\n", " col(\"repurchase_make_whole_proceeds_flag\"),\n", " col(\"foreclosure_principal_write_off_amount\"),\n", " col(\"servicing_activity_indicator\"),\n", " col(\"quarter\")\n", " )\n", " \n", " perfDf.select(\"*\").filter(\"current_actual_upb != 0.0\")\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": 10, "id": "ce429163", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "defined object extractAcqColumns\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "object extractAcqColumns{\n", " def apply(rawDf : DataFrame) : DataFrame = {\n", " val acqDf = rawDf.select(\n", " col(\"loan_id\"),\n", " col(\"orig_channel\"),\n", " upper(col(\"seller_name\")).as(\"seller_name\"),\n", " col(\"orig_interest_rate\"),\n", " col(\"orig_upb\"),\n", " col(\"orig_loan_term\"),\n", " date_format(to_date(col(\"orig_date\"),\"MMyyyy\"), \"MM/yyyy\").as(\"orig_date\"),\n", " date_format(to_date(col(\"first_pay_date\"),\"MMyyyy\"), \"MM/yyyy\").as(\"first_pay_date\"),\n", " col(\"orig_ltv\"),\n", " col(\"orig_cltv\"),\n", " col(\"num_borrowers\"),\n", " col(\"dti\"),\n", " col(\"borrower_credit_score\"),\n", " col(\"first_home_buyer\"),\n", " col(\"loan_purpose\"),\n", " col(\"property_type\"),\n", " col(\"num_units\"),\n", " col(\"occupancy_status\"),\n", " col(\"property_state\"),\n", " col(\"zip\"),\n", " col(\"mortgage_insurance_percent\"),\n", " col(\"product_type\"),\n", " col(\"coborrow_credit_score\"),\n", " col(\"mortgage_insurance_type\"),\n", " col(\"relocation_mortgage_indicator\"),\n", " col(\"quarter\"),\n", " dense_rank().over(Window.partitionBy(\"loan_id\").orderBy(to_date(col(\"monthly_reporting_period\"),\"MMyyyy\"))).as(\"rank\")\n", " )\n", "\n", " acqDf.select(\"*\").filter(col(\"rank\") === 1)\n", " }\n", "\n", "}" ] }, { "cell_type": "markdown", "id": "37c64d85", "metadata": {}, "source": [ "* Build the spark session and data reader" ] }, { "cell_type": "code", "execution_count": 11, "id": "98d37174", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "sparkSession = org.apache.spark.sql.SparkSession@694178ec\n", "reader = org.apache.spark.sql.DataFrameReader@4b2afd51\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "org.apache.spark.sql.DataFrameReader@4b2afd51" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// Build the spark session and data reader as usual\n", "val sparkSession = SparkSession.builder.appName(\"mortgage-gpu\").config(\"spark.sql.cache.serializer\", \"com.nvidia.spark.ParquetCachedBatchSerializer\").getOrCreate\n", "\n", "// GPU run, set to true\n", "sparkSession.conf.set(\"spark.rapids.sql.enabled\", true)\n", "// CPU run, set to false\n", "// sparkSession.conf.set('spark.rapids.sql.enabled', 'false')\n", "// remove config(\"spark.sql.cache.serializer\", \"com.nvidia.spark.ParquetCachedBatchSerializer\") for CPU\n", "sparkSession.conf.set(\"spark.sql.files.maxPartitionBytes\", \"1G\")\n", "sparkSession.conf.set(\"spark.sql.broadcastTimeout\", 700)\n", "// use GPU to 
read CSV\n", "sparkSession.conf.set(\"spark.rapids.sql.csv.read.double.enabled\", true)\n", "\n", "val reader = sparkSession.read.schema(rawSchema)" ] }, { "cell_type": "markdown", "id": "b47b5456", "metadata": {}, "source": [ "* Read CSV Files" ] }, { "cell_type": "code", "execution_count": 12, "id": "5bac2301", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "optionsMap = Map(header -> true)\n", "rawDf = [reference_pool_id: string, loan_id: bigint ... 107 more fields]\n", "perfSet = [loan_id: bigint, monthly_reporting_period: string ... 30 more fields]\n", "acqSet = [loan_id: bigint, orig_channel: string ... 25 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[loan_id: bigint, orig_channel: string ... 25 more fields]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val rawDf_csv = reader.option(\"header\", false)\n", " .option(\"nullValue\", \"\")\n", " .option(\"delimiter\", \"|\")\n", " .option(\"parserLib\", \"univocity\")\n", " .schema(rawSchema)\n", " .csv(dataPath)\n", " .withColumn(\"quarter\", GetQuarterFromCsvFileName())\n", "\n", "rawDf_csv.write.mode(\"overwrite\").parquet(output_csv2parquet)\n", "val rawDf = spark.read.parquet(output_csv2parquet)\n", "\n", "val perfSet = extractPerfColumns(rawDf)\n", "val acqSet = extractAcqColumns(rawDf)" ] }, { "cell_type": "markdown", "id": "f4c814c8", "metadata": {}, "source": [ "* Define ETL Object" ] }, { "cell_type": "code", "execution_count": 13, "id": "a16155cb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "defined trait MortgageETL\n", "allCols = List(orig_channel, first_home_buyer, loan_purpose, property_type, occupancy_status, property_state, product_type, relocation_mortgage_indicator, seller_name, mod_flag, orig_interest_rate, orig_upb, orig_loan_term, orig_ltv, orig_cltv, num_borrowers, dti, borrower_credit_score, num_units, zip, mortgage_insurance_percent, current_loan_delinquency_status, current_actual_upb, interest_rate, loan_age, msa, non_interest_bearing_upb, delinquency_12)\n", "defined object PerformanceETL\n", "defined object AcquisitionETL\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "List(orig_channel, first_home_buyer, loan_purpose, property_type, occupancy_status, property_state, product_type, relocation_mortgage_indicator, seller_name, mod_flag, orig_interest_rate, orig_upb, orig_loan_term, orig_ltv, orig_cltv, num_borrowers, dti, borrower_credit_score, num_units, zip, mortgage_insurance_percent, current_loan_delinquency_status, current_actual_upb, interest_rate, loan_age, msa, non_interest_bearing_upb, delinquency_12)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trait MortgageETL {\n", " var dataFrame: DataFrame = _\n", "\n", " def from(df: DataFrame): this.type = {\n", " dataFrame = df\n", " this\n", " }\n", "}\n", "val allCols = (categaryCols ++ numericCols).map(c => col(c._1))\n", "\n", "object PerformanceETL extends MortgageETL {\n", "\n", " def prepare: this.type = {\n", " dataFrame = dataFrame\n", " .withColumn(\"monthly_reporting_period\", to_date(col(\"monthly_reporting_period\"), \"MM/dd/yyyy\"))\n", " .withColumn(\"monthly_reporting_period_month\", month(col(\"monthly_reporting_period\")))\n", " .withColumn(\"monthly_reporting_period_year\", year(col(\"monthly_reporting_period\")))\n", " .withColumn(\"monthly_reporting_period_day\", dayofmonth(col(\"monthly_reporting_period\")))\n", " 
.withColumn(\"last_paid_installment_date\", to_date(col(\"last_paid_installment_date\"), \"MM/dd/yyyy\"))\n", " .withColumn(\"foreclosed_after\", to_date(col(\"foreclosed_after\"), \"MM/dd/yyyy\"))\n", " .withColumn(\"disposition_date\", to_date(col(\"disposition_date\"), \"MM/dd/yyyy\"))\n", " .withColumn(\"maturity_date\", to_date(col(\"maturity_date\"), \"MM/yyyy\"))\n", " .withColumn(\"zero_balance_effective_date\", to_date(col(\"zero_balance_effective_date\"), \"MM/yyyy\"))\n", " .withColumn(\"current_actual_upb\", col(\"current_actual_upb\"))\n", " .withColumn(\"current_loan_delinquency_status\", col(\"current_loan_delinquency_status\"))\n", " this\n", " }\n", "\n", " def createDelinquency(spark: SparkSession): this.type = {\n", " val aggDF = dataFrame\n", " .select(\n", " col(\"quarter\"),\n", " col(\"loan_id\"),\n", " col(\"current_loan_delinquency_status\"),\n", " when(col(\"current_loan_delinquency_status\") >= 1, col(\"monthly_reporting_period\")).alias(\"delinquency_30\"),\n", " when(col(\"current_loan_delinquency_status\") >= 3, col(\"monthly_reporting_period\")).alias(\"delinquency_90\"),\n", " when(col(\"current_loan_delinquency_status\") >= 6, col(\"monthly_reporting_period\")).alias(\"delinquency_180\")\n", " )\n", " .groupBy(\"quarter\", \"loan_id\")\n", " .agg(\n", " max(\"current_loan_delinquency_status\").alias(\"delinquency_12\"),\n", " min(\"delinquency_30\").alias(\"delinquency_30\"),\n", " min(\"delinquency_90\").alias(\"delinquency_90\"),\n", " min(\"delinquency_180\").alias(\"delinquency_180\")\n", " )\n", " .select(\n", " col(\"quarter\"),\n", " col(\"loan_id\"),\n", " (col(\"delinquency_12\") >= 1).alias(\"ever_30\"),\n", " (col(\"delinquency_12\") >= 3).alias(\"ever_90\"),\n", " (col(\"delinquency_12\") >= 6).alias(\"ever_180\"),\n", " col(\"delinquency_30\"),\n", " col(\"delinquency_90\"),\n", " col(\"delinquency_180\")\n", " )\n", "\n", " val joinedDf = dataFrame\n", " .withColumnRenamed(\"monthly_reporting_period\", \"timestamp\")\n", " .withColumnRenamed(\"monthly_reporting_period_month\", \"timestamp_month\")\n", " .withColumnRenamed(\"monthly_reporting_period_year\", \"timestamp_year\")\n", " .withColumnRenamed(\"current_loan_delinquency_status\", \"delinquency_12\")\n", " .withColumnRenamed(\"current_actual_upb\", \"upb_12\")\n", " .select(\"quarter\", \"loan_id\", \"timestamp\", \"delinquency_12\", \"upb_12\", \"timestamp_month\", \"timestamp_year\")\n", " .join(aggDF, Seq(\"loan_id\", \"quarter\"), \"left_outer\")\n", "\n", " // calculate the 12 month delinquency and upb values\n", " val months = 12\n", " val monthArray = 0.until(months).toArray\n", " val testDf = joinedDf\n", " // explode on a small amount of data is actually slightly more efficient than a cross join\n", " .withColumn(\"month_y\", explode(lit(monthArray)))\n", " .select(\n", " col(\"quarter\"),\n", " floor(((col(\"timestamp_year\") * 12 + col(\"timestamp_month\")) - 24000) / months).alias(\"josh_mody\"),\n", " floor(((col(\"timestamp_year\") * 12 + col(\"timestamp_month\")) - 24000 - col(\"month_y\")) / months).alias(\"josh_mody_n\"),\n", " col(\"ever_30\"),\n", " col(\"ever_90\"),\n", " col(\"ever_180\"),\n", " col(\"delinquency_30\"),\n", " col(\"delinquency_90\"),\n", " col(\"delinquency_180\"),\n", " col(\"loan_id\"),\n", " col(\"month_y\"),\n", " col(\"delinquency_12\"),\n", " col(\"upb_12\")\n", " )\n", " .groupBy(\"quarter\", \"loan_id\", \"josh_mody_n\", \"ever_30\", \"ever_90\", \"ever_180\", \"delinquency_30\", \"delinquency_90\", \"delinquency_180\", \"month_y\")\n", " 
.agg(max(\"delinquency_12\").alias(\"delinquency_12\"), min(\"upb_12\").alias(\"upb_12\"))\n", " .withColumn(\"timestamp_year\", floor((lit(24000) + (col(\"josh_mody_n\") * lit(months)) + (col(\"month_y\") - 1)) / lit(12)))\n", " .withColumn(\"timestamp_month_tmp\", pmod(lit(24000) + (col(\"josh_mody_n\") * lit(months)) + col(\"month_y\"), lit(12)))\n", " .withColumn(\"timestamp_month\", when(col(\"timestamp_month_tmp\") === lit(0), lit(12)).otherwise(col(\"timestamp_month_tmp\")))\n", " .withColumn(\"delinquency_12\", ((col(\"delinquency_12\") > 3).cast(\"int\") + (col(\"upb_12\") === 0).cast(\"int\")).alias(\"delinquency_12\"))\n", " .drop(\"timestamp_month_tmp\", \"josh_mody_n\", \"month_y\")\n", "\n", " dataFrame = dataFrame\n", " .withColumnRenamed(\"monthly_reporting_period_month\", \"timestamp_month\")\n", " .withColumnRenamed(\"monthly_reporting_period_year\", \"timestamp_year\")\n", " .join(testDf, Seq(\"quarter\", \"loan_id\", \"timestamp_year\", \"timestamp_month\"), \"left\").drop(\"timestamp_year\", \"timestamp_month\")\n", " this\n", " }\n", "}\n", "\n", "object AcquisitionETL extends MortgageETL {\n", "\n", " def createAcquisition(spark: SparkSession): this.type = {\n", " val nameMapping = NameMapping(spark, \"from_seller_name\", \"to_seller_name\")\n", " dataFrame = dataFrame\n", " .join(nameMapping, col(\"seller_name\") === col(\"from_seller_name\"), \"left\")\n", " .drop(\"from_seller_name\")\n", " /* backup the original name before we replace it */\n", " .withColumn(\"old_name\", col(\"seller_name\"))\n", " /* replace seller_name with the new version if we found one in the mapping, or the old version\n", " if we didn't */\n", " .withColumn(\"seller_name\", coalesce(col(\"to_seller_name\"), col(\"seller_name\")))\n", " .drop(\"to_seller_name\")\n", " .withColumn(\"orig_date\", to_date(col(\"orig_date\"), \"MM/yyyy\"))\n", " .withColumn(\"first_pay_date\", to_date(col(\"first_pay_date\"), \"MM/yyyy\"))\n", " this\n", " }\n", "\n", " def cleanPrime(perfDF: DataFrame): this.type = {\n", " dataFrame = perfDF.join(dataFrame, Seq(\"loan_id\", \"quarter\"), \"inner\").drop(\"quarter\")\n", " this\n", " }\n", "}\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "78b76252", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "transform: (perfDF: org.apache.spark.sql.DataFrame, acqDF: org.apache.spark.sql.DataFrame, spark: org.apache.spark.sql.SparkSession)org.apache.spark.sql.DataFrame\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def transform(perfDF: DataFrame, acqDF: DataFrame, spark: SparkSession): DataFrame = {\n", " val etlPerfDF = PerformanceETL.from(perfDF)\n", " .prepare\n", " .createDelinquency(spark)\n", " .dataFrame\n", " val cleanDF = AcquisitionETL.from(acqDF)\n", " .createAcquisition(spark)\n", " .cleanPrime(etlPerfDF)\n", " .dataFrame\n", "\n", " // Convert to xgb required Dataset\n", " castStringColumnsToNumeric(cleanDF, spark)\n", " .select(allCols: _*)\n", " .withColumn(labelColName, when(col(labelColName) > 0, 1).otherwise(0))\n", " .na.fill(0.0f)\n", " }" ] }, { "cell_type": "markdown", "id": "b1234f49", "metadata": {}, "source": [ "## Run ETL Process and Save the Result" ] }, { "cell_type": "code", "execution_count": 15, "id": "ffdb0a62", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Elapsed time : 399.241s\n" ] }, { "data": { "text/plain": [ "t0 = 1656695479451\n", "optionsMap = Map(header -> true)\n", "rawDF = [orig_channel: int, first_home_buyer: int ... 
26 more fields]\n", "t1 = 1656695878692\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1656695878692" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val t0 = System.currentTimeMillis\n", "val rawDF = transform(\n", " perfSet,\n", " acqSet,\n", " sparkSession\n", " )\n", "\n", "val etlDataPath = new Path(outPath, \"data\").toString\n", "rawDF.write.mode(\"overwrite\").parquet(etlDataPath)\n", "\n", "if(saveTrainEvalDataset == true)\n", "{\n", " val etlDf = sparkSession.read.parquet(etlDataPath)\n", " val sets = etlDf.randomSplit(Array[Double](0.8, 0.2))\n", " val train = sets(0)\n", " val eval = sets(1)\n", " train.write.mode(\"overwrite\").parquet(new Path(outPath, \"train\").toString)\n", " eval.write.mode(\"overwrite\").parquet(new Path(outPath, \"eval\").toString)\n", "}\n", "\n", "\n", "val t1 = System.currentTimeMillis\n", "println(\"Elapsed time : \" + ((t1 - t0).toFloat / 1000) + \"s\")\n", "sparkSession.stop()" ] } ], "metadata": { "kernelspec": { "display_name": "XGBoost4j-Spark Scala", "language": "scala", "name": "XGBoost4j-Spark_scala" }, "language_info": { "codemirror_mode": "text/x-scala", "file_extension": ".scala", "mimetype": "text/x-scala", "name": "scala", "pygments_lexer": "scala", "version": "2.12.15" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-gpu.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to XGBoost Spark with GPU\n", "\n", "The goal of this notebook is to show how to train a XGBoost Model with Spark RAPIDS XGBoost library on GPUs. The dataset used with this notebook is derived from Fannie Mae’s Single-Family Loan Performance Data with all rights reserved by Fannie Mae. This processed dataset is redistributed with permission and consent from Fannie Mae. This notebook uses XGBoost to train 12-month mortgage loan delinquency prediction model.\n", "\n", "## Load libraries\n", "First load some common libraries will be used by both GPU version and CPU version xgboost." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}\n", "import org.apache.spark.sql.SparkSession\n", "import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\n", "import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Besides CPU version requires some extra libraries, such as:\n", "\n", "```scala\n", "import org.apache.spark.ml.feature.VectorAssembler\n", "import org.apache.spark.sql.DataFrame\n", "import org.apache.spark.sql.functions._\n", "import org.apache.spark.sql.types.FloatType\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set the dataset path" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// You need to update them to your real paths! 
The input data files are the output of the mortgage ETL job\n", "val dataRoot = sys.env.getOrElse(\"DATA_ROOT\", \"/data\")\n", "val trainPath = dataRoot + \"/mortgage/output/train/\"\n", "val evalPath = dataRoot + \"/mortgage/output/eval/\"\n", "val transPath = dataRoot + \"/mortgage/output/eval/\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build the schema and parameters\n", "The mortgage data has 27 columns: 26 features and 1 label. \"delinquency_12\" is the label column. The schema will be used to load data later.\n", "\n", "The next block also defines some key parameters used in the XGBoost training process." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "labelColName = delinquency_12\n", "schema = StructType(StructField(orig_channel,DoubleType,true), StructField(first_home_buyer,DoubleType,true), StructField(loan_purpose,DoubleType,true), StructField(property_type,DoubleType,true), StructField(occupancy_status,DoubleType,true), StructField(property_state,DoubleType,true), StructField(product_type,DoubleType,true), StructField(relocation_mortgage_indicator,DoubleType,true), StructField(seller_name,DoubleType,true), StructField(mod_flag,DoubleType,true), StructField(orig_interest_rate,DoubleType,true), StructField(orig_upb,IntegerType,true), StructField(orig_loan_term,IntegerType,true), StructField(orig_ltv,DoubleType,true), StructField(orig_cltv,DoubleType,true), StructField(num_borrowers,DoubleT...\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "StructType(StructField(orig_channel,DoubleType,true), StructField(first_home_buyer,DoubleType,true), StructField(loan_purpose,DoubleType,true), StructField(property_type,DoubleType,true), StructField(occupancy_status,DoubleType,true), StructField(property_state,DoubleType,true), StructField(product_type,DoubleType,true), StructField(relocation_mortgage_indicator,DoubleType,true), StructField(seller_name,DoubleType,true), StructField(mod_flag,DoubleType,true), StructField(orig_interest_rate,DoubleType,true), StructField(orig_upb,IntegerType,true), StructField(orig_loan_term,IntegerType,true), StructField(orig_ltv,DoubleType,true), StructField(orig_cltv,DoubleType,true), StructField(num_borrowers,DoubleT..." 
] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val labelColName = \"delinquency_12\"\n", "val schema = StructType(List(\n", " StructField(\"orig_channel\", DoubleType),\n", " StructField(\"first_home_buyer\", DoubleType),\n", " StructField(\"loan_purpose\", DoubleType),\n", " StructField(\"property_type\", DoubleType),\n", " StructField(\"occupancy_status\", DoubleType),\n", " StructField(\"property_state\", DoubleType),\n", " StructField(\"product_type\", DoubleType),\n", " StructField(\"relocation_mortgage_indicator\", DoubleType),\n", " StructField(\"seller_name\", DoubleType),\n", " StructField(\"mod_flag\", DoubleType),\n", " StructField(\"orig_interest_rate\", DoubleType),\n", " StructField(\"orig_upb\", DoubleType),\n", " StructField(\"orig_loan_term\", IntegerType),\n", " StructField(\"orig_ltv\", DoubleType),\n", " StructField(\"orig_cltv\", DoubleType),\n", " StructField(\"num_borrowers\", DoubleType),\n", " StructField(\"dti\", DoubleType),\n", " StructField(\"borrower_credit_score\", DoubleType),\n", " StructField(\"num_units\", IntegerType),\n", " StructField(\"zip\", IntegerType),\n", " StructField(\"mortgage_insurance_percent\", DoubleType),\n", " StructField(\"current_loan_delinquency_status\", IntegerType),\n", " StructField(\"current_actual_upb\", DoubleType),\n", " StructField(\"interest_rate\", DoubleType),\n", " StructField(\"loan_age\", DoubleType),\n", " StructField(\"msa\", DoubleType),\n", " StructField(\"non_interest_bearing_upb\", DoubleType),\n", " StructField(labelColName, IntegerType)))\n", "\n", "val featureNames = schema.filter(_.name != labelColName).map(_.name).toArray\n", "\n", "val commParamMap = Map(\n", " \"objective\" -> \"binary:logistic\",\n", " \"num_round\" -> 100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a new spark session and load data\n", "\n", "A new spark session should be created to continue all the following spark operations.\n", "\n", "NOTE: in this notebook, the dependency jars have been loaded when installing toree kernel. Alternatively the jars can be loaded into notebook by [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, there's one restriction for `%AddJar`: the jar uploaded can only be available when `AddJar` is called just after a new spark session is created. Do it as below:\n", "\n", "```scala\n", "import org.apache.spark.sql.SparkSession\n", "val spark = SparkSession.builder().appName(\"mortgage-GPU\").getOrCreate\n", "%AddJar file:/data/libs/rapids-4-spark-XXX.jar\n", "%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\n", "%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\n", "// ...\n", "```\n", "\n", "##### Please note the new jar \"rapids-4-spark-XXX.jar\" is only needed for GPU version, you can not add it to dependence list for CPU version." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "sparkSession = org.apache.spark.sql.SparkSession@26420dda\n", "reader = org.apache.spark.sql.DataFrameReader@77740a8c\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "org.apache.spark.sql.DataFrameReader@77740a8c" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// Build the spark session and data reader as usual\n", "val sparkSession = SparkSession.builder.appName(\"mortgage-gpu\").getOrCreate\n", "val reader = sparkSession.read" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "trainSet = [orig_channel: double, first_home_buyer: double ... 26 more fields]\n", "evalSet = [orig_channel: double, first_home_buyer: double ... 26 more fields]\n", "transSet = [orig_channel: double, first_home_buyer: double ... 26 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[orig_channel: double, first_home_buyer: double ... 26 more fields]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val trainSet = reader.parquet(trainPath)\n", "val evalSet = reader.parquet(evalPath)\n", "val transSet = reader.parquet(transPath)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set xgboost parameters and build a XGBoostClassifier\n", "\n", "For CPU version, `num_workers` is recommended being equal to the number of CPU cores, while for GPU version, it should be set to the number of GPUs in Spark cluster.\n", "\n", "Besides the `device` for CPU version is also different from that for GPU version. Now only \"cuda\" is supported for training on GPU.\n", "\n", "```scala\n", "// difference in parameters\n", " \"num_workers\" -> 12,\n", " \"device\" -> \"cpu\",\n", "```" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "xgbParamFinal = Map(objective -> binary:logistic, num_round -> 100, tree_method -> hist, device -> cuda, num_workers -> 1)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Map(objective -> binary:logistic, num_round -> 100, tree_method -> hist, device -> cuda, num_workers -> 1)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": "val xgbParamFinal = commParamMap ++ Map(\"tree_method\" -> \"hist\", \"device\" -> \"cuda\", \"num_workers\" -> 1)" }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "xgbClassifier = xgbc_ecac6474dbb2\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "xgbc_ecac6474dbb2" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val xgbClassifier = new XGBoostClassifier(xgbParamFinal)\n", " .setLabelCol(labelColName)\n", " .setFeaturesCol(featureNames)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Benchmark and train\n", "The object `benchmark` is used to compute the elapsed time of some operations.\n", "\n", "Training with evaluation dataset is also supported, the same as CPU version's behavior:\n", "\n", "* Call API `setEvalDataset` after initializing an XGBoostClassifier\n", "\n", "```scala\n", "xgbClassifier.setEvalDataset(evalSet)\n", "```" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"xgbc_ecac6474dbb2" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xgbClassifier.setEvalDataset(evalSet)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "defined object Benchmark\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "object Benchmark {\n", " def time[R](phase: String)(block: => R): (R, Float) = {\n", " val t0 = System.currentTimeMillis\n", " val result = block // call-by-name\n", " val t1 = System.currentTimeMillis\n", " println(\"Elapsed time [\" + phase + \"]: \" + ((t1 - t0).toFloat / 1000) + \"s\")\n", " (result, (t1 - t0).toFloat / 1000)\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "------ Training ------\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=38315, DMLC_NUM_WORKER=1}\n" ] }, { "data": { "text/plain": [ "xgbClassificationModel = xgbc_ecac6474dbb2\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Elapsed time [train]: 8.083s\n" ] }, { "data": { "text/plain": [ "xgbc_ecac6474dbb2" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// Start training\n", "println(\"\\n------ Training ------\")\n", "val (xgbClassificationModel, _) = Benchmark.time(\"train\") {\n", " xgbClassifier.fit(trainSet)\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transformation and evaluation\n", "Here uses `transSet` to evaluate our model and prints some useful columns to show our prediction result. After that `MulticlassClassificationEvaluator` is used to calculate an overall accuracy of our predictions." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "------ Transforming ------\n", "Elapsed time [transform]: 1.916s\n", "+------------+--------------+--------------------+--------------------+----------+\n", "|orig_channel|delinquency_12| rawPrediction| probability|prediction|\n", "+------------+--------------+--------------------+--------------------+----------+\n", "| 0.0| 0|[7.57764625549316...|[0.99948849738575...| 0.0|\n", "| 0.0| 0|[8.74893283843994...|[0.99984139463049...| 0.0|\n", "| 0.0| 0|[8.74893283843994...|[0.99984139463049...| 0.0|\n", "| 0.0| 0|[8.74893283843994...|[0.99984139463049...| 0.0|\n", "| 0.0| 0|[7.57764625549316...|[0.99948849738575...| 0.0|\n", "| 0.0| 0|[7.57764625549316...|[0.99948849738575...| 0.0|\n", "| 0.0| 0|[7.57764625549316...|[0.99948849738575...| 0.0|\n", "| 0.0| 0|[6.58476591110229...|[0.99862065445631...| 0.0|\n", "| 0.0| 0|[7.98751401901245...|[0.99966043786844...| 0.0|\n", "| 0.0| 0|[7.21919107437133...|[0.99926814140053...| 0.0|\n", "+------------+--------------+--------------------+--------------------+----------+\n", "only showing top 10 rows\n", "\n", "\n", "------Accuracy of Evaluation------\n", "1.0\n" ] }, { "data": { "text/plain": [ "results = [orig_channel: double, first_home_buyer: double ... 
29 more fields]\n", "evaluator = MulticlassClassificationEvaluator: uid=mcEval_d9645b60a007, metricName=f1, metricLabel=0.0, beta=1.0, eps=1.0E-15\n", "accuracy = 1.0\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.0" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "println(\"\\n------ Transforming ------\")\n", "val (results, _) = Benchmark.time(\"transform\") {\n", " val ret = xgbClassificationModel.transform(transSet).cache()\n", " ret.foreachPartition((_: Iterator[_]) => ())\n", " ret\n", "}\n", "results.select(\"orig_channel\", labelColName,\"rawPrediction\",\"probability\",\"prediction\").show(10)\n", "\n", "println(\"\\n------Accuracy of Evaluation------\")\n", "val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName)\n", "val accuracy = evaluator.evaluate(results)\n", "println(accuracy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save the model to disk and load model\n", "Save the model to disk and then load it to memory. After that use the loaded model to do a new prediction." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Elapsed time [transform2]: 0.044s\n" ] }, { "data": { "text/plain": [ "modelFromDisk = xgbc_ecac6474dbb2\n", "results2 = [orig_channel: double, first_home_buyer: double ... 29 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "+------------+----------------+------------+-------------+----------------+--------------+------------+-----------------------------+-----------+--------+------------------+--------+--------------+--------+---------+-------------+----+---------------------+---------+---+--------------------------+-------------------------------+------------------+-------------+--------+-------+------------------------+--------------+--------------------+--------------------+----------+\n", "|orig_channel|first_home_buyer|loan_purpose|property_type|occupancy_status|property_state|product_type|relocation_mortgage_indicator|seller_name|mod_flag|orig_interest_rate|orig_upb|orig_loan_term|orig_ltv|orig_cltv|num_borrowers| dti|borrower_credit_score|num_units|zip|mortgage_insurance_percent|current_loan_delinquency_status|current_actual_upb|interest_rate|loan_age| msa|non_interest_bearing_upb|delinquency_12| rawPrediction| probability|prediction|\n", "+------------+----------------+------------+-------------+----------------+--------------+------------+-----------------------------+-----------+--------+------------------+--------+--------------+--------+---------+-------------+----+---------------------+---------+---+--------------------------+-------------------------------+------------------+-------------+--------+-------+------------------------+--------------+--------------------+--------------------+----------+\n", "| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 5.75| 81000| 360| 95.0| 0.0| 1.0|39.0| 696.0| 1|191| 30.0| -2| 7747.01| 5.75| 81.0|37980.0| 0.0| 0|[7.57764625549316...|[0.99948849738575...| 0.0|\n", "| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 5.75| 81000| 360| 95.0| 0.0| 1.0|39.0| 696.0| 1|191| 30.0| 0| 0.0| 5.75| 0.0|37980.0| 0.0| 0|[8.74893283843994...|[0.99984139463049...| 0.0|\n", "| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 5.75| 81000| 360| 95.0| 0.0| 1.0|39.0| 696.0| 1|191| 30.0| 0| 0.0| 5.75| 2.0|37980.0| 0.0| 
0|[8.74893283843994...|[0.99984139463049...| 0.0|\n", "| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 5.75| 81000| 360| 95.0| 0.0| 1.0|39.0| 696.0| 1|191| 30.0| 0| 0.0| 5.75| 5.0|37980.0| 0.0| 0|[8.74893283843994...|[0.99984139463049...| 0.0|\n", "| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 5.75| 81000| 360| 95.0| 0.0| 1.0|39.0| 696.0| 1|191| 30.0| 0| 7747.01| 5.75| 80.0|37980.0| 0.0| 0|[7.57764625549316...|[0.99948849738575...| 0.0|\n", "| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 5.75| 81000| 360| 95.0| 0.0| 1.0|39.0| 696.0| 1|191| 30.0| 0| 13155.21| 5.75| 79.0|37980.0| 0.0| 0|[7.57764625549316...|[0.99948849738575...| 0.0|\n", "| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 5.75| 81000| 360| 95.0| 0.0| 1.0|39.0| 696.0| 1|191| 30.0| 0| 18526.93| 5.75| 78.0|37980.0| 0.0| 0|[7.57764625549316...|[0.99948849738575...| 0.0|\n", "| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 5.75| 81000| 360| 95.0| 0.0| 1.0|39.0| 696.0| 1|191| 30.0| 0| 23883.73| 5.75| 77.0|37980.0| 0.0| 0|[6.58476591110229...|[0.99862065445631...| 0.0|\n", "| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 5.75| 81000| 360| 95.0| 0.0| 1.0|39.0| 696.0| 1|191| 30.0| 0| 29214.98| 5.75| 76.0|37980.0| 0.0| 0|[7.98751401901245...|[0.99966043786844...| 0.0|\n", "| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 5.75| 81000| 360| 95.0| 0.0| 1.0|39.0| 696.0| 1|191| 30.0| 0| 34520.81| 5.75| 75.0|37980.0| 0.0| 0|[7.21919107437133...|[0.99926814140053...| 0.0|\n", "+------------+----------------+------------+-------------+----------------+--------------+------------+-----------------------------+-----------+--------+------------------+--------+--------------+--------+---------+-------------+----+---------------------+---------+---+--------------------------+-------------------------------+------------------+-------------+--------+-------+------------------------+--------------+--------------------+--------------------+----------+\n", "only showing top 10 rows\n", "\n" ] }, { "data": { "text/plain": [ "[orig_channel: double, first_home_buyer: double ... 
29 more fields]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xgbClassificationModel.write.overwrite.save(dataRoot + \"/mortgage/model/\")\n", "\n", "val modelFromDisk = XGBoostClassificationModel.load(dataRoot + \"/mortgage/model/\")\n", "\n", "val (results2, _) = Benchmark.time(\"transform2\") {\n", " modelFromDisk.transform(transSet)\n", "}\n", "results2.show(10)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "sparkSession.close()" ] } ], "metadata": { "kernelspec": { "display_name": "XGBoost4j-Spark - Scala", "language": "scala", "name": "XGBoost4j-Spark_scala" }, "language_info": { "codemirror_mode": "text/x-scala", "file_extension": ".scala", "mimetype": "text/x-scala", "name": "scala", "pygments_lexer": "scala", "version": "2.12.15" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage_gpu_crossvalidation.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Mortgage CrossValidation with GPU accelerating on XGBoost\n", "\n", "In this notebook, we will show you how to levarage GPU to accelerate mortgage CrossValidation of XGBoost to find out the best model given a group of parameters.\n", "\n", "## Import classes\n", "First we need load some common classes that both GPU version and CPU version will use:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}\n", "\n", "import org.apache.spark.sql.SparkSession\n", "import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\n", "import org.apache.spark.ml.tuning.{ParamGridBuilder,CrossValidator}\n", "import org.apache.spark.sql.types.{FloatType, IntegerType, StructField, StructType, DoubleType}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "what is new to xgboost-spark users is **rapids.CrossValidator**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set dataset path" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// You need to update them to your real paths!\n", "val dataRoot = sys.env.getOrElse(\"DATA_ROOT\", \"/data\")\n", "val trainParquetPath=dataRoot + \"/mortgage/output/train\"\n", "val evalParquetPath=dataRoot + \"/mortgage/output/eval\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Set the schema of the dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "val labelColName = \"delinquency_12\"\n", "val schema = StructType(List(\n", " StructField(\"orig_channel\", FloatType),\n", " StructField(\"first_home_buyer\", FloatType),\n", " StructField(\"loan_purpose\", FloatType),\n", " StructField(\"property_type\", FloatType),\n", " StructField(\"occupancy_status\", FloatType),\n", " StructField(\"property_state\", FloatType),\n", " StructField(\"product_type\", FloatType),\n", " StructField(\"relocation_mortgage_indicator\", FloatType),\n", " StructField(\"seller_name\", FloatType),\n", " StructField(\"mod_flag\", FloatType),\n", " StructField(\"orig_interest_rate\", FloatType),\n", " StructField(\"orig_upb\", DoubleType),\n", " StructField(\"orig_loan_term\", IntegerType),\n", " StructField(\"orig_ltv\", FloatType),\n", " StructField(\"orig_cltv\", FloatType),\n", " 
StructField(\"num_borrowers\", FloatType),\n", " StructField(\"dti\", FloatType),\n", " StructField(\"borrower_credit_score\", FloatType),\n", " StructField(\"num_units\", IntegerType),\n", " StructField(\"zip\", IntegerType),\n", " StructField(\"mortgage_insurance_percent\", FloatType),\n", " StructField(\"current_loan_delinquency_status\", IntegerType),\n", " StructField(\"current_actual_upb\", FloatType),\n", " StructField(\"interest_rate\", FloatType),\n", " StructField(\"loan_age\", FloatType),\n", " StructField(\"msa\", FloatType),\n", " StructField(\"non_interest_bearing_upb\", FloatType),\n", " StructField(labelColName, IntegerType)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a new spark session and load data\n", "we must create a new spark session to continue all spark operations.\n", "\n", "NOTE: in this notebook, we have uploaded dependency jars when installing toree kernel. If we don't upload them at installation time, we can also upload in notebook by [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, there's one restriction for `%AddJar`: the jar uploaded can only be available when `AddJar` is called after a new spark session is created. We must use it as below:\n", "\n", "```scala\n", "import org.apache.spark.sql.SparkSession\n", "val spark = SparkSession.builder().appName(\"mortgage-gpu-cv\").getOrCreate\n", "%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\n", "%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\n", "// ...\n", "```" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "spark = org.apache.spark.sql.SparkSession@51af6ff3\n", "trainDs = [orig_channel: double, first_home_buyer: double ... 26 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[orig_channel: double, first_home_buyer: double ... 
26 more fields]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val spark = SparkSession.builder().appName(\"mortgage-gpu-cv\").getOrCreate()\n", "val trainDs = spark.read.parquet(trainParquetPath)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Find out features to train" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "featureNames = Array(orig_channel, first_home_buyer, loan_purpose, property_type, occupancy_status, property_state, product_type, relocation_mortgage_indicator, seller_name, mod_flag, orig_interest_rate, orig_upb, orig_loan_term, orig_ltv, orig_cltv, num_borrowers, dti, borrower_credit_score, num_units, zip, mortgage_insurance_percent, current_loan_delinquency_status, current_actual_upb, interest_rate, loan_age, msa, non_interest_bearing_upb)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Array(orig_channel, first_home_buyer, loan_purpose, property_type, occupancy_status, property_state, product_type, relocation_mortgage_indicator, seller_name, mod_flag, orig_interest_rate, orig_upb, orig_loan_term, orig_ltv, orig_cltv, num_borrowers, dti, borrower_credit_score, num_units, zip, mortgage_insurance_percent, current_loan_delinquency_status, current_actual_upb, interest_rate, loan_age, msa, non_interest_bearing_upb)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val featureNames = schema.filter(_.name != labelColName).map(_.name).toArray" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "classifierParam = Map(objective -> binary:logistic, num_round -> 100, num_workers -> 1, tree_method -> hist, device -> cuda)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Map(objective -> binary:logistic, num_round -> 100, num_workers -> 1, tree_method -> hist, device -> cuda)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val classifierParam = Map(\n", " \"objective\" -> \"binary:logistic\",\n", " \"num_round\" -> 100,\n", " \"num_workers\" -> 1,\n", " \"tree_method\" -> \"hist\",\n", " \"device\" -> \"cuda\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Construct CrossValidator" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "classifier = xgbc_ae8896ab2b67\n", "paramGrid = \n", "evaluator = MulticlassClassificationEvaluator: uid=mcEval_ebda5b6cea6c, metricName=f1, metricLabel=0.0, beta=1.0, eps=1.0E-15\n", "cv = cv_cb7d8efe9ab5\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Array({\n", "\txgbc_ae8896ab2b67-eta: 0.2,\n", "\txgbc_ae8896ab2b67-maxDepth: 3\n", "}, {\n", "\txgbc_ae8896ab2b67-eta: 0.2,\n", "\txgbc_ae8896ab2b67-maxDepth: 10\n", "}, {\n", "\txgbc_ae8896ab2b67-eta: 0.6,\n", "\txgbc_ae8896ab2b67-maxDepth: 3\n", "}, {\n", "\txgbc_ae8896ab2b67-eta: 0.6,\n", "\txgbc_ae8896ab2b67-maxDepth: 10\n", "})\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "cv_cb7d8efe9ab5" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val classifier = new XGBoostClassifier(classifierParam)\n", " .setLabelCol(labelColName)\n", " .setFeaturesCol(featureNames)\n", "val paramGrid = new ParamGridBuilder()\n", " .addGrid(classifier.maxDepth, Array(3, 10))\n", " 
.addGrid(classifier.eta, Array(0.2, 0.6))\n", " .build()\n", "val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName)\n", "val cv = new CrossValidator()\n", " .setEstimator(classifier)\n", " .setEvaluator(evaluator)\n", " .setEstimatorParamMaps(paramGrid)\n", " .setNumFolds(3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## train with CrossValidator" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=41609, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=45469, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=52795, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=53483, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=58067, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=43717, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=36075, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=53851, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=42227, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=46587, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=51295, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=54695, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=54019, DMLC_NUM_WORKER=1}\n" ] }, { "data": { "text/plain": [ "model = xgbc_ae8896ab2b67\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "xgbc_ae8896ab2b67" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val model = cv.fit(trainDs).bestModel.asInstanceOf[XGBoostClassificationModel]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## tranform with best model trained by CrossValidator" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "transformDs = [orig_channel: double, first_home_buyer: double ... 26 more fields]\n", "df = [orig_channel: double, first_home_buyer: double ... 
29 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "+--------------+--------------------+--------------------+----------+\n", "|delinquency_12| rawPrediction| probability|prediction|\n", "+--------------+--------------------+--------------------+----------+\n", "| 0|[17.3849449157714...|[0.99999997182821...| 0.0|\n", "| 0|[16.6074829101562...|[0.99999993869981...| 0.0|\n", "| 0|[16.0062618255615...|[0.99999988816731...| 0.0|\n", "| 0|[16.7623615264892...|[0.99999994749521...| 0.0|\n", "| 0|[15.1363153457641...|[0.99999973307967...| 0.0|\n", "+--------------+--------------------+--------------------+----------+\n", "only showing top 5 rows\n", "\n" ] }, { "data": { "text/plain": [ "[orig_channel: double, first_home_buyer: double ... 29 more fields]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val transformDs = spark.read.parquet(evalParquetPath)\n", "val df = model.transform(transformDs).cache()\n", "df.drop(featureNames: _*).show(5)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "evaluator = MulticlassClassificationEvaluator: uid=mcEval_d880c25944f1, metricName=f1, metricLabel=0.0, beta=1.0, eps=1.0E-15\n", "accuracy = 1.0\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.0" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName)\n", "val accuracy = evaluator.evaluate(df)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "spark.close()" ] } ], "metadata": { "kernelspec": { "display_name": "XGBoost4j-Spark - Scala", "language": "scala", "name": "XGBoost4j-Spark_scala" }, "language_info": { "codemirror_mode": "text/x-scala", "file_extension": ".scala", "mimetype": "text/x-scala", "name": "scala", "pygments_lexer": "scala", "version": "2.12.15" } }, "nbformat": 4, "nbformat_minor": 2 }
================================================ FILE: examples/XGBoost-Examples/mortgage/pom.xml ================================================
<project>
    <parent>
        <artifactId>sample_xgboost_examples</artifactId>
        <groupId>com.nvidia</groupId>
        <version>0.2.3-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>
    <artifactId>spark_examples_mortgage_${scala.binary.version}</artifactId>
    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
    </properties>
    <dependencies>
        <dependency>
            <groupId>com.nvidia</groupId>
            <artifactId>spark_examples_utility_${scala.binary.version}</artifactId>
            <version>${project.version}</version>
            <scope>compile</scope>
        </dependency>
    </dependencies>
    <build>
        <sourceDirectory>scala/src</sourceDirectory>
    </build>
</project>
================================================ FILE: examples/XGBoost-Examples/mortgage/python/com/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/mortgage/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/mortgage/consts.py ================================================ # # Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from pyspark.sql.types import * label = 'delinquency_12' schema = StructType([ StructField('orig_channel', FloatType()), StructField('first_home_buyer', FloatType()), StructField('loan_purpose', FloatType()), StructField('property_type', FloatType()), StructField('occupancy_status', FloatType()), StructField('property_state', FloatType()), StructField('product_type', FloatType()), StructField('relocation_mortgage_indicator', FloatType()), StructField('seller_name', FloatType()), StructField('mod_flag', FloatType()), StructField('orig_interest_rate', FloatType()), StructField('orig_upb', DoubleType()), StructField('orig_loan_term', IntegerType()), StructField('orig_ltv', FloatType()), StructField('orig_cltv', FloatType()), StructField('num_borrowers', FloatType()), StructField('dti', FloatType()), StructField('borrower_credit_score', FloatType()), StructField('num_units', IntegerType()), StructField('zip', IntegerType()), StructField('mortgage_insurance_percent', FloatType()), StructField('current_loan_delinquency_status', IntegerType()), StructField('current_actual_upb', FloatType()), StructField('interest_rate', FloatType()), StructField('loan_age', FloatType()), StructField('msa', FloatType()), StructField('non_interest_bearing_upb', FloatType()), StructField(label, IntegerType()), ]) name_mapping = { 'WITMER FUNDING, LLC': 'Witmer', 'WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015': 'Wells Fargo', 'WELLS FARGO BANK, NA': 'Wells Fargo', 'WELLS FARGO BANK, N.A.': 'Wells Fargo', 'WELLS FARGO BANK, NA': 'Wells Fargo', 'USAA FEDERAL SAVINGS BANK': 'USAA', 'UNITED SHORE FINANCIAL SERVICES, LLC D\\/B\\/A UNITED WHOLESALE MORTGAGE': 'United Seq(e', 'U.S. BANK N.A.': 'US Bank', 'SUNTRUST MORTGAGE INC.': 'Suntrust', 'STONEGATE MORTGAGE CORPORATION': 'Stonegate Mortgage', 'STEARNS LENDING, LLC': 'Stearns Lending', 'STEARNS LENDING, INC.': 'Stearns Lending', 'SIERRA PACIFIC MORTGAGE COMPANY, INC.': 'Sierra Pacific Mortgage', 'REGIONS BANK': 'Regions', 'RBC MORTGAGE COMPANY': 'RBC', 'QUICKEN LOANS INC.': 'Quicken Loans', 'PULTE MORTGAGE, L.L.C.': 'Pulte Mortgage', 'PROVIDENT FUNDING ASSOCIATES, L.P.': 'Provident Funding', 'PROSPECT MORTGAGE, LLC': 'Prospect Mortgage', 'PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC': 'Principal Residential', 'PNC BANK, N.A.': 'PNC', 'PMT CREDIT RISK TRANSFER TRUST 2015-2': 'PennyMac', 'PHH MORTGAGE CORPORATION': 'PHH Mortgage', 'PENNYMAC CORP.': 'PennyMac', 'PACIFIC UNION FINANCIAL, LLC': 'Other', 'OTHER': 'Other', 'NYCB MORTGAGE COMPANY, LLC': 'NYCB', 'NEW YORK COMMUNITY BANK': 'NYCB', 'NETBANK FUNDING SERVICES': 'Netbank', 'NATIONSTAR MORTGAGE, LLC': 'Nationstar Mortgage', 'METLIFE BANK, NA': 'Metlife', 'LOANDEPOT.COM, LLC': 'LoanDepot.com', 'J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1': 'JP Morgan Chase', 'J.P. 
MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1': 'JP Morgan Chase', 'JPMORGAN CHASE BANK, NATIONAL ASSOCIATION': 'JP Morgan Chase', 'JPMORGAN CHASE BANK, NA': 'JP Morgan Chase', 'JP MORGAN CHASE BANK, NA': 'JP Morgan Chase', 'IRWIN MORTGAGE, CORPORATION': 'Irwin Mortgage', 'IMPAC MORTGAGE CORP.': 'Impac Mortgage', 'HSBC BANK USA, NATIONAL ASSOCIATION': 'HSBC', 'HOMEWARD RESIDENTIAL, INC.': 'Homeward Mortgage', 'HOMESTREET BANK': 'Other', 'HOMEBRIDGE FINANCIAL SERVICES, INC.': 'HomeBridge', 'HARWOOD STREET FUNDING I, LLC': 'Harwood Mortgage', 'GUILD MORTGAGE COMPANY': 'Guild Mortgage', 'GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)': 'GMAC', 'GMAC MORTGAGE, LLC': 'GMAC', 'GMAC (USAA)': 'GMAC', 'FREMONT BANK': 'Fremont Bank', 'FREEDOM MORTGAGE CORP.': 'Freedom Mortgage', 'FRANKLIN AMERICAN MORTGAGE COMPANY': 'Franklin America', 'FLEET NATIONAL BANK': 'Fleet National', 'FLAGSTAR CAPITAL MARKETS CORPORATION': 'Flagstar Bank', 'FLAGSTAR BANK, FSB': 'Flagstar Bank', 'FIRST TENNESSEE BANK NATIONAL ASSOCIATION': 'Other', 'FIFTH THIRD BANK': 'Fifth Third Bank', 'FEDERAL HOME LOAN BANK OF CHICAGO': 'Fedral Home of Chicago', 'FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB': 'FDIC', 'DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A.': 'Downey Mortgage', 'DITECH FINANCIAL LLC': 'Ditech', 'CITIMORTGAGE, INC.': 'Citi', 'CHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY': 'Chicago Mortgage', 'CHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY': 'Chicago Mortgage', 'CHASE HOME FINANCE, LLC': 'JP Morgan Chase', 'CHASE HOME FINANCE FRANKLIN AMERICAN MORTGAGE COMPANY': 'JP Morgan Chase', 'CHASE HOME FINANCE (CIE 1)': 'JP Morgan Chase', 'CHASE HOME FINANCE': 'JP Morgan Chase', 'CASHCALL, INC.': 'CashCall', 'CAPITAL ONE, NATIONAL ASSOCIATION': 'Capital One', 'CALIBER HOME LOANS, INC.': 'Caliber Funding', 'BISHOPS GATE RESIDENTIAL MORTGAGE TRUST': 'Bishops Gate Mortgage', 'BANK OF AMERICA, N.A.': 'Bank of America', 'AMTRUST BANK': 'AmTrust', 'AMERISAVE MORTGAGE CORPORATION': 'Amerisave', 'AMERIHOME MORTGAGE COMPANY, LLC': 'AmeriHome Mortgage', 'ALLY BANK': 'Ally Bank', 'ACADEMY MORTGAGE CORPORATION': 'Academy Mortgage', 'NO CASH-OUT REFINANCE': 'OTHER REFINANCE', 'REFINANCE - NOT SPECIFIED': 'OTHER REFINANCE', 'Other REFINANCE': 'OTHER REFINANCE', } rawSchema = StructType([ StructField("reference_pool_id", StringType()), StructField("loan_id", LongType()), StructField("monthly_reporting_period", StringType()), StructField("orig_channel", StringType()), StructField("seller_name", StringType()), StructField("servicer", StringType()), StructField("master_servicer", StringType()), StructField("orig_interest_rate", DoubleType()), StructField("interest_rate", DoubleType()), StructField("orig_upb", DoubleType()), StructField("upb_at_issuance", StringType()), StructField("current_actual_upb", DoubleType()), StructField("orig_loan_term", IntegerType()), StructField("orig_date", StringType()), StructField("first_pay_date", StringType()), StructField("loan_age", DoubleType()), StructField("remaining_months_to_legal_maturity", DoubleType()), StructField("adj_remaining_months_to_maturity", DoubleType()), StructField("maturity_date", StringType()), StructField("orig_ltv", DoubleType()), StructField("orig_cltv", DoubleType()), StructField("num_borrowers", DoubleType()), StructField("dti", DoubleType()), StructField("borrower_credit_score", DoubleType()), StructField("coborrow_credit_score", DoubleType()), StructField("first_home_buyer", StringType()), StructField("loan_purpose", StringType()), 
StructField("property_type", StringType()), StructField("num_units", IntegerType()), StructField("occupancy_status", StringType()), StructField("property_state", StringType()), StructField("msa", DoubleType()), StructField("zip", IntegerType()), StructField("mortgage_insurance_percent", DoubleType()), StructField("product_type", StringType()), StructField("prepayment_penalty_indicator", StringType()), StructField("interest_only_loan_indicator", StringType()), StructField("interest_only_first_principal_and_interest_payment_date", StringType()), StructField("months_to_amortization", StringType()), StructField("current_loan_delinquency_status", IntegerType()), StructField("loan_payment_history", StringType()), StructField("mod_flag", StringType()), StructField("mortgage_insurance_cancellation_indicator", StringType()), StructField("zero_balance_code", StringType()), StructField("zero_balance_effective_date", StringType()), StructField("upb_at_the_time_of_removal", StringType()), StructField("repurchase_date", StringType()), StructField("scheduled_principal_current", StringType()), StructField("total_principal_current", StringType()), StructField("unscheduled_principal_current", StringType()), StructField("last_paid_installment_date", StringType()), StructField("foreclosed_after", StringType()), StructField("disposition_date", StringType()), StructField("foreclosure_costs", DoubleType()), StructField("prop_preservation_and_repair_costs", DoubleType()), StructField("asset_recovery_costs", DoubleType()), StructField("misc_holding_expenses", DoubleType()), StructField("holding_taxes", DoubleType()), StructField("net_sale_proceeds", DoubleType()), StructField("credit_enhancement_proceeds", DoubleType()), StructField("repurchase_make_whole_proceeds", StringType()), StructField("other_foreclosure_proceeds", DoubleType()), StructField("non_interest_bearing_upb", DoubleType()), StructField("principal_forgiveness_upb", StringType()), StructField("original_list_start_date", StringType()), StructField("original_list_price", StringType()), StructField("current_list_start_date", StringType()), StructField("current_list_price", StringType()), StructField("borrower_credit_score_at_issuance", StringType()), StructField("co-borrower_credit_score_at_issuance", StringType()), StructField("borrower_credit_score_current", StringType()), StructField("co-Borrower_credit_score_current", StringType()), StructField("mortgage_insurance_type", DoubleType()), StructField("servicing_activity_indicator", StringType()), StructField("current_period_modification_loss_amount", StringType()), StructField("cumulative_modification_loss_amount", StringType()), StructField("current_period_credit_event_net_gain_or_loss", StringType()), StructField("cumulative_credit_event_net_gain_or_loss", StringType()), StructField("homeready_program_indicator", StringType()), StructField("foreclosure_principal_write_off_amount", StringType()), StructField("relocation_mortgage_indicator", StringType()), StructField("zero_balance_code_change_date", StringType()), StructField("loan_holdback_indicator", StringType()), StructField("loan_holdback_effective_date", StringType()), StructField("delinquent_accrued_interest", StringType()), StructField("property_valuation_method", StringType()), StructField("high_balance_loan_indicator", StringType()), StructField("arm_initial_fixed-rate_period_lt_5_yr_indicator", StringType()), StructField("arm_product_type", StringType()), StructField("initial_fixed-rate_period", StringType()), 
StructField("interest_rate_adjustment_frequency", StringType()), StructField("next_interest_rate_adjustment_date", StringType()), StructField("next_payment_change_date", StringType()), StructField("index", StringType()), StructField("arm_cap_structure", StringType()), StructField("initial_interest_rate_cap_up_percent", StringType()), StructField("periodic_interest_rate_cap_up_percent", StringType()), StructField("lifetime_interest_rate_cap_up_percent", StringType()), StructField("mortgage_margin", StringType()), StructField("arm_balloon_indicator", StringType()), StructField("arm_plan_number", StringType()), StructField("borrower_assistance_plan", StringType()), StructField("hltv_refinance_option_indicator", StringType()), StructField("deal_name", StringType()), StructField("repurchase_make_whole_proceeds_flag", StringType()), StructField("alternative_delinquency_resolution", StringType()), StructField("alternative_delinquency_resolution_count", StringType()), StructField("total_deferral_amount", StringType()) ]) categorical_columns = [ 'orig_channel', 'first_home_buyer', 'loan_purpose', 'property_type', 'occupancy_status', 'property_state', 'product_type', 'relocation_mortgage_indicator', 'seller_name', 'mod_flag', ] numeric_columns = [ 'orig_interest_rate', 'orig_upb', 'orig_loan_term', 'orig_ltv', 'orig_cltv', 'num_borrowers', 'dti', 'borrower_credit_score', 'num_units', 'zip', 'mortgage_insurance_percent', 'current_loan_delinquency_status', 'current_actual_upb', 'interest_rate', 'loan_age', 'msa', 'non_interest_bearing_upb', 'delinquency_12', ] ================================================ FILE: examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/mortgage/cross_validator_main.py ================================================ # # Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from pyspark.ml.tuning import ParamGridBuilder, CrossValidator from .consts import * from com.nvidia.spark.examples.utility.utils import * from pyspark.sql import SparkSession from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel def main(args, xgboost_args): spark = (SparkSession .builder .appName(args.mainClass) .getOrCreate()) train_data, eval_data, trans_data = valid_input_data(spark, args, '', schema) if args.mode in ['all', 'train']: if train_data is None: print('-' * 80) print('Usage: training data path required when mode is all or train') exit(1) train_data, features = transform_data(train_data, label, args.use_gpu) xgboost_args['features_col'] = features xgboost_args['label_col'] = label classifier = SparkXGBClassifier(**xgboost_args) evaluator = (MulticlassClassificationEvaluator() .setLabelCol(label)) param_grid = (ParamGridBuilder() .addGrid(classifier.max_depth, [6, 8]) .addGrid(classifier.n_estimators, [20, 40]) .build()) cross_validator = (CrossValidator() .setEstimator(classifier) .setEvaluator(evaluator) .setEstimatorParamMaps(param_grid) .setNumFolds(3)) if not train_data: print('-' * 80) print('Usage: training data path required when mode is all or train') exit(1) model = with_benchmark('Training', lambda: cross_validator.fit(train_data)) # get the best model to do transform model = model.bestModel if args.modelPath: writer = model.write().overwrite() if args.overwrite else model writer.save(args.modelPath) else: model = SparkXGBClassifierModel.load(args.modelPath) if args.mode in ['all', 'transform']: if not trans_data: print('-' * 80) print('Usage: trans data path required when mode is all or transform') exit(1) trans_data, _ = transform_data(trans_data, label, args.use_gpu) def transform(): result = model.transform(trans_data).cache() result.foreachPartition(lambda _: None) return result result = with_benchmark('Transformation', transform) show_sample(args, result, label) with_benchmark('Evaluation', lambda: check_classification_accuracy(result, label)) spark.stop() ================================================ FILE: examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/mortgage/etl.py ================================================ # # Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from .consts import * from pyspark.sql.functions import * from pyspark.sql.types import * from pyspark.sql.window import Window from sys import exit get_quarter = udf(lambda path: path.split(r'.')[0].split('/')[-1], StringType()) standardize_name = udf(lambda name: name_mapping.get(name), StringType()) def load_data(spark, paths, schema, args, extra_csv_opts={}): reader = (spark .read .format(args.format) .option('asFloats', args.asFloats) .option('maxRowsPerChunk', args.maxRowsPerChunk)) if args.format == 'csv': (reader .schema(schema) .option('delimiter', '|') .option('header', False)) for k, v in extra_csv_opts.items(): reader.option(k, v) return reader.load(paths) def prepare_rawDf(spark, args): extra_csv_options = { 'nullValue': '', 'parserLib': 'univocity', } paths = extract_paths(args.dataPaths, 'data::') rawDf = load_data(spark, paths, rawSchema, args, extra_csv_options) return rawDf def extract_perf_columns(rawDf): perfDf = rawDf.select( col("loan_id"), date_format(to_date(col("monthly_reporting_period"),"MMyyyy"), "MM/dd/yyyy").alias("monthly_reporting_period"), upper(col("servicer")).alias("servicer"), col("interest_rate"), col("current_actual_upb"), col("loan_age"), col("remaining_months_to_legal_maturity"), col("adj_remaining_months_to_maturity"), date_format(to_date(col("maturity_date"),"MMyyyy"), "MM/yyyy").alias("maturity_date"), col("msa"), col("current_loan_delinquency_status"), col("mod_flag"), col("zero_balance_code"), date_format(to_date(col("zero_balance_effective_date"),"MMyyyy"), "MM/yyyy").alias("zero_balance_effective_date"), date_format(to_date(col("last_paid_installment_date"),"MMyyyy"), "MM/dd/yyyy").alias("last_paid_installment_date"), date_format(to_date(col("foreclosed_after"),"MMyyyy"), "MM/dd/yyyy").alias("foreclosed_after"), date_format(to_date(col("disposition_date"),"MMyyyy"), "MM/dd/yyyy").alias("disposition_date"), col("foreclosure_costs"), col("prop_preservation_and_repair_costs"), col("asset_recovery_costs"), col("misc_holding_expenses"), col("holding_taxes"), col("net_sale_proceeds"), col("credit_enhancement_proceeds"), col("repurchase_make_whole_proceeds"), col("other_foreclosure_proceeds"), col("non_interest_bearing_upb"), col("principal_forgiveness_upb"), col("repurchase_make_whole_proceeds_flag"), col("foreclosure_principal_write_off_amount"), col("servicing_activity_indicator")) return perfDf.select("*").filter("current_actual_upb != 0.0") def prepare_performance(spark, args, rawDf): performance = (extract_perf_columns(rawDf) .withColumn('quarter', get_quarter(input_file_name())) .withColumn('timestamp', to_date(col('monthly_reporting_period'), 'MM/dd/yyyy')) .withColumn('timestamp_year', year(col('timestamp'))) .withColumn('timestamp_month', month(col('timestamp')))) aggregation = (performance .select( 'quarter', 'loan_id', 'current_loan_delinquency_status', when(col('current_loan_delinquency_status') >= 1, col('timestamp')) .alias('delinquency_30'), when(col('current_loan_delinquency_status') >= 3, col('timestamp')) .alias('delinquency_90'), when(col('current_loan_delinquency_status') >= 6, col('timestamp')) .alias('delinquency_180')) .groupBy('quarter', 'loan_id') .agg( max('current_loan_delinquency_status').alias('delinquency_12'), min('delinquency_30').alias('delinquency_30'), min('delinquency_90').alias('delinquency_90'), min('delinquency_180').alias('delinquency_180')) .select( 'quarter', 'loan_id', (col('delinquency_12') >= 1).alias('ever_30'), (col('delinquency_12') >= 3).alias('ever_90'), (col('delinquency_12') >= 
6).alias('ever_180'), 'delinquency_30', 'delinquency_90', 'delinquency_180')) months = spark.createDataFrame(range(12), IntegerType()).withColumnRenamed('value', 'month_y') to_join = (performance .select( 'quarter', 'loan_id', 'timestamp_year', 'timestamp_month', col('current_loan_delinquency_status').alias('delinquency_12'), col('current_actual_upb').alias('upb_12')) .join(aggregation, ['loan_id', 'quarter'], 'left_outer') .crossJoin(months) .select( 'quarter', floor( (col('timestamp_year') * 12 + col('timestamp_month') - 24000 - col('month_y')) / 12 ).alias('josh_mody_n'), 'ever_30', 'ever_90', 'ever_180', 'delinquency_30', 'delinquency_90', 'delinquency_180', 'loan_id', 'month_y', 'delinquency_12', 'upb_12') .groupBy( 'quarter', 'loan_id', 'josh_mody_n', 'ever_30', 'ever_90', 'ever_180', 'delinquency_30', 'delinquency_90', 'delinquency_180', 'month_y') .agg( max('delinquency_12').alias('delinquency_12'), min('upb_12').alias('upb_12')) .withColumn( 'timestamp_year', floor((24000 + (col('josh_mody_n') * 12) + (col('month_y') - 1)) / 12)) .withColumn( 'timestamp_month_tmp', (24000 + (col('josh_mody_n') * 12) + col('month_y')) % 12) .withColumn( 'timestamp_month', when(col('timestamp_month_tmp') == 0, 12).otherwise(col('timestamp_month_tmp'))) .withColumn( 'delinquency_12', ((col('delinquency_12') > 3).cast('int') + (col('upb_12') == 0).cast('int'))) .drop('timestamp_month_tmp', 'josh_mody_n', 'month_y')) return (performance .join(to_join, ['quarter', 'loan_id', 'timestamp_year', 'timestamp_month'], 'left') .drop('timestamp_year', 'timestamp_month')) def extract_acq_columns(rawDf): acqDf = rawDf.select( col("loan_id"), col("orig_channel"), upper(col("seller_name")).alias("seller_name"), col("orig_interest_rate"), col("orig_upb"), col("orig_loan_term"), date_format(to_date(col("orig_date"),"MMyyyy"), "MM/yyyy").alias("orig_date"), date_format(to_date(col("first_pay_date"),"MMyyyy"), "MM/yyyy").alias("first_pay_date"), col("orig_ltv"), col("orig_cltv"), col("num_borrowers"), col("dti"), col("borrower_credit_score"), col("first_home_buyer"), col("loan_purpose"), col("property_type"), col("num_units"), col("occupancy_status"), col("property_state"), col("zip"), col("mortgage_insurance_percent"), col("product_type"), col("coborrow_credit_score"), col("mortgage_insurance_type"), col("relocation_mortgage_indicator"), dense_rank().over(Window.partitionBy("loan_id").orderBy(to_date(col("monthly_reporting_period"),"MMyyyy"))).alias("rank") ) return acqDf.select("*").filter(col("rank")==1) def prepare_acquisition(spark, args, rawDf): return (extract_acq_columns(rawDf) .withColumn('quarter', get_quarter(input_file_name())) .withColumn('seller_name', standardize_name(col('seller_name')))) def extract_paths(paths, prefix): results = [ path[len(prefix):] for path in paths if path.startswith(prefix) ] if not results: print('-' * 80) print('Usage: {} data path required'.format(prefix)) exit(1) return results def etl(spark, args): rawDf = prepare_rawDf(spark, args) rawDf.write.parquet(extract_paths(args.dataPaths, 'tmp::')[0], mode='overwrite') rawDf = spark.read.parquet(extract_paths(args.dataPaths, 'tmp::')[0]) performance = prepare_performance(spark, args, rawDf) acquisition = prepare_acquisition(spark, args, rawDf) return (performance .join(acquisition, ['loan_id', 'quarter'], 'left_outer') .select( [(md5(col(x)) % 100).alias(x) for x in categorical_columns] + [col(x) for x in numeric_columns]) .withColumn('delinquency_12', when(col('delinquency_12') > 0, 1).otherwise(0)) .na .fill(0)) 
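Aside: the 12-month window arithmetic in prepare_performance above (the 'josh_mody_n' bucket and the timestamp it is joined back to) is easier to follow with concrete numbers. The following is a minimal standalone sketch, not part of the repository; it re-states the same formulas in plain Python with a made-up reporting period to show the round trip.

from math import floor

def bucket(year, month, month_y):
    # same formula as the 'josh_mody_n' column in prepare_performance
    return floor((year * 12 + month - 24000 - month_y) / 12)

def window_timestamp(josh_mody_n, month_y):
    # reverse mapping used when the aggregated window is joined back
    tmp = (24000 + josh_mody_n * 12 + month_y) % 12
    year = floor((24000 + josh_mody_n * 12 + (month_y - 1)) / 12)
    return year, 12 if tmp == 0 else tmp

# hypothetical record: reporting period 2003-04, month_y = 5
n = bucket(2003, 4, 5)              # floor(35 / 12) = 2
print(n, window_timestamp(n, 5))    # 2 (2002, 5)

Each reporting period is cross-joined with month_y = 0..11, so every record lands in twelve shifted buckets; aggregating per bucket and mapping the bucket back with window_timestamp attaches the 12-month maximum delinquency and minimum UPB to each month of the window.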
================================================ FILE: examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/mortgage/etl_main.py ================================================ # # Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from .etl import etl, extract_paths from com.nvidia.spark.examples.utility.utils import * from pyspark.sql import SparkSession def main(args, xgboost_args): spark = (SparkSession .builder .appName(args.mainClass) .getOrCreate()) etled_df = etl(spark, args) # only the first 'out::' path is used as the output location outPath = extract_paths(args.dataPaths, 'out::')[0] etled_df.write.mode("overwrite").parquet(outPath) ================================================ FILE: examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/mortgage/main.py ================================================ # # Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel from .consts import * from com.nvidia.spark.examples.utility.utils import * from pyspark.sql import SparkSession def main(args, xgboost_args): spark = (SparkSession .builder .appName(args.mainClass) .getOrCreate()) train_data, eval_data, trans_data = valid_input_data(spark, args, '', schema) if args.mode in ['all', 'train']: if train_data is None: print('-' * 80) print('Usage: training data path required when mode is all or train') exit(1) train_data, features = transform_data(train_data, label, args.use_gpu) xgboost_args['features_col'] = features xgboost_args['label_col'] = label classifier = SparkXGBClassifier(**xgboost_args) if eval_data: # TODO pass model = with_benchmark('Training', lambda: classifier.fit(train_data)) if args.modelPath: writer = model.write().overwrite() if args.overwrite else model writer.save(args.modelPath) else: model = SparkXGBClassifierModel.load(args.modelPath) if args.mode in ['all', 'transform']: trans_data, _ = transform_data(trans_data, label, args.use_gpu) def transform(): result = model.transform(trans_data).cache() result.foreachPartition(lambda _: None) return result if not trans_data: print('-' * 80) print('Usage: trans data path required when mode is all or transform') exit(1) result = with_benchmark('Transformation', transform) show_sample(args, result, label) with_benchmark('Evaluation', lambda: check_classification_accuracy(result, label)) spark.stop() ================================================ FILE: examples/XGBoost-Examples/mortgage/scala/src/com/nvidia/spark/examples/mortgage/CrossValidationMain.scala ================================================ /* * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package com.nvidia.spark.examples.mortgage import com.nvidia.spark.examples.utility.{XGBoostArgs, Benchmark} import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier} import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} import org.apache.spark.sql.SparkSession object CrossValidationMain extends Mortgage { def main(args: Array[String]): Unit = { val appArgs = XGBoostArgs(args) val processor = this.getClass.getSimpleName.stripSuffix("$").substring(0, 3) val appInfo = Seq(appName, processor, appArgs.format) val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2)) // build spark session val spark = SparkSession.builder().appName(appInfo.mkString("-")).getOrCreate() // build data reader val dataReader = spark.read try { // loaded XGBoost ETLed data val pathsArray = appArgs.getDataPaths // 0: train 1: eval 2:transform val datasets = pathsArray.map { paths => if (paths.nonEmpty) { appArgs.format match { case "csv" => Some(dataReader.option("header", appArgs.hasHeader).schema(schema).csv(paths: _*)) case "orc" => Some(dataReader.orc(paths: _*)) case "parquet" => Some(dataReader.parquet(paths: _*)) case _ => throw new IllegalArgumentException("Unsupported data file format!") } } else { None } } val xgbClassificationModel = if (appArgs.isToTrain) { // build XGBoost classifier val xgbParamFinal = appArgs.xgboostParams(commParamMap) val xgbClassifier = new XGBoostClassifier(xgbParamFinal) .setLabelCol(labelColName) .setFeaturesCol(featureNames) // Tune model using cross validation val paramGrid = new ParamGridBuilder() .addGrid(xgbClassifier.maxDepth, Array(3, 10)) .addGrid(xgbClassifier.eta, Array(0.2, 0.6)) .build() val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName) val cv = new CrossValidator() .setEstimator(xgbClassifier) .setEvaluator(evaluator) .setEstimatorParamMaps(paramGrid) .setNumFolds(appArgs.numFold) // Start training println("\n------ CrossValidation ------") // Shall we not log the time if it is abnormal, which is usually caused by training failure val (model, _) = benchmark.time("CrossValidation") { cv.fit(datasets(0).get).bestModel.asInstanceOf[XGBoostClassificationModel] } // Save model if modelPath exists appArgs.modelPath.foreach(path => if (appArgs.isOverwrite) model.write.overwrite().save(path) else model.save(path)) model } else { XGBoostClassificationModel.load(appArgs.modelPath.get) } if (appArgs.isToTransform) { println("\n------ Transforming ------") var (results, _) = benchmark.time("transform") { val ret = xgbClassificationModel.transform(datasets(2).get).cache() // Trigger the transformation ret.foreachPartition((_: Iterator[_]) => ()) ret } results = if (appArgs.isShowFeatures) { results } else { results.select(labelColName, "rawPrediction", "probability", "prediction") } results.show(appArgs.numRows) println("\n------Accuracy of Evaluation------") val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName) evaluator.evaluate(results) match { case accuracy if !accuracy.isNaN => benchmark.value(accuracy, "Accuracy", "Accuracy for") // Throw an exception when NaN ? } } } finally { spark.close() } } } ================================================ FILE: examples/XGBoost-Examples/mortgage/scala/src/com/nvidia/spark/examples/mortgage/ETLMain.scala ================================================ /* * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.examples.mortgage import com.nvidia.spark.examples.utility.{XGBoostArgs, Benchmark} import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession object ETLMain extends Mortgage { def main(args: Array[String]): Unit = { val xgbArgs = XGBoostArgs(args) val subTitle = getClass.getSimpleName.stripSuffix("$").substring(0, 3) val appInfo = Seq(appName, subTitle, xgbArgs.format) val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2)) // build spark session val spark = SparkSession.builder().appName(appInfo.mkString("-")).getOrCreate() try { val (dataPaths, outPath, tmpPath) = checkAndGetPaths(xgbArgs.dataPaths) println("\n------ Start ETL ------") benchmark.time("ETL") { // ETL the raw data val rawDF = xgbArgs.format match { case "csv" => XGBoostETL.csv(spark, dataPaths, tmpPath, false) case "orc" => XGBoostETL.orc(spark, dataPaths) case "parquet" => XGBoostETL.parquet(spark, dataPaths) case _ => throw new IllegalArgumentException("Unsupported data file format!") } rawDF.write.mode("overwrite").parquet(outPath) } if (xgbArgs.saveDict) { XGBoostETL.saveDictTable(new Path(outPath, ".dict").toString) } } finally { XGBoostETL.clean() spark.close() } } def checkAndGetPaths(paths: Seq[String]): (Seq[String], String, String) = { val prefixes = Array("data::", "out::", "tmp::") val validPaths = paths.filter(_.nonEmpty).map(_.trim) // get and check perf data paths val dataPaths = validPaths.filter(_.startsWith(prefixes.head)) require(dataPaths.nonEmpty, s"$appName ETL requires at least one path for data file." + s" Please specify it by '-dataPath=data::your_data_path'") // get and check out path val outPath = validPaths.filter(_.startsWith(prefixes(1))) require(outPath.nonEmpty, s"$appName ETL requires a path to save the ETLed data file. Please specify it" + " by '-dataPath=out::your_out_path', only the first path is used if multiple paths are found.") // get and check tmp path val tmpPath = validPaths.filter(_.startsWith(prefixes(2))) require(tmpPath.nonEmpty, s"$appName ETL requires a path to save the temp parquet files. Please specify it" + " by '-dataPath=tmp::your_out_path'.") // check data paths not specified type val unknownPaths = validPaths.filterNot(p => prefixes.exists(p.contains(_))) require(unknownPaths.isEmpty, s"Unknown type for data path: ${unknownPaths.head}, $appName requires to specify" + " the type for each data path by adding the prefix 'data::' or 'out::'.") (dataPaths.map(_.stripPrefix(prefixes.head)), outPath.head.stripPrefix(prefixes(1)), tmpPath.head.stripPrefix(prefixes(2))) } } ================================================ FILE: examples/XGBoost-Examples/mortgage/scala/src/com/nvidia/spark/examples/mortgage/Main.scala ================================================ /* * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.examples.mortgage import com.nvidia.spark.examples.utility.{XGBoostArgs, Benchmark} import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier} import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.sql.SparkSession object Main extends Mortgage { def main(args: Array[String]): Unit = { val appArgs = XGBoostArgs(args) val processor = this.getClass.getSimpleName.stripSuffix("$").substring(0, 3) val appInfo = Seq(appName, processor, appArgs.format) val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2)) // build spark session val spark = SparkSession.builder().appName(appInfo.mkString("-")).getOrCreate() // build data reader val dataReader = spark.read try { // loaded XGBoost ETLed data val pathsArray = appArgs.getDataPaths // 0: train 1: eval 2:transform val datasets = pathsArray.map { paths => if (paths.nonEmpty) { appArgs.format match { case "csv" => Some(dataReader.option("header", appArgs.hasHeader).schema(schema).csv(paths: _*)) case "orc" => Some(dataReader.orc(paths: _*)) case "parquet" => Some(dataReader.parquet(paths: _*)) case _ => throw new IllegalArgumentException("Unsupported data file format!") } } else { None } } val xgbClassificationModel = if (appArgs.isToTrain) { // build XGBoost classifier val xgbParamFinal = appArgs.xgboostParams(commParamMap) val xgbClassifier = new XGBoostClassifier(xgbParamFinal) .setLabelCol(labelColName) .setFeaturesCol(featureNames) datasets(1).foreach(_ => xgbClassifier.setEvalDataset(_)) // Start training println("\n------ Training ------") // Shall we not log the time if it is abnormal, which is usually caused by training failure val (model, _) = benchmark.time("train") { xgbClassifier.fit(datasets(0).get) } // Save model if modelPath exists appArgs.modelPath.foreach(path => if (appArgs.isOverwrite) model.write.overwrite().save(path) else model.save(path)) model } else { XGBoostClassificationModel.load(appArgs.modelPath.get) } if (appArgs.isToTransform) { println("\n------ Transforming ------") var (results, _) = benchmark.time("transform") { val ret = xgbClassificationModel.transform(datasets(2).get).cache() // Trigger the transformation ret.foreachPartition((_: Iterator[_]) => ()) ret } results = if (appArgs.isShowFeatures) { results } else { results.select(labelColName, "rawPrediction", "probability", "prediction") } results.show(appArgs.numRows) println("\n------Accuracy of Evaluation------") val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName) evaluator.evaluate(results) match { case accuracy if !accuracy.isNaN => benchmark.value(accuracy, "Accuracy", "Accuracy for") // Throw an exception when NaN ? } } } finally { spark.close() } } } ================================================ FILE: examples/XGBoost-Examples/mortgage/scala/src/com/nvidia/spark/examples/mortgage/Mortgage.scala ================================================ /* * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.examples.mortgage import org.apache.spark.sql.types.{FloatType, IntegerType, StructField, StructType, DoubleType} private[mortgage] trait Mortgage { val appName = "Mortgage" val labelColName = "delinquency_12" protected val categaryCols = List( ("orig_channel", FloatType), ("first_home_buyer", FloatType), ("loan_purpose", FloatType), ("property_type", FloatType), ("occupancy_status", FloatType), ("property_state", FloatType), ("product_type", FloatType), ("relocation_mortgage_indicator", FloatType), ("seller_name", FloatType), ("mod_flag", FloatType) ) protected val numericCols = List( ("orig_interest_rate", FloatType), ("orig_upb", DoubleType), ("orig_loan_term", IntegerType), ("orig_ltv", FloatType), ("orig_cltv", FloatType), ("num_borrowers", FloatType), ("dti", FloatType), ("borrower_credit_score", FloatType), ("num_units", IntegerType), ("zip", IntegerType), ("mortgage_insurance_percent", FloatType), ("current_loan_delinquency_status", IntegerType), ("current_actual_upb", FloatType), ("interest_rate", FloatType), ("loan_age", FloatType), ("msa", FloatType), ("non_interest_bearing_upb", FloatType), (labelColName, IntegerType) ) lazy val schema = StructType((categaryCols ++ numericCols).map(col => StructField(col._1, col._2))) lazy val featureNames = schema.filter(_.name != labelColName).map(_.name).toArray val commParamMap = Map( "objective" -> "binary:logistic", "num_round" -> 100) } ================================================ FILE: examples/XGBoost-Examples/mortgage/scala/src/com/nvidia/spark/examples/mortgage/XGBoostETL.scala ================================================ /* * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.examples.mortgage import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, SparkSession} object GetQuarterFromCsvFileName { // The format is path/TYPE_yyyy\QQ.txt followed by a (_index)* where index is a single digit number [0-9] // i.e. 
mortgage/perf/Performance_2003Q4.txt_0_1 // So we strip off the .txt and everything after it // and then take everything after the last remaining _ def apply(): Column = substring_index( substring_index(input_file_name(), ".", 1), "/", -1) } private object CsvReader { def readRaw(spark: SparkSession, paths: Seq[String], optionsMap: Map[String, String]): DataFrame = { val rawSchema = StructType(Array( StructField("reference_pool_id", StringType), StructField("loan_id", LongType), StructField("monthly_reporting_period", StringType), StructField("orig_channel", StringType), StructField("seller_name", StringType), StructField("servicer", StringType), StructField("master_servicer", StringType), StructField("orig_interest_rate", DoubleType), StructField("interest_rate", DoubleType), StructField("orig_upb", DoubleType), StructField("upb_at_issuance", StringType), StructField("current_actual_upb", DoubleType), StructField("orig_loan_term", IntegerType), StructField("orig_date", StringType), StructField("first_pay_date", StringType), StructField("loan_age", DoubleType), StructField("remaining_months_to_legal_maturity", DoubleType), StructField("adj_remaining_months_to_maturity", DoubleType), StructField("maturity_date", StringType), StructField("orig_ltv", DoubleType), StructField("orig_cltv", DoubleType), StructField("num_borrowers", DoubleType), StructField("dti", DoubleType), StructField("borrower_credit_score", DoubleType), StructField("coborrow_credit_score", DoubleType), StructField("first_home_buyer", StringType), StructField("loan_purpose", StringType), StructField("property_type", StringType), StructField("num_units", IntegerType), StructField("occupancy_status", StringType), StructField("property_state", StringType), StructField("msa", DoubleType), StructField("zip", IntegerType), StructField("mortgage_insurance_percent", DoubleType), StructField("product_type", StringType), StructField("prepayment_penalty_indicator", StringType), StructField("interest_only_loan_indicator", StringType), StructField("interest_only_first_principal_and_interest_payment_date", StringType), StructField("months_to_amortization", StringType), StructField("current_loan_delinquency_status", IntegerType), StructField("loan_payment_history", StringType), StructField("mod_flag", StringType), StructField("mortgage_insurance_cancellation_indicator", StringType), StructField("zero_balance_code", StringType), StructField("zero_balance_effective_date", StringType), StructField("upb_at_the_time_of_removal", StringType), StructField("repurchase_date", StringType), StructField("scheduled_principal_current", StringType), StructField("total_principal_current", StringType), StructField("unscheduled_principal_current", StringType), StructField("last_paid_installment_date", StringType), StructField("foreclosed_after", StringType), StructField("disposition_date", StringType), StructField("foreclosure_costs", DoubleType), StructField("prop_preservation_and_repair_costs", DoubleType), StructField("asset_recovery_costs", DoubleType), StructField("misc_holding_expenses", DoubleType), StructField("holding_taxes", DoubleType), StructField("net_sale_proceeds", DoubleType), StructField("credit_enhancement_proceeds", DoubleType), StructField("repurchase_make_whole_proceeds", StringType), StructField("other_foreclosure_proceeds", DoubleType), StructField("non_interest_bearing_upb", DoubleType), StructField("principal_forgiveness_upb", StringType), StructField("original_list_start_date", StringType), StructField("original_list_price", 
StringType), StructField("current_list_start_date", StringType), StructField("current_list_price", StringType), StructField("borrower_credit_score_at_issuance", StringType), StructField("co-borrower_credit_score_at_issuance", StringType), StructField("borrower_credit_score_current", StringType), StructField("co-Borrower_credit_score_current", StringType), StructField("mortgage_insurance_type", DoubleType), StructField("servicing_activity_indicator", StringType), StructField("current_period_modification_loss_amount", StringType), StructField("cumulative_modification_loss_amount", StringType), StructField("current_period_credit_event_net_gain_or_loss", StringType), StructField("cumulative_credit_event_net_gain_or_loss", StringType), StructField("homeready_program_indicator", StringType), StructField("foreclosure_principal_write_off_amount", StringType), StructField("relocation_mortgage_indicator", StringType), StructField("zero_balance_code_change_date", StringType), StructField("loan_holdback_indicator", StringType), StructField("loan_holdback_effective_date", StringType), StructField("delinquent_accrued_interest", StringType), StructField("property_valuation_method", StringType), StructField("high_balance_loan_indicator", StringType), StructField("arm_initial_fixed-rate_period_lt_5_yr_indicator", StringType), StructField("arm_product_type", StringType), StructField("initial_fixed-rate_period", StringType), StructField("interest_rate_adjustment_frequency", StringType), StructField("next_interest_rate_adjustment_date", StringType), StructField("next_payment_change_date", StringType), StructField("index", StringType), StructField("arm_cap_structure", StringType), StructField("initial_interest_rate_cap_up_percent", StringType), StructField("periodic_interest_rate_cap_up_percent", StringType), StructField("lifetime_interest_rate_cap_up_percent", StringType), StructField("mortgage_margin", StringType), StructField("arm_balloon_indicator", StringType), StructField("arm_plan_number", StringType), StructField("borrower_assistance_plan", StringType), StructField("hltv_refinance_option_indicator", StringType), StructField("deal_name", StringType), StructField("repurchase_make_whole_proceeds_flag", StringType), StructField("alternative_delinquency_resolution", StringType), StructField("alternative_delinquency_resolution_count", StringType), StructField("total_deferral_amount", StringType) ) ) spark.read .options(optionsMap) .option("nullValue", "") .option("delimiter", "|") .schema(rawSchema) .csv(paths: _*) .withColumn("quarter", GetQuarterFromCsvFileName()) } } object extractPerfColumns{ def apply(rawDf : DataFrame) : DataFrame = { val perfDf = rawDf.select( col("loan_id"), date_format(to_date(col("monthly_reporting_period"),"MMyyyy"), "MM/dd/yyyy").as("monthly_reporting_period"), upper(col("servicer")).as("servicer"), col("interest_rate"), col("current_actual_upb"), col("loan_age"), col("remaining_months_to_legal_maturity"), col("adj_remaining_months_to_maturity"), date_format(to_date(col("maturity_date"),"MMyyyy"), "MM/yyyy").as("maturity_date"), col("msa"), col("current_loan_delinquency_status"), col("mod_flag"), col("zero_balance_code"), date_format(to_date(col("zero_balance_effective_date"),"MMyyyy"), "MM/yyyy").as("zero_balance_effective_date"), date_format(to_date(col("last_paid_installment_date"),"MMyyyy"), "MM/dd/yyyy").as("last_paid_installment_date"), date_format(to_date(col("foreclosed_after"),"MMyyyy"), "MM/dd/yyyy").as("foreclosed_after"), 
date_format(to_date(col("disposition_date"),"MMyyyy"), "MM/dd/yyyy").as("disposition_date"), col("foreclosure_costs"), col("prop_preservation_and_repair_costs"), col("asset_recovery_costs"), col("misc_holding_expenses"), col("holding_taxes"), col("net_sale_proceeds"), col("credit_enhancement_proceeds"), col("repurchase_make_whole_proceeds"), col("other_foreclosure_proceeds"), col("non_interest_bearing_upb"), col("principal_forgiveness_upb"), col("repurchase_make_whole_proceeds_flag"), col("foreclosure_principal_write_off_amount"), col("servicing_activity_indicator"), col("quarter") ) perfDf.select("*").filter("current_actual_upb != 0.0") } } object extractAcqColumns{ def apply(rawDf : DataFrame) : DataFrame = { val acqDf = rawDf.select( col("loan_id"), col("orig_channel"), upper(col("seller_name")).as("seller_name"), col("orig_interest_rate"), col("orig_upb"), col("orig_loan_term"), date_format(to_date(col("orig_date"),"MMyyyy"), "MM/yyyy").as("orig_date"), date_format(to_date(col("first_pay_date"),"MMyyyy"), "MM/yyyy").as("first_pay_date"), col("orig_ltv"), col("orig_cltv"), col("num_borrowers"), col("dti"), col("borrower_credit_score"), col("first_home_buyer"), col("loan_purpose"), col("property_type"), col("num_units"), col("occupancy_status"), col("property_state"), col("zip"), col("mortgage_insurance_percent"), col("product_type"), col("coborrow_credit_score"), col("mortgage_insurance_type"), col("relocation_mortgage_indicator"), col("quarter"), dense_rank().over(Window.partitionBy("loan_id").orderBy(to_date(col("monthly_reporting_period"),"MMyyyy"))).as("rank") ) acqDf.select("*").filter(col("rank") === 1).drop("rank") } } object NameMapping { /** * Returns a dataframe with two columns named based off of the column names passed in. * The fromColName has the original name we want to clean up, the toColName * will have the name we want to go to, the unambiguous name. */ def apply(spark: SparkSession, fromColName: String, toColName: String): DataFrame = { import spark.sqlContext.implicits._ broadcast(Seq( ("WITMER FUNDING, LLC", "Witmer"), ("WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015", "Wells Fargo"), ("WELLS FARGO BANK, NA" , "Wells Fargo"), ("WELLS FARGO BANK, N.A." , "Wells Fargo"), ("WELLS FARGO BANK, NA" , "Wells Fargo"), ("USAA FEDERAL SAVINGS BANK" , "USAA"), ("UNITED SHORE FINANCIAL SERVICES, LLC D\\/B\\/A UNITED WHOLESALE MORTGAGE" , "United Seq(e"), ("U.S. BANK N.A." , "US Bank"), ("SUNTRUST MORTGAGE INC." , "Suntrust"), ("STONEGATE MORTGAGE CORPORATION" , "Stonegate Mortgage"), ("STEARNS LENDING, LLC" , "Stearns Lending"), ("STEARNS LENDING, INC." , "Stearns Lending"), ("SIERRA PACIFIC MORTGAGE COMPANY, INC." , "Sierra Pacific Mortgage"), ("REGIONS BANK" , "Regions"), ("RBC MORTGAGE COMPANY" , "RBC"), ("QUICKEN LOANS INC." , "Quicken Loans"), ("PULTE MORTGAGE, L.L.C." , "Pulte Mortgage"), ("PROVIDENT FUNDING ASSOCIATES, L.P." , "Provident Funding"), ("PROSPECT MORTGAGE, LLC" , "Prospect Mortgage"), ("PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC" , "Principal Residential"), ("PNC BANK, N.A." , "PNC"), ("PMT CREDIT RISK TRANSFER TRUST 2015-2" , "PennyMac"), ("PHH MORTGAGE CORPORATION" , "PHH Mortgage"), ("PENNYMAC CORP." , "PennyMac"), ("PACIFIC UNION FINANCIAL, LLC" , "Other"), ("OTHER" , "Other"), ("NYCB MORTGAGE COMPANY, LLC" , "NYCB"), ("NEW YORK COMMUNITY BANK" , "NYCB"), ("NETBANK FUNDING SERVICES" , "Netbank"), ("NATIONSTAR MORTGAGE, LLC" , "Nationstar Mortgage"), ("METLIFE BANK, NA" , "Metlife"), ("LOANDEPOT.COM, LLC" , "LoanDepot.com"), ("J.P. 
MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1" , "JP Morgan Chase"), ("J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1" , "JP Morgan Chase"), ("JPMORGAN CHASE BANK, NATIONAL ASSOCIATION" , "JP Morgan Chase"), ("JPMORGAN CHASE BANK, NA" , "JP Morgan Chase"), ("JP MORGAN CHASE BANK, NA" , "JP Morgan Chase"), ("IRWIN MORTGAGE, CORPORATION" , "Irwin Mortgage"), ("IMPAC MORTGAGE CORP." , "Impac Mortgage"), ("HSBC BANK USA, NATIONAL ASSOCIATION" , "HSBC"), ("HOMEWARD RESIDENTIAL, INC." , "Homeward Mortgage"), ("HOMESTREET BANK" , "Other"), ("HOMEBRIDGE FINANCIAL SERVICES, INC." , "HomeBridge"), ("HARWOOD STREET FUNDING I, LLC" , "Harwood Mortgage"), ("GUILD MORTGAGE COMPANY" , "Guild Mortgage"), ("GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)" , "GMAC"), ("GMAC MORTGAGE, LLC" , "GMAC"), ("GMAC (USAA)" , "GMAC"), ("FREMONT BANK" , "Fremont Bank"), ("FREEDOM MORTGAGE CORP." , "Freedom Mortgage"), ("FRANKLIN AMERICAN MORTGAGE COMPANY" , "Franklin America"), ("FLEET NATIONAL BANK" , "Fleet National"), ("FLAGSTAR CAPITAL MARKETS CORPORATION" , "Flagstar Bank"), ("FLAGSTAR BANK, FSB" , "Flagstar Bank"), ("FIRST TENNESSEE BANK NATIONAL ASSOCIATION" , "Other"), ("FIFTH THIRD BANK" , "Fifth Third Bank"), ("FEDERAL HOME LOAN BANK OF CHICAGO" , "Fedral Home of Chicago"), ("FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB" , "FDIC"), ("DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A." , "Downey Mortgage"), ("DITECH FINANCIAL LLC" , "Ditech"), ("CITIMORTGAGE, INC." , "Citi"), ("CHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY" , "Chicago Mortgage"), ("CHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY" , "Chicago Mortgage"), ("CHASE HOME FINANCE, LLC" , "JP Morgan Chase"), ("CHASE HOME FINANCE FRANKLIN AMERICAN MORTGAGE COMPANY" , "JP Morgan Chase"), ("CHASE HOME FINANCE (CIE 1)" , "JP Morgan Chase"), ("CHASE HOME FINANCE" , "JP Morgan Chase"), ("CASHCALL, INC." , "CashCall"), ("CAPITAL ONE, NATIONAL ASSOCIATION" , "Capital One"), ("CALIBER HOME LOANS, INC." , "Caliber Funding"), ("BISHOPS GATE RESIDENTIAL MORTGAGE TRUST" , "Bishops Gate Mortgage"), ("BANK OF AMERICA, N.A." 
, "Bank of America"), ("AMTRUST BANK" , "AmTrust"), ("AMERISAVE MORTGAGE CORPORATION" , "Amerisave"), ("AMERIHOME MORTGAGE COMPANY, LLC" , "AmeriHome Mortgage"), ("ALLY BANK" , "Ally Bank"), ("ACADEMY MORTGAGE CORPORATION" , "Academy Mortgage"), ("NO CASH-OUT REFINANCE" , "OTHER REFINANCE"), ("REFINANCE - NOT SPECIFIED" , "OTHER REFINANCE"), ("Other REFINANCE" , "OTHER REFINANCE") ).toDF(fromColName, toColName)) } } private trait MortgageETL { var dataFrame: DataFrame = _ def from(df: DataFrame): this.type = { dataFrame = df this } } private object PerformanceETL extends MortgageETL { def prepare: this.type = { dataFrame = dataFrame .withColumn("monthly_reporting_period", to_date(col("monthly_reporting_period"), "MM/dd/yyyy")) .withColumn("monthly_reporting_period_month", month(col("monthly_reporting_period"))) .withColumn("monthly_reporting_period_year", year(col("monthly_reporting_period"))) .withColumn("monthly_reporting_period_day", dayofmonth(col("monthly_reporting_period"))) .withColumn("last_paid_installment_date", to_date(col("last_paid_installment_date"), "MM/dd/yyyy")) .withColumn("foreclosed_after", to_date(col("foreclosed_after"), "MM/dd/yyyy")) .withColumn("disposition_date", to_date(col("disposition_date"), "MM/dd/yyyy")) .withColumn("maturity_date", to_date(col("maturity_date"), "MM/yyyy")) .withColumn("zero_balance_effective_date", to_date(col("zero_balance_effective_date"), "MM/yyyy")) .withColumn("current_actual_upb", col("current_actual_upb")) .withColumn("current_loan_delinquency_status", col("current_loan_delinquency_status")) this } def createDelinquency(spark: SparkSession): this.type = { val aggDF = dataFrame .select( col("quarter"), col("loan_id"), col("current_loan_delinquency_status"), when(col("current_loan_delinquency_status") >= 1, col("monthly_reporting_period")).alias("delinquency_30"), when(col("current_loan_delinquency_status") >= 3, col("monthly_reporting_period")).alias("delinquency_90"), when(col("current_loan_delinquency_status") >= 6, col("monthly_reporting_period")).alias("delinquency_180") ) .groupBy("quarter", "loan_id") .agg( max("current_loan_delinquency_status").alias("delinquency_12"), min("delinquency_30").alias("delinquency_30"), min("delinquency_90").alias("delinquency_90"), min("delinquency_180").alias("delinquency_180") ) .select( col("quarter"), col("loan_id"), (col("delinquency_12") >= 1).alias("ever_30"), (col("delinquency_12") >= 3).alias("ever_90"), (col("delinquency_12") >= 6).alias("ever_180"), col("delinquency_30"), col("delinquency_90"), col("delinquency_180") ) val joinedDf = dataFrame .withColumnRenamed("monthly_reporting_period", "timestamp") .withColumnRenamed("monthly_reporting_period_month", "timestamp_month") .withColumnRenamed("monthly_reporting_period_year", "timestamp_year") .withColumnRenamed("current_loan_delinquency_status", "delinquency_12") .withColumnRenamed("current_actual_upb", "upb_12") .select("quarter", "loan_id", "timestamp", "delinquency_12", "upb_12", "timestamp_month", "timestamp_year") .join(aggDF, Seq("loan_id", "quarter"), "left_outer") // calculate the 12 month delinquency and upb values val months = 12 val monthArray = 0.until(months).toArray val testDf = joinedDf // explode on a small amount of data is actually slightly more efficient than a cross join .withColumn("month_y", explode(lit(monthArray))) .select( col("quarter"), floor(((col("timestamp_year") * 12 + col("timestamp_month")) - 24000) / months).alias("josh_mody"), floor(((col("timestamp_year") * 12 + col("timestamp_month")) - 24000 - 
col("month_y")) / months).alias("josh_mody_n"), col("ever_30"), col("ever_90"), col("ever_180"), col("delinquency_30"), col("delinquency_90"), col("delinquency_180"), col("loan_id"), col("month_y"), col("delinquency_12"), col("upb_12") ) .groupBy("quarter", "loan_id", "josh_mody_n", "ever_30", "ever_90", "ever_180", "delinquency_30", "delinquency_90", "delinquency_180", "month_y") .agg(max("delinquency_12").alias("delinquency_12"), min("upb_12").alias("upb_12")) .withColumn("timestamp_year", floor((lit(24000) + (col("josh_mody_n") * lit(months)) + (col("month_y") - 1)) / lit(12))) .withColumn("timestamp_month_tmp", pmod(lit(24000) + (col("josh_mody_n") * lit(months)) + col("month_y"), lit(12))) .withColumn("timestamp_month", when(col("timestamp_month_tmp") === lit(0), lit(12)).otherwise(col("timestamp_month_tmp"))) .withColumn("delinquency_12", ((col("delinquency_12") > 3).cast("int") + (col("upb_12") === 0).cast("int")).alias("delinquency_12")) .drop("timestamp_month_tmp", "josh_mody_n", "month_y") dataFrame = dataFrame .withColumnRenamed("monthly_reporting_period_month", "timestamp_month") .withColumnRenamed("monthly_reporting_period_year", "timestamp_year") .join(testDf, Seq("quarter", "loan_id", "timestamp_year", "timestamp_month"), "left").drop("timestamp_year", "timestamp_month") this } } private object AcquisitionETL extends MortgageETL { def createAcquisition(spark: SparkSession): this.type = { val nameMapping = NameMapping(spark, "from_seller_name", "to_seller_name") dataFrame = dataFrame .join(nameMapping, col("seller_name") === col("from_seller_name"), "left") .drop("from_seller_name") /* backup the original name before we replace it */ .withColumn("old_name", col("seller_name")) /* replace seller_name with the new version if we found one in the mapping, or the old version if we didn't */ .withColumn("seller_name", coalesce(col("to_seller_name"), col("seller_name"))) .drop("to_seller_name") .withColumn("orig_date", to_date(col("orig_date"), "MM/yyyy")) .withColumn("first_pay_date", to_date(col("first_pay_date"), "MM/yyyy")) this } def cleanPrime(perfDF: DataFrame): this.type = { dataFrame = perfDF.join(dataFrame, Seq("loan_id", "quarter"), "inner").drop("quarter") this } } object XGBoostETL extends Mortgage { private lazy val allCols = (categaryCols ++ numericCols).map(c => col(c._1)) private var cachedDictDF: DataFrame = _ /** * Generate a dictionary from string to numeric value for multiple category columns. * * (Copied the solution of casting string to numeric from the utils of DLRM.) */ private def genDictionary(etlDF: DataFrame, colNames: Seq[String]): DataFrame = { val cntTable = etlDF .select(posexplode(array(colNames.map(col(_)): _*))) .withColumnRenamed("pos", "column_id") .withColumnRenamed("col", "data") .filter("data is not null") .groupBy("column_id", "data") .count() val windowed = Window.partitionBy("column_id").orderBy(desc("count")) cntTable .withColumn("id", row_number().over(windowed)) .drop("count") } /** * Cast all the category columns to numeric columns in the given data frame. * Then it is suitable for XGBoost training/transforming */ private def castStringColumnsToNumeric(inputDF: DataFrame, spark: SparkSession): DataFrame = { val cateColNames = categaryCols.map(_._1) cachedDictDF = genDictionary(inputDF, cateColNames).cache() // Generate the final table with all columns being numeric. 
cateColNames.foldLeft(inputDF) { case (df, colName) => val colPos = cateColNames.indexOf(colName) val colDictDF = cachedDictDF .filter(col("column_id") === colPos) .drop("column_id") .withColumnRenamed("data", colName) df.join(broadcast(colDictDF), Seq(colName), "left") .drop(colName) .withColumnRenamed("id", colName) } } private def transform(perfDF: DataFrame, acqDF: DataFrame, spark: SparkSession): DataFrame = { val etlPerfDF = PerformanceETL.from(perfDF) .prepare .createDelinquency(spark) .dataFrame val cleanDF = AcquisitionETL.from(acqDF) .createAcquisition(spark) .cleanPrime(etlPerfDF) .dataFrame // Convert to xgb required Dataset castStringColumnsToNumeric(cleanDF, spark) .select(allCols: _*) .withColumn(labelColName, when(col(labelColName) > 0, 1).otherwise(0)) .na.fill(0.0f) } def clean(): Unit = { if (cachedDictDF != null) { cachedDictDF.unpersist() cachedDictDF = null } } def saveDictTable(outPath: String): Unit = { if (cachedDictDF != null) { // The dict data is small, so merge it into one file. cachedDictDF .repartition(1) .write .mode("overwrite") .parquet(outPath) } } def csv(spark: SparkSession, dataPaths: Seq[String], tmpPath: String, hasHeader: Boolean): DataFrame = { val optionsMap = Map("header" -> hasHeader.toString) val rawDf_csv = CsvReader.readRaw(spark, dataPaths, optionsMap) rawDf_csv.write.mode("overwrite").parquet(tmpPath) val rawDf = spark.read.parquet(tmpPath) val perfDf = extractPerfColumns(rawDf) val acqDf = extractAcqColumns(rawDf) transform( perfDf, acqDf, spark ) } def parquet(spark: SparkSession, dataPaths: Seq[String]): DataFrame = { val rawDf = spark.read.parquet(dataPaths: _*) val perfDf = extractPerfColumns(rawDf) val acqDf = extractAcqColumns(rawDf) transform( perfDf, acqDf, spark ) } def orc(spark: SparkSession, dataPaths: Seq[String]): DataFrame = { val rawDf = spark.read.orc(dataPaths: _*) val perfDf = extractPerfColumns(rawDf) val acqDf = extractAcqColumns(rawDf) transform( perfDf, acqDf, spark ) } } ================================================ FILE: examples/XGBoost-Examples/pack_pyspark_example.sh ================================================ #!/bin/bash # Copyright (c) 2024-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Follow these steps to package the Python zip file rm -fr samples.zip cd agaricus/python ; zip -r ../../samples.zip com ; cd ../.. cd mortgage/python ; zip -r ../../samples.zip com ; cd ../.. cd taxi/python ; zip -r ../../samples.zip com ; cd ../.. cd utility/python ; zip -r ../../samples.zip com ; cd ../.. 
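Editorial aside (not a file in the repository): the snippet below is a minimal PySpark sketch of the frequency-dictionary encoding that `XGBoostETL.genDictionary` and `castStringColumnsToNumeric` implement in the mortgage ETL Scala code above. The session name, toy DataFrame, and column names are made up for illustration; only standard `pyspark.sql` APIs are used.

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.appName("dict-encoding-sketch").getOrCreate()

# Toy frame standing in for the cleaned mortgage table (columns are illustrative).
df = spark.createDataFrame(
    [("R", "CA"), ("B", "CA"), ("R", "NY"), ("R", None)],
    ["orig_channel", "property_state"])
cat_cols = ["orig_channel", "property_state"]

# 1) Build one dictionary for all category columns: count each (column, value)
#    pair, then rank values per column by frequency so the most common value
#    gets the smallest id.
cnt = (df.select(F.posexplode(F.array(*[F.col(c) for c in cat_cols])))
         .withColumnRenamed("pos", "column_id")
         .withColumnRenamed("col", "data")
         .filter("data is not null")
         .groupBy("column_id", "data")
         .count())
w = Window.partitionBy("column_id").orderBy(F.desc("count"))
dict_df = cnt.withColumn("id", F.row_number().over(w)).drop("count").cache()

# 2) Replace each string column with its numeric id via a broadcast join,
#    keeping the original column name.
encoded = df
for pos, name in enumerate(cat_cols):
    col_dict = (dict_df.filter(F.col("column_id") == pos)
                       .drop("column_id")
                       .withColumnRenamed("data", name))
    encoded = (encoded.join(F.broadcast(col_dict), [name], "left")
                      .drop(name)
                      .withColumnRenamed("id", name))

encoded.show()
```

As in the Scala version, unmatched or null category values simply come through as null ids here; the real ETL fills those with 0 before handing the table to XGBoost.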
================================================ FILE: examples/XGBoost-Examples/pom.xml ================================================ 4.0.0 com.nvidia sample_xgboost_examples pom Sample XGBoost4J-Spark applications utility agaricus mortgage taxi aggregator 0.2.3-SNAPSHOT sample_xgboost_apps UTF-8 3.1.0-SNAPSHOT 3.5.0 2.12.8 2.12 ml.dmlc xgboost4j-spark-gpu_${scala.binary.version} ${xgboost.version} org.scala-lang scala-library ${scala.version} provided org.apache.spark spark-sql_${scala.binary.version} ${spark.version} provided org.apache.spark spark-mllib_${scala.binary.version} ${spark.version} provided org.scalatest scalatest_${scala.binary.version} 3.2.15 test org.scala-tools maven-scala-plugin 2.15.2 compile testCompile org.scalatest scalatest-maven-plugin 1.0 test test org.apache.maven.plugins maven-assembly-plugin 2.6 assembly/assembly-no-scala.xml assembly package single scala-2.13 2.1.0-SNAPSHOT 3.5.0 2.13.11 2.13 sonatype-repo sonatype-staging-repo Sonatype staging repo https://oss.sonatype.org/content/repositories/staging XGBoost4J Snapshot Repo XGBoost4J Snapshot Repo https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/snapshot/ ================================================ FILE: examples/XGBoost-Examples/taxi/.gitignore ================================================ .idea target *.iml ================================================ FILE: examples/XGBoost-Examples/taxi/notebooks/python/cv-taxi-gpu.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to XGBoost-Spark Cross Validation with GPU\n", "\n", "The goal of this notebook is to show you how to levarage GPU to accelerate XGBoost spark cross validatoin for hyperparameter tuning. The best model for the given hyperparameters will be returned.\n", "\n", "Here takes the application 'Taxi' as an example.\n", "\n", "A few libraries are required for this notebook:\n", " 1. cudf-cu11\n", " 2. xgboost\n", " 3. scikit-learn\n", " 4. numpy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Import the Required Libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from xgboost.spark import SparkXGBRegressor, SparkXGBRegressorModel\n", "from pyspark.ml.tuning import ParamGridBuilder, CrossValidator\n", "from pyspark.ml.evaluation import RegressionEvaluator\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.types import FloatType, IntegerType, StructField, StructType\n", "from time import time\n", "from pyspark.conf import SparkConf\n", "import os\n", "# os.environ['PYSPARK_PYTHON'] = \"./environment/bin/python\"\n", "# os.environ['PYSPARK_DRIVER_PYTHON'] = \"./environment/bin/python\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create a Spark Session" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-11-30 08:02:09,748 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "2022-11-30 08:02:10,103 WARN resource.ResourceUtils: The configuration of cores (exec = 2 task = 1, runnable tasks = 2) will result in wasted resources due to resource gpu limiting the number of runnable tasks per executor to: 1. 
Please adjust your configuration.\n", "2022-11-30 08:02:23,737 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1.\n", "2022-11-30 08:02:23,752 WARN rapids.RapidsPluginUtils: spark.rapids.sql.multiThreadedRead.numThreads is set to 20.\n", "2022-11-30 08:02:23,756 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\n", "2022-11-30 08:02:23,757 WARN rapids.RapidsPluginUtils: spark.rapids.sql.explain is set to `NOT_ON_GPU`. Set it to 'NONE' to suppress the diagnostics logging about the query placement on the GPU.\n", "2022-11-30 08:02:24,226 WARN yarn.Client: Neither spark.yarn.jars nor spark.yarn.archive is set, falling back to uploading libraries under SPARK_HOME.\n" ] } ], "source": [ "SPARK_MASTER_URL = os.getenv(\"SPARK_MASTER_URL\", \"/your-url\")\n", "\n", "RAPIDS_JAR = os.getenv(\"RAPIDS_JAR\", \"/your-jar-path\")\n", "\n", "# You need to update with your real hardware resource \n", "driverMem = os.getenv(\"DRIVER_MEM\", \"2g\")\n", "executorMem = os.getenv(\"EXECUTOR_MEM\", \"2g\")\n", "pinnedPoolSize = os.getenv(\"PINNED_POOL_SIZE\", \"2g\")\n", "concurrentGpuTasks = os.getenv(\"CONCURRENT_GPU_TASKS\", \"2\")\n", "executorCores = int(os.getenv(\"EXECUTOR_CORES\", \"2\"))\n", "# Common spark settings\n", "conf = SparkConf()\n", "conf.setMaster(SPARK_MASTER_URL)\n", "conf.setAppName(\"Microbenchmark on GPU\")\n", "conf.set(\"spark.executor.instances\",\"1\")\n", "conf.set(\"spark.driver.memory\", driverMem)\n", "## The tasks will run on GPU memory, so there is no need to set a high host memory\n", "conf.set(\"spark.executor.memory\", executorMem)\n", "## The tasks will run on GPU cores, so there is no need to use many cpu cores\n", "conf.set(\"spark.executor.cores\", executorCores)\n", "\n", "# Plugin settings\n", "conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", "conf.set(\"spark.rapids.sql.concurrentGpuTasks\", concurrentGpuTasks)\n", "conf.set(\"spark.rapids.memory.pinnedPool.size\", pinnedPoolSize)\n", "# since pyspark and xgboost share the same GPU, we disable RMM to avoid GPU OOM while training \n", "conf.set(\"spark.rapids.memory.gpu.pool\", \"NONE\")\n", "conf.set(\"spark.locality.wait\",\"0\")\n", "##############note: only support value=1 https://github.com/dmlc/xgboost/blame/master/python-package/xgboost/spark/core.py#L370-L374\n", "conf.set(\"spark.task.resource.gpu.amount\", 1) \n", "conf.set(\"spark.rapids.sql.enabled\", \"true\") \n", "conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n", "conf.set(\"spark.sql.cache.serializer\",\"com.nvidia.spark.ParquetCachedBatchSerializer\")\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", 200000) \n", "conf.set(\"spark.driver.extraClassPath\", RAPIDS_JAR)\n", "conf.set(\"spark.executor.extraClassPath\", RAPIDS_JAR)\n", "# if you pass/unpack the archive file and enable the environment\n", "# conf.set(\"spark.yarn.dist.archives\", \"your_pyspark_venv.tar.gz#environment\")\n", "# Create spark session\n", "spark = SparkSession.builder.config(conf=conf).getOrCreate()\n", "\n", "reader = spark.read" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Specify the Data Schema and Load the Data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "label = 'fare_amount'\n", "schema = StructType([\n", " StructField('vendor_id', FloatType()),\n", " StructField('passenger_count', FloatType()),\n", " StructField('trip_distance', FloatType()),\n", " 
StructField('pickup_longitude', FloatType()),\n", " StructField('pickup_latitude', FloatType()),\n", " StructField('rate_code', FloatType()),\n", " StructField('store_and_fwd', FloatType()),\n", " StructField('dropoff_longitude', FloatType()),\n", " StructField('dropoff_latitude', FloatType()),\n", " StructField(label, FloatType()),\n", " StructField('hour', FloatType()),\n", " StructField('year', IntegerType()),\n", " StructField('month', IntegerType()),\n", " StructField('day', FloatType()),\n", " StructField('day_of_week', FloatType()),\n", " StructField('is_weekend', FloatType()),\n", "])\n", "\n", "features = [ x.name for x in schema if x.name != label ]\n", "\n", "# You need to update them to your real paths!\n", "dataRoot = os.getenv(\"DATA_ROOT\", \"/data\")\n", "train_path = dataRoot + \"/taxi/csv/train\"\n", "eval_path = dataRoot + \"/taxi/csv/test\"\n", "\n", "data_format = 'csv'\n", "has_header = 'true'\n", "if data_format == 'csv':\n", " train_data = reader.schema(schema).option('header',has_header).csv(train_path)\n", " trans_data = reader.schema(schema).option('header',has_header).csv(eval_path)\n", "else :\n", " train_data = reader.load(train_path)\n", " trans_data = reader.load(eval_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Build a XGBoost-Spark CrossValidator" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# First build a regressor of GPU version using *setFeaturesCols* to set feature columns\n", "params = { \n", " \"tree_method\": \"hist\",\n", " \"grow_policy\": \"depthwise\",\n", " \"num_workers\": 1,\n", " \"device\": \"cuda\",\n", "}\n", "params['features_col'] = features\n", "params['label_col'] = label\n", "\n", "regressor = SparkXGBRegressor(**params)\n", "# Then build the evaluator and the hyperparameters\n", "evaluator = (RegressionEvaluator()\n", " .setLabelCol(label))\n", "param_grid = (ParamGridBuilder()\n", " .addGrid(regressor.max_depth, [3, 6])\n", " .addGrid(regressor.n_estimators, [100, 200])\n", " .build())\n", "# Finally the corss validator\n", "cross_validator = (CrossValidator()\n", " .setEstimator(regressor)\n", " .setEvaluator(evaluator)\n", " .setEstimatorParamMaps(param_grid)\n", " .setNumFolds(2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Start Cross Validation by Fitting Data to CrossValidator" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "If features_cols param set, then features_col param is ignored.\n", "/data/home/yuanli/work/reviews/pr252/pyspark_venv_20221125/lib/python3.8/site-packages/xgboost/sklearn.py:808: UserWarning: Loading a native XGBoost model with Scikit-Learn interface.\n", " warnings.warn(\"Loading a native XGBoost model with Scikit-Learn interface.\")\n", "2022-11-30 08:03:14,308 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! 
createexternalrow(prediction#889, fare_amount#890, 1.0#891, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#889 could run on GPU\n", " @Expression fare_amount#890 could run on GPU\n", " @Expression 1.0#891 could run on GPU\n", " !Expression obj#895 cannot run on GPU because expression AttributeReference obj#895 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", "\n", "2022-11-30 08:03:14,317 WARN util.package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n", "If features_cols param set, then features_col param is ignored.\n", "2022-11-30 08:03:20,073 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#1789, fare_amount#1790, 1.0#1791, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#1789 could run on GPU\n", " @Expression fare_amount#1790 could run on GPU\n", " @Expression 1.0#1791 could run on GPU\n", " !Expression obj#1795 cannot run on GPU because expression AttributeReference obj#1795 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", "\n", "If features_cols param set, then features_col param is ignored.\n", "2022-11-30 08:03:23,687 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#2689, fare_amount#2690, 1.0#2691, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#2689 could run on GPU\n", " @Expression fare_amount#2690 could run on GPU\n", " @Expression 1.0#2691 could run on GPU\n", " !Expression obj#2695 cannot run on GPU because expression AttributeReference obj#2695 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", "\n", "If features_cols param set, then features_col param is ignored.\n", "2022-11-30 08:03:27,457 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! 
createexternalrow(prediction#3589, fare_amount#3590, 1.0#3591, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#3589 could run on GPU\n", " @Expression fare_amount#3590 could run on GPU\n", " @Expression 1.0#3591 could run on GPU\n", " !Expression obj#3595 cannot run on GPU because expression AttributeReference obj#3595 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", "\n", "If features_cols param set, then features_col param is ignored.\n", "2022-11-30 08:03:30,964 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#4659, fare_amount#4660, 1.0#4661, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#4659 could run on GPU\n", " @Expression fare_amount#4660 could run on GPU\n", " @Expression 1.0#4661 could run on GPU\n", " !Expression obj#4665 cannot run on GPU because expression AttributeReference obj#4665 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", "\n", "If features_cols param set, then features_col param is ignored.\n", "2022-11-30 08:03:34,524 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#5559, fare_amount#5560, 1.0#5561, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#5559 could run on GPU\n", " @Expression fare_amount#5560 could run on GPU\n", " @Expression 1.0#5561 could run on GPU\n", " !Expression obj#5565 cannot run on GPU because expression AttributeReference obj#5565 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", "\n", "If features_cols param set, then features_col param is ignored.\n", "2022-11-30 08:03:38,067 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! 
createexternalrow(prediction#6459, fare_amount#6460, 1.0#6461, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#6459 could run on GPU\n", " @Expression fare_amount#6460 could run on GPU\n", " @Expression 1.0#6461 could run on GPU\n", " !Expression obj#6465 cannot run on GPU because expression AttributeReference obj#6465 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", "\n", "If features_cols param set, then features_col param is ignored.\n", "2022-11-30 08:03:41,793 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#7359, fare_amount#7360, 1.0#7361, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#7359 could run on GPU\n", " @Expression fare_amount#7360 could run on GPU\n", " @Expression 1.0#7361 could run on GPU\n", " !Expression obj#7365 cannot run on GPU because expression AttributeReference obj#7365 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "If features_cols param set, then features_col param is ignored.\n", "[Stage 34:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Cross-Validation takes 55.19 seconds\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "def with_benchmark(phrase, action):\n", " start = time()\n", " result = action()\n", " end = time()\n", " print('{} takes {} seconds'.format(phrase, round(end - start, 2)))\n", " return result\n", "model = with_benchmark('Cross-Validation', lambda: cross_validator.fit(train_data)).bestModel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Transform On the Best Model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Transforming takes 0.23 seconds\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-11-30 08:03:45,503 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+-----------+-----------+\n", "|fare_amount| prediction|\n", "+-----------+-----------+\n", "| 5.0| 5.01032114|\n", "| 34.0| 31.134758|\n", "| 10.0|9.288980484|\n", "| 16.5|15.33446312|\n", "| 7.0|8.197098732|\n", "+-----------+-----------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "def transform():\n", " result = model.transform(trans_data).cache()\n", " result.foreachPartition(lambda _: None)\n", " return result\n", "result = with_benchmark('Transforming', transform)\n", "result.select(label, 'prediction').show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Evaluation" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluation takes 0.05 seconds\n", "RMSE is 2.055690464034438\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-11-30 08:03:45,728 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#7645, fare_amount#8271, 1.0#8272, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#7645 could run on GPU\n", " @Expression fare_amount#8271 could run on GPU\n", " @Expression 1.0#8272 could run on GPU\n", " !Expression obj#8276 cannot run on GPU because expression AttributeReference obj#8276 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", "\n" ] } ], "source": [ "accuracy = with_benchmark(\n", " 'Evaluation',\n", " lambda: RegressionEvaluator().setLabelCol(label).evaluate(result))\n", "print('RMSE is ' + str(accuracy))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/XGBoost-Examples/taxi/notebooks/python/taxi-ETL.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "71bf747a", "metadata": {}, "source": [ "# Introduction to Taxi ETL Job\n", "This is the Taxi ETL job to generate the input datasets for the Taxi XGBoost job." ] }, { "cell_type": "markdown", "id": "f0524408", "metadata": {}, "source": [ "## Prerequirement\n", "### 1. Download data\n", "All data could be found at https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page\n", "\n", "### 2. Download needed jars\n", "* [rapids-4-spark_2.12-26.02.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)\n", "\n", "### 3. 
Start Spark Standalone\n", "Before running the script, please setup Spark standalone mode\n", "\n", "### 4. Add ENV\n", "```\n", "$ export SPARK_JARS=rapids-4-spark_2.12-26.02.0.jar\n", "$ export PYSPARK_DRIVER_PYTHON=jupyter \n", "$ export PYSPARK_DRIVER_PYTHON_OPTS=notebook\n", "```\n", "\n", "### 5. Start Jupyter Notebook with plugin config\n", "\n", "```\n", "$ pyspark --master ${SPARK_MASTER} \\\n", "--jars ${SPARK_JARS} \\\n", "--conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n", "--conf spark.rapids.sql.incompatibleDateFormats.enabled=true \\\n", "--conf spark.rapids.sql.csv.read.double.enabled=true \\\n", "--py-files ${SPARK_PY_FILES}\n", "```\n", "\n", "## Import Libs" ] }, { "cell_type": "code", "execution_count": 1, "id": "d2283aab", "metadata": {}, "outputs": [], "source": [ "import time\n", "import os\n", "import math\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.functions import *\n", "from pyspark.sql.types import *" ] }, { "cell_type": "markdown", "id": "f7ffcace", "metadata": {}, "source": [ "## Script Settings\n", "\n", "### File Path Settings\n", "* Define input/output file path" ] }, { "cell_type": "code", "execution_count": 2, "id": "b348778a", "metadata": {}, "outputs": [], "source": [ "# You need to update them to your real paths! You can download the dataset \n", "# from https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page\n", "# or you can just unzip datasets/taxi-small.tar.gz and use the provided\n", "# sample dataset datasets/taxi/taxi-etl-input-small.csv\n", "dataRoot = os.getenv('DATA_ROOT', '/data')\n", "rawPath = dataRoot + '/taxi/taxi-etl-input-small.csv'\n", "outPath = dataRoot + '/taxi/output'" ] }, { "cell_type": "markdown", "id": "0a500530", "metadata": {}, "source": [ "## Function and Object Define\n", "### Define the constants\n", "\n", "* Define input file schema" ] }, { "cell_type": "code", "execution_count": 3, "id": "094f31c5", "metadata": {}, "outputs": [], "source": [ "raw_schema = StructType([\n", " StructField('vendor_id', StringType()),\n", " StructField('pickup_datetime', StringType()),\n", " StructField('dropoff_datetime', StringType()),\n", " StructField('passenger_count', IntegerType()),\n", " StructField('trip_distance', DoubleType()),\n", " StructField('pickup_longitude', DoubleType()),\n", " StructField('pickup_latitude', DoubleType()),\n", " StructField('rate_code', StringType()),\n", " StructField('store_and_fwd_flag', StringType()),\n", " StructField('dropoff_longitude', DoubleType()),\n", " StructField('dropoff_latitude', DoubleType()),\n", " StructField('payment_type', StringType()),\n", " StructField('fare_amount', DoubleType()),\n", " StructField('surcharge', DoubleType()),\n", " StructField('mta_tax', DoubleType()),\n", " StructField('tip_amount', DoubleType()),\n", " StructField('tolls_amount', DoubleType()),\n", " StructField('total_amount', DoubleType()),\n", "])" ] }, { "cell_type": "markdown", "id": "72a4ae18", "metadata": {}, "source": [ "* Define some ETL functions" ] }, { "cell_type": "code", "execution_count": 4, "id": "b45b7606", "metadata": {}, "outputs": [], "source": [ "def drop_useless(data_frame):\n", " return data_frame.drop(\n", " 'dropoff_datetime',\n", " 'payment_type',\n", " 'surcharge',\n", " 'mta_tax',\n", " 'tip_amount',\n", " 'tolls_amount',\n", " 'total_amount')" ] }, { "cell_type": "code", "execution_count": 5, "id": "7af7073d", "metadata": {}, "outputs": [], "source": [ "def encode_categories(data_frame):\n", " categories = [ 'vendor_id', 'rate_code', 'store_and_fwd_flag' 
]\n", " for category in categories:\n", " data_frame = data_frame.withColumn(category, hash(col(category)))\n", " return data_frame.withColumnRenamed(\"store_and_fwd_flag\", \"store_and_fwd\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "b799cd5a", "metadata": {}, "outputs": [], "source": [ "def fill_na(data_frame):\n", " return data_frame.fillna(-1)" ] }, { "cell_type": "code", "execution_count": 7, "id": "ceee5c7c", "metadata": {}, "outputs": [], "source": [ "def remove_invalid(data_frame):\n", " conditions = [\n", " ( 'fare_amount', 0, 500 ),\n", " ( 'passenger_count', 0, 6 ),\n", " ( 'pickup_longitude', -75, -73 ),\n", " ( 'dropoff_longitude', -75, -73 ),\n", " ( 'pickup_latitude', 40, 42 ),\n", " ( 'dropoff_latitude', 40, 42 ),\n", " ]\n", " for column, min, max in conditions:\n", " data_frame = data_frame.filter('{} > {} and {} < {}'.format(column, min, column, max))\n", " return data_frame" ] }, { "cell_type": "code", "execution_count": 8, "id": "bd28ae14", "metadata": {}, "outputs": [], "source": [ "def convert_datetime(data_frame):\n", " datetime = col('pickup_datetime')\n", " return (data_frame\n", " .withColumn('pickup_datetime', to_timestamp(datetime))\n", " .withColumn('year', year(datetime))\n", " .withColumn('month', month(datetime))\n", " .withColumn('day', dayofmonth(datetime))\n", " .withColumn('day_of_week', dayofweek(datetime))\n", " .withColumn(\n", " 'is_weekend',\n", " col('day_of_week').isin(1, 7).cast(IntegerType())) # 1: Sunday, 7: Saturday\n", " .withColumn('hour', hour(datetime))\n", " .drop('pickup_datetime'))" ] }, { "cell_type": "code", "execution_count": 9, "id": "39e45f15", "metadata": {}, "outputs": [], "source": [ "def add_h_distance(data_frame):\n", " p = math.pi / 180\n", " lat1 = col('pickup_latitude')\n", " lon1 = col('pickup_longitude')\n", " lat2 = col('dropoff_latitude')\n", " lon2 = col('dropoff_longitude')\n", " internal_value = (0.5\n", " - cos((lat2 - lat1) * p) / 2\n", " + cos(lat1 * p) * cos(lat2 * p) * (1 - cos((lon2 - lon1) * p)) / 2)\n", " h_distance = 12734 * asin(sqrt(internal_value))\n", " return data_frame.withColumn('h_distance', h_distance)" ] }, { "cell_type": "markdown", "id": "d52b062c", "metadata": {}, "source": [ "* Define main ETL function" ] }, { "cell_type": "code", "execution_count": 10, "id": "9fd36618", "metadata": {}, "outputs": [], "source": [ "def pre_process(data_frame):\n", " processes = [\n", " drop_useless,\n", " encode_categories,\n", " fill_na,\n", " remove_invalid,\n", " convert_datetime,\n", " add_h_distance,\n", " ]\n", " for process in processes:\n", " data_frame = process(data_frame)\n", " return data_frame" ] }, { "cell_type": "markdown", "id": "2798f19a", "metadata": {}, "source": [ "## Run ETL Process and Save the Result\n", "* Create Spark Session and create dataframe" ] }, { "cell_type": "code", "execution_count": 11, "id": "26ca4ca6", "metadata": {}, "outputs": [], "source": [ "spark = (SparkSession\n", " .builder\n", " .appName(\"Taxi-ETL\")\n", " .getOrCreate())\n", "reader = (spark\n", " .read\n", " .format('csv'))\n", "reader.schema(raw_schema).option('header', 'True')\n", "\n", "raw_data = reader.load(rawPath)" ] }, { "cell_type": "markdown", "id": "6243b736", "metadata": {}, "source": [ "* Run ETL Process and Save the Result" ] }, { "cell_type": "code", "execution_count": 12, "id": "27f2119b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "5.114504098892212\n" ] } ], "source": [ "start = time.time()\n", "etled_train, etled_eval, etled_trans = 
pre_process(raw_data).randomSplit(list(map(float, (80,20,0))))\n", "etled_train.write.mode(\"overwrite\").parquet(outPath+'/train')\n", "etled_eval.write.mode(\"overwrite\").parquet(outPath+'/eval')\n", "etled_trans.write.mode(\"overwrite\").parquet(outPath+'/trans')\n", "end = time.time()\n", "print(end - start)\n", "spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "91af3c97", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/XGBoost-Examples/taxi/notebooks/python/taxi-gpu.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to XGBoost Spark3.1 with GPU\n", "\n", "Taxi is an example of xgboost regressor. This notebook will show you how to load data, train the xgboost model and use this model to predict \"fare_amount\" of your taxi trip.\n", "\n", "A few libraries required for this notebook:\n", " 1. cudf-cu11\n", " 2. xgboost\n", " 3. scikit-learn\n", " 4. numpy\n", "\n", "This notebook also illustrates the ease of porting a sample CPU based Spark xgboost4j code into GPU. There is no change required for running Spark XGBoost on GPU because both CPU and GPU call the same API. For CPU run, we need to vectorize the trained dataset before fitting data to regressor." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Import Required Libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from xgboost.spark import SparkXGBRegressor, SparkXGBRegressorModel\n", "from pyspark.ml.evaluation import RegressionEvaluator\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.types import FloatType, IntegerType, StructField, StructType\n", "from time import time\n", "from pyspark.conf import SparkConf\n", "import os\n", "# if you pass/unpack the archive file and enable the environment\n", "# os.environ['PYSPARK_PYTHON'] = \"./environment/bin/python\"\n", "# os.environ['PYSPARK_DRIVER_PYTHON'] = \"./environment/bin/python\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Besides CPU version requires two extra libraries.\n", "```Python\n", "from pyspark.ml.feature import VectorAssembler\n", "from pyspark.sql.functions import col\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create Spark Session and Data Reader" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-11-30 07:51:19,104 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n", "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "2022-11-30 07:51:19,480 WARN resource.ResourceUtils: The configuration of cores (exec = 2 task = 1, runnable tasks = 2) will result in wasted resources due to resource gpu limiting the number of runnable tasks per executor to: 1. 
Please adjust your configuration.\n", "2022-11-30 07:51:33,277 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1.\n", "2022-11-30 07:51:33,292 WARN rapids.RapidsPluginUtils: spark.rapids.sql.multiThreadedRead.numThreads is set to 20.\n", "2022-11-30 07:51:33,295 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\n", "2022-11-30 07:51:33,295 WARN rapids.RapidsPluginUtils: spark.rapids.sql.explain is set to `NOT_ON_GPU`. Set it to 'NONE' to suppress the diagnostics logging about the query placement on the GPU.\n", "2022-11-30 07:51:33,798 WARN yarn.Client: Neither spark.yarn.jars nor spark.yarn.archive is set, falling back to uploading libraries under SPARK_HOME.\n" ] } ], "source": [ "SPARK_MASTER_URL = os.getenv(\"SPARK_MASTER_URL\", \"/your-url\")\n", "\n", "RAPIDS_JAR = os.getenv(\"RAPIDS_JAR\", \"/your-jar-path\")\n", "\n", "# You need to update with your real hardware resource \n", "driverMem = os.getenv(\"DRIVER_MEM\", \"2g\")\n", "executorMem = os.getenv(\"EXECUTOR_MEM\", \"2g\")\n", "pinnedPoolSize = os.getenv(\"PINNED_POOL_SIZE\", \"2g\")\n", "concurrentGpuTasks = os.getenv(\"CONCURRENT_GPU_TASKS\", \"2\")\n", "executorCores = int(os.getenv(\"EXECUTOR_CORES\", \"2\"))\n", "# Common spark settings\n", "conf = SparkConf()\n", "conf.setMaster(SPARK_MASTER_URL)\n", "conf.setAppName(\"Microbenchmark on GPU\")\n", "conf.set(\"spark.executor.instances\",\"1\")\n", "conf.set(\"spark.driver.memory\", driverMem)\n", "## The tasks will run on GPU memory, so there is no need to set a high host memory\n", "conf.set(\"spark.executor.memory\", executorMem)\n", "## The tasks will run on GPU cores, so there is no need to use many cpu cores\n", "conf.set(\"spark.executor.cores\", executorCores)\n", "\n", "# Plugin settings\n", "conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n", "conf.set(\"spark.rapids.sql.concurrentGpuTasks\", concurrentGpuTasks)\n", "conf.set(\"spark.rapids.memory.pinnedPool.size\", pinnedPoolSize)\n", "# since pyspark and xgboost share the same GPU, we disable RMM to avoid GPU OOM while training \n", "conf.set(\"spark.rapids.memory.gpu.pool\", \"NONE\")\n", "conf.set(\"spark.locality.wait\",\"0\")\n", "##############note: only support value=1 https://github.com/dmlc/xgboost/blame/master/python-package/xgboost/spark/core.py#L370-L374\n", "conf.set(\"spark.task.resource.gpu.amount\", 1) \n", "conf.set(\"spark.rapids.sql.enabled\", \"true\") \n", "conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n", "conf.set(\"spark.sql.cache.serializer\",\"com.nvidia.spark.ParquetCachedBatchSerializer\")\n", "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", 200000) \n", "conf.set(\"spark.driver.extraClassPath\", RAPIDS_JAR)\n", "conf.set(\"spark.executor.extraClassPath\", RAPIDS_JAR)\n", "\n", "# if you pass/unpack the archive file and enable the environment\n", "# conf.set(\"spark.yarn.dist.archives\", \"your_pyspark_venv.tar.gz#environment\")\n", "# Create spark session\n", "spark = SparkSession.builder.config(conf=conf).getOrCreate()\n", "\n", "reader = spark.read" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Specify the Data Schema and Load the Data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "label = 'fare_amount'\n", "schema = StructType([\n", " StructField('vendor_id', FloatType()),\n", " StructField('passenger_count', FloatType()),\n", " StructField('trip_distance', FloatType()),\n", " 
StructField('pickup_longitude', FloatType()),\n", " StructField('pickup_latitude', FloatType()),\n", " StructField('rate_code', FloatType()),\n", " StructField('store_and_fwd', FloatType()),\n", " StructField('dropoff_longitude', FloatType()),\n", " StructField('dropoff_latitude', FloatType()),\n", " StructField(label, FloatType()),\n", " StructField('hour', FloatType()),\n", " StructField('year', IntegerType()),\n", " StructField('month', IntegerType()),\n", " StructField('day', FloatType()),\n", " StructField('day_of_week', FloatType()),\n", " StructField('is_weekend', FloatType()),\n", "])\n", "features = [ x.name for x in schema if x.name != label ]\n", "\n", "# You need to update them to your real paths!\n", "dataRoot = os.getenv(\"DATA_ROOT\", \"/data\")\n", "train_path = dataRoot + \"/taxi/csv/train\"\n", "eval_path = dataRoot + \"/taxi/csv/test\"\n", "\n", "data_format = 'csv'\n", "has_header = 'true'\n", "if data_format == 'csv':\n", " train_data = reader.schema(schema).option('header',has_header).csv(train_path)\n", " trans_data = reader.schema(schema).option('header',has_header).csv(eval_path)\n", "else :\n", " train_data = reader.load(train_path)\n", " trans_data = reader.load(eval_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note on CPU version, vectorization is required before fitting data to regressor, which means you need to assemble all feature columns into one column.\n", "\n", "```Python\n", "def vectorize(data_frame):\n", " to_floats = [ col(x.name).cast(FloatType()) for x in data_frame.schema ]\n", " return (VectorAssembler()\n", " .setInputCols(features)\n", " .setOutputCol('features')\n", " .transform(data_frame.select(to_floats))\n", " .select(col('features'), col(label)))\n", "\n", "train_data = vectorize(train_data)\n", "trans_data = vectorize(trans_data)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create a XGBoostRegressor" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "params = { \n", " \"tree_method\": \"hist\",\n", " \"grow_policy\": \"depthwise\",\n", " \"num_workers\": 1,\n", " \"device\": \"cuda\",\n", "}\n", "params['features_col'] = features\n", "params['label_col'] = label\n", " \n", "regressor = SparkXGBRegressor(**params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The parameter `num_workers` should be set to the number of GPUs in Spark cluster for GPU version, while for CPU version it is usually equal to the number of the CPU cores.\n", "\n", "Concerning the `device`, GPU version only supports `cuda` currently, while `cpu` is designed and used here for CPU training.\n", "\n", "An example of CPU classifier:\n", "```\n", "classifier = SparkXGBClassifier(\n", " feature_col=features,\n", " label_col=label, \n", " num_workers=1024,\n", ")\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Train the Data with Benchmark" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "If features_cols param set, then features_col param is ignored.\n", "[Stage 2:> (0 + 1) / 1]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training takes 24.12 seconds\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r", "/data/home/yuanli/work/reviews/pr252/pyspark_venv_20221125/lib/python3.8/site-packages/xgboost/sklearn.py:808: UserWarning: Loading a native XGBoost model with Scikit-Learn interface.\n", " 
warnings.warn(\"Loading a native XGBoost model with Scikit-Learn interface.\")\n" ] } ], "source": [ "def with_benchmark(phrase, action):\n", " start = time()\n", " result = action()\n", " end = time()\n", " print('{} takes {} seconds'.format(phrase, round(end - start, 2)))\n", " return result\n", "model = with_benchmark('Training', lambda: regressor.fit(train_data))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Save and Reload the Model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "If features_cols param set, then features_col param is ignored.\n" ] } ], "source": [ "model.write().overwrite().save(dataRoot + '/model/taxi')" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "loaded_model = SparkXGBRegressorModel().load(dataRoot + '/model/taxi')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Transformation and Show Result Sample" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-11-30 07:52:27,357 WARN util.package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Transformation takes 0.93 seconds\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-11-30 07:52:28,189 WARN rapids.GpuOverrides: \n", "!Exec cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\n", " @Partitioning could run on GPU\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+--------------+---------------+-------------+-----------+-----------+\n", "| vendor_id|passenger_count|trip_distance|fare_amount| prediction|\n", "+--------------+---------------+-------------+-----------+-----------+\n", "|1.559730432E09| 2.0| 0.699999988| 5.0|5.046935558|\n", "|1.559730432E09| 3.0| 10.69999981| 34.0|31.72706413|\n", "|1.559730432E09| 1.0| 2.299999952| 10.0|9.294451714|\n", "|1.559730432E09| 1.0| 4.400000095| 16.5|15.05233097|\n", "|1.559730432E09| 1.0| 1.5| 7.0|8.995832443|\n", "+--------------+---------------+-------------+-----------+-----------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "def transform():\n", " result = loaded_model.transform(trans_data).cache()\n", " result.foreachPartition(lambda _: None)\n", " return result\n", "result = with_benchmark('Transformation', transform)\n", "result.select('vendor_id', 'passenger_count', 'trip_distance', label, 'prediction').show(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note on CPU version: You cannot `select` the feature columns after vectorization. So please use `result.show(5)` instead." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Evaluation" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluation takes 0.22 seconds\n", "RMSE is 1.9141528471228921\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-11-30 07:52:28,580 WARN rapids.GpuOverrides: \n", "! cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\n", " ! createexternalrow(prediction#87, fare_amount#728, 1.0#729, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\n", " @Expression prediction#87 could run on GPU\n", " @Expression fare_amount#728 could run on GPU\n", " @Expression 1.0#729 could run on GPU\n", " !Expression obj#733 cannot run on GPU because expression AttributeReference obj#733 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\n", "\n" ] } ], "source": [ "accuracy = with_benchmark(\n", " 'Evaluation',\n", " lambda: RegressionEvaluator().setLabelCol(label).evaluate(result))\n", "print('RMSE is ' + str(accuracy))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Stop" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "spark.stop()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/XGBoost-Examples/taxi/notebooks/scala/taxi-ETL.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "e0336840", "metadata": {}, "source": [ "# Introduction to Taxi ETL Job\n", "This is the Taxi ETL job to generate the input datasets for the Taxi XGBoost job." ] }, { "cell_type": "markdown", "id": "86fd8ad9", "metadata": {}, "source": [ "## Prerequirement\n", "### 1. Download data\n", "All data could be found at https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page\n", "\n", "### 2. Download needed jar\n", "* [rapids-4-spark_2.12-26.02.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)\n", "\n", "### 3. Start Spark Standalone\n", "Before running the script, please setup Spark standalone mode\n", "\n", "### 4. 
Add ENV\n", "```\n", "$ export SPARK_JARS=rapids-4-spark_2.12-26.02.0.jar\n", "\n", "```\n", "\n", "### 5.Start Jupyter Notebook with spylon-kernel or toree\n", "\n", "```\n", "$ jupyter notebook --allow-root --notebook-dir=${your-dir} --config=${your-configs}\n", "```\n", "\n", "## Import Libs" ] }, { "cell_type": "code", "execution_count": 1, "id": "1e50cfad", "metadata": {}, "outputs": [], "source": [ "import org.apache.spark.sql.SparkSession\n", "import org.apache.spark.sql.DataFrame\n", "import org.apache.spark.sql.functions._\n", "import org.apache.spark.sql.types.DataTypes.{DoubleType, IntegerType, StringType}\n", "import org.apache.spark.sql.types.{FloatType, StructField, StructType}" ] }, { "cell_type": "markdown", "id": "24f69140", "metadata": {}, "source": [ "## Script Settings\n", "\n", "### 1. File Path Settings\n", "* Define input file path" ] }, { "cell_type": "code", "execution_count": 6, "id": "317b9415", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "lastException = null\n", "dataRoot = /data\n", "rawPath = /data/taxi/taxi-etl-input-small.csv\n", "outPath = /data/datasets/taxi/output\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "/data/taxi/output" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val dataRoot = sys.env.getOrElse(\"DATA_ROOT\", \"/data\")\n", "val rawPath = dataRoot + \"/taxi/taxi-etl-input-small.csv\"\n", "val outPath = dataRoot + \"/taxi/output\"" ] }, { "cell_type": "markdown", "id": "6f036d30", "metadata": {}, "source": [ "## Function and Object Define\n", "### Define the constants\n", "\n", "* Define input file schema" ] }, { "cell_type": "code", "execution_count": 7, "id": "acc23ac1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "rawSchema = StructType(StructField(vendor_id,StringType,true), StructField(pickup_datetime,StringType,true), StructField(dropoff_datetime,StringType,true), StructField(passenger_count,IntegerType,true), StructField(trip_distance,DoubleType,true), StructField(pickup_longitude,DoubleType,true), StructField(pickup_latitude,DoubleType,true), StructField(rate_code,StringType,true), StructField(store_and_fwd_flag,StringType,true), StructField(dropoff_longitude,DoubleType,true), StructField(dropoff_latitude,DoubleType,true), StructField(payment_type,StringType,true), StructField(fare_amount,DoubleType,true), StructField(surcharge,DoubleType,true), StructField(mta_tax,DoubleType,true), StructField(tip_amount,DoubleType,true), StructField(tolls_amount,Doubl...\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "StructType(StructField(vendor_id,StringType,true), StructField(pickup_datetime,StringType,true), StructField(dropoff_datetime,StringType,true), StructField(passenger_count,IntegerType,true), StructField(trip_distance,DoubleType,true), StructField(pickup_longitude,DoubleType,true), StructField(pickup_latitude,DoubleType,true), StructField(rate_code,StringType,true), StructField(store_and_fwd_flag,StringType,true), StructField(dropoff_longitude,DoubleType,true), StructField(dropoff_latitude,DoubleType,true), StructField(payment_type,StringType,true), StructField(fare_amount,DoubleType,true), StructField(surcharge,DoubleType,true), StructField(mta_tax,DoubleType,true), StructField(tip_amount,DoubleType,true), StructField(tolls_amount,Doubl..." 
] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val rawSchema = StructType(Seq(\n", " StructField(\"vendor_id\", StringType),\n", " StructField(\"pickup_datetime\", StringType),\n", " StructField(\"dropoff_datetime\", StringType),\n", " StructField(\"passenger_count\", IntegerType),\n", " StructField(\"trip_distance\", DoubleType),\n", " StructField(\"pickup_longitude\", DoubleType),\n", " StructField(\"pickup_latitude\", DoubleType),\n", " StructField(\"rate_code\", StringType),\n", " StructField(\"store_and_fwd_flag\", StringType),\n", " StructField(\"dropoff_longitude\", DoubleType),\n", " StructField(\"dropoff_latitude\", DoubleType),\n", " StructField(\"payment_type\", StringType),\n", " StructField(\"fare_amount\", DoubleType),\n", " StructField(\"surcharge\", DoubleType),\n", " StructField(\"mta_tax\", DoubleType),\n", " StructField(\"tip_amount\", DoubleType),\n", " StructField(\"tolls_amount\", DoubleType),\n", " StructField(\"total_amount\", DoubleType)\n", " ))" ] }, { "cell_type": "code", "execution_count": 8, "id": "2e467519", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "trainRatio = 80\n", "evalRatio = 20\n", "trainEvalRatio = 0\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "dataRatios: (Int, Int, Int)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def dataRatios: (Int, Int, Int) = {\n", " val ratios = (80, 20)\n", " (ratios._1, ratios._2, 100 - ratios._1 - ratios._2)\n", " }\n", "val (trainRatio, evalRatio, trainEvalRatio) = dataRatios" ] }, { "cell_type": "markdown", "id": "5c2024d7", "metadata": {}, "source": [ "* Build the spark session and dataframe" ] }, { "cell_type": "code", "execution_count": 9, "id": "b551ca1d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "sparkSession = org.apache.spark.sql.SparkSession@68530eb7\n", "df = [vendor_id: string, pickup_datetime: string ... 16 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[vendor_id: string, pickup_datetime: string ... 
16 more fields]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// Build the spark session and data reader as usual\n", "val sparkSession = SparkSession.builder.appName(\"taxi-etl\").getOrCreate\n", "val df = sparkSession.read.option(\"header\", true).schema(rawSchema).csv(rawPath)" ] }, { "cell_type": "markdown", "id": "2f50ff7d", "metadata": {}, "source": [ "* Define some ETL functions" ] }, { "cell_type": "code", "execution_count": 10, "id": "3ca5738f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dropUseless: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def dropUseless(dataFrame: DataFrame): DataFrame = {\n", " dataFrame.drop(\n", " \"dropoff_datetime\",\n", " \"payment_type\",\n", " \"surcharge\",\n", " \"mta_tax\",\n", " \"tip_amount\",\n", " \"tolls_amount\",\n", " \"total_amount\")\n", " }" ] }, { "cell_type": "code", "execution_count": 11, "id": "852b06c3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "encodeCategories: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def encodeCategories(dataFrame: DataFrame): DataFrame = {\n", " val categories = Seq(\"vendor_id\", \"rate_code\", \"store_and_fwd_flag\")\n", "\n", " (categories.foldLeft(dataFrame) {\n", " case (df, category) => df.withColumn(category, hash(col(category)))\n", " }).withColumnRenamed(\"store_and_fwd_flag\", \"store_and_fwd\")\n", " }" ] }, { "cell_type": "code", "execution_count": 12, "id": "dbf0ab75", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "fillNa: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def fillNa(dataFrame: DataFrame): DataFrame = {\n", " dataFrame.na.fill(-1)\n", " }" ] }, { "cell_type": "code", "execution_count": 13, "id": "39308a05", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "removeInvalid: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def removeInvalid(dataFrame: DataFrame): DataFrame = {\n", " val conditions = Seq(\n", " Seq(\"fare_amount\", 0, 500),\n", " Seq(\"passenger_count\", 0, 6),\n", " Seq(\"pickup_longitude\", -75, -73),\n", " Seq(\"dropoff_longitude\", -75, -73),\n", " Seq(\"pickup_latitude\", 40, 42),\n", " Seq(\"dropoff_latitude\", 40, 42))\n", "\n", " conditions\n", " .map { case Seq(column, min, max) => \"%s > %d and %s < %d\".format(column, min, column, max) }\n", " .foldLeft(dataFrame) {\n", " _.filter(_)\n", " }\n", " }" ] }, { "cell_type": "code", "execution_count": 14, "id": "11cd052b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "convertDatetime: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def convertDatetime(dataFrame: DataFrame): DataFrame = {\n", " val datetime = col(\"pickup_datetime\")\n", " dataFrame\n", " .withColumn(\"pickup_datetime\", to_timestamp(datetime))\n", " .withColumn(\"year\", year(datetime))\n", " .withColumn(\"month\", month(datetime))\n", " .withColumn(\"day\", dayofmonth(datetime))\n", " .withColumn(\"day_of_week\", dayofweek(datetime))\n", " .withColumn(\n", " \"is_weekend\",\n", " col(\"day_of_week\").isin(1, 7).cast(IntegerType)) // 1: Sunday, 7: 
Saturday\n", " .withColumn(\"hour\", hour(datetime))\n", " .drop(datetime.toString)\n", " }" ] }, { "cell_type": "code", "execution_count": 15, "id": "71e1b568", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "addHDistance: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def addHDistance(dataFrame: DataFrame): DataFrame = {\n", " val P = math.Pi / 180\n", " val lat1 = col(\"pickup_latitude\")\n", " val lon1 = col(\"pickup_longitude\")\n", " val lat2 = col(\"dropoff_latitude\")\n", " val lon2 = col(\"dropoff_longitude\")\n", " val internalValue = (lit(0.5)\n", " - cos((lat2 - lat1) * P) / 2\n", " + cos(lat1 * P) * cos(lat2 * P) * (lit(1) - cos((lon2 - lon1) * P)) / 2)\n", " val hDistance = lit(12734) * asin(sqrt(internalValue))\n", " dataFrame.withColumn(\"h_distance\", hDistance)\n", " }" ] }, { "cell_type": "markdown", "id": "6fe805d5", "metadata": {}, "source": [ "* Define main ETL function" ] }, { "cell_type": "code", "execution_count": 19, "id": "6da3b832", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "preProcess: (dataFrame: org.apache.spark.sql.DataFrame, splits: Array[Int])Array[org.apache.spark.sql.DataFrame]\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def preProcess(dataFrame: DataFrame, splits: Array[Int]): Array[DataFrame] = {\n", " val processes = Seq[DataFrame => DataFrame](\n", " dropUseless,\n", " encodeCategories,\n", " fillNa,\n", " removeInvalid,\n", " convertDatetime,\n", " addHDistance\n", " )\n", "\n", " processes\n", " .foldLeft(dataFrame) { case (df, process) => process(df) }\n", " .randomSplit(splits.map(_.toDouble))\n", " }" ] }, { "cell_type": "code", "execution_count": 20, "id": "85541b03", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dataset = Array([vendor_id: int, passenger_count: int ... 15 more fields], [vendor_id: int, passenger_count: int ... 15 more fields], [vendor_id: int, passenger_count: int ... 15 more fields])\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Array([vendor_id: int, passenger_count: int ... 15 more fields], [vendor_id: int, passenger_count: int ... 15 more fields], [vendor_id: int, passenger_count: int ... 
15 more fields])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val dataset = preProcess(df, Array(trainRatio, trainEvalRatio, evalRatio))" ] }, { "cell_type": "markdown", "id": "6787cac7", "metadata": {}, "source": [ "## Run ETL Process and Save the Result" ] }, { "cell_type": "code", "execution_count": 21, "id": "371886e8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Elapsed time : 4.371s\n" ] }, { "data": { "text/plain": [ "t0 = 1654139600797\n", "t1 = 1654139605168\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1654139605168" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val t0 = System.currentTimeMillis\n", "for ((name, index) <- Seq(\"train\", \"eval\", \"trans\").zipWithIndex) {\n", " dataset(index).write.mode(\"overwrite\").parquet(outPath + \"/parquet/\" + name)\n", " dataset(index).write.mode(\"overwrite\").csv(outPath + \"/csv/\" + name)\n", " }\n", "val t1 = System.currentTimeMillis\n", "println(\"Elapsed time : \" + ((t1 - t0).toFloat / 1000) + \"s\")\n", "sparkSession.stop()" ] }, { "cell_type": "code", "execution_count": null, "id": "8d89fa1b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "XGBoost4j-Spark - Scala", "language": "scala", "name": "XGBoost4j-Spark_scala" }, "language_info": { "codemirror_mode": "text/x-scala", "file_extension": ".scala", "mimetype": "text/x-scala", "name": "scala", "pygments_lexer": "scala", "version": "2.12.15" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/XGBoost-Examples/taxi/notebooks/scala/taxi-gpu.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to XGBoost Spark with GPU\n", "\n", "Taxi is an example of XGBoost regressor. This notebook will show you how to load data, train the XGBoost model and use this model to predict \"fare_amount\" of your taxi trip.\n", "\n", "## Load libraries\n", "First load some common libraries will be used by both GPU version and CPU version XGBoost." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressor, XGBoostRegressionModel}\n", "import org.apache.spark.sql.SparkSession\n", "import org.apache.spark.ml.evaluation.RegressionEvaluator\n", "import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Besides CPU version requires some extra libraries, such as:\n", "\n", "```scala\n", "import org.apache.spark.ml.feature.VectorAssembler\n", "import org.apache.spark.sql.DataFrame\n", "import org.apache.spark.sql.functions._\n", "import org.apache.spark.sql.types.FloatType\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set the dataset path" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dataRoot = /data\n", "trainPath = /data/taxi/csv/train/\n", "evalPath = /data/taxi/csv/test/\n", "transPath = /data/taxi/csv/test/\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "/data/taxi/csv/test/" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// You need to update them to your real paths! 
The input data files can be the output of taxi-etl jobs, or you can\n", "// just use the provided sample datasets under the datasets path. \n", "val dataRoot = sys.env.getOrElse(\"DATA_ROOT\", \"/data\")\n", "val trainPath = dataRoot + \"/taxi/csv/train/\"\n", "val evalPath = dataRoot + \"/taxi/csv/test/\"\n", "val transPath = dataRoot + \"/taxi/csv/test/\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build the schema of the dataset\n", "The Taxi data has 16 columns: 15 features and 1 label. \"fare_amount\" is the label column. The schema will be used to load data later. \n", "\n", "The next block also defines some key parameters used in the XGBoost training process." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "labelName = fare_amount\n", "schema = \n", "featureNames = Array(vendor_id, passenger_count, trip_distance, pickup_longitude, pickup_latitude, rate_code, store_and_fwd, dropoff_longitude, dropoff_latitude, hour, year, month, day, day_of_week, is_weekend)\n", "paramMap = \n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val labelName = \"fare_amount\"\n", "lazy val schema =\n", " StructType(Array(\n", " StructField(\"vendor_id\", DoubleType),\n", " StructField(\"passenger_count\", DoubleType),\n", " StructField(\"trip_distance\", DoubleType),\n", " StructField(\"pickup_longitude\", DoubleType),\n", " StructField(\"pickup_latitude\", DoubleType),\n", " StructField(\"rate_code\", DoubleType),\n", " StructField(\"store_and_fwd\", DoubleType),\n", " StructField(\"dropoff_longitude\", DoubleType),\n", " StructField(\"dropoff_latitude\", DoubleType),\n", " StructField(labelName, DoubleType),\n", " StructField(\"hour\", DoubleType),\n", " StructField(\"year\", IntegerType),\n", " StructField(\"month\", IntegerType),\n", " StructField(\"day\", DoubleType),\n", " StructField(\"day_of_week\", DoubleType),\n", " StructField(\"is_weekend\", DoubleType)\n", " ))\n", "\n", "val featureNames = schema.filter(_.name != labelName).map(_.name).toArray\n", "\n", "lazy val paramMap = Map(\n", " \"num_round\" -> 100\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a new spark session and load data\n", "\n", "A new spark session should be created to continue all the following spark operations.\n", "\n", "NOTE: in this notebook, the dependency jars have been loaded when installing the toree kernel. Alternatively the jars can be loaded into the notebook by [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, there's one restriction for `%AddJar`: the uploaded jar is only available when `AddJar` is called just after a new spark session is created. Do it as below:\n", "\n", "```scala\n", "import org.apache.spark.sql.SparkSession\n", "val spark = SparkSession.builder().appName(\"taxi-GPU\").getOrCreate\n", "%AddJar file:/data/libs/rapids-4-spark-XXX.jar\n", "%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\n", "%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\n", "// ...\n", "```\n", "\n", "##### Please note the new jar \"rapids-4-spark-XXX.jar\" is only needed for the GPU version; do not add it to the dependency list for the CPU version."
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "sparkSession = org.apache.spark.sql.SparkSession@6efbc93b\n", "reader = org.apache.spark.sql.DataFrameReader@64b8d6da\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "org.apache.spark.sql.DataFrameReader@64b8d6da" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// Build the spark session and data reader as usual\n", "val sparkSession = SparkSession.builder().appName(\"taxi-GPU\").getOrCreate\n", "val reader = sparkSession.read.option(\"header\", true).schema(schema)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "trainSet = [vendor_id: double, passenger_count: double ... 14 more fields]\n", "evalSet = [vendor_id: double, passenger_count: double ... 14 more fields]\n", "transSet = [vendor_id: double, passenger_count: double ... 14 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[vendor_id: double, passenger_count: double ... 14 more fields]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// Please make sure to change the api to reader.parquet if you load parquet files.\n", "val trainSet = reader.csv(trainPath)\n", "val evalSet = reader.csv(evalPath)\n", "val transSet = reader.csv(transPath)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set XGBoost parameters and build a XGBoostRegressor\n", "\n", "For CPU version, `num_workers` is recommended being equal to the number of CPU cores, while for GPU version, it should be set to the number of GPUs in Spark cluster.\n", "\n", "Besides the `device` for CPU version is also different from that for GPU version. 
Now only \"cuda\" is supported for training on GPU.\n", "\n", "```scala\n", "// difference in parameters\n", " \"num_workers\" -> 12" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "xgbParamFinal = Map(num_round -> 100, tree_method -> hist, device -> cuda, num_workers -> 1)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Map(num_round -> 100, tree_method -> hist, device -> cuda, num_workers -> 1)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": "val xgbParamFinal = paramMap ++ Map(\"tree_method\" -> \"hist\", \"device\" -> \"cuda\", \"num_workers\" -> 1)" }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "xgbRegressor = xgbr_d36c6f5fd67c\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "xgbr_d36c6f5fd67c" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val xgbRegressor = new XGBoostRegressor(xgbParamFinal)\n", " .setLabelCol(labelName)\n", " .setFeaturesCol(featureNames)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Benchmark and train\n", "The object `benchmark` is used to compute the elapsed time of some operations.\n", "\n", "Training with evaluation dataset is also supported, the same as CPU version's behavior:\n", "\n", "* Call API `setEvalDataset` after initializing an XGBoostClassifier\n", "\n", "```scala\n", "xgbClassifier.setEvalDataset(evalSet)\n", "```" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "xgbr_d36c6f5fd67c" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xgbRegressor.setEvalDataset(evalSet)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "defined object Benchmark\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "object Benchmark {\n", " def time[R](phase: String)(block: => R): (R, Float) = {\n", " val t0 = System.currentTimeMillis\n", " val result = block // call-by-name\n", " val t1 = System.currentTimeMillis\n", " println(\"Elapsed time [\" + phase + \"]: \" + ((t1 - t0).toFloat / 1000) + \"s\")\n", " (result, (t1 - t0).toFloat / 1000)\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=37275, DMLC_NUM_WORKER=1}\n" ] }, { "data": { "text/plain": [ "model = xgbr_d36c6f5fd67c\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Elapsed time [train]: 7.441s\n" ] }, { "data": { "text/plain": [ "xgbr_d36c6f5fd67c" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// start training\n", "val (model, _) = Benchmark.time(\"train\") {\n", " xgbRegressor.fit(trainSet)\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transformation and evaluation\n", "Here uses `transSet` to evaluate our model and use some key columns to show our predictions. Finally we use `RegressionEvaluator` to calculate an overall `rmse` of our predictions." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Elapsed time [transform]: 2.134s\n", "+-------------+---------------+------------------+-----------+------------------+\n", "| vendor_id|passenger_count| trip_distance|fare_amount| prediction|\n", "+-------------+---------------+------------------+-----------+------------------+\n", "|1.559730423E9| 2.0|0.7000000000000001| 5.0| 5.04693603515625|\n", "|1.559730423E9| 3.0|10.700000000000001| 34.0|31.727073669433594|\n", "|1.559730423E9| 1.0| 2.3| 10.0| 9.294451713562012|\n", "|1.559730423E9| 1.0| 4.4| 16.5| 15.05233097076416|\n", "|1.559730423E9| 1.0| 1.5| 7.0| 8.995831489562988|\n", "|1.559730423E9| 1.0| 0.8| 7.5| 6.239481449127197|\n", "|1.559730423E9| 1.0| 1.2| 5.5| 7.339130401611328|\n", "|1.559730423E9| 1.0| 3.0| 2.5|13.403449058532715|\n", "| 4.52563162E8| 1.0|2.3399999999999994| 9.5| 9.672189712524414|\n", "| 4.52563162E8| 1.0| 3.17| 12.0|11.674100875854492|\n", "+-------------+---------------+------------------+-----------+------------------+\n", "only showing top 10 rows\n", "\n", "Elapsed time [evaluation]: 0.17s\n", "RMSE == 1.9141528880798715\n" ] }, { "data": { "text/plain": [ "prediction = [vendor_id: double, passenger_count: double ... 15 more fields]\n", "evaluator = RegressionEvaluator: uid=regEval_547b9abc7a3b, metricName=rmse, throughOrigin=false\n", "rmse = 1.9141528880798715\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.9141528880798715" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// start transform\n", "val (prediction, _) = Benchmark.time(\"transform\") {\n", " val ret = model.transform(transSet).cache()\n", " ret.foreachPartition((_: Iterator[_]) => ())\n", " ret\n", "}\n", "prediction.select(\"vendor_id\", \"passenger_count\", \"trip_distance\", labelName, \"prediction\").show(10)\n", "val evaluator = new RegressionEvaluator().setLabelCol(labelName)\n", "val (rmse, _) = Benchmark.time(\"evaluation\") {\n", " evaluator.evaluate(prediction)\n", "}\n", "println(s\"RMSE == $rmse\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save the model to disk and load model\n", "Save the model to disk and then load it to memory. After that use the loaded model to do a new prediction." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Elapsed time [transform2]: 0.025s\n" ] }, { "data": { "text/plain": [ "modelFromDisk = xgbr_d36c6f5fd67c\n", "results2 = [vendor_id: double, passenger_count: double ... 
15 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "+-------------+---------------+------------------+-----------+------------------+\n", "| vendor_id|passenger_count| trip_distance|fare_amount| prediction|\n", "+-------------+---------------+------------------+-----------+------------------+\n", "|1.559730423E9| 2.0|0.7000000000000001| 5.0| 5.04693603515625|\n", "|1.559730423E9| 3.0|10.700000000000001| 34.0|31.727073669433594|\n", "|1.559730423E9| 1.0| 2.3| 10.0| 9.294451713562012|\n", "|1.559730423E9| 1.0| 4.4| 16.5| 15.05233097076416|\n", "|1.559730423E9| 1.0| 1.5| 7.0| 8.995831489562988|\n", "+-------------+---------------+------------------+-----------+------------------+\n", "only showing top 5 rows\n", "\n" ] }, { "data": { "text/plain": [ "[vendor_id: double, passenger_count: double ... 15 more fields]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.write.overwrite.save(dataRoot + \"/model/taxi\")\n", "\n", "val modelFromDisk = XGBoostRegressionModel.load(dataRoot + \"/model/taxi\")\n", "val (results2, _) = Benchmark.time(\"transform2\") {\n", " modelFromDisk.transform(transSet)\n", "}\n", "results2.select(\"vendor_id\", \"passenger_count\", \"trip_distance\", labelName, \"prediction\").show(5)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "sparkSession.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "XGBoost4j-Spark - Scala", "language": "scala", "name": "XGBoost4j-Spark_scala" }, "language_info": { "codemirror_mode": "text/x-scala", "file_extension": ".scala", "mimetype": "text/x-scala", "name": "scala", "pygments_lexer": "scala", "version": "2.12.15" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/XGBoost-Examples/taxi/notebooks/scala/taxi_gpu_crossvalidation.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Taxi CrossValidation with GPU Acceleration on XGBoost\n", "\n", "In this notebook, we will show you how to leverage the GPU to accelerate taxi CrossValidation on XGBoost to find the best model given a group of parameters.\n", "\n", "## Import classes\n", "First we need to load some common classes that both the GPU version and the CPU version will use:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}\n", "import org.apache.spark.ml.evaluation.{RegressionEvaluator}\n", "import org.apache.spark.ml.tuning.{ParamGridBuilder,CrossValidator}\n", "import org.apache.spark.sql.SparkSession\n", "import org.apache.spark.sql.types.{FloatType, IntegerType, StructField, StructType}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What is new to xgboost-spark users is rapids.GpuDataReader and **rapids.CrossValidator**." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "ename": "Syntax Error.", "evalue": "", "output_type": "error", "traceback": [] } ], "source": [ "// import ml.dmlc.xgboost4j.scala.spark.rapids.CrossValidator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set dataset path" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dataRoot 
= /data\n", "trainParquetPath = /data/taxi/parquet/train\n", "evalParquetPath = /data/taxi/parquet/eval\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "/data/taxi/parquet/eval" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// You need to update them to your real paths! The input data files can be the output of taxi-etl jobs, or you can\n", "// just use the provided sample datasets under datasets path. \n", "val dataRoot = sys.env.getOrElse(\"DATA_ROOT\", \"/data\")\n", "val trainParquetPath=dataRoot + \"/taxi/parquet/train\"\n", "val evalParquetPath=dataRoot + \"/taxi/parquet/eval\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Set the schema of the dataset" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "labelColName = fare_amount\n", "schema = StructType(StructField(vendor_id,FloatType,true), StructField(passenger_count,FloatType,true), StructField(trip_distance,FloatType,true), StructField(pickup_longitude,FloatType,true), StructField(pickup_latitude,FloatType,true), StructField(rate_code,FloatType,true), StructField(store_and_fwd,FloatType,true), StructField(dropoff_longitude,FloatType,true), StructField(dropoff_latitude,FloatType,true), StructField(fare_amount,FloatType,true), StructField(hour,FloatType,true), StructField(year,IntegerType,true), StructField(month,IntegerType,true), StructField(day,FloatType,true), StructField(day_of_week,FloatType,true), StructField(is_weekend,FloatType,true))\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "StructType(StructField(vendor_id,FloatType,true), StructField(passenger_count,FloatType,true), StructField(trip_distance,FloatType,true), StructField(pickup_longitude,FloatType,true), StructField(pickup_latitude,FloatType,true), StructField(rate_code,FloatType,true), StructField(store_and_fwd,FloatType,true), StructField(dropoff_longitude,FloatType,true), StructField(dropoff_latitude,FloatType,true), StructField(fare_amount,FloatType,true), StructField(hour,FloatType,true), StructField(year,IntegerType,true), StructField(month,IntegerType,true), StructField(day,FloatType,true), StructField(day_of_week,FloatType,true), StructField(is_weekend,FloatType,true))" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val labelColName = \"fare_amount\"\n", "val schema =\n", " StructType(Array(\n", " StructField(\"vendor_id\", FloatType),\n", " StructField(\"passenger_count\", FloatType),\n", " StructField(\"trip_distance\", FloatType),\n", " StructField(\"pickup_longitude\", FloatType),\n", " StructField(\"pickup_latitude\", FloatType),\n", " StructField(\"rate_code\", FloatType),\n", " StructField(\"store_and_fwd\", FloatType),\n", " StructField(\"dropoff_longitude\", FloatType),\n", " StructField(\"dropoff_latitude\", FloatType),\n", " StructField(labelColName, FloatType),\n", " StructField(\"hour\", FloatType),\n", " StructField(\"year\", IntegerType),\n", " StructField(\"month\", IntegerType),\n", " StructField(\"day\", FloatType),\n", " StructField(\"day_of_week\", FloatType),\n", " StructField(\"is_weekend\", FloatType)\n", " ))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a new spark session and load data\n", "we must create a new spark session to continue all spark operations. 
It will also be used to initialize the `GpuDataReader` which is a data reader powered by GPU.\n", "\n", "NOTE: in this notebook, we have uploaded the dependency jars when installing the toree kernel. If we don't upload them at installation time, we can also upload them in the notebook by [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, there's one restriction for `%AddJar`: the uploaded jar is only available when `AddJar` is called after a new spark session is created. We must use it as below:\n", "\n", "```scala\n", "import org.apache.spark.sql.SparkSession\n", "val spark = SparkSession.builder().appName(\"Taxi-GPU-CV\").getOrCreate\n", "%AddJar file:/data/libs/rapids-4-spark-XXX.jar\n", "%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\n", "%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\n", "// ...\n", "```" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "spark = org.apache.spark.sql.SparkSession@1b953a9c\n", "trainDs = [vendor_id: int, passenger_count: int ... 15 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[vendor_id: int, passenger_count: int ... 15 more fields]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val spark = SparkSession.builder().appName(\"taxi-gpu-cv\").getOrCreate()\n", "val trainDs = spark.read.parquet(trainParquetPath)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Find out features to train" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "featureNames = Array(vendor_id, passenger_count, trip_distance, pickup_longitude, pickup_latitude, rate_code, store_and_fwd, dropoff_longitude, dropoff_latitude, hour, year, month, day, day_of_week, is_weekend)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Array(vendor_id, passenger_count, trip_distance, pickup_longitude, pickup_latitude, rate_code, store_and_fwd, dropoff_longitude, dropoff_latitude, hour, year, month, day, day_of_week, is_weekend)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val featureNames = schema.filter(_.name != labelColName).map(_.name).toArray" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "regressorParam = Map(num_round -> 100, tree_method -> hist, device -> cuda, num_workers -> 1)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Map(num_round -> 100, tree_method -> hist, device -> cuda, num_workers -> 1)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val regressorParam = Map(\n", " \"num_round\" -> 100,\n", " \"tree_method\" -> \"hist\",\n", " \"device\" -> \"cuda\",\n", " \"num_workers\" -> 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Construct CrossValidator" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "regressor = xgbr_1c1bd6fa3a5f\n", "paramGrid = \n", "evaluator = RegressionEvaluator: uid=regEval_c7293a967512, metricName=rmse, throughOrigin=false\n", "cv = cv_06528fc9d704\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Array({\n", "\txgbr_1c1bd6fa3a5f-eta: 0.2,\n", "\txgbr_1c1bd6fa3a5f-maxDepth: 3\n", "}, {\n", "\txgbr_1c1bd6fa3a5f-eta: 0.6,\n", "\txgbr_1c1bd6fa3a5f-maxDepth: 
3\n", "}, {\n", "\txgbr_1c1bd6fa3a5f-eta: 0.2,\n", "\txgbr_1c1bd6fa3a5f-maxDepth: 10\n", "}, {\n", "\txgbr_1c1bd6fa3a5f-eta: 0.6,\n", "\txgbr_1c1bd6fa3a5f-maxDepth: 10\n", "})\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "cv_06528fc9d704" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val regressor = new XGBoostRegressor(regressorParam)\n", " .setLabelCol(labelColName)\n", " .setFeaturesCol(featureNames)\n", "val paramGrid = new ParamGridBuilder()\n", " .addGrid(regressor.maxDepth, Array(3, 10))\n", " .addGrid(regressor.eta, Array(0.2, 0.6))\n", " .build()\n", "val evaluator = new RegressionEvaluator().setLabelCol(labelColName)\n", "val cv = new CrossValidator()\n", " .setEstimator(regressor)\n", " .setEvaluator(evaluator)\n", " .setEstimatorParamMaps(paramGrid)\n", " .setNumFolds(3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## train with CrossValidator" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=36551, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=40153, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=46553, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=50795, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=44927, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=55309, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=55163, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=54783, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=49873, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=36003, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=41429, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=60783, DMLC_NUM_WORKER=1}\n", "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=49361, DMLC_NUM_WORKER=1}\n" ] }, { "data": { "text/plain": [ "model = xgbr_1c1bd6fa3a5f\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "xgbr_1c1bd6fa3a5f" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val model = cv.fit(trainDs).bestModel.asInstanceOf[XGBoostRegressionModel]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## tranform with best model trained by CrossValidator" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "transformDs = [vendor_id: int, passenger_count: int ... 15 more fields]\n", "df = [vendor_id: int, passenger_count: int ... 
16 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "+-----------+------------------+\n", "|fare_amount| prediction|\n", "+-----------+------------------+\n", "| 11.4|12.278875350952148|\n", "| 7.4|7.4439215660095215|\n", "| 5.0| 4.565710067749023|\n", "| 8.5| 9.188780784606934|\n", "| 7.4| 7.266360759735107|\n", "+-----------+------------------+\n", "only showing top 5 rows\n", "\n" ] }, { "data": { "text/plain": [ "[vendor_id: int, passenger_count: int ... 16 more fields]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val transformDs = spark.read.parquet(evalParquetPath)\n", "val df = model.transform(transformDs).cache()\n", "df.select(\"fare_amount\", \"prediction\").show(5)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "evaluator = RegressionEvaluator: uid=regEval_1c57378a8fe1, metricName=rmse, throughOrigin=false\n", "rmse = 2.2492672858545992\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "2.2492672858545992" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val evaluator = new RegressionEvaluator().setLabelCol(labelColName)\n", "val rmse = evaluator.evaluate(df)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "spark.close()" ] } ], "metadata": { "kernelspec": { "display_name": "XGBoost4j-Spark - Scala", "language": "scala", "name": "XGBoost4j-Spark_scala" }, "language_info": { "codemirror_mode": "text/x-scala", "file_extension": ".scala", "mimetype": "text/x-scala", "name": "scala", "pygments_lexer": "scala", "version": "2.12.15" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/XGBoost-Examples/taxi/pom.xml ================================================ sample_xgboost_examples com.nvidia 0.2.3-SNAPSHOT 4.0.0 spark_examples_taxi_${scala.binary.version} 8 8 com.nvidia spark_examples_utility_${scala.binary.version} ${project.version} compile scala/src ================================================ FILE: examples/XGBoost-Examples/taxi/python/com/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/taxi/python/com/nvidia/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/taxi/python/com/nvidia/spark/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/taxi/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/taxi/consts.py ================================================ # # Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from pyspark.sql.types import * label = 'fare_amount' raw_schema = StructType([ StructField('vendor_id', StringType()), StructField('pickup_datetime', StringType()), StructField('dropoff_datetime', StringType()), StructField('passenger_count', IntegerType()), StructField('trip_distance', DoubleType()), StructField('pickup_longitude', DoubleType()), StructField('pickup_latitude', DoubleType()), StructField('rate_code', StringType()), StructField('store_and_fwd_flag', StringType()), StructField('dropoff_longitude', DoubleType()), StructField('dropoff_latitude', DoubleType()), StructField('payment_type', StringType()), StructField(label, DoubleType()), StructField('surcharge', DoubleType()), StructField('mta_tax', DoubleType()), StructField('tip_amount', DoubleType()), StructField('tolls_amount', DoubleType()), StructField('total_amount', DoubleType()), ]) final_schema = StructType([ StructField('vendor_id', FloatType()), StructField('passenger_count', FloatType()), StructField('trip_distance', FloatType()), StructField('pickup_longitude', FloatType()), StructField('pickup_latitude', FloatType()), StructField('rate_code', FloatType()), StructField('store_and_fwd', FloatType()), StructField('dropoff_longitude', FloatType()), StructField('dropoff_latitude', FloatType()), StructField(label, FloatType()), StructField('hour', FloatType()), StructField('year', IntegerType()), StructField('month', IntegerType()), StructField('day', FloatType()), StructField('day_of_week', FloatType()), StructField('is_weekend', FloatType()), ]) ================================================ FILE: examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/taxi/cross_validator_main.py ================================================ # # Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from .consts import * from com.nvidia.spark.examples.utility.utils import * from pyspark.ml.tuning import ParamGridBuilder, CrossValidator from pyspark.sql import SparkSession from xgboost.spark import SparkXGBRegressor, SparkXGBRegressorModel def main(args, xgboost_args): spark = (SparkSession .builder .appName(args.mainClass) .getOrCreate()) train_data, eval_data, trans_data = valid_input_data(spark, args, raw_schema, final_schema) if args.mode in ['all', 'train']: if train_data is None: print('-' * 80) print('Usage: training data path required when mode is all or train') print('-' * 80) exit(1) train_data, features = transform_data(train_data, label, args.use_gpu) xgboost_args['features_col'] = features xgboost_args['label_col'] = label regressor = SparkXGBRegressor(**xgboost_args) param_grid = (ParamGridBuilder() .addGrid(regressor.max_depth, [6, 8]) .addGrid(regressor.n_estimators, [20, 40]) .build()) evaluator = (RegressionEvaluator() .setLabelCol(label)) cross_validator = (CrossValidator() .setEstimator(regressor) .setEvaluator(evaluator) .setEstimatorParamMaps(param_grid) .setNumFolds(3)) model = with_benchmark('Training', lambda: cross_validator.fit(train_data)) # get the best model to do transform model = model.bestModel if args.modelPath: writer = model.write().overwrite() if args.overwrite else model writer.save(args.modelPath) else: model = SparkXGBRegressorModel.load(args.modelPath) if args.mode in ['all', 'transform']: if trans_data is None: print('-' * 80) print('Usage: trans data path required when mode is all or transform') print('-' * 80) exit(1) trans_data, _ = transform_data(trans_data, label, args.use_gpu) def transform(): result = model.transform(trans_data).cache() result.foreachPartition(lambda _: None) return result result = with_benchmark('Transformation', transform) show_sample(args, result, label) with_benchmark('Evaluation', lambda: check_regression_accuracy(result, label)) spark.stop() ================================================ FILE: examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/taxi/etl_main.py ================================================ # # Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
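# Overview of this ETL entry point (see main() below): it reads the raw taxi data from the
# paths tagged 'raw::', cleans and feature-engineers it with pre_process(), splits the result
# according to args.splitRatios, and writes the train/eval/trans splits as Parquet under the
# path tagged 'out::'.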
from .consts import * from .pre_process import pre_process from com.nvidia.spark.examples.utility.utils import * from pyspark.sql import SparkSession def main(args, xgboost_args): spark = (SparkSession .builder .appName(args.mainClass) .getOrCreate()) raw_data_path = extract_paths(args.dataPaths, 'raw::') output_path = extract_paths(args.dataPaths, 'out::')[0] if not raw_data_path: print('-' * 80) print('Usage: raw data path required when ETL') exit(1) if not output_path: print('-' * 80) print('Usage: output data path required when ETL') exit(1) raw_data = prepare_data(spark, args, raw_schema, raw_data_path) etled_train, etled_eval, etled_trans = pre_process(raw_data).randomSplit(list(map(float, args.splitRatios))) etled_train.write.mode("overwrite").parquet(output_path + '/train') etled_eval.write.mode("overwrite").parquet(output_path + '/eval') etled_trans.write.mode("overwrite").parquet(output_path + '/trans') ================================================ FILE: examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/taxi/main.py ================================================ # # Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from .consts import * from com.nvidia.spark.examples.utility.utils import * from pyspark.sql import SparkSession from xgboost.spark import SparkXGBRegressor, SparkXGBRegressorModel def main(args, xgboost_args): spark = (SparkSession .builder .appName(args.mainClass) .getOrCreate()) train_data, eval_data, trans_data = valid_input_data(spark, args, raw_schema, final_schema) if args.mode in ['all', 'train']: if not train_data: print('-' * 80) print('Usage: training data path required when mode is all or train') print('-' * 80) exit(1) train_data, features = transform_data(train_data, label, args.use_gpu) xgboost_args['features_col'] = features xgboost_args['label_col'] = label regressor = SparkXGBRegressor(**xgboost_args) if eval_data: # pass pass model = with_benchmark('Training', lambda: regressor.fit(train_data)) if args.modelPath: writer = model.write().overwrite() if args.overwrite else model writer.save(args.modelPath) else: model = SparkXGBRegressorModel.load(args.modelPath) if args.mode in ['all', 'transform']: if not trans_data: print('-' * 80) print('Usage: trans data path required when mode is all or transform') print('-' * 80) exit(1) trans_data, _ = transform_data(trans_data, label, args.use_gpu) def transform(): result = model.transform(trans_data).cache() result.foreachPartition(lambda _: None) return result result = with_benchmark('Transformation', transform) show_sample(args, result, label) with_benchmark('Evaluation', lambda: check_regression_accuracy(result, label)) spark.stop() ================================================ FILE: examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/taxi/pre_process.py ================================================ # # Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import math from pyspark.sql.functions import * from pyspark.sql.types import * from pyspark.sql.functions import col def pre_process(data_frame): processes = [ drop_useless, encode_categories, fill_na, remove_invalid, convert_datetime, add_h_distance, ] for process in processes: data_frame = process(data_frame) return data_frame def drop_useless(data_frame): return data_frame.drop( 'dropoff_datetime', 'payment_type', 'surcharge', 'mta_tax', 'tip_amount', 'tolls_amount', 'total_amount') def encode_categories(data_frame): categories = [ 'vendor_id', 'rate_code', 'store_and_fwd_flag' ] for category in categories: data_frame = data_frame.withColumn(category, hash(col(category))) return data_frame.withColumnRenamed("store_and_fwd_flag", "store_and_fwd") def fill_na(data_frame): return data_frame.fillna(-1) def remove_invalid(data_frame): conditions = [ ( 'fare_amount', 0, 500 ), ( 'passenger_count', 0, 6 ), ( 'pickup_longitude', -75, -73 ), ( 'dropoff_longitude', -75, -73 ), ( 'pickup_latitude', 40, 42 ), ( 'dropoff_latitude', 40, 42 ), ] for column, min, max in conditions: data_frame = data_frame.filter('{} > {} and {} < {}'.format(column, min, column, max)) return data_frame def convert_datetime(data_frame): datetime = col('pickup_datetime') return (data_frame .withColumn('pickup_datetime', to_timestamp(datetime)) .withColumn('year', year(datetime)) .withColumn('month', month(datetime)) .withColumn('day', dayofmonth(datetime)) .withColumn('day_of_week', dayofweek(datetime)) .withColumn( 'is_weekend', col('day_of_week').isin(1, 7).cast(IntegerType())) # 1: Sunday, 7: Saturday .withColumn('hour', hour(datetime)) .drop('pickup_datetime')) def add_h_distance(data_frame): p = math.pi / 180 lat1 = col('pickup_latitude') lon1 = col('pickup_longitude') lat2 = col('dropoff_latitude') lon2 = col('dropoff_longitude') internal_value = (0.5 - cos((lat2 - lat1) * p) / 2 + cos(lat1 * p) * cos(lat2 * p) * (1 - cos((lon2 - lon1) * p)) / 2) h_distance = 12734 * asin(sqrt(internal_value)) return data_frame.withColumn('h_distance', h_distance) ================================================ FILE: examples/XGBoost-Examples/taxi/scala/src/com/nvidia/spark/examples/taxi/CrossValidationMain.scala ================================================ /* * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package com.nvidia.spark.examples.taxi import com.nvidia.spark.examples.utility.{XGBoostArgs, Benchmark} import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor} import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} import org.apache.spark.sql.SparkSession object CrossValidationMain extends Taxi { def main(args: Array[String]): Unit = { val xgboostArgs = XGBoostArgs.parse(args) val processor = this.getClass.getSimpleName.stripSuffix("$").substring(0, 3) val appInfo = Seq(appName, processor, xgboostArgs.format) // build spark session val spark = SparkSession.builder() .appName(appInfo.mkString("-")) .getOrCreate() val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2)) // build data reader val dataReader = spark.read val (pathsArray, dataReadSchema, needEtl) = getDataPaths(xgboostArgs.dataPaths, xgboostArgs.isToTrain, xgboostArgs.isToTransform) // 0: train 1: eval 2:transform var datasets = pathsArray.map { paths => if (paths.nonEmpty) { xgboostArgs.format match { case "csv" => Some(dataReader.option("header", xgboostArgs.hasHeader).schema(dataReadSchema).csv(paths: _*)) case "orc" => Some(dataReader.orc(paths: _*)) case "parquet" => Some(dataReader.parquet(paths: _*)) case _ => throw new IllegalArgumentException("Unsupported data file format!") } } else { None } } if (needEtl) datasets = datasets.map(_.map(preProcess(_))) val xgbRegressionModel = if (xgboostArgs.isToTrain) { // build XGBoost XGBoostRegressor val xgbParamFinal = xgboostArgs.xgboostParams(commParamMap) val xgbRegressor = new XGBoostRegressor(xgbParamFinal) .setLabelCol(labelColName) .setFeaturesCol(featureNames) // Tune model using cross validation val paramGrid = new ParamGridBuilder() .addGrid(xgbRegressor.maxDepth, Array(3, 10)) .addGrid(xgbRegressor.eta, Array(0.2, 0.6)) .build() val evaluator = new RegressionEvaluator().setLabelCol(labelColName) val cv = new CrossValidator() .setEstimator(xgbRegressor) .setEvaluator(evaluator) .setEstimatorParamMaps(paramGrid) .setNumFolds(xgboostArgs.numFold) println("\n------ Training ------") // Shall we not log the time if it is abnormal, which is usually caused by training failure val (model, _) = benchmark.time("CrossValidator") { cv.fit(datasets(0).get).bestModel.asInstanceOf[XGBoostRegressionModel] } // Save model if modelPath exists xgboostArgs.modelPath.foreach(path => if (xgboostArgs.isOverwrite) model.write.overwrite().save(path) else model.save(path)) model } else { XGBoostRegressionModel.load(xgboostArgs.modelPath.get) } if (xgboostArgs.isToTransform) { println("\n------ Transforming ------") var (prediction, _) = benchmark.time("transform") { val ret = xgbRegressionModel.transform(datasets(2).get).cache() ret.foreachPartition((_: Iterator[_]) => ()) ret } prediction = if (xgboostArgs.isShowFeatures) { prediction } else { prediction.select(labelColName, "prediction") } prediction.show(xgboostArgs.numRows) println("\n------Accuracy of Evaluation------") val evaluator = new RegressionEvaluator().setLabelCol(labelColName) evaluator.evaluate(prediction) match { case rmse if !rmse.isNaN => benchmark.value(rmse, "RMSE", "RMSE for") // Throw an exception when NaN ? } } spark.close() } } ================================================ FILE: examples/XGBoost-Examples/taxi/scala/src/com/nvidia/spark/examples/taxi/ETLMain.scala ================================================ /* * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.examples.taxi import com.nvidia.spark.examples.utility.{XGBoostArgs, Benchmark} import org.apache.spark.sql.SparkSession object ETLMain extends Taxi { def main(args: Array[String]): Unit = { val xgboostArgs = XGBoostArgs.parse(args) val processor = this.getClass.getSimpleName.stripSuffix("$").substring(0, 3) val appInfo = Seq(appName, processor, xgboostArgs.format) // build spark session val spark = SparkSession.builder() .appName(appInfo.mkString("-")) .getOrCreate() val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2)) // build data reader val dataReader = spark.read val (rawPaths, outPath) = checkAndGetPaths(xgboostArgs.dataPaths) val df = xgboostArgs.format match { case "csv" => dataReader.option("header", xgboostArgs.hasHeader).schema(rawSchema).csv(rawPaths: _*) case "parquet" => dataReader.parquet(rawPaths: _*) case "orc" => dataReader.orc(rawPaths: _*) case _ => throw new IllegalArgumentException("Unsupported data file format!") } val (trainRatio, evalRatio, trainEvalRatio) = xgboostArgs.dataRatios val dataset = preProcess(df, Array(trainRatio, trainEvalRatio, evalRatio)) benchmark.time("ETL") { for ((name, index) <- Seq("train", "eval", "trans").zipWithIndex) { dataset(index).write.mode("overwrite").parquet(outPath + "/parquet/" + name) dataset(index).write.mode("overwrite").csv(outPath + "/csv/" + name) } } spark.close() } private def checkAndGetPaths(paths: Seq[String]): (Seq[String], String) = { val prefixes = Array("raw::", "out::") val validPaths = paths.filter(_.nonEmpty).map(_.trim) // get and check train data paths val rawPaths = validPaths.filter(_.startsWith(prefixes.head)) require(rawPaths.nonEmpty, s"$appName ETL requires at least one path for taxi data file." + s" Please specify it by '-dataPath=raw::your_taxi_data_path'") // get and check out path val outPath = validPaths.filter(_.startsWith(prefixes(1))) require(outPath.nonEmpty, s"$appName ETL requires a path to save the ETLed data file. Please specify it" + " by '-dataPath=out::your_out_path', only the first path is used if multiple paths are found.") // check data paths not specified type val unknownPaths = validPaths.filterNot(p => prefixes.exists(p.contains(_))) require(unknownPaths.isEmpty, s"Unknown type for data path: ${unknownPaths.head}, $appName requires to specify" + " the type for each data path by adding the prefix 'raw::' or 'out::'") (rawPaths.map(_.stripPrefix(prefixes.head)), outPath.head.stripPrefix(prefixes(1))) } } ================================================ FILE: examples/XGBoost-Examples/taxi/scala/src/com/nvidia/spark/examples/taxi/Main.scala ================================================ /* * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.examples.taxi import com.nvidia.spark.examples.utility.{XGBoostArgs, Benchmark} import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor} import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.sql.SparkSession object Main extends Taxi { def main(args: Array[String]): Unit = { val xgboostArgs = XGBoostArgs.parse(args) val processor = this.getClass.getSimpleName.stripSuffix("$").substring(0, 3) val appInfo = Seq(appName, processor, xgboostArgs.format) // build spark session val spark = SparkSession.builder() .appName(appInfo.mkString("-")) .getOrCreate() val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2)) // build data reader val dataReader = spark.read val (pathsArray, dataReadSchema, needEtl) = getDataPaths(xgboostArgs.dataPaths, xgboostArgs.isToTrain, xgboostArgs.isToTransform) // 0: train 1: eval 2:transform var datasets = pathsArray.map { paths => if (paths.nonEmpty) { xgboostArgs.format match { case "csv" => Some(dataReader.option("header", xgboostArgs.hasHeader).schema(dataReadSchema).csv(paths: _*)) case "orc" => Some(dataReader.orc(paths: _*)) case "parquet" => Some(dataReader.parquet(paths: _*)) case _ => throw new IllegalArgumentException("Unsupported data file format!") } } else { None } } if (needEtl) datasets = datasets.map(_.map(preProcess(_))) val xgbRegressionModel = if (xgboostArgs.isToTrain) { // build XGBoost XGBoostRegressor val xgbParamFinal = xgboostArgs.xgboostParams(commParamMap) val xgbRegressor = new XGBoostRegressor(xgbParamFinal) .setLabelCol(labelColName) .setFeaturesCol(featureNames) datasets(1).foreach(_ => xgbRegressor.setEvalDataset(_)) println("\n------ Training ------") // Shall we not log the time if it is abnormal, which is usually caused by training failure val (model, _) = benchmark.time("train") { xgbRegressor.fit(datasets(0).get) } // Save model if modelPath exists xgboostArgs.modelPath.foreach(path => if (xgboostArgs.isOverwrite) model.write.overwrite().save(path) else model.save(path)) model } else { XGBoostRegressionModel.load(xgboostArgs.modelPath.get) } if (xgboostArgs.isToTransform) { println("\n------ Transforming ------") var (prediction, _) = benchmark.time("transform") { val ret = xgbRegressionModel.transform(datasets(2).get).cache() ret.foreachPartition((_: Iterator[_]) => ()) ret } prediction = if (xgboostArgs.isShowFeatures) { prediction } else { prediction.select(labelColName, "prediction") } prediction.show(xgboostArgs.numRows) println("\n------Accuracy of Evaluation------") val evaluator = new RegressionEvaluator().setLabelCol(labelColName) evaluator.evaluate(prediction) match { case rmse if !rmse.isNaN => benchmark.value(rmse, "RMSE", "RMSE for") // Throw an exception when NaN ? } } spark.close() } } ================================================ FILE: examples/XGBoost-Examples/taxi/scala/src/com/nvidia/spark/examples/taxi/Taxi.scala ================================================ /* * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.examples.taxi import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DataTypes.{DoubleType, IntegerType, StringType} import org.apache.spark.sql.types.{FloatType, StructField, StructType} private[taxi] trait Taxi { val appName = "Taxi" lazy val labelColName = "fare_amount" lazy val featureNames = etledSchema.filter(_.name != labelColName).map(_.name).toArray lazy val commParamMap = Map( "num_round" -> 100 ) val rawSchema = StructType(Seq( StructField("vendor_id", StringType), StructField("pickup_datetime", StringType), StructField("dropoff_datetime", StringType), StructField("passenger_count", IntegerType), StructField("trip_distance", DoubleType), StructField("pickup_longitude", DoubleType), StructField("pickup_latitude", DoubleType), StructField("rate_code", StringType), StructField("store_and_fwd_flag", StringType), StructField("dropoff_longitude", DoubleType), StructField("dropoff_latitude", DoubleType), StructField("payment_type", StringType), StructField(labelColName, DoubleType), StructField("surcharge", DoubleType), StructField("mta_tax", DoubleType), StructField("tip_amount", DoubleType), StructField("tolls_amount", DoubleType), StructField("total_amount", DoubleType) )) private val etledSchema = StructType(Array( StructField("vendor_id", FloatType), StructField("passenger_count", FloatType), StructField("trip_distance", FloatType), StructField("pickup_longitude", FloatType), StructField("pickup_latitude", FloatType), StructField("rate_code", FloatType), StructField("store_and_fwd", FloatType), StructField("dropoff_longitude", FloatType), StructField("dropoff_latitude", FloatType), StructField(labelColName, FloatType), StructField("hour", FloatType), StructField("year", IntegerType), StructField("month", IntegerType), StructField("day", FloatType), StructField("day_of_week", FloatType), StructField("is_weekend", FloatType) )) def preProcess(dataFrame: DataFrame): DataFrame = { val processes = Seq[DataFrame => DataFrame]( dropUseless, encodeCategories, fillNa, removeInvalid, convertDatetime, addHDistance ) processes .foldLeft(dataFrame) { case (df, process) => process(df) } } def preProcess(dataFrame: DataFrame, splits: Array[Int]): Array[DataFrame] = { val processes = Seq[DataFrame => DataFrame]( dropUseless, encodeCategories, fillNa, removeInvalid, convertDatetime, addHDistance ) processes .foldLeft(dataFrame) { case (df, process) => process(df) } .cache() .randomSplit(splits.map(_.toDouble)) } def dropUseless(dataFrame: DataFrame): DataFrame = { dataFrame.drop( "dropoff_datetime", "payment_type", "surcharge", "mta_tax", "tip_amount", "tolls_amount", "total_amount") } def encodeCategories(dataFrame: DataFrame): DataFrame = { val categories = Seq("vendor_id", "rate_code", "store_and_fwd_flag") (categories.foldLeft(dataFrame) { case (df, category) => df.withColumn(category, hash(col(category))) }).withColumnRenamed("store_and_fwd_flag", "store_and_fwd") } def 
fillNa(dataFrame: DataFrame): DataFrame = { dataFrame.na.fill(-1) } def removeInvalid(dataFrame: DataFrame): DataFrame = { val conditions = Seq( Seq("fare_amount", 0, 500), Seq("passenger_count", 0, 6), Seq("pickup_longitude", -75, -73), Seq("dropoff_longitude", -75, -73), Seq("pickup_latitude", 40, 42), Seq("dropoff_latitude", 40, 42)) conditions .map { case Seq(column, min, max) => "%s > %d and %s < %d".format(column, min, column, max) } .foldLeft(dataFrame) { _.filter(_) } } def convertDatetime(dataFrame: DataFrame): DataFrame = { val datetime = col("pickup_datetime") dataFrame .withColumn("pickup_datetime", to_timestamp(datetime)) .withColumn("year", year(datetime)) .withColumn("month", month(datetime)) .withColumn("day", dayofmonth(datetime)) .withColumn("day_of_week", dayofweek(datetime)) .withColumn( "is_weekend", col("day_of_week").isin(1, 7).cast(IntegerType)) // 1: Sunday, 7: Saturday .withColumn("hour", hour(datetime)) .drop(datetime.toString) } def addHDistance(dataFrame: DataFrame): DataFrame = { val P = math.Pi / 180 val lat1 = col("pickup_latitude") val lon1 = col("pickup_longitude") val lat2 = col("dropoff_latitude") val lon2 = col("dropoff_longitude") val internalValue = (lit(0.5) - cos((lat2 - lat1) * P) / 2 + cos(lat1 * P) * cos(lat2 * P) * (lit(1) - cos((lon2 - lon1) * P)) / 2) val hDistance = lit(12734) * asin(sqrt(internalValue)) dataFrame.withColumn("h_distance", hDistance) } /** * getDataPaths check and get train/eval/transform paths * * @return Array(train_paths, eval_paths, transform_paths) */ def getDataPaths(dataPaths: Seq[String], isToTrain: Boolean, isToTransform: Boolean): (Array[Seq[String]], StructType, Boolean) = { val paths = dataPaths val etledPrefixes = Array("train::", "eval::", "trans::") val rawPrefixes = Array("rawTrain::", "rawEval::", "rawTrans::") val validPaths = paths.filter(_.nonEmpty).map(_.trim) val p1 = validPaths.filter(p => etledPrefixes.exists(p.startsWith(_))) val p2 = validPaths.filter(p => rawPrefixes.exists(p.startsWith(_))) require(p1.isEmpty || p2.isEmpty, s"requires directly train by '-dataPath=${etledPrefixes(0)}train_data_path" + s" -dataPath=${etledPrefixes(1)}eval_data_path -dataPath=${etledPrefixes(2)}transform_data_path' Or " + s"E2E train by '-dataPath=${rawPrefixes(0)}train_data_path -dataPath=${rawPrefixes(1)}eval_data_path" + s" -dataPath=${rawPrefixes(2)}transform_data_path'") val (prefixes, schema, needEtl) = if (p1.nonEmpty) (etledPrefixes, etledSchema, false) else (rawPrefixes, rawSchema, true) // get train data paths val trainPaths = validPaths.filter(_.startsWith(prefixes.head)) if (isToTrain) { require(trainPaths.nonEmpty, s"requires at least one path for train file." + s" Please specify it by '-dataPath=${prefixes(0)}your_train_data_path'") } // get eval path val evalPaths = validPaths.filter(_.startsWith(prefixes(1))) // get and check train data paths val transformPaths = validPaths.filter(_.startsWith(prefixes(2))) if (isToTransform) { require(transformPaths.nonEmpty, s"requires at least one path for transform file." 
+ s" Please specify it by '-dataPath=${prefixes(2)}your_transform_data_path'") } // check data paths not specified type val unknownPaths = validPaths.filterNot(p => prefixes.exists(p.startsWith(_))) require(unknownPaths.isEmpty, s"Unknown type for data path: ${unknownPaths.head}, requires to specify" + s" the type for each data path by adding the prefix '${prefixes(0)}' or '${prefixes(1)}' or '${prefixes(2)}'.") (Array(trainPaths.map(_.stripPrefix(prefixes.head)), evalPaths.map(_.stripPrefix(prefixes(1))), transformPaths.map(_.stripPrefix(prefixes(2)))), schema, needEtl) } } ================================================ FILE: examples/XGBoost-Examples/utility/.gitignore ================================================ .idea target *.iml ================================================ FILE: examples/XGBoost-Examples/utility/pom.xml ================================================ sample_xgboost_examples com.nvidia 0.2.3-SNAPSHOT 4.0.0 spark_examples_utility_${scala.binary.version} 8 8 scala/src ================================================ FILE: examples/XGBoost-Examples/utility/python/com/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/utility/python/com/nvidia/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/utility/python/com/nvidia/spark/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/utility/python/com/nvidia/spark/examples/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/utility/python/com/nvidia/spark/examples/main.py ================================================ # # Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from .utility.args import parse_arguments from importlib import import_module def main(): args, xgboost_args = parse_arguments() getattr(import_module(args.mainClass), 'main')(args, xgboost_args) ================================================ FILE: examples/XGBoost-Examples/utility/python/com/nvidia/spark/examples/utility/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/XGBoost-Examples/utility/python/com/nvidia/spark/examples/utility/args.py ================================================ # # Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import typing from argparse import ArgumentParser from distutils.util import strtobool from re import match from sys import exit def _to_bool(literal): return bool(strtobool(literal)) def _to_ratio_pair(literal): # e.g., '80:20' return match(r'^\d+:\d+$', literal) and [int(x) for x in literal.split(':')] MAX_CHUNK_SIZE = 2 ** 31 - 1 _examples = [ 'com.nvidia.spark.examples.agaricus.main', 'com.nvidia.spark.examples.mortgage.main', 'com.nvidia.spark.examples.mortgage.etl_main', 'com.nvidia.spark.examples.mortgage.cross_validator_main', 'com.nvidia.spark.examples.taxi.main', 'com.nvidia.spark.examples.taxi.etl_main', 'com.nvidia.spark.examples.taxi.cross_validator_main', ] def _validate_args(args): usage = '' if not args.dataPaths: usage += ' --dataPaths is required.\n' if not (args.dataRatios and 0 <= args.dataRatios[0] <= 100 and 0 <= args.dataRatios[1] <= 100 and args.dataRatios[0] + args.dataRatios[1] <= 100): usage += ' --dataRatios should be in format \'Int:Int\', these two ints should be' \ ' in range [0, 100] and the sum should be less than or equal to 100.\n' if not (1 <= args.maxRowsPerChunk <= MAX_CHUNK_SIZE): usage += ' --maxRowsPerChunk should be in range [1, {}].\n'.format(MAX_CHUNK_SIZE) if usage: print('-' * 80) print('Usage:\n' + usage) exit(1) def _attach_derived_args(args): args.trainRatio = args.dataRatios[0] args.evalRatio = args.dataRatios[1] args.trainEvalRatio = 100 - args.trainRatio - args.evalRatio args.splitRatios = [args.trainRatio, args.trainEvalRatio, args.evalRatio] def _inspect_xgb_parameters() -> typing.Dict[str, type]: """inspect XGBModel parameters from __init__""" from xgboost import XGBModel from typing import get_type_hints, get_origin xgb_parameters = {} xgb_model_sig = get_type_hints(XGBModel.__init__) for k, v in xgb_model_sig.items(): if k != "kwargs" and k != "return": if get_origin(v) == typing.Union: xgb_parameters[k] = v.__args__[0] else: xgb_parameters[k] = v # some extra parameters used by xgboost pyspark xgb_parameters['objective'] = str xgb_parameters['force_repartition'] = _to_bool xgb_parameters['use_gpu'] = _to_bool xgb_parameters['num_workers'] = int xgb_parameters['enable_sparse_data_optim'] = _to_bool return xgb_parameters def parse_arguments(): parser = ArgumentParser() # application arguments parser.add_argument('--mainClass', required=True, choices=_examples) parser.add_argument('--mode', choices=['all', 'train', 'transform'], default='all') parser.add_argument('--format', required=True, choices=['csv', 'parquet', 'orc']) parser.add_argument('--hasHeader', type=_to_bool, default=True) parser.add_argument('--asFloats', type=_to_bool, default=True) parser.add_argument('--maxRowsPerChunk', type=int, default=MAX_CHUNK_SIZE) parser.add_argument('--modelPath') parser.add_argument('--overwrite', type=_to_bool, default=False) parser.add_argument('--dataPath', dest='dataPaths', action='append') parser.add_argument('--dataRatios', type=_to_ratio_pair, default=[80, 20]) parser.add_argument('--numRows', type=int, default=5) parser.add_argument('--showFeatures', type=_to_bool, default=True) xgboost_all_args = _inspect_xgb_parameters() for arg, tp in xgboost_all_args.items(): parser.add_argument('--' + arg, type=tp) parsed_all = parser.parse_args() _validate_args(parsed_all) _attach_derived_args(parsed_all) parsed_xgboost = { k: v for k, v in vars(parsed_all).items() if k in xgboost_all_args and v is not None } return parsed_all, parsed_xgboost ================================================ FILE: 
examples/XGBoost-Examples/utility/python/com/nvidia/spark/examples/utility/utils.py ================================================ # # Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import typing from pyspark.ml.evaluation import * from pyspark.ml.feature import VectorAssembler from pyspark.sql import DataFrame from pyspark.sql.functions import col from pyspark.sql.types import FloatType from com.nvidia.spark.examples.taxi.pre_process import pre_process from time import time def merge_dicts(dict_x, dict_y): result = dict_x.copy() result.update(dict_y) return result def show_sample(args, data_frame, label): data_frame = data_frame if args.showFeatures else data_frame.select(label, 'prediction') data_frame.show(args.numRows) def vectorize_data_frame(data_frame, label): features = [x.name for x in data_frame.schema if x.name != label] to_floats = [col(x.name).cast(FloatType()) for x in data_frame.schema] return (VectorAssembler() .setInputCols(features) .setOutputCol('features') .transform(data_frame.select(to_floats)) .select(col('features'), col(label))) def vectorize_data_frames(data_frames, label): return [vectorize_data_frame(x, label) for x in data_frames] def with_benchmark(phrase, action): start = time() result = action() end = time() print('-' * 100) print('{} takes {} seconds'.format(phrase, round(end - start, 2))) return result def check_classification_accuracy(data_frame, label): accuracy = (MulticlassClassificationEvaluator() .setLabelCol(label) .evaluate(data_frame)) print('-' * 100) print('Accuracy is ' + str(accuracy)) def check_regression_accuracy(data_frame, label): accuracy = (RegressionEvaluator() .setLabelCol(label) .evaluate(data_frame)) print('-' * 100) print('RMSE is ' + str(accuracy)) def prepare_data(spark, args, schema, dataPath): reader = (spark .read .format(args.format)) if args.format == 'csv': reader.schema(schema).option('header', args.hasHeader) return reader.load(dataPath) def extract_paths(paths, prefix): results = [path[len(prefix):] for path in paths if path.startswith(prefix)] return results def transform_data( df: DataFrame, label: str, use_gpu: typing.Optional[bool], ) -> (DataFrame, typing.Union[str, typing.List[str]]): if use_gpu: features = [x.name for x in df.schema if x.name != label] else: df = vectorize_data_frame(df, label) features = 'features' return df, features def valid_input_data(spark, args, raw_schema, final_schema): e2e = False for path in args.dataPaths: if 'raw' in path: e2e = True break raw_train_path = '' raw_eval_path = '' raw_trans_path = '' eval_path = '' if e2e: raw_train_path = extract_paths(args.dataPaths, 'rawTrain::') raw_eval_path = extract_paths(args.dataPaths, 'rawEval::') raw_trans_path = extract_paths(args.dataPaths, 'rawTrans::') train_data = '' eval_data = '' trans_data = '' # if this is an e2e run if raw_train_path or raw_eval_path or raw_trans_path: raw_train_data = prepare_data(spark, args, raw_schema, raw_train_path) raw_eval_data = '' raw_trans_data = '' if 
raw_eval_path: raw_eval_data = prepare_data(spark, args, raw_schema, raw_eval_path) if raw_trans_path: raw_trans_data = prepare_data(spark, args, raw_schema, raw_trans_path) train_data = pre_process(raw_train_data) if raw_eval_data: eval_data = pre_process(raw_eval_data) if raw_trans_data: trans_data = pre_process(raw_trans_data) # if this is just a train/transform else: train_path = extract_paths(args.dataPaths, 'train::') eval_path = extract_paths(args.dataPaths, 'eval::') trans_path = extract_paths(args.dataPaths, 'trans::') if train_path: train_data = prepare_data(spark, args, final_schema, train_path) if eval_path: eval_data = prepare_data(spark, args, final_schema, eval_path) if trans_path: trans_data = prepare_data(spark, args, final_schema, trans_path) return (train_data, eval_data, trans_data) ================================================ FILE: examples/XGBoost-Examples/utility/scala/src/com/nvidia/spark/examples/utility/Benchmark.scala ================================================ /* * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.examples.utility import scala.util.Properties class Benchmark( appName: String, processor: String, dataFormat: String) { def time[R](phase: String, silent: (Any, Float) => Boolean = (_,_) => false) (block: => R): (R, Float) = { val t0 = System.currentTimeMillis val result = block // call-by-name val elapsedTimeSec = (System.currentTimeMillis - t0).toFloat / 1000 logging(elapsedTimeSec, phase, "Elapsed time for", "s", silent(result, elapsedTimeSec)) (result, elapsedTimeSec) } def value(value: Any, name: String = "value", prefix: String="", suffix: String = "") = { logging(value, name, prefix, suffix, false) } private def logging(value: Any, name: String , prefix: String, suffix: String, silent: Boolean) = { if (!silent) { val logString = buildLogSimple(value, prefix, suffix, buildRuntimeInfo(name)) println("\n--------------") println("==> Benchmark: " + logString) println("--------------\n") } } private def buildRuntimeInfo(name: String): String = { // Get runtime information from Environment val osType = Properties.envOrElse("RAPIDS_XGB_EXAMPLE_OS_TYPE", "Unknown") val cudaVersion = Properties.envOrElse("RAPIDS_XGB_EXAMPLE_CUDA_VERSION", "Unknown") val sparkVersion = Properties.envOrElse("RAPIDS_XGB_EXAMPLE_SPARK_VERSION", "Unknown") Seq(appName, processor, name, dataFormat, "stub", cudaVersion, osType, sparkVersion) .mkString(" ") } private def buildLogSimple(value: Any, prefix: String, suffix: String, runtimeInfo: String): String = prefix + " [" + runtimeInfo + "]: " + value + suffix } object Benchmark { def apply(appName: String, processor: String, dataFormat: String) = new Benchmark(appName, processor, dataFormat) } ================================================ FILE: examples/XGBoost-Examples/utility/scala/src/com/nvidia/spark/examples/utility/SparkSetup.scala ================================================ /* * Copyright (c) 2019-2021, NVIDIA CORPORATION. 
All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.examples.utility import org.apache.spark.sql.SparkSession object SparkSetup { def apply(args: Array[String], appName: String) = { val builder = SparkSession.builder() val masterBuilder = Option(System.getenv("SPARK_MASTER")).map { master => builder.master(master) }.getOrElse(builder) masterBuilder.appName(appName).getOrCreate() } def apply(args: Array[String]): SparkSession = SparkSetup(args, "default") } ================================================ FILE: examples/XGBoost-Examples/utility/scala/src/com/nvidia/spark/examples/utility/Vectorize.scala ================================================ /* * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package com.nvidia.spark.examples.utility import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.FloatType object Vectorize { def apply(df: DataFrame, labelName: String, changeLabelName: Boolean = true): DataFrame = { val features = df.schema.collect { case f if f.name != labelName => f.name } val toFloat = df.schema.map(f => col(f.name).cast(FloatType)) val labelCol = if (changeLabelName) col(labelName).alias("label") else col(labelName) new VectorAssembler() .setInputCols(features.toArray) .setOutputCol("features") .transform(df.select(toFloat: _*)) .select(col("features"), labelCol) } def apply(df: DataFrame, featureNames: Seq[String], labelName: String): DataFrame = { val toFloat = df.schema.map(f => col(f.name).cast(FloatType)) new VectorAssembler() .setInputCols(featureNames.toArray) .setOutputCol("features") .transform(df.select(toFloat: _*)) .select(col("features"), col(labelName)) } def apply(featureNames: Seq[String], df: DataFrame, otherNames: String*): DataFrame = { val resultCols = (otherNames :+ "features").map(col(_)) new VectorAssembler() .setInputCols(featureNames.toArray) .setOutputCol("features") .transform(df) .select(resultCols: _*) } def criteoApply(df: DataFrame, featureNames: Seq[String], labelName: String): DataFrame = { val toFloat = df.schema.map(f => col(f.name).cast(FloatType)) new VectorAssembler() .setHandleInvalid("keep") .setInputCols(featureNames.toArray) .setOutputCol("features") .transform(df.select(toFloat: _*)) .select(col("features"), col(labelName)) } } ================================================ FILE: examples/XGBoost-Examples/utility/scala/src/com/nvidia/spark/examples/utility/XGBoostArgs.scala ================================================ /* * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.nvidia.spark.examples.utility import com.google.common.base.CaseFormat import scala.collection.mutable import scala.util.Try private case class XGBoostArg( required: Boolean = false, parse: String => Any = value => value, message: String = "") object XGBoostArgs { private val modes = Seq("all", "train", "transform") private val formats = Seq("csv", "parquet", "orc") private val stringToBool = Map( "true" -> true, "false" -> false, "1" -> true, "0" -> false ) private val booleanMessage = "Expect 'true' or '1' for true, 'false' or '0' for false." 
private def parseDataRatios(value: String): (Int, Int) = { val ratios = value.split(":").filter(_.nonEmpty).map(_.toInt) require(ratios.length == 2 && ratios(0) + ratios(1) <= 100) (ratios(0), ratios(1)) } private val supportedArgs = Map( "mode" -> XGBoostArg( parse = value => { require(modes.contains(value)); value }, message = s"Expect one of [${modes.mkString(", ")}]"), "format" -> XGBoostArg(true, parse = value => { require(formats.contains(value)); value }, message = s"Expect one of [${formats.mkString(", ")}]"), "dataPath" -> XGBoostArg(true), "dataRatios" -> XGBoostArg( parse = parseDataRatios, message = "Expect as :, both train and transform require Int, and total value <= 100"), "modelPath" -> XGBoostArg(), "numRows" -> XGBoostArg(parse = _.toInt, message = "Require an Int."), "numFold" -> XGBoostArg(parse = _.toInt, message = "Require an Int."), "showFeatures" -> XGBoostArg(parse = stringToBool, message = booleanMessage), "overwrite" -> XGBoostArg(parse = stringToBool, message = booleanMessage), "hasHeader" -> XGBoostArg(parse = stringToBool, message = booleanMessage), "saveDict" -> XGBoostArg(parse = stringToBool, message = booleanMessage), ) private def help: Unit = { println("\n\nSupported arguments:") println(" -dataPath=path: String, Required\n" + " The path of data file(s). Use multiple '-dataPath=path#' to specify multiple paths. Such as" + " '-dataPath=path1 -dataPath=path2'.\n") println(" -format=: String, Required\n" + " The format of the data, now only supports 'csv', 'parquet' and 'orc'.\n") println(" -mode=: String\n" + " To control the behavior of apps. Default is 'all'. \n" + " * all: Do training and transformation.\n" + " * train: Do training only, will save model to 'modelPath' if specified.\n" + " * transform: Transformation only, 'modelPath' is required to provide the model.\n") println(" -modelPath=path: String\n" + " Specify where to save model after training, or where to load model for transforming only. \n") println(" -overwrite=value: Boolean\n" + " Whether to overwrite the current model data under 'modelPath'. Default is false\n") println(" -dataRatios=train:transform\n" + " The ratios of data used for train and transform, then the ratio for evaluation is (100-train-test)." + " default is 80:20, no evaluation\n") println(" -hasHeader=value: Boolean\n" + " Whether the csv file has header. Default is true.\n") println(" -numRows=value: Int\n" + " Number of the rows to show after transformation. Default is 5.\n") println(" -numFold=value: Int\n" + " Number of the folders to be used in Cross Validation. Default is 3.\n") println(" -showFeatures=value: Boolean\n" + " Whether to include the features columns when showing results of transformation. Default is true.\n") println(" -saveDict=value: Boolean\n" + " Whether to save the dictionary table for Mortgage ETL. It is saved under '/.dict'. Default is true.\n") println(" -rabitTrackerHost=value: String\n" + " Specify rabit tracker host IP address. In some environments XGBoost might fail to resolve\n" + "the IP address of the rabit tracker, a symptom is user receiving ``OSError: [Errno 99]\n" + "Cannot assign requested address`` error during training. A quick workaround is to\n" + "specify the address explicitly.\n") println("For XGBoost arguments:") println(" Now we pass all XGBoost parameters transparently to XGBoost, no longer to verify them.") println(" Both of the formats are supported, such as 'numWorkers'. 
You can pass as either one below:") println(" -numWorkers=10 or -num_workers=10 ") println() } def apply(args: Array[String]) = parse(args) def parse(args: Array[String]): XGBoostArgs = { val appArgsMap = mutable.HashMap.empty[String, Any] val xgbArgsMap = mutable.HashMap.empty[String, String] try { args.filter(_.nonEmpty).foreach { argString => require(argString.startsWith("-") && argString.contains('='), s"Invalid argument: $argString, expect '-name=value'") val parts = argString.stripPrefix("-").split('=').filter(_.nonEmpty) require(parts.length == 2, s"Invalid argument: $argString, expect '-name=value'") val (key, value) = (parts(0), parts(1)) if (supportedArgs.contains(key)) { // App arguments val parseTry = Try(supportedArgs(key).parse(value)) require(parseTry.isSuccess, s"Invalid value to '$key'. ${supportedArgs(key).message}") if (key == "dataPath") { val paths = appArgsMap.getOrElse(key, Seq.empty).asInstanceOf[Seq[String]] :+ parseTry.get appArgsMap += key -> paths } else { appArgsMap += key -> parseTry.get } } else { // Supposed to be XGBooost parameters xgbArgsMap += key -> value } } supportedArgs.filter(_._2.required).foreach { case (name, _) => require(appArgsMap.contains(name), s"Missing argument: $name.") } new XGBoostArgs(appArgsMap.toMap, xgbArgsMap.toMap) } catch { case e: Exception => help throw e } } } class XGBoostArgs private[utility] ( val appArgsMap: Map[String, Any], val xgbArgsMap: Map[String, String]) { def format: String = appArgsMap("format").asInstanceOf[String] def modelPath: Option[String] = appArgsMap.get("modelPath").asInstanceOf[Option[String]] // mode is optional with default value 'all' private def mode: String = appArgsMap.getOrElse("mode", "all").asInstanceOf[String] private[utility] def verifyArgsRelation: Unit = { if (mode == "train" && modelPath.isEmpty) { println("==> You may want to specify the 'modelPath' to save the model when 'train only' mode.") } if (mode == "transform") { require(modelPath.nonEmpty, "'modelPath' is required for mode: transform") } } verifyArgsRelation def isToTrain: Boolean = mode != "transform" def isToTransform: Boolean = mode != "train" def dataPaths: Seq[String] = appArgsMap("dataPath").asInstanceOf[Seq[String]] def dataRatios: (Int, Int, Int) = { val ratios = appArgsMap.get("dataRatios").asInstanceOf[Option[(Int, Int)]].getOrElse((80, 20)) (ratios._1, ratios._2, 100 - ratios._1 - ratios._2) } def isShowFeatures: Boolean = appArgsMap.get("showFeatures").forall(_.asInstanceOf[Boolean]) def isOverwrite: Boolean = appArgsMap.get("overwrite").exists(_.asInstanceOf[Boolean]) def hasHeader: Boolean = appArgsMap.get("hasHeader").forall(_.asInstanceOf[Boolean]) def saveDict: Boolean = appArgsMap.get("saveDict").forall(_.asInstanceOf[Boolean]) def numRows: Int = appArgsMap.get("numRows").asInstanceOf[Option[Int]].getOrElse(5) def numFold: Int = appArgsMap.get("numFold").asInstanceOf[Option[Int]].getOrElse(3) def xgboostParams(otherParams: Map[String, Any] = Map.empty): Map[String, Any] = { val params = otherParams ++ xgbArgsMap.map{ case (name, value) if !name.contains('_') => (CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, name), value) case (name, value) => (name, value) } val hostIp = params.getOrElse("rabit_tracker_host", "").toString if (!hostIp.isEmpty) { params ++ Map("rabitTrackerHostIp" -> hostIp) } else params } /** * getDataPaths check and get train/eval/transform paths * @return Array(train_paths, eval_paths, transform_paths) */ def getDataPaths: Array[Seq[String]] = { val paths = dataPaths val prefixes = 
Array("train::", "eval::", "trans::") val validPaths = paths.filter(_.nonEmpty).map(_.trim) // get train data paths val trainPaths = validPaths.filter(_.startsWith(prefixes.head)) if (isToTrain) { require(trainPaths.nonEmpty, s"requires at least one path for train file." + s" Please specify it by '-dataPath=train::your_train_data_path'") } // get eval path val evalPaths = validPaths.filter(_.startsWith(prefixes(1))) // get and check train data paths val transformPaths = validPaths.filter(_.startsWith(prefixes(2))) if (isToTransform) { require(transformPaths.nonEmpty, s"requires at least one path for transform file." + s" Please specify it by '-dataPath=trans::your_transform_data_path'") } // check data paths not specified type val unknownPaths = validPaths.filterNot(p => prefixes.exists(p.contains(_))) require(unknownPaths.isEmpty, s"Unknown type for data path: ${unknownPaths.head}, requires to specify" + " the type for each data path by adding the prefix 'train::' or 'eval::' or 'trans::'.") Array(trainPaths.map(_.stripPrefix(prefixes.head)), evalPaths.map(_.stripPrefix(prefixes(1))), transformPaths.map(_.stripPrefix(prefixes(2)))) } } ================================================ FILE: examples/spark-connect-gpu/client/Dockerfile ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Build stage FROM maven:3.8-openjdk-17 AS builder # Set platform to linux/amd64 ENV DOCKER_DEFAULT_PLATFORM=linux/amd64 # Copy the entire project COPY scala /build WORKDIR /build # Build all modules RUN mvn clean install -DskipTests # Spark connect client image FROM apache/spark:4.0.0 USER root RUN set -x \ && apt -q update -y && apt-get install -y vim git COPY requirements.txt /tmp/requirements.txt RUN pip3 install -r /tmp/requirements.txt RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 10 RUN mkdir -p /home/spark/demo COPY notebook /home/spark/demo/notebook COPY scala /home/spark/demo/scala COPY --from=builder /build/target/*-jar-with-dependencies.jar /home/spark/demo/scala/ COPY python /home/spark/demo/python # Prepare NDS, make NDS as a package. 
COPY nds /home/spark/demo/nds
RUN git clone --depth 1 -b dev https://github.com/NVIDIA/spark-rapids-benchmarks /tmp/spark-rapids-benchmarks && \
    cp /tmp/spark-rapids-benchmarks/nds/nds_power.py \
       /tmp/spark-rapids-benchmarks/nds/check.py \
       /tmp/spark-rapids-benchmarks/nds/nds_schema.py \
       /tmp/spark-rapids-benchmarks/nds/PysparkBenchReport.py \
       /home/spark/demo/nds/ && \
    rm -rf /tmp/spark-rapids-benchmarks
RUN chown -R spark:spark /home/spark
RUN chown -R spark:spark /home/spark/demo
RUN usermod -d /home/spark spark
USER spark
WORKDIR /home/spark/demo
SHELL [ "/bin/bash", "-c" ]
ENV SHELL=/bin/bash

================================================
FILE: examples/spark-connect-gpu/client/README.md
================================================
# GPU-Accelerated Spark Connect for ETL and ML (Spark 4.0)

This project demonstrates Python/Scala batch jobs and a complete GPU-accelerated ETL and machine learning pipeline using Apache Spark 4.0 with Spark Connect, featuring the RAPIDS Accelerator.

## 🏗️ Architecture

The client side consists of a single Docker service:

**Jupyter Lab - Spark Connect Client** (`spark-connect-client`) - Interactive development environment

The first step, however, is to set up the GPU-accelerated Spark Connect Server. More details can be found [here](../server/README.md).

## 📋 Prerequisites

### Required

- [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/linux)
- At least 8GB of available RAM
- Available ports: 8888

## 🚀 Quick Start

1. **Clone and navigate to the project:**

   ```bash
   cd examples/spark-connect-gpu/client
   ```

2. **Start all services:**

   Set the `SPARK_REMOTE` environment variable to point to your spark-connect-gpu server. By default this is `sc://localhost` (for same-node deployments). If the client and server are on different nodes, you can either establish an SSH tunnel with port 15002 forwarded (e.g., `ssh -g -L 15002:localhost:15002 -N CONNECT_SERVER_IP`) and use the default `SPARK_REMOTE` value (`sc://localhost`), or override it with the server's accessible IP address:

   ```bash
   export SPARK_REMOTE=sc://CONNECT_SERVER_IP
   ```

   Then start the client service:

   ```bash
   $ docker compose up -d
   ```

   (`docker compose` can be used in place of `docker-compose` here and throughout)

3. **Access the Web UI interfaces:**

   **Jupyter Lab**: http://localhost:8888 (no password required) - Interactive notebook environment

4. **Run the demo ETL + ML notebook:**
   - Navigate to `notebook/spark-connect-gpu-etl-ml.ipynb` in Jupyter Lab
   - You can also open it in VS Code by selecting http://localhost:8888 as the existing notebook server connection
   - Run the complete ETL and ML pipeline demonstration

5. **Run the demo Python batch job:**
   - Create a terminal in Jupyter Lab
   - Navigate to `/home/spark/demo/python`
   - Execute `python batch-job.py`

6. **Run the demo Scala batch job:**
   - Create a terminal in Jupyter Lab
   - Navigate to `/home/spark/demo/scala`
   - Execute `./run.sh`

7. **Run the demo NDS notebook:**
   - Navigate to `nds/nds.ipynb` in Jupyter Lab
   - Run the NDS demonstration

## Advanced GPU Configurations

Most users won't need to adjust the GPU configurations. However, if you'd like to tune your GPU for better performance, refer to the [advanced GPU configurations documentation](https://nvidia.github.io/spark-rapids/docs/additional-functionality/advanced_configs.html).

**Note**: Configurations prefixed with `spark.rapids.sql` are session-specific and can be set safely. However, those marked as **startup** will not take effect in Spark Connect.
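As a quick sanity check of the setup above, the following minimal sketch (not part of the repository) shows how a PySpark Connect client could attach to the server referenced by `SPARK_REMOTE` and apply a session-scoped `spark.rapids.sql` setting. The specific config key and the toy query are illustrative assumptions only.

```python
# Minimal sketch: attach a PySpark Connect client to the GPU-accelerated
# Spark Connect server and set a session-scoped RAPIDS config.
# Assumes pyspark (with the Spark Connect client) is installed and that
# SPARK_REMOTE (e.g. sc://localhost) points at a running spark-connect-gpu server.
import os

from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .remote(os.environ.get("SPARK_REMOTE", "sc://localhost"))
         .getOrCreate())

# spark.rapids.sql.* settings are session-scoped, so they can be changed here;
# startup-only settings must instead be configured on the server.
spark.conf.set("spark.rapids.sql.enabled", "true")

# Trivial query to confirm the session is live.
spark.range(1_000_000).selectExpr("sum(id) AS total").show()
```

Because `docker-compose.yaml` (below) passes `SPARK_REMOTE` into the client container, the same pattern can be used from a terminal inside Jupyter Lab.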
## 🐳 Service Details

### JupyterLab - Spark Connect Client

- **Image**: Based on `apache/spark:4.0.0`
- **Environment**: Pre-configured with PySpark Connect Client
- **Ports**: 8888 (Jupyter Lab)
- **Volumes**: Notebooks and work directory mounted

## 🧹 Cleanup

Stop and remove all services:

```bash
docker-compose down -v
```

Remove built images:

```bash
docker-compose down --rmi all -v
```

### Logs

Logs for the Spark Connect client (Jupyter Lab) container can be viewed with:

```bash
docker logs spark-connect-client
```

## 📖 Additional Resources

- [Apache Spark 4.0 Documentation](https://spark.apache.org/docs/latest/)
- [Spark Connect Guide](https://spark.apache.org/docs/latest/spark-connect-overview.html)
- [NVIDIA RAPIDS Accelerator](https://nvidia.github.io/spark-rapids/)
- [Data and AI Summit Session](https://www.databricks.com/dataaisummit/session/gpu-accelerated-spark-connect)

================================================ FILE: examples/spark-connect-gpu/client/docker-compose.yaml ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # YAML anchors for shared configurations x-spark-common: &spark-common volumes: - ${DATA_DIR:-${PWD}/data}:/data services: spark-connect-client: <<: *spark-common image: spark-connect-client-image build: context: .
dockerfile: Dockerfile container_name: spark-connect-client hostname: spark-connect-client network_mode: host environment: - SPARK_REMOTE=${SPARK_REMOTE:-sc://localhost} command: > bash -c 'jupyter-lab --port 8888 --no-browser --IdentityProvider.token="" --ServerApp.password="" --ServerApp.ip='0.0.0.0' --ServerApp.allow_origin='*' ' ================================================ FILE: examples/spark-connect-gpu/client/nds/nds.ipynb ================================================ { "cells": [ { "metadata": {}, "cell_type": "markdown", "source": "### Run nds_power.py directly", "id": "2274cf637f6f4702" }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, "source": "%run nds_power.py /data/nds query_0.sql time.csv", "id": "b5d2eeaeb7a2f63" }, { "metadata": { "collapsed": true }, "cell_type": "markdown", "source": [ "### Importing and Executing APIs in a Jupyter Notebook\n", "\n", "Alternatively, you can import the relevant APIs into your Jupyter notebook and execute them as shown below:" ], "id": "cb4ca58b118a9209" }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, "source": [ "from nds_power import gen_sql_from_stream, run_query_stream\n", "\n", "query_stream_file = \"query_0.sql\"\n", "nds_data_path = \"/data/nds\"\n", "time_log_file = \"time.csv\"\n", "\n", "query_dict = gen_sql_from_stream(query_stream_file)\n", "\n", "run_query_stream(input_prefix=nds_data_path,\n", " property_file=None,\n", " query_dict=query_dict,\n", " time_log_output_path=time_log_file,\n", " extra_time_log_output_path=None,\n", " sub_queries=None,\n", " warmup_iterations=0,\n", " iterations=1,\n", " plan_types=\"logical\",\n", " )" ], "id": "f8ccb334ec1c6766" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/spark-connect-gpu/client/nds/query_0.sql ================================================ -- start query 1 in stream 0 using template query96.tpl select count(*) from store_sales ,household_demographics ,time_dim, store where ss_sold_time_sk = time_dim.t_time_sk and ss_hdemo_sk = household_demographics.hd_demo_sk and ss_store_sk = s_store_sk and time_dim.t_hour = 8 and time_dim.t_minute >= 30 and household_demographics.hd_dep_count = 5 and store.s_store_name = 'ese' order by count(*) LIMIT 100; -- end query 1 in stream 0 using template query96.tpl -- start query 2 in stream 0 using template query7.tpl select i_item_id, avg(ss_quantity) agg1, avg(ss_list_price) agg2, avg(ss_coupon_amt) agg3, avg(ss_sales_price) agg4 from store_sales, customer_demographics, date_dim, item, promotion where ss_sold_date_sk = d_date_sk and ss_item_sk = i_item_sk and ss_cdemo_sk = cd_demo_sk and ss_promo_sk = p_promo_sk and cd_gender = 'M' and cd_marital_status = 'M' and cd_education_status = '4 yr Degree' and (p_channel_email = 'N' or p_channel_event = 'N') and d_year = 2001 group by i_item_id order by i_item_id LIMIT 100; -- end query 2 in stream 0 using template query7.tpl -- start query 3 in stream 0 using template query75.tpl WITH all_sales AS ( SELECT d_year ,i_brand_id ,i_class_id ,i_category_id ,i_manufact_id ,SUM(sales_cnt) AS sales_cnt ,SUM(sales_amt) AS sales_amt FROM (SELECT 
d_year ,i_brand_id ,i_class_id ,i_category_id ,i_manufact_id ,cs_quantity - COALESCE(cr_return_quantity,0) AS sales_cnt ,cs_ext_sales_price - COALESCE(cr_return_amount,0.0) AS sales_amt FROM catalog_sales JOIN item ON i_item_sk=cs_item_sk JOIN date_dim ON d_date_sk=cs_sold_date_sk LEFT JOIN catalog_returns ON (cs_order_number=cr_order_number AND cs_item_sk=cr_item_sk) WHERE i_category='Shoes' UNION SELECT d_year ,i_brand_id ,i_class_id ,i_category_id ,i_manufact_id ,ss_quantity - COALESCE(sr_return_quantity,0) AS sales_cnt ,ss_ext_sales_price - COALESCE(sr_return_amt,0.0) AS sales_amt FROM store_sales JOIN item ON i_item_sk=ss_item_sk JOIN date_dim ON d_date_sk=ss_sold_date_sk LEFT JOIN store_returns ON (ss_ticket_number=sr_ticket_number AND ss_item_sk=sr_item_sk) WHERE i_category='Shoes' UNION SELECT d_year ,i_brand_id ,i_class_id ,i_category_id ,i_manufact_id ,ws_quantity - COALESCE(wr_return_quantity,0) AS sales_cnt ,ws_ext_sales_price - COALESCE(wr_return_amt,0.0) AS sales_amt FROM web_sales JOIN item ON i_item_sk=ws_item_sk JOIN date_dim ON d_date_sk=ws_sold_date_sk LEFT JOIN web_returns ON (ws_order_number=wr_order_number AND ws_item_sk=wr_item_sk) WHERE i_category='Shoes') sales_detail GROUP BY d_year, i_brand_id, i_class_id, i_category_id, i_manufact_id) SELECT prev_yr.d_year AS prev_year ,curr_yr.d_year AS year ,curr_yr.i_brand_id ,curr_yr.i_class_id ,curr_yr.i_category_id ,curr_yr.i_manufact_id ,prev_yr.sales_cnt AS prev_yr_cnt ,curr_yr.sales_cnt AS curr_yr_cnt ,curr_yr.sales_cnt-prev_yr.sales_cnt AS sales_cnt_diff ,curr_yr.sales_amt-prev_yr.sales_amt AS sales_amt_diff FROM all_sales curr_yr, all_sales prev_yr WHERE curr_yr.i_brand_id=prev_yr.i_brand_id AND curr_yr.i_class_id=prev_yr.i_class_id AND curr_yr.i_category_id=prev_yr.i_category_id AND curr_yr.i_manufact_id=prev_yr.i_manufact_id AND curr_yr.d_year=2000 AND prev_yr.d_year=2000-1 AND CAST(curr_yr.sales_cnt AS DECIMAL(17,2))/CAST(prev_yr.sales_cnt AS DECIMAL(17,2))<0.9 ORDER BY sales_cnt_diff,sales_amt_diff LIMIT 100; -- end query 3 in stream 0 using template query75.tpl -- start query 4 in stream 0 using template query44.tpl select asceding.rnk, i1.i_product_name best_performing, i2.i_product_name worst_performing from(select * from (select item_sk,rank() over (order by rank_col asc) rnk from (select ss_item_sk item_sk,avg(ss_net_profit) rank_col from store_sales ss1 where ss_store_sk = 30 group by ss_item_sk having avg(ss_net_profit) > 0.9*(select avg(ss_net_profit) rank_col from store_sales where ss_store_sk = 30 and ss_hdemo_sk is null group by ss_store_sk))V1)V11 where rnk < 11) asceding, (select * from (select item_sk,rank() over (order by rank_col desc) rnk from (select ss_item_sk item_sk,avg(ss_net_profit) rank_col from store_sales ss1 where ss_store_sk = 30 group by ss_item_sk having avg(ss_net_profit) > 0.9*(select avg(ss_net_profit) rank_col from store_sales where ss_store_sk = 30 and ss_hdemo_sk is null group by ss_store_sk))V2)V21 where rnk < 11) descending, item i1, item i2 where asceding.rnk = descending.rnk and i1.i_item_sk=asceding.item_sk and i2.i_item_sk=descending.item_sk order by asceding.rnk LIMIT 100; -- end query 4 in stream 0 using template query44.tpl -- start query 5 in stream 0 using template query39.tpl with inv as (select w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy ,stdev,mean, case mean when 0 then null else stdev/mean end cov from(select w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy ,stddev_samp(inv_quantity_on_hand) stdev,avg(inv_quantity_on_hand) mean from inventory ,item ,warehouse 
,date_dim where inv_item_sk = i_item_sk and inv_warehouse_sk = w_warehouse_sk and inv_date_sk = d_date_sk and d_year =2001 group by w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy) foo where case mean when 0 then 0 else stdev/mean end > 1) select inv1.w_warehouse_sk,inv1.i_item_sk,inv1.d_moy,inv1.mean, inv1.cov ,inv2.w_warehouse_sk,inv2.i_item_sk,inv2.d_moy,inv2.mean, inv2.cov from inv inv1,inv inv2 where inv1.i_item_sk = inv2.i_item_sk and inv1.w_warehouse_sk = inv2.w_warehouse_sk and inv1.d_moy=1 and inv2.d_moy=1+1 order by inv1.w_warehouse_sk,inv1.i_item_sk,inv1.d_moy,inv1.mean,inv1.cov ,inv2.d_moy,inv2.mean, inv2.cov ; with inv as (select w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy ,stdev,mean, case mean when 0 then null else stdev/mean end cov from(select w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy ,stddev_samp(inv_quantity_on_hand) stdev,avg(inv_quantity_on_hand) mean from inventory ,item ,warehouse ,date_dim where inv_item_sk = i_item_sk and inv_warehouse_sk = w_warehouse_sk and inv_date_sk = d_date_sk and d_year =2001 group by w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy) foo where case mean when 0 then 0 else stdev/mean end > 1) select inv1.w_warehouse_sk,inv1.i_item_sk,inv1.d_moy,inv1.mean, inv1.cov ,inv2.w_warehouse_sk,inv2.i_item_sk,inv2.d_moy,inv2.mean, inv2.cov from inv inv1,inv inv2 where inv1.i_item_sk = inv2.i_item_sk and inv1.w_warehouse_sk = inv2.w_warehouse_sk and inv1.d_moy=1 and inv2.d_moy=1+1 and inv1.cov > 1.5 order by inv1.w_warehouse_sk,inv1.i_item_sk,inv1.d_moy,inv1.mean,inv1.cov ,inv2.d_moy,inv2.mean, inv2.cov ; -- end query 5 in stream 0 using template query39.tpl -- start query 6 in stream 0 using template query80.tpl with ssr as (select s_store_id as store_id, sum(ss_ext_sales_price) as sales, sum(coalesce(sr_return_amt, 0)) as returns, sum(ss_net_profit - coalesce(sr_net_loss, 0)) as profit from store_sales left outer join store_returns on (ss_item_sk = sr_item_sk and ss_ticket_number = sr_ticket_number), date_dim, store, item, promotion where ss_sold_date_sk = d_date_sk and d_date between cast('2002-08-04' as date) and (cast('2002-08-04' as date) + interval 30 days) and ss_store_sk = s_store_sk and ss_item_sk = i_item_sk and i_current_price > 50 and ss_promo_sk = p_promo_sk and p_channel_tv = 'N' group by s_store_id) , csr as (select cp_catalog_page_id as catalog_page_id, sum(cs_ext_sales_price) as sales, sum(coalesce(cr_return_amount, 0)) as returns, sum(cs_net_profit - coalesce(cr_net_loss, 0)) as profit from catalog_sales left outer join catalog_returns on (cs_item_sk = cr_item_sk and cs_order_number = cr_order_number), date_dim, catalog_page, item, promotion where cs_sold_date_sk = d_date_sk and d_date between cast('2002-08-04' as date) and (cast('2002-08-04' as date) + interval 30 days) and cs_catalog_page_sk = cp_catalog_page_sk and cs_item_sk = i_item_sk and i_current_price > 50 and cs_promo_sk = p_promo_sk and p_channel_tv = 'N' group by cp_catalog_page_id) , wsr as (select web_site_id, sum(ws_ext_sales_price) as sales, sum(coalesce(wr_return_amt, 0)) as returns, sum(ws_net_profit - coalesce(wr_net_loss, 0)) as profit from web_sales left outer join web_returns on (ws_item_sk = wr_item_sk and ws_order_number = wr_order_number), date_dim, web_site, item, promotion where ws_sold_date_sk = d_date_sk and d_date between cast('2002-08-04' as date) and (cast('2002-08-04' as date) + interval 30 days) and ws_web_site_sk = web_site_sk and ws_item_sk = i_item_sk and i_current_price > 50 and ws_promo_sk = p_promo_sk and p_channel_tv = 'N' group by 
web_site_id) select channel , id , sum(sales) as sales , sum(returns) as returns , sum(profit) as profit from (select 'store channel' as channel , 'store' || store_id as id , sales , returns , profit from ssr union all select 'catalog channel' as channel , 'catalog_page' || catalog_page_id as id , sales , returns , profit from csr union all select 'web channel' as channel , 'web_site' || web_site_id as id , sales , returns , profit from wsr ) x group by rollup (channel, id) order by channel ,id LIMIT 100; -- end query 6 in stream 0 using template query80.tpl -- start query 7 in stream 0 using template query32.tpl select sum(cs_ext_discount_amt) as `excess discount amount` from catalog_sales ,item ,date_dim where i_manufact_id = 283 and i_item_sk = cs_item_sk and d_date between '1999-02-22' and (cast('1999-02-22' as date) + interval 90 days) and d_date_sk = cs_sold_date_sk and cs_ext_discount_amt > ( select 1.3 * avg(cs_ext_discount_amt) from catalog_sales ,date_dim where cs_item_sk = i_item_sk and d_date between '1999-02-22' and (cast('1999-02-22' as date) + interval 90 days) and d_date_sk = cs_sold_date_sk ) LIMIT 100; -- end query 7 in stream 0 using template query32.tpl -- start query 8 in stream 0 using template query19.tpl select i_brand_id brand_id, i_brand brand, i_manufact_id, i_manufact, sum(ss_ext_sales_price) ext_price from date_dim, store_sales, item,customer,customer_address,store where d_date_sk = ss_sold_date_sk and ss_item_sk = i_item_sk and i_manager_id=8 and d_moy=11 and d_year=1999 and ss_customer_sk = c_customer_sk and c_current_addr_sk = ca_address_sk and substr(ca_zip,1,5) <> substr(s_zip,1,5) and ss_store_sk = s_store_sk group by i_brand ,i_brand_id ,i_manufact_id ,i_manufact order by ext_price desc ,i_brand ,i_brand_id ,i_manufact_id ,i_manufact LIMIT 100 ; -- end query 8 in stream 0 using template query19.tpl -- start query 9 in stream 0 using template query25.tpl select i_item_id ,i_item_desc ,s_store_id ,s_store_name ,min(ss_net_profit) as store_sales_profit ,min(sr_net_loss) as store_returns_loss ,min(cs_net_profit) as catalog_sales_profit from store_sales ,store_returns ,catalog_sales ,date_dim d1 ,date_dim d2 ,date_dim d3 ,store ,item where d1.d_moy = 4 and d1.d_year = 2002 and d1.d_date_sk = ss_sold_date_sk and i_item_sk = ss_item_sk and s_store_sk = ss_store_sk and ss_customer_sk = sr_customer_sk and ss_item_sk = sr_item_sk and ss_ticket_number = sr_ticket_number and sr_returned_date_sk = d2.d_date_sk and d2.d_moy between 4 and 10 and d2.d_year = 2002 and sr_customer_sk = cs_bill_customer_sk and sr_item_sk = cs_item_sk and cs_sold_date_sk = d3.d_date_sk and d3.d_moy between 4 and 10 and d3.d_year = 2002 group by i_item_id ,i_item_desc ,s_store_id ,s_store_name order by i_item_id ,i_item_desc ,s_store_id ,s_store_name LIMIT 100; -- end query 9 in stream 0 using template query25.tpl -- start query 10 in stream 0 using template query78.tpl with ws as (select d_year AS ws_sold_year, ws_item_sk, ws_bill_customer_sk ws_customer_sk, sum(ws_quantity) ws_qty, sum(ws_wholesale_cost) ws_wc, sum(ws_sales_price) ws_sp from web_sales left join web_returns on wr_order_number=ws_order_number and ws_item_sk=wr_item_sk join date_dim on ws_sold_date_sk = d_date_sk where wr_order_number is null group by d_year, ws_item_sk, ws_bill_customer_sk ), cs as (select d_year AS cs_sold_year, cs_item_sk, cs_bill_customer_sk cs_customer_sk, sum(cs_quantity) cs_qty, sum(cs_wholesale_cost) cs_wc, sum(cs_sales_price) cs_sp from catalog_sales left join catalog_returns on 
cr_order_number=cs_order_number and cs_item_sk=cr_item_sk join date_dim on cs_sold_date_sk = d_date_sk where cr_order_number is null group by d_year, cs_item_sk, cs_bill_customer_sk ), ss as (select d_year AS ss_sold_year, ss_item_sk, ss_customer_sk, sum(ss_quantity) ss_qty, sum(ss_wholesale_cost) ss_wc, sum(ss_sales_price) ss_sp from store_sales left join store_returns on sr_ticket_number=ss_ticket_number and ss_item_sk=sr_item_sk join date_dim on ss_sold_date_sk = d_date_sk where sr_ticket_number is null group by d_year, ss_item_sk, ss_customer_sk ) select ss_customer_sk, round(ss_qty/(coalesce(ws_qty,0)+coalesce(cs_qty,0)),2) ratio, ss_qty store_qty, ss_wc store_wholesale_cost, ss_sp store_sales_price, coalesce(ws_qty,0)+coalesce(cs_qty,0) other_chan_qty, coalesce(ws_wc,0)+coalesce(cs_wc,0) other_chan_wholesale_cost, coalesce(ws_sp,0)+coalesce(cs_sp,0) other_chan_sales_price from ss left join ws on (ws_sold_year=ss_sold_year and ws_item_sk=ss_item_sk and ws_customer_sk=ss_customer_sk) left join cs on (cs_sold_year=ss_sold_year and cs_item_sk=ss_item_sk and cs_customer_sk=ss_customer_sk) where (coalesce(ws_qty,0)>0 or coalesce(cs_qty, 0)>0) and ss_sold_year=2001 order by ss_customer_sk, ss_qty desc, ss_wc desc, ss_sp desc, other_chan_qty, other_chan_wholesale_cost, other_chan_sales_price, ratio LIMIT 100; -- end query 10 in stream 0 using template query78.tpl -- start query 11 in stream 0 using template query86.tpl select sum(ws_net_paid) as total_sum ,i_category ,i_class ,grouping(i_category)+grouping(i_class) as lochierarchy ,rank() over ( partition by grouping(i_category)+grouping(i_class), case when grouping(i_class) = 0 then i_category end order by sum(ws_net_paid) desc) as rank_within_parent from web_sales ,date_dim d1 ,item where d1.d_month_seq between 1205 and 1205+11 and d1.d_date_sk = ws_sold_date_sk and i_item_sk = ws_item_sk group by rollup(i_category,i_class) order by lochierarchy desc, case when lochierarchy = 0 then i_category end, rank_within_parent LIMIT 100; -- end query 11 in stream 0 using template query86.tpl -- start query 12 in stream 0 using template query1.tpl with customer_total_return as (select sr_customer_sk as ctr_customer_sk ,sr_store_sk as ctr_store_sk ,sum(SR_RETURN_AMT_INC_TAX) as ctr_total_return from store_returns ,date_dim where sr_returned_date_sk = d_date_sk and d_year =1999 group by sr_customer_sk ,sr_store_sk) select c_customer_id from customer_total_return ctr1 ,store ,customer where ctr1.ctr_total_return > (select avg(ctr_total_return)*1.2 from customer_total_return ctr2 where ctr1.ctr_store_sk = ctr2.ctr_store_sk) and s_store_sk = ctr1.ctr_store_sk and s_state = 'SD' and ctr1.ctr_customer_sk = c_customer_sk order by c_customer_id LIMIT 100; -- end query 12 in stream 0 using template query1.tpl -- start query 13 in stream 0 using template query91.tpl select cc_call_center_id Call_Center, cc_name Call_Center_Name, cc_manager Manager, sum(cr_net_loss) Returns_Loss from call_center, catalog_returns, date_dim, customer, customer_address, customer_demographics, household_demographics where cr_call_center_sk = cc_call_center_sk and cr_returned_date_sk = d_date_sk and cr_returning_customer_sk= c_customer_sk and cd_demo_sk = c_current_cdemo_sk and hd_demo_sk = c_current_hdemo_sk and ca_address_sk = c_current_addr_sk and d_year = 2002 and d_moy = 11 and ( (cd_marital_status = 'M' and cd_education_status = 'Unknown') or(cd_marital_status = 'W' and cd_education_status = 'Advanced Degree')) and hd_buy_potential like 'Unknown%' and ca_gmt_offset = -6 group 
by cc_call_center_id,cc_name,cc_manager,cd_marital_status,cd_education_status order by sum(cr_net_loss) desc; -- end query 13 in stream 0 using template query91.tpl -- start query 14 in stream 0 using template query21.tpl select * from(select w_warehouse_name ,i_item_id ,sum(case when (cast(d_date as date) < cast ('2000-05-19' as date)) then inv_quantity_on_hand else 0 end) as inv_before ,sum(case when (cast(d_date as date) >= cast ('2000-05-19' as date)) then inv_quantity_on_hand else 0 end) as inv_after from inventory ,warehouse ,item ,date_dim where i_current_price between 0.99 and 1.49 and i_item_sk = inv_item_sk and inv_warehouse_sk = w_warehouse_sk and inv_date_sk = d_date_sk and d_date between (cast ('2000-05-19' as date) - interval 30 days) and (cast ('2000-05-19' as date) + interval 30 days) group by w_warehouse_name, i_item_id) x where (case when inv_before > 0 then inv_after / inv_before else null end) between 2.0/3.0 and 3.0/2.0 order by w_warehouse_name ,i_item_id LIMIT 100; -- end query 14 in stream 0 using template query21.tpl -- start query 15 in stream 0 using template query43.tpl select s_store_name, s_store_id, sum(case when (d_day_name='Sunday') then ss_sales_price else null end) sun_sales, sum(case when (d_day_name='Monday') then ss_sales_price else null end) mon_sales, sum(case when (d_day_name='Tuesday') then ss_sales_price else null end) tue_sales, sum(case when (d_day_name='Wednesday') then ss_sales_price else null end) wed_sales, sum(case when (d_day_name='Thursday') then ss_sales_price else null end) thu_sales, sum(case when (d_day_name='Friday') then ss_sales_price else null end) fri_sales, sum(case when (d_day_name='Saturday') then ss_sales_price else null end) sat_sales from date_dim, store_sales, store where d_date_sk = ss_sold_date_sk and s_store_sk = ss_store_sk and s_gmt_offset = -5 and d_year = 2000 group by s_store_name, s_store_id order by s_store_name, s_store_id,sun_sales,mon_sales,tue_sales,wed_sales,thu_sales,fri_sales,sat_sales LIMIT 100; -- end query 15 in stream 0 using template query43.tpl -- start query 16 in stream 0 using template query27.tpl select i_item_id, s_state, grouping(s_state) g_state, avg(ss_quantity) agg1, avg(ss_list_price) agg2, avg(ss_coupon_amt) agg3, avg(ss_sales_price) agg4 from store_sales, customer_demographics, date_dim, store, item where ss_sold_date_sk = d_date_sk and ss_item_sk = i_item_sk and ss_store_sk = s_store_sk and ss_cdemo_sk = cd_demo_sk and cd_gender = 'F' and cd_marital_status = 'D' and cd_education_status = 'College' and d_year = 2002 and s_state in ('SD','AL', 'TN', 'TN', 'SD', 'SD') group by rollup (i_item_id, s_state) order by i_item_id ,s_state LIMIT 100; -- end query 16 in stream 0 using template query27.tpl -- start query 17 in stream 0 using template query94.tpl select count(distinct ws_order_number) as `order count` ,sum(ws_ext_ship_cost) as `total shipping cost` ,sum(ws_net_profit) as `total net profit` from web_sales ws1 ,date_dim ,customer_address ,web_site where d_date between '2001-5-01' and (cast('2001-5-01' as date) + interval 60 days) and ws1.ws_ship_date_sk = d_date_sk and ws1.ws_ship_addr_sk = ca_address_sk and ca_state = 'AR' and ws1.ws_web_site_sk = web_site_sk and web_company_name = 'pri' and exists (select * from web_sales ws2 where ws1.ws_order_number = ws2.ws_order_number and ws1.ws_warehouse_sk <> ws2.ws_warehouse_sk) and not exists(select * from web_returns wr1 where ws1.ws_order_number = wr1.wr_order_number) order by count(distinct ws_order_number) LIMIT 100; -- end query 17 in 
stream 0 using template query94.tpl -- start query 18 in stream 0 using template query45.tpl select ca_zip, ca_county, sum(ws_sales_price) from web_sales, customer, customer_address, date_dim, item where ws_bill_customer_sk = c_customer_sk and c_current_addr_sk = ca_address_sk and ws_item_sk = i_item_sk and ( substr(ca_zip,1,5) in ('85669', '86197','88274','83405','86475', '85392', '85460', '80348', '81792') or i_item_id in (select i_item_id from item where i_item_sk in (2, 3, 5, 7, 11, 13, 17, 19, 23, 29) ) ) and ws_sold_date_sk = d_date_sk and d_qoy = 2 and d_year = 2000 group by ca_zip, ca_county order by ca_zip, ca_county LIMIT 100; -- end query 18 in stream 0 using template query45.tpl -- start query 19 in stream 0 using template query58.tpl with ss_items as (select i_item_id item_id ,sum(ss_ext_sales_price) ss_item_rev from store_sales ,item ,date_dim where ss_item_sk = i_item_sk and d_date in (select d_date from date_dim where d_week_seq = (select d_week_seq from date_dim where d_date = '2002-04-19')) and ss_sold_date_sk = d_date_sk group by i_item_id), cs_items as (select i_item_id item_id ,sum(cs_ext_sales_price) cs_item_rev from catalog_sales ,item ,date_dim where cs_item_sk = i_item_sk and d_date in (select d_date from date_dim where d_week_seq = (select d_week_seq from date_dim where d_date = '2002-04-19')) and cs_sold_date_sk = d_date_sk group by i_item_id), ws_items as (select i_item_id item_id ,sum(ws_ext_sales_price) ws_item_rev from web_sales ,item ,date_dim where ws_item_sk = i_item_sk and d_date in (select d_date from date_dim where d_week_seq =(select d_week_seq from date_dim where d_date = '2002-04-19')) and ws_sold_date_sk = d_date_sk group by i_item_id) select ss_items.item_id ,ss_item_rev ,ss_item_rev/((ss_item_rev+cs_item_rev+ws_item_rev)/3) * 100 ss_dev ,cs_item_rev ,cs_item_rev/((ss_item_rev+cs_item_rev+ws_item_rev)/3) * 100 cs_dev ,ws_item_rev ,ws_item_rev/((ss_item_rev+cs_item_rev+ws_item_rev)/3) * 100 ws_dev ,(ss_item_rev+cs_item_rev+ws_item_rev)/3 average from ss_items,cs_items,ws_items where ss_items.item_id=cs_items.item_id and ss_items.item_id=ws_items.item_id and ss_item_rev between 0.9 * cs_item_rev and 1.1 * cs_item_rev and ss_item_rev between 0.9 * ws_item_rev and 1.1 * ws_item_rev and cs_item_rev between 0.9 * ss_item_rev and 1.1 * ss_item_rev and cs_item_rev between 0.9 * ws_item_rev and 1.1 * ws_item_rev and ws_item_rev between 0.9 * ss_item_rev and 1.1 * ss_item_rev and ws_item_rev between 0.9 * cs_item_rev and 1.1 * cs_item_rev order by item_id ,ss_item_rev LIMIT 100; -- end query 19 in stream 0 using template query58.tpl -- start query 20 in stream 0 using template query64.tpl with cs_ui as (select cs_item_sk ,sum(cs_ext_list_price) as sale,sum(cr_refunded_cash+cr_reversed_charge+cr_store_credit) as refund from catalog_sales ,catalog_returns where cs_item_sk = cr_item_sk and cs_order_number = cr_order_number group by cs_item_sk having sum(cs_ext_list_price)>2*sum(cr_refunded_cash+cr_reversed_charge+cr_store_credit)), cross_sales as (select i_product_name product_name ,i_item_sk item_sk ,s_store_name store_name ,s_zip store_zip ,ad1.ca_street_number b_street_number ,ad1.ca_street_name b_street_name ,ad1.ca_city b_city ,ad1.ca_zip b_zip ,ad2.ca_street_number c_street_number ,ad2.ca_street_name c_street_name ,ad2.ca_city c_city ,ad2.ca_zip c_zip ,d1.d_year as syear ,d2.d_year as fsyear ,d3.d_year s2year ,count(*) cnt ,sum(ss_wholesale_cost) s1 ,sum(ss_list_price) s2 ,sum(ss_coupon_amt) s3 FROM store_sales ,store_returns ,cs_ui ,date_dim d1 ,date_dim 
d2 ,date_dim d3 ,store ,customer ,customer_demographics cd1 ,customer_demographics cd2 ,promotion ,household_demographics hd1 ,household_demographics hd2 ,customer_address ad1 ,customer_address ad2 ,income_band ib1 ,income_band ib2 ,item WHERE ss_store_sk = s_store_sk AND ss_sold_date_sk = d1.d_date_sk AND ss_customer_sk = c_customer_sk AND ss_cdemo_sk= cd1.cd_demo_sk AND ss_hdemo_sk = hd1.hd_demo_sk AND ss_addr_sk = ad1.ca_address_sk and ss_item_sk = i_item_sk and ss_item_sk = sr_item_sk and ss_ticket_number = sr_ticket_number and ss_item_sk = cs_ui.cs_item_sk and c_current_cdemo_sk = cd2.cd_demo_sk AND c_current_hdemo_sk = hd2.hd_demo_sk AND c_current_addr_sk = ad2.ca_address_sk and c_first_sales_date_sk = d2.d_date_sk and c_first_shipto_date_sk = d3.d_date_sk and ss_promo_sk = p_promo_sk and hd1.hd_income_band_sk = ib1.ib_income_band_sk and hd2.hd_income_band_sk = ib2.ib_income_band_sk and cd1.cd_marital_status <> cd2.cd_marital_status and i_color in ('lawn','blush','smoke','ghost','floral','chartreuse') and i_current_price between 51 and 51 + 10 and i_current_price between 51 + 1 and 51 + 15 group by i_product_name ,i_item_sk ,s_store_name ,s_zip ,ad1.ca_street_number ,ad1.ca_street_name ,ad1.ca_city ,ad1.ca_zip ,ad2.ca_street_number ,ad2.ca_street_name ,ad2.ca_city ,ad2.ca_zip ,d1.d_year ,d2.d_year ,d3.d_year ) select cs1.product_name ,cs1.store_name ,cs1.store_zip ,cs1.b_street_number ,cs1.b_street_name ,cs1.b_city ,cs1.b_zip ,cs1.c_street_number ,cs1.c_street_name ,cs1.c_city ,cs1.c_zip ,cs1.syear ,cs1.cnt ,cs1.s1 as s11 ,cs1.s2 as s21 ,cs1.s3 as s31 ,cs2.s1 as s12 ,cs2.s2 as s22 ,cs2.s3 as s32 ,cs2.syear ,cs2.cnt from cross_sales cs1,cross_sales cs2 where cs1.item_sk=cs2.item_sk and cs1.syear = 2001 and cs2.syear = 2001 + 1 and cs2.cnt <= cs1.cnt and cs1.store_name = cs2.store_name and cs1.store_zip = cs2.store_zip order by cs1.product_name ,cs1.store_name ,cs2.cnt ,cs1.s1 ,cs2.s1; -- end query 20 in stream 0 using template query64.tpl -- start query 21 in stream 0 using template query36.tpl select sum(ss_net_profit)/sum(ss_ext_sales_price) as gross_margin ,i_category ,i_class ,grouping(i_category)+grouping(i_class) as lochierarchy ,rank() over ( partition by grouping(i_category)+grouping(i_class), case when grouping(i_class) = 0 then i_category end order by sum(ss_net_profit)/sum(ss_ext_sales_price) asc) as rank_within_parent from store_sales ,date_dim d1 ,item ,store where d1.d_year = 1999 and d1.d_date_sk = ss_sold_date_sk and i_item_sk = ss_item_sk and s_store_sk = ss_store_sk and s_state in ('AL','TN','SD','SD', 'SD','SD','SD','SD') group by rollup(i_category,i_class) order by lochierarchy desc ,case when lochierarchy = 0 then i_category end ,rank_within_parent LIMIT 100; -- end query 21 in stream 0 using template query36.tpl -- start query 22 in stream 0 using template query33.tpl with ss as ( select i_manufact_id,sum(ss_ext_sales_price) total_sales from store_sales, date_dim, customer_address, item where i_manufact_id in (select i_manufact_id from item where i_category in ('Electronics')) and ss_item_sk = i_item_sk and ss_sold_date_sk = d_date_sk and d_year = 2002 and d_moy = 1 and ss_addr_sk = ca_address_sk and ca_gmt_offset = -6 group by i_manufact_id), cs as ( select i_manufact_id,sum(cs_ext_sales_price) total_sales from catalog_sales, date_dim, customer_address, item where i_manufact_id in (select i_manufact_id from item where i_category in ('Electronics')) and cs_item_sk = i_item_sk and cs_sold_date_sk = d_date_sk and d_year = 2002 and d_moy = 1 and cs_bill_addr_sk = 
ca_address_sk and ca_gmt_offset = -6 group by i_manufact_id), ws as ( select i_manufact_id,sum(ws_ext_sales_price) total_sales from web_sales, date_dim, customer_address, item where i_manufact_id in (select i_manufact_id from item where i_category in ('Electronics')) and ws_item_sk = i_item_sk and ws_sold_date_sk = d_date_sk and d_year = 2002 and d_moy = 1 and ws_bill_addr_sk = ca_address_sk and ca_gmt_offset = -6 group by i_manufact_id) select i_manufact_id ,sum(total_sales) total_sales from (select * from ss union all select * from cs union all select * from ws) tmp1 group by i_manufact_id order by total_sales LIMIT 100; -- end query 22 in stream 0 using template query33.tpl -- start query 23 in stream 0 using template query46.tpl select c_last_name ,c_first_name ,ca_city ,bought_city ,ss_ticket_number ,amt,profit from (select ss_ticket_number ,ss_customer_sk ,ca_city bought_city ,sum(ss_coupon_amt) amt ,sum(ss_net_profit) profit from store_sales,date_dim,store,household_demographics,customer_address where store_sales.ss_sold_date_sk = date_dim.d_date_sk and store_sales.ss_store_sk = store.s_store_sk and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk and store_sales.ss_addr_sk = customer_address.ca_address_sk and (household_demographics.hd_dep_count = 3 or household_demographics.hd_vehicle_count= 4) and date_dim.d_dow in (6,0) and date_dim.d_year in (2000,2000+1,2000+2) and store.s_city in ('Oak Grove','Fairview','Five Points','Riverside','Pleasant Hill') group by ss_ticket_number,ss_customer_sk,ss_addr_sk,ca_city) dn,customer,customer_address current_addr where ss_customer_sk = c_customer_sk and customer.c_current_addr_sk = current_addr.ca_address_sk and current_addr.ca_city <> bought_city order by c_last_name ,c_first_name ,ca_city ,bought_city ,ss_ticket_number LIMIT 100; -- end query 23 in stream 0 using template query46.tpl -- start query 24 in stream 0 using template query62.tpl select substr(w_warehouse_name,1,20) ,sm_type ,web_name ,sum(case when (ws_ship_date_sk - ws_sold_date_sk <= 30 ) then 1 else 0 end) as `30 days` ,sum(case when (ws_ship_date_sk - ws_sold_date_sk > 30) and (ws_ship_date_sk - ws_sold_date_sk <= 60) then 1 else 0 end ) as `31-60 days` ,sum(case when (ws_ship_date_sk - ws_sold_date_sk > 60) and (ws_ship_date_sk - ws_sold_date_sk <= 90) then 1 else 0 end) as `61-90 days` ,sum(case when (ws_ship_date_sk - ws_sold_date_sk > 90) and (ws_ship_date_sk - ws_sold_date_sk <= 120) then 1 else 0 end) as `91-120 days` ,sum(case when (ws_ship_date_sk - ws_sold_date_sk > 120) then 1 else 0 end) as `>120 days` from web_sales ,warehouse ,ship_mode ,web_site ,date_dim where d_month_seq between 1211 and 1211 + 11 and ws_ship_date_sk = d_date_sk and ws_warehouse_sk = w_warehouse_sk and ws_ship_mode_sk = sm_ship_mode_sk and ws_web_site_sk = web_site_sk group by substr(w_warehouse_name,1,20) ,sm_type ,web_name order by substr(w_warehouse_name,1,20) ,sm_type ,web_name LIMIT 100; -- end query 24 in stream 0 using template query62.tpl -- start query 25 in stream 0 using template query16.tpl select count(distinct cs_order_number) as `order count` ,sum(cs_ext_ship_cost) as `total shipping cost` ,sum(cs_net_profit) as `total net profit` from catalog_sales cs1 ,date_dim ,customer_address ,call_center where d_date between '1999-4-01' and (cast('1999-4-01' as date) + interval 60 days) and cs1.cs_ship_date_sk = d_date_sk and cs1.cs_ship_addr_sk = ca_address_sk and ca_state = 'MD' and cs1.cs_call_center_sk = cc_call_center_sk and cc_county in ('Ziebach County','Williamson 
County','Walker County','Williamson County', 'Ziebach County' ) and exists (select * from catalog_sales cs2 where cs1.cs_order_number = cs2.cs_order_number and cs1.cs_warehouse_sk <> cs2.cs_warehouse_sk) and not exists(select * from catalog_returns cr1 where cs1.cs_order_number = cr1.cr_order_number) order by count(distinct cs_order_number) LIMIT 100; -- end query 25 in stream 0 using template query16.tpl -- start query 26 in stream 0 using template query10.tpl select cd_gender, cd_marital_status, cd_education_status, count(*) cnt1, cd_purchase_estimate, count(*) cnt2, cd_credit_rating, count(*) cnt3, cd_dep_count, count(*) cnt4, cd_dep_employed_count, count(*) cnt5, cd_dep_college_count, count(*) cnt6 from customer c,customer_address ca,customer_demographics where c.c_current_addr_sk = ca.ca_address_sk and ca_county in ('Bottineau County','Marion County','Randolph County','Providence County','Sagadahoc County') and cd_demo_sk = c.c_current_cdemo_sk and exists (select * from store_sales,date_dim where c.c_customer_sk = ss_customer_sk and ss_sold_date_sk = d_date_sk and d_year = 2000 and d_moy between 1 and 1+3) and (exists (select * from web_sales,date_dim where c.c_customer_sk = ws_bill_customer_sk and ws_sold_date_sk = d_date_sk and d_year = 2000 and d_moy between 1 ANd 1+3) or exists (select * from catalog_sales,date_dim where c.c_customer_sk = cs_ship_customer_sk and cs_sold_date_sk = d_date_sk and d_year = 2000 and d_moy between 1 and 1+3)) group by cd_gender, cd_marital_status, cd_education_status, cd_purchase_estimate, cd_credit_rating, cd_dep_count, cd_dep_employed_count, cd_dep_college_count order by cd_gender, cd_marital_status, cd_education_status, cd_purchase_estimate, cd_credit_rating, cd_dep_count, cd_dep_employed_count, cd_dep_college_count LIMIT 100; -- end query 26 in stream 0 using template query10.tpl -- start query 27 in stream 0 using template query63.tpl select * from (select i_manager_id ,sum(ss_sales_price) sum_sales ,avg(sum(ss_sales_price)) over (partition by i_manager_id) avg_monthly_sales from item ,store_sales ,date_dim ,store where ss_item_sk = i_item_sk and ss_sold_date_sk = d_date_sk and ss_store_sk = s_store_sk and d_month_seq in (1179,1179+1,1179+2,1179+3,1179+4,1179+5,1179+6,1179+7,1179+8,1179+9,1179+10,1179+11) and (( i_category in ('Books','Children','Electronics') and i_class in ('personal','portable','reference','self-help') and i_brand in ('scholaramalgamalg #14','scholaramalgamalg #7', 'exportiunivamalg #9','scholaramalgamalg #9')) or( i_category in ('Women','Music','Men') and i_class in ('accessories','classical','fragrances','pants') and i_brand in ('amalgimporto #1','edu packscholar #1','exportiimporto #1', 'importoamalg #1'))) group by i_manager_id, d_moy) tmp1 where case when avg_monthly_sales > 0 then abs (sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1 order by i_manager_id ,avg_monthly_sales ,sum_sales LIMIT 100; -- end query 27 in stream 0 using template query63.tpl -- start query 28 in stream 0 using template query69.tpl select cd_gender, cd_marital_status, cd_education_status, count(*) cnt1, cd_purchase_estimate, count(*) cnt2, cd_credit_rating, count(*) cnt3 from customer c,customer_address ca,customer_demographics where c.c_current_addr_sk = ca.ca_address_sk and ca_state in ('IN','ND','PA') and cd_demo_sk = c.c_current_cdemo_sk and exists (select * from store_sales,date_dim where c.c_customer_sk = ss_customer_sk and ss_sold_date_sk = d_date_sk and d_year = 1999 and d_moy between 2 and 2+2) and (not exists (select 
* from web_sales,date_dim where c.c_customer_sk = ws_bill_customer_sk and ws_sold_date_sk = d_date_sk and d_year = 1999 and d_moy between 2 and 2+2) and not exists (select * from catalog_sales,date_dim where c.c_customer_sk = cs_ship_customer_sk and cs_sold_date_sk = d_date_sk and d_year = 1999 and d_moy between 2 and 2+2)) group by cd_gender, cd_marital_status, cd_education_status, cd_purchase_estimate, cd_credit_rating order by cd_gender, cd_marital_status, cd_education_status, cd_purchase_estimate, cd_credit_rating LIMIT 100; -- end query 28 in stream 0 using template query69.tpl -- start query 29 in stream 0 using template query60.tpl with ss as ( select i_item_id,sum(ss_ext_sales_price) total_sales from store_sales, date_dim, customer_address, item where i_item_id in (select i_item_id from item where i_category in ('Music')) and ss_item_sk = i_item_sk and ss_sold_date_sk = d_date_sk and d_year = 1998 and d_moy = 10 and ss_addr_sk = ca_address_sk and ca_gmt_offset = -5 group by i_item_id), cs as ( select i_item_id,sum(cs_ext_sales_price) total_sales from catalog_sales, date_dim, customer_address, item where i_item_id in (select i_item_id from item where i_category in ('Music')) and cs_item_sk = i_item_sk and cs_sold_date_sk = d_date_sk and d_year = 1998 and d_moy = 10 and cs_bill_addr_sk = ca_address_sk and ca_gmt_offset = -5 group by i_item_id), ws as ( select i_item_id,sum(ws_ext_sales_price) total_sales from web_sales, date_dim, customer_address, item where i_item_id in (select i_item_id from item where i_category in ('Music')) and ws_item_sk = i_item_sk and ws_sold_date_sk = d_date_sk and d_year = 1998 and d_moy = 10 and ws_bill_addr_sk = ca_address_sk and ca_gmt_offset = -5 group by i_item_id) select i_item_id ,sum(total_sales) total_sales from (select * from ss union all select * from cs union all select * from ws) tmp1 group by i_item_id order by i_item_id ,total_sales LIMIT 100; -- end query 29 in stream 0 using template query60.tpl -- start query 30 in stream 0 using template query59.tpl with wss as (select d_week_seq, ss_store_sk, sum(case when (d_day_name='Sunday') then ss_sales_price else null end) sun_sales, sum(case when (d_day_name='Monday') then ss_sales_price else null end) mon_sales, sum(case when (d_day_name='Tuesday') then ss_sales_price else null end) tue_sales, sum(case when (d_day_name='Wednesday') then ss_sales_price else null end) wed_sales, sum(case when (d_day_name='Thursday') then ss_sales_price else null end) thu_sales, sum(case when (d_day_name='Friday') then ss_sales_price else null end) fri_sales, sum(case when (d_day_name='Saturday') then ss_sales_price else null end) sat_sales from store_sales,date_dim where d_date_sk = ss_sold_date_sk group by d_week_seq,ss_store_sk ) select s_store_name1,s_store_id1,d_week_seq1 ,sun_sales1/sun_sales2,mon_sales1/mon_sales2 ,tue_sales1/tue_sales2,wed_sales1/wed_sales2,thu_sales1/thu_sales2 ,fri_sales1/fri_sales2,sat_sales1/sat_sales2 from (select s_store_name s_store_name1,wss.d_week_seq d_week_seq1 ,s_store_id s_store_id1,sun_sales sun_sales1 ,mon_sales mon_sales1,tue_sales tue_sales1 ,wed_sales wed_sales1,thu_sales thu_sales1 ,fri_sales fri_sales1,sat_sales sat_sales1 from wss,store,date_dim d where d.d_week_seq = wss.d_week_seq and ss_store_sk = s_store_sk and d_month_seq between 1202 and 1202 + 11) y, (select s_store_name s_store_name2,wss.d_week_seq d_week_seq2 ,s_store_id s_store_id2,sun_sales sun_sales2 ,mon_sales mon_sales2,tue_sales tue_sales2 ,wed_sales wed_sales2,thu_sales thu_sales2 ,fri_sales 
fri_sales2,sat_sales sat_sales2 from wss,store,date_dim d where d.d_week_seq = wss.d_week_seq and ss_store_sk = s_store_sk and d_month_seq between 1202+ 12 and 1202 + 23) x where s_store_id1=s_store_id2 and d_week_seq1=d_week_seq2-52 order by s_store_name1,s_store_id1,d_week_seq1 LIMIT 100; -- end query 30 in stream 0 using template query59.tpl -- start query 31 in stream 0 using template query37.tpl select i_item_id ,i_item_desc ,i_current_price from item, inventory, date_dim, catalog_sales where i_current_price between 16 and 16 + 30 and inv_item_sk = i_item_sk and d_date_sk=inv_date_sk and d_date between cast('1999-03-27' as date) and (cast('1999-03-27' as date) + interval 60 days) and i_manufact_id in (821,673,849,745) and inv_quantity_on_hand between 100 and 500 and cs_item_sk = i_item_sk group by i_item_id,i_item_desc,i_current_price order by i_item_id LIMIT 100; -- end query 31 in stream 0 using template query37.tpl -- start query 32 in stream 0 using template query98.tpl select i_item_id ,i_item_desc ,i_category ,i_class ,i_current_price ,sum(ss_ext_sales_price) as itemrevenue ,sum(ss_ext_sales_price)*100/sum(sum(ss_ext_sales_price)) over (partition by i_class) as revenueratio from store_sales ,item ,date_dim where ss_item_sk = i_item_sk and i_category in ('Children', 'Women', 'Shoes') and ss_sold_date_sk = d_date_sk and d_date between cast('2001-03-09' as date) and (cast('2001-03-09' as date) + interval 30 days) group by i_item_id ,i_item_desc ,i_category ,i_class ,i_current_price order by i_category ,i_class ,i_item_id ,i_item_desc ,revenueratio; -- end query 32 in stream 0 using template query98.tpl -- start query 33 in stream 0 using template query85.tpl select substr(r_reason_desc,1,20) ,avg(ws_quantity) ,avg(wr_refunded_cash) ,avg(wr_fee) from web_sales, web_returns, web_page, customer_demographics cd1, customer_demographics cd2, customer_address, date_dim, reason where ws_web_page_sk = wp_web_page_sk and ws_item_sk = wr_item_sk and ws_order_number = wr_order_number and ws_sold_date_sk = d_date_sk and d_year = 2001 and cd1.cd_demo_sk = wr_refunded_cdemo_sk and cd2.cd_demo_sk = wr_returning_cdemo_sk and ca_address_sk = wr_refunded_addr_sk and r_reason_sk = wr_reason_sk and ( ( cd1.cd_marital_status = 'W' and cd1.cd_marital_status = cd2.cd_marital_status and cd1.cd_education_status = 'Primary' and cd1.cd_education_status = cd2.cd_education_status and ws_sales_price between 100.00 and 150.00 ) or ( cd1.cd_marital_status = 'D' and cd1.cd_marital_status = cd2.cd_marital_status and cd1.cd_education_status = 'College' and cd1.cd_education_status = cd2.cd_education_status and ws_sales_price between 50.00 and 100.00 ) or ( cd1.cd_marital_status = 'S' and cd1.cd_marital_status = cd2.cd_marital_status and cd1.cd_education_status = '2 yr Degree' and cd1.cd_education_status = cd2.cd_education_status and ws_sales_price between 150.00 and 200.00 ) ) and ( ( ca_country = 'United States' and ca_state in ('PA', 'IN', 'VA') and ws_net_profit between 100 and 200 ) or ( ca_country = 'United States' and ca_state in ('TX', 'MO', 'MS') and ws_net_profit between 150 and 300 ) or ( ca_country = 'United States' and ca_state in ('MT', 'OR', 'MN') and ws_net_profit between 50 and 250 ) ) group by r_reason_desc order by substr(r_reason_desc,1,20) ,avg(ws_quantity) ,avg(wr_refunded_cash) ,avg(wr_fee) LIMIT 100; -- end query 33 in stream 0 using template query85.tpl -- start query 34 in stream 0 using template query70.tpl select sum(ss_net_profit) as total_sum ,s_state ,s_county 
,grouping(s_state)+grouping(s_county) as lochierarchy ,rank() over ( partition by grouping(s_state)+grouping(s_county), case when grouping(s_county) = 0 then s_state end order by sum(ss_net_profit) desc) as rank_within_parent from store_sales ,date_dim d1 ,store where d1.d_month_seq between 1191 and 1191+11 and d1.d_date_sk = ss_sold_date_sk and s_store_sk = ss_store_sk and s_state in ( select s_state from (select s_state as s_state, rank() over ( partition by s_state order by sum(ss_net_profit) desc) as ranking from store_sales, store, date_dim where d_month_seq between 1191 and 1191+11 and d_date_sk = ss_sold_date_sk and s_store_sk = ss_store_sk group by s_state ) tmp1 where ranking <= 5 ) group by rollup(s_state,s_county) order by lochierarchy desc ,case when lochierarchy = 0 then s_state end ,rank_within_parent LIMIT 100; -- end query 34 in stream 0 using template query70.tpl -- start query 35 in stream 0 using template query67.tpl select * from (select i_category ,i_class ,i_brand ,i_product_name ,d_year ,d_qoy ,d_moy ,s_store_id ,sumsales ,rank() over (partition by i_category order by sumsales desc) rk from (select i_category ,i_class ,i_brand ,i_product_name ,d_year ,d_qoy ,d_moy ,s_store_id ,sum(coalesce(ss_sales_price*ss_quantity,0)) sumsales from store_sales ,date_dim ,store ,item where ss_sold_date_sk=d_date_sk and ss_item_sk=i_item_sk and ss_store_sk = s_store_sk and d_month_seq between 1192 and 1192+11 group by rollup(i_category, i_class, i_brand, i_product_name, d_year, d_qoy, d_moy,s_store_id))dw1) dw2 where rk <= 100 order by i_category ,i_class ,i_brand ,i_product_name ,d_year ,d_qoy ,d_moy ,s_store_id ,sumsales ,rk LIMIT 100; -- end query 35 in stream 0 using template query67.tpl -- start query 36 in stream 0 using template query28.tpl select * from (select avg(ss_list_price) B1_LP ,count(ss_list_price) B1_CNT ,count(distinct ss_list_price) B1_CNTD from store_sales where ss_quantity between 0 and 5 and (ss_list_price between 49 and 49+10 or ss_coupon_amt between 5040 and 5040+1000 or ss_wholesale_cost between 4 and 4+20)) B1, (select avg(ss_list_price) B2_LP ,count(ss_list_price) B2_CNT ,count(distinct ss_list_price) B2_CNTD from store_sales where ss_quantity between 6 and 10 and (ss_list_price between 5 and 5+10 or ss_coupon_amt between 441 and 441+1000 or ss_wholesale_cost between 80 and 80+20)) B2, (select avg(ss_list_price) B3_LP ,count(ss_list_price) B3_CNT ,count(distinct ss_list_price) B3_CNTD from store_sales where ss_quantity between 11 and 15 and (ss_list_price between 153 and 153+10 or ss_coupon_amt between 10459 and 10459+1000 or ss_wholesale_cost between 3 and 3+20)) B3, (select avg(ss_list_price) B4_LP ,count(ss_list_price) B4_CNT ,count(distinct ss_list_price) B4_CNTD from store_sales where ss_quantity between 16 and 20 and (ss_list_price between 14 and 14+10 or ss_coupon_amt between 13311 and 13311+1000 or ss_wholesale_cost between 1 and 1+20)) B4, (select avg(ss_list_price) B5_LP ,count(ss_list_price) B5_CNT ,count(distinct ss_list_price) B5_CNTD from store_sales where ss_quantity between 21 and 25 and (ss_list_price between 29 and 29+10 or ss_coupon_amt between 6047 and 6047+1000 or ss_wholesale_cost between 27 and 27+20)) B5, (select avg(ss_list_price) B6_LP ,count(ss_list_price) B6_CNT ,count(distinct ss_list_price) B6_CNTD from store_sales where ss_quantity between 26 and 30 and (ss_list_price between 159 and 159+10 or ss_coupon_amt between 2432 and 2432+1000 or ss_wholesale_cost between 48 and 48+20)) B6 LIMIT 100; -- end query 36 in stream 0 using 
template query28.tpl -- start query 37 in stream 0 using template query81.tpl with customer_total_return as (select cr_returning_customer_sk as ctr_customer_sk ,ca_state as ctr_state, sum(cr_return_amt_inc_tax) as ctr_total_return from catalog_returns ,date_dim ,customer_address where cr_returned_date_sk = d_date_sk and d_year =2002 and cr_returning_addr_sk = ca_address_sk group by cr_returning_customer_sk ,ca_state ) select c_customer_id,c_salutation,c_first_name,c_last_name,ca_street_number,ca_street_name ,ca_street_type,ca_suite_number,ca_city,ca_county,ca_state,ca_zip,ca_country,ca_gmt_offset ,ca_location_type,ctr_total_return from customer_total_return ctr1 ,customer_address ,customer where ctr1.ctr_total_return > (select avg(ctr_total_return)*1.2 from customer_total_return ctr2 where ctr1.ctr_state = ctr2.ctr_state) and ca_address_sk = c_current_addr_sk and ca_state = 'IL' and ctr1.ctr_customer_sk = c_customer_sk order by c_customer_id,c_salutation,c_first_name,c_last_name,ca_street_number,ca_street_name ,ca_street_type,ca_suite_number,ca_city,ca_county,ca_state,ca_zip,ca_country,ca_gmt_offset ,ca_location_type,ctr_total_return LIMIT 100; -- end query 37 in stream 0 using template query81.tpl -- start query 38 in stream 0 using template query97.tpl with ssci as ( select ss_customer_sk customer_sk ,ss_item_sk item_sk from store_sales,date_dim where ss_sold_date_sk = d_date_sk and d_month_seq between 1176 and 1176 + 11 group by ss_customer_sk ,ss_item_sk), csci as( select cs_bill_customer_sk customer_sk ,cs_item_sk item_sk from catalog_sales,date_dim where cs_sold_date_sk = d_date_sk and d_month_seq between 1176 and 1176 + 11 group by cs_bill_customer_sk ,cs_item_sk) select sum(case when ssci.customer_sk is not null and csci.customer_sk is null then 1 else 0 end) store_only ,sum(case when ssci.customer_sk is null and csci.customer_sk is not null then 1 else 0 end) catalog_only ,sum(case when ssci.customer_sk is not null and csci.customer_sk is not null then 1 else 0 end) store_and_catalog from ssci full outer join csci on (ssci.customer_sk=csci.customer_sk and ssci.item_sk = csci.item_sk) LIMIT 100; -- end query 38 in stream 0 using template query97.tpl -- start query 39 in stream 0 using template query66.tpl select w_warehouse_name ,w_warehouse_sq_ft ,w_city ,w_county ,w_state ,w_country ,ship_carriers ,year ,sum(jan_sales) as jan_sales ,sum(feb_sales) as feb_sales ,sum(mar_sales) as mar_sales ,sum(apr_sales) as apr_sales ,sum(may_sales) as may_sales ,sum(jun_sales) as jun_sales ,sum(jul_sales) as jul_sales ,sum(aug_sales) as aug_sales ,sum(sep_sales) as sep_sales ,sum(oct_sales) as oct_sales ,sum(nov_sales) as nov_sales ,sum(dec_sales) as dec_sales ,sum(jan_sales/w_warehouse_sq_ft) as jan_sales_per_sq_foot ,sum(feb_sales/w_warehouse_sq_ft) as feb_sales_per_sq_foot ,sum(mar_sales/w_warehouse_sq_ft) as mar_sales_per_sq_foot ,sum(apr_sales/w_warehouse_sq_ft) as apr_sales_per_sq_foot ,sum(may_sales/w_warehouse_sq_ft) as may_sales_per_sq_foot ,sum(jun_sales/w_warehouse_sq_ft) as jun_sales_per_sq_foot ,sum(jul_sales/w_warehouse_sq_ft) as jul_sales_per_sq_foot ,sum(aug_sales/w_warehouse_sq_ft) as aug_sales_per_sq_foot ,sum(sep_sales/w_warehouse_sq_ft) as sep_sales_per_sq_foot ,sum(oct_sales/w_warehouse_sq_ft) as oct_sales_per_sq_foot ,sum(nov_sales/w_warehouse_sq_ft) as nov_sales_per_sq_foot ,sum(dec_sales/w_warehouse_sq_ft) as dec_sales_per_sq_foot ,sum(jan_net) as jan_net ,sum(feb_net) as feb_net ,sum(mar_net) as mar_net ,sum(apr_net) as apr_net ,sum(may_net) as may_net ,sum(jun_net) as 
jun_net ,sum(jul_net) as jul_net ,sum(aug_net) as aug_net ,sum(sep_net) as sep_net ,sum(oct_net) as oct_net ,sum(nov_net) as nov_net ,sum(dec_net) as dec_net from ( select w_warehouse_name ,w_warehouse_sq_ft ,w_city ,w_county ,w_state ,w_country ,'ZOUROS' || ',' || 'ZHOU' as ship_carriers ,d_year as year ,sum(case when d_moy = 1 then ws_sales_price* ws_quantity else 0 end) as jan_sales ,sum(case when d_moy = 2 then ws_sales_price* ws_quantity else 0 end) as feb_sales ,sum(case when d_moy = 3 then ws_sales_price* ws_quantity else 0 end) as mar_sales ,sum(case when d_moy = 4 then ws_sales_price* ws_quantity else 0 end) as apr_sales ,sum(case when d_moy = 5 then ws_sales_price* ws_quantity else 0 end) as may_sales ,sum(case when d_moy = 6 then ws_sales_price* ws_quantity else 0 end) as jun_sales ,sum(case when d_moy = 7 then ws_sales_price* ws_quantity else 0 end) as jul_sales ,sum(case when d_moy = 8 then ws_sales_price* ws_quantity else 0 end) as aug_sales ,sum(case when d_moy = 9 then ws_sales_price* ws_quantity else 0 end) as sep_sales ,sum(case when d_moy = 10 then ws_sales_price* ws_quantity else 0 end) as oct_sales ,sum(case when d_moy = 11 then ws_sales_price* ws_quantity else 0 end) as nov_sales ,sum(case when d_moy = 12 then ws_sales_price* ws_quantity else 0 end) as dec_sales ,sum(case when d_moy = 1 then ws_net_paid * ws_quantity else 0 end) as jan_net ,sum(case when d_moy = 2 then ws_net_paid * ws_quantity else 0 end) as feb_net ,sum(case when d_moy = 3 then ws_net_paid * ws_quantity else 0 end) as mar_net ,sum(case when d_moy = 4 then ws_net_paid * ws_quantity else 0 end) as apr_net ,sum(case when d_moy = 5 then ws_net_paid * ws_quantity else 0 end) as may_net ,sum(case when d_moy = 6 then ws_net_paid * ws_quantity else 0 end) as jun_net ,sum(case when d_moy = 7 then ws_net_paid * ws_quantity else 0 end) as jul_net ,sum(case when d_moy = 8 then ws_net_paid * ws_quantity else 0 end) as aug_net ,sum(case when d_moy = 9 then ws_net_paid * ws_quantity else 0 end) as sep_net ,sum(case when d_moy = 10 then ws_net_paid * ws_quantity else 0 end) as oct_net ,sum(case when d_moy = 11 then ws_net_paid * ws_quantity else 0 end) as nov_net ,sum(case when d_moy = 12 then ws_net_paid * ws_quantity else 0 end) as dec_net from web_sales ,warehouse ,date_dim ,time_dim ,ship_mode where ws_warehouse_sk = w_warehouse_sk and ws_sold_date_sk = d_date_sk and ws_sold_time_sk = t_time_sk and ws_ship_mode_sk = sm_ship_mode_sk and d_year = 2000 and t_time between 18479 and 18479+28800 and sm_carrier in ('ZOUROS','ZHOU') group by w_warehouse_name ,w_warehouse_sq_ft ,w_city ,w_county ,w_state ,w_country ,d_year union all select w_warehouse_name ,w_warehouse_sq_ft ,w_city ,w_county ,w_state ,w_country ,'ZOUROS' || ',' || 'ZHOU' as ship_carriers ,d_year as year ,sum(case when d_moy = 1 then cs_ext_sales_price* cs_quantity else 0 end) as jan_sales ,sum(case when d_moy = 2 then cs_ext_sales_price* cs_quantity else 0 end) as feb_sales ,sum(case when d_moy = 3 then cs_ext_sales_price* cs_quantity else 0 end) as mar_sales ,sum(case when d_moy = 4 then cs_ext_sales_price* cs_quantity else 0 end) as apr_sales ,sum(case when d_moy = 5 then cs_ext_sales_price* cs_quantity else 0 end) as may_sales ,sum(case when d_moy = 6 then cs_ext_sales_price* cs_quantity else 0 end) as jun_sales ,sum(case when d_moy = 7 then cs_ext_sales_price* cs_quantity else 0 end) as jul_sales ,sum(case when d_moy = 8 then cs_ext_sales_price* cs_quantity else 0 end) as aug_sales ,sum(case when d_moy = 9 then cs_ext_sales_price* cs_quantity else 
0 end) as sep_sales ,sum(case when d_moy = 10 then cs_ext_sales_price* cs_quantity else 0 end) as oct_sales ,sum(case when d_moy = 11 then cs_ext_sales_price* cs_quantity else 0 end) as nov_sales ,sum(case when d_moy = 12 then cs_ext_sales_price* cs_quantity else 0 end) as dec_sales ,sum(case when d_moy = 1 then cs_net_paid_inc_ship * cs_quantity else 0 end) as jan_net ,sum(case when d_moy = 2 then cs_net_paid_inc_ship * cs_quantity else 0 end) as feb_net ,sum(case when d_moy = 3 then cs_net_paid_inc_ship * cs_quantity else 0 end) as mar_net ,sum(case when d_moy = 4 then cs_net_paid_inc_ship * cs_quantity else 0 end) as apr_net ,sum(case when d_moy = 5 then cs_net_paid_inc_ship * cs_quantity else 0 end) as may_net ,sum(case when d_moy = 6 then cs_net_paid_inc_ship * cs_quantity else 0 end) as jun_net ,sum(case when d_moy = 7 then cs_net_paid_inc_ship * cs_quantity else 0 end) as jul_net ,sum(case when d_moy = 8 then cs_net_paid_inc_ship * cs_quantity else 0 end) as aug_net ,sum(case when d_moy = 9 then cs_net_paid_inc_ship * cs_quantity else 0 end) as sep_net ,sum(case when d_moy = 10 then cs_net_paid_inc_ship * cs_quantity else 0 end) as oct_net ,sum(case when d_moy = 11 then cs_net_paid_inc_ship * cs_quantity else 0 end) as nov_net ,sum(case when d_moy = 12 then cs_net_paid_inc_ship * cs_quantity else 0 end) as dec_net from catalog_sales ,warehouse ,date_dim ,time_dim ,ship_mode where cs_warehouse_sk = w_warehouse_sk and cs_sold_date_sk = d_date_sk and cs_sold_time_sk = t_time_sk and cs_ship_mode_sk = sm_ship_mode_sk and d_year = 2000 and t_time between 18479 AND 18479+28800 and sm_carrier in ('ZOUROS','ZHOU') group by w_warehouse_name ,w_warehouse_sq_ft ,w_city ,w_county ,w_state ,w_country ,d_year ) x group by w_warehouse_name ,w_warehouse_sq_ft ,w_city ,w_county ,w_state ,w_country ,ship_carriers ,year order by w_warehouse_name LIMIT 100; -- end query 39 in stream 0 using template query66.tpl -- start query 40 in stream 0 using template query90.tpl select cast(amc as decimal(15,4))/cast(pmc as decimal(15,4)) am_pm_ratio from ( select count(*) amc from web_sales, household_demographics , time_dim, web_page where ws_sold_time_sk = time_dim.t_time_sk and ws_ship_hdemo_sk = household_demographics.hd_demo_sk and ws_web_page_sk = web_page.wp_web_page_sk and time_dim.t_hour between 12 and 12+1 and household_demographics.hd_dep_count = 0 and web_page.wp_char_count between 5000 and 5200) at, ( select count(*) pmc from web_sales, household_demographics , time_dim, web_page where ws_sold_time_sk = time_dim.t_time_sk and ws_ship_hdemo_sk = household_demographics.hd_demo_sk and ws_web_page_sk = web_page.wp_web_page_sk and time_dim.t_hour between 15 and 15+1 and household_demographics.hd_dep_count = 0 and web_page.wp_char_count between 5000 and 5200) pt order by am_pm_ratio LIMIT 100; -- end query 40 in stream 0 using template query90.tpl -- start query 41 in stream 0 using template query17.tpl select i_item_id ,i_item_desc ,s_state ,count(ss_quantity) as store_sales_quantitycount ,avg(ss_quantity) as store_sales_quantityave ,stddev_samp(ss_quantity) as store_sales_quantitystdev ,stddev_samp(ss_quantity)/avg(ss_quantity) as store_sales_quantitycov ,count(sr_return_quantity) as store_returns_quantitycount ,avg(sr_return_quantity) as store_returns_quantityave ,stddev_samp(sr_return_quantity) as store_returns_quantitystdev ,stddev_samp(sr_return_quantity)/avg(sr_return_quantity) as store_returns_quantitycov ,count(cs_quantity) as catalog_sales_quantitycount ,avg(cs_quantity) as 
catalog_sales_quantityave ,stddev_samp(cs_quantity) as catalog_sales_quantitystdev ,stddev_samp(cs_quantity)/avg(cs_quantity) as catalog_sales_quantitycov from store_sales ,store_returns ,catalog_sales ,date_dim d1 ,date_dim d2 ,date_dim d3 ,store ,item where d1.d_quarter_name = '2001Q1' and d1.d_date_sk = ss_sold_date_sk and i_item_sk = ss_item_sk and s_store_sk = ss_store_sk and ss_customer_sk = sr_customer_sk and ss_item_sk = sr_item_sk and ss_ticket_number = sr_ticket_number and sr_returned_date_sk = d2.d_date_sk and d2.d_quarter_name in ('2001Q1','2001Q2','2001Q3') and sr_customer_sk = cs_bill_customer_sk and sr_item_sk = cs_item_sk and cs_sold_date_sk = d3.d_date_sk and d3.d_quarter_name in ('2001Q1','2001Q2','2001Q3') group by i_item_id ,i_item_desc ,s_state order by i_item_id ,i_item_desc ,s_state LIMIT 100; -- end query 41 in stream 0 using template query17.tpl -- start query 42 in stream 0 using template query47.tpl with v1 as( select i_category, i_brand, s_store_name, s_company_name, d_year, d_moy, sum(ss_sales_price) sum_sales, avg(sum(ss_sales_price)) over (partition by i_category, i_brand, s_store_name, s_company_name, d_year) avg_monthly_sales, rank() over (partition by i_category, i_brand, s_store_name, s_company_name order by d_year, d_moy) rn from item, store_sales, date_dim, store where ss_item_sk = i_item_sk and ss_sold_date_sk = d_date_sk and ss_store_sk = s_store_sk and ( d_year = 2001 or ( d_year = 2001-1 and d_moy =12) or ( d_year = 2001+1 and d_moy =1) ) group by i_category, i_brand, s_store_name, s_company_name, d_year, d_moy), v2 as( select v1.s_company_name ,v1.d_year, v1.d_moy ,v1.avg_monthly_sales ,v1.sum_sales, v1_lag.sum_sales psum, v1_lead.sum_sales nsum from v1, v1 v1_lag, v1 v1_lead where v1.i_category = v1_lag.i_category and v1.i_category = v1_lead.i_category and v1.i_brand = v1_lag.i_brand and v1.i_brand = v1_lead.i_brand and v1.s_store_name = v1_lag.s_store_name and v1.s_store_name = v1_lead.s_store_name and v1.s_company_name = v1_lag.s_company_name and v1.s_company_name = v1_lead.s_company_name and v1.rn = v1_lag.rn + 1 and v1.rn = v1_lead.rn - 1) select * from v2 where d_year = 2001 and avg_monthly_sales > 0 and case when avg_monthly_sales > 0 then abs(sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1 order by sum_sales - avg_monthly_sales, avg_monthly_sales LIMIT 100; -- end query 42 in stream 0 using template query47.tpl -- start query 43 in stream 0 using template query95.tpl with ws_wh as (select ws1.ws_order_number,ws1.ws_warehouse_sk wh1,ws2.ws_warehouse_sk wh2 from web_sales ws1,web_sales ws2 where ws1.ws_order_number = ws2.ws_order_number and ws1.ws_warehouse_sk <> ws2.ws_warehouse_sk) select count(distinct ws_order_number) as `order count` ,sum(ws_ext_ship_cost) as `total shipping cost` ,sum(ws_net_profit) as `total net profit` from web_sales ws1 ,date_dim ,customer_address ,web_site where d_date between '1999-3-01' and (cast('1999-3-01' as date) + interval 60 days) and ws1.ws_ship_date_sk = d_date_sk and ws1.ws_ship_addr_sk = ca_address_sk and ca_state = 'OR' and ws1.ws_web_site_sk = web_site_sk and web_company_name = 'pri' and ws1.ws_order_number in (select ws_order_number from ws_wh) and ws1.ws_order_number in (select wr_order_number from web_returns,ws_wh where wr_order_number = ws_wh.ws_order_number) order by count(distinct ws_order_number) LIMIT 100; -- end query 43 in stream 0 using template query95.tpl -- start query 44 in stream 0 using template query92.tpl select sum(ws_ext_discount_amt) as `Excess Discount 
Amount` from web_sales ,item ,date_dim where i_manufact_id = 783 and i_item_sk = ws_item_sk and d_date between '1999-03-21' and (cast('1999-03-21' as date) + interval 90 days) and d_date_sk = ws_sold_date_sk and ws_ext_discount_amt > ( SELECT 1.3 * avg(ws_ext_discount_amt) FROM web_sales ,date_dim WHERE ws_item_sk = i_item_sk and d_date between '1999-03-21' and (cast('1999-03-21' as date) + interval 90 days) and d_date_sk = ws_sold_date_sk ) order by sum(ws_ext_discount_amt) LIMIT 100; -- end query 44 in stream 0 using template query92.tpl -- start query 45 in stream 0 using template query3.tpl select dt.d_year ,item.i_brand_id brand_id ,item.i_brand brand ,sum(ss_sales_price) sum_agg from date_dim dt ,store_sales ,item where dt.d_date_sk = store_sales.ss_sold_date_sk and store_sales.ss_item_sk = item.i_item_sk and item.i_manufact_id = 211 and dt.d_moy=11 group by dt.d_year ,item.i_brand ,item.i_brand_id order by dt.d_year ,sum_agg desc ,brand_id LIMIT 100; -- end query 45 in stream 0 using template query3.tpl -- start query 46 in stream 0 using template query51.tpl WITH web_v1 as ( select ws_item_sk item_sk, d_date, sum(sum(ws_sales_price)) over (partition by ws_item_sk order by d_date rows between unbounded preceding and current row) cume_sales from web_sales ,date_dim where ws_sold_date_sk=d_date_sk and d_month_seq between 1195 and 1195+11 and ws_item_sk is not NULL group by ws_item_sk, d_date), store_v1 as ( select ss_item_sk item_sk, d_date, sum(sum(ss_sales_price)) over (partition by ss_item_sk order by d_date rows between unbounded preceding and current row) cume_sales from store_sales ,date_dim where ss_sold_date_sk=d_date_sk and d_month_seq between 1195 and 1195+11 and ss_item_sk is not NULL group by ss_item_sk, d_date) select * from (select item_sk ,d_date ,web_sales ,store_sales ,max(web_sales) over (partition by item_sk order by d_date rows between unbounded preceding and current row) web_cumulative ,max(store_sales) over (partition by item_sk order by d_date rows between unbounded preceding and current row) store_cumulative from (select case when web.item_sk is not null then web.item_sk else store.item_sk end item_sk ,case when web.d_date is not null then web.d_date else store.d_date end d_date ,web.cume_sales web_sales ,store.cume_sales store_sales from web_v1 web full outer join store_v1 store on (web.item_sk = store.item_sk and web.d_date = store.d_date) )x )y where web_cumulative > store_cumulative order by item_sk ,d_date LIMIT 100; -- end query 46 in stream 0 using template query51.tpl -- start query 47 in stream 0 using template query35.tpl select ca_state, cd_gender, cd_marital_status, cd_dep_count, count(*) cnt1, stddev_samp(cd_dep_count) aggone1, sum(cd_dep_count) aggtwo1, min(cd_dep_count) aggthree1, cd_dep_employed_count, count(*) cnt2, stddev_samp(cd_dep_employed_count) aggone2, sum(cd_dep_employed_count) aggtwo2, min(cd_dep_employed_count) aggthree2, cd_dep_college_count, count(*) cnt3, stddev_samp(cd_dep_college_count) aggone3, sum(cd_dep_college_count) aggtwo3, min(cd_dep_college_count) aggthree3 from customer c,customer_address ca,customer_demographics where c.c_current_addr_sk = ca.ca_address_sk and cd_demo_sk = c.c_current_cdemo_sk and exists (select * from store_sales,date_dim where c.c_customer_sk = ss_customer_sk and ss_sold_date_sk = d_date_sk and d_year = 2001 and d_qoy < 4) and (exists (select * from web_sales,date_dim where c.c_customer_sk = ws_bill_customer_sk and ws_sold_date_sk = d_date_sk and d_year = 2001 and d_qoy < 4) or exists (select * from 
catalog_sales,date_dim where c.c_customer_sk = cs_ship_customer_sk and cs_sold_date_sk = d_date_sk and d_year = 2001 and d_qoy < 4)) group by ca_state, cd_gender, cd_marital_status, cd_dep_count, cd_dep_employed_count, cd_dep_college_count order by ca_state, cd_gender, cd_marital_status, cd_dep_count, cd_dep_employed_count, cd_dep_college_count LIMIT 100; -- end query 47 in stream 0 using template query35.tpl -- start query 48 in stream 0 using template query49.tpl select channel, item, return_ratio, return_rank, currency_rank from (select 'web' as channel ,web.item ,web.return_ratio ,web.return_rank ,web.currency_rank from ( select item ,return_ratio ,currency_ratio ,rank() over (order by return_ratio) as return_rank ,rank() over (order by currency_ratio) as currency_rank from ( select ws.ws_item_sk as item ,(cast(sum(coalesce(wr.wr_return_quantity,0)) as decimal(15,4))/ cast(sum(coalesce(ws.ws_quantity,0)) as decimal(15,4) )) as return_ratio ,(cast(sum(coalesce(wr.wr_return_amt,0)) as decimal(15,4))/ cast(sum(coalesce(ws.ws_net_paid,0)) as decimal(15,4) )) as currency_ratio from web_sales ws left outer join web_returns wr on (ws.ws_order_number = wr.wr_order_number and ws.ws_item_sk = wr.wr_item_sk) ,date_dim where wr.wr_return_amt > 10000 and ws.ws_net_profit > 1 and ws.ws_net_paid > 0 and ws.ws_quantity > 0 and ws_sold_date_sk = d_date_sk and d_year = 2000 and d_moy = 12 group by ws.ws_item_sk ) in_web ) web where ( web.return_rank <= 10 or web.currency_rank <= 10 ) union select 'catalog' as channel ,catalog.item ,catalog.return_ratio ,catalog.return_rank ,catalog.currency_rank from ( select item ,return_ratio ,currency_ratio ,rank() over (order by return_ratio) as return_rank ,rank() over (order by currency_ratio) as currency_rank from ( select cs.cs_item_sk as item ,(cast(sum(coalesce(cr.cr_return_quantity,0)) as decimal(15,4))/ cast(sum(coalesce(cs.cs_quantity,0)) as decimal(15,4) )) as return_ratio ,(cast(sum(coalesce(cr.cr_return_amount,0)) as decimal(15,4))/ cast(sum(coalesce(cs.cs_net_paid,0)) as decimal(15,4) )) as currency_ratio from catalog_sales cs left outer join catalog_returns cr on (cs.cs_order_number = cr.cr_order_number and cs.cs_item_sk = cr.cr_item_sk) ,date_dim where cr.cr_return_amount > 10000 and cs.cs_net_profit > 1 and cs.cs_net_paid > 0 and cs.cs_quantity > 0 and cs_sold_date_sk = d_date_sk and d_year = 2000 and d_moy = 12 group by cs.cs_item_sk ) in_cat ) catalog where ( catalog.return_rank <= 10 or catalog.currency_rank <=10 ) union select 'store' as channel ,store.item ,store.return_ratio ,store.return_rank ,store.currency_rank from ( select item ,return_ratio ,currency_ratio ,rank() over (order by return_ratio) as return_rank ,rank() over (order by currency_ratio) as currency_rank from ( select sts.ss_item_sk as item ,(cast(sum(coalesce(sr.sr_return_quantity,0)) as decimal(15,4))/cast(sum(coalesce(sts.ss_quantity,0)) as decimal(15,4) )) as return_ratio ,(cast(sum(coalesce(sr.sr_return_amt,0)) as decimal(15,4))/cast(sum(coalesce(sts.ss_net_paid,0)) as decimal(15,4) )) as currency_ratio from store_sales sts left outer join store_returns sr on (sts.ss_ticket_number = sr.sr_ticket_number and sts.ss_item_sk = sr.sr_item_sk) ,date_dim where sr.sr_return_amt > 10000 and sts.ss_net_profit > 1 and sts.ss_net_paid > 0 and sts.ss_quantity > 0 and ss_sold_date_sk = d_date_sk and d_year = 2000 and d_moy = 12 group by sts.ss_item_sk ) in_store ) store where ( store.return_rank <= 10 or store.currency_rank <= 10 ) ) order by 1,4,5,2 LIMIT 100; -- end query 48 in stream 0 
using template query49.tpl -- start query 49 in stream 0 using template query9.tpl select case when (select count(*) from store_sales where ss_quantity between 1 and 20) > 144610 then (select avg(ss_ext_tax) from store_sales where ss_quantity between 1 and 20) else (select avg(ss_net_paid) from store_sales where ss_quantity between 1 and 20) end bucket1 , case when (select count(*) from store_sales where ss_quantity between 21 and 40) > 162498 then (select avg(ss_ext_tax) from store_sales where ss_quantity between 21 and 40) else (select avg(ss_net_paid) from store_sales where ss_quantity between 21 and 40) end bucket2, case when (select count(*) from store_sales where ss_quantity between 41 and 60) > 28387 then (select avg(ss_ext_tax) from store_sales where ss_quantity between 41 and 60) else (select avg(ss_net_paid) from store_sales where ss_quantity between 41 and 60) end bucket3, case when (select count(*) from store_sales where ss_quantity between 61 and 80) > 442573 then (select avg(ss_ext_tax) from store_sales where ss_quantity between 61 and 80) else (select avg(ss_net_paid) from store_sales where ss_quantity between 61 and 80) end bucket4, case when (select count(*) from store_sales where ss_quantity between 81 and 100) > 212532 then (select avg(ss_ext_tax) from store_sales where ss_quantity between 81 and 100) else (select avg(ss_net_paid) from store_sales where ss_quantity between 81 and 100) end bucket5 from reason where r_reason_sk = 1 ; -- end query 49 in stream 0 using template query9.tpl -- start query 50 in stream 0 using template query31.tpl with ss as (select ca_county,d_qoy, d_year,sum(ss_ext_sales_price) as store_sales from store_sales,date_dim,customer_address where ss_sold_date_sk = d_date_sk and ss_addr_sk=ca_address_sk group by ca_county,d_qoy, d_year), ws as (select ca_county,d_qoy, d_year,sum(ws_ext_sales_price) as web_sales from web_sales,date_dim,customer_address where ws_sold_date_sk = d_date_sk and ws_bill_addr_sk=ca_address_sk group by ca_county,d_qoy, d_year) select ss1.ca_county ,ss1.d_year ,ws2.web_sales/ws1.web_sales web_q1_q2_increase ,ss2.store_sales/ss1.store_sales store_q1_q2_increase ,ws3.web_sales/ws2.web_sales web_q2_q3_increase ,ss3.store_sales/ss2.store_sales store_q2_q3_increase from ss ss1 ,ss ss2 ,ss ss3 ,ws ws1 ,ws ws2 ,ws ws3 where ss1.d_qoy = 1 and ss1.d_year = 2000 and ss1.ca_county = ss2.ca_county and ss2.d_qoy = 2 and ss2.d_year = 2000 and ss2.ca_county = ss3.ca_county and ss3.d_qoy = 3 and ss3.d_year = 2000 and ss1.ca_county = ws1.ca_county and ws1.d_qoy = 1 and ws1.d_year = 2000 and ws1.ca_county = ws2.ca_county and ws2.d_qoy = 2 and ws2.d_year = 2000 and ws1.ca_county = ws3.ca_county and ws3.d_qoy = 3 and ws3.d_year =2000 and case when ws1.web_sales > 0 then ws2.web_sales/ws1.web_sales else null end > case when ss1.store_sales > 0 then ss2.store_sales/ss1.store_sales else null end and case when ws2.web_sales > 0 then ws3.web_sales/ws2.web_sales else null end > case when ss2.store_sales > 0 then ss3.store_sales/ss2.store_sales else null end order by web_q2_q3_increase; -- end query 50 in stream 0 using template query31.tpl -- start query 51 in stream 0 using template query11.tpl with year_total as ( select c_customer_id customer_id ,c_first_name customer_first_name ,c_last_name customer_last_name ,c_preferred_cust_flag customer_preferred_cust_flag ,c_birth_country customer_birth_country ,c_login customer_login ,c_email_address customer_email_address ,d_year dyear ,sum(ss_ext_list_price-ss_ext_discount_amt) year_total ,'s' sale_type 
from customer ,store_sales ,date_dim where c_customer_sk = ss_customer_sk and ss_sold_date_sk = d_date_sk group by c_customer_id ,c_first_name ,c_last_name ,c_preferred_cust_flag ,c_birth_country ,c_login ,c_email_address ,d_year union all select c_customer_id customer_id ,c_first_name customer_first_name ,c_last_name customer_last_name ,c_preferred_cust_flag customer_preferred_cust_flag ,c_birth_country customer_birth_country ,c_login customer_login ,c_email_address customer_email_address ,d_year dyear ,sum(ws_ext_list_price-ws_ext_discount_amt) year_total ,'w' sale_type from customer ,web_sales ,date_dim where c_customer_sk = ws_bill_customer_sk and ws_sold_date_sk = d_date_sk group by c_customer_id ,c_first_name ,c_last_name ,c_preferred_cust_flag ,c_birth_country ,c_login ,c_email_address ,d_year ) select t_s_secyear.customer_id ,t_s_secyear.customer_first_name ,t_s_secyear.customer_last_name ,t_s_secyear.customer_preferred_cust_flag from year_total t_s_firstyear ,year_total t_s_secyear ,year_total t_w_firstyear ,year_total t_w_secyear where t_s_secyear.customer_id = t_s_firstyear.customer_id and t_s_firstyear.customer_id = t_w_secyear.customer_id and t_s_firstyear.customer_id = t_w_firstyear.customer_id and t_s_firstyear.sale_type = 's' and t_w_firstyear.sale_type = 'w' and t_s_secyear.sale_type = 's' and t_w_secyear.sale_type = 'w' and t_s_firstyear.dyear = 1998 and t_s_secyear.dyear = 1998+1 and t_w_firstyear.dyear = 1998 and t_w_secyear.dyear = 1998+1 and t_s_firstyear.year_total > 0 and t_w_firstyear.year_total > 0 and case when t_w_firstyear.year_total > 0 then t_w_secyear.year_total / t_w_firstyear.year_total else 0.0 end > case when t_s_firstyear.year_total > 0 then t_s_secyear.year_total / t_s_firstyear.year_total else 0.0 end order by t_s_secyear.customer_id ,t_s_secyear.customer_first_name ,t_s_secyear.customer_last_name ,t_s_secyear.customer_preferred_cust_flag LIMIT 100; -- end query 51 in stream 0 using template query11.tpl -- start query 52 in stream 0 using template query93.tpl select ss_customer_sk ,sum(act_sales) sumsales from (select ss_item_sk ,ss_ticket_number ,ss_customer_sk ,case when sr_return_quantity is not null then (ss_quantity-sr_return_quantity)*ss_sales_price else (ss_quantity*ss_sales_price) end act_sales from store_sales left outer join store_returns on (sr_item_sk = ss_item_sk and sr_ticket_number = ss_ticket_number) ,reason where sr_reason_sk = r_reason_sk and r_reason_desc = 'reason 56') t group by ss_customer_sk order by sumsales, ss_customer_sk LIMIT 100; -- end query 52 in stream 0 using template query93.tpl -- start query 53 in stream 0 using template query29.tpl select i_item_id ,i_item_desc ,s_store_id ,s_store_name ,max(ss_quantity) as store_sales_quantity ,max(sr_return_quantity) as store_returns_quantity ,max(cs_quantity) as catalog_sales_quantity from store_sales ,store_returns ,catalog_sales ,date_dim d1 ,date_dim d2 ,date_dim d3 ,store ,item where d1.d_moy = 4 and d1.d_year = 2000 and d1.d_date_sk = ss_sold_date_sk and i_item_sk = ss_item_sk and s_store_sk = ss_store_sk and ss_customer_sk = sr_customer_sk and ss_item_sk = sr_item_sk and ss_ticket_number = sr_ticket_number and sr_returned_date_sk = d2.d_date_sk and d2.d_moy between 4 and 4 + 3 and d2.d_year = 2000 and sr_customer_sk = cs_bill_customer_sk and sr_item_sk = cs_item_sk and cs_sold_date_sk = d3.d_date_sk and d3.d_year in (2000,2000+1,2000+2) group by i_item_id ,i_item_desc ,s_store_id ,s_store_name order by i_item_id ,i_item_desc ,s_store_id ,s_store_name LIMIT 100; -- end 
query 53 in stream 0 using template query29.tpl -- start query 54 in stream 0 using template query38.tpl select count(*) from ( select distinct c_last_name, c_first_name, d_date from store_sales, date_dim, customer where store_sales.ss_sold_date_sk = date_dim.d_date_sk and store_sales.ss_customer_sk = customer.c_customer_sk and d_month_seq between 1212 and 1212 + 11 intersect select distinct c_last_name, c_first_name, d_date from catalog_sales, date_dim, customer where catalog_sales.cs_sold_date_sk = date_dim.d_date_sk and catalog_sales.cs_bill_customer_sk = customer.c_customer_sk and d_month_seq between 1212 and 1212 + 11 intersect select distinct c_last_name, c_first_name, d_date from web_sales, date_dim, customer where web_sales.ws_sold_date_sk = date_dim.d_date_sk and web_sales.ws_bill_customer_sk = customer.c_customer_sk and d_month_seq between 1212 and 1212 + 11 ) hot_cust LIMIT 100; -- end query 54 in stream 0 using template query38.tpl -- start query 55 in stream 0 using template query22.tpl select i_product_name ,i_brand ,i_class ,i_category ,avg(inv_quantity_on_hand) qoh from inventory ,date_dim ,item where inv_date_sk=d_date_sk and inv_item_sk=i_item_sk and d_month_seq between 1188 and 1188 + 11 group by rollup(i_product_name ,i_brand ,i_class ,i_category) order by qoh, i_product_name, i_brand, i_class, i_category LIMIT 100; -- end query 55 in stream 0 using template query22.tpl -- start query 56 in stream 0 using template query89.tpl select * from( select i_category, i_class, i_brand, s_store_name, s_company_name, d_moy, sum(ss_sales_price) sum_sales, avg(sum(ss_sales_price)) over (partition by i_category, i_brand, s_store_name, s_company_name) avg_monthly_sales from item, store_sales, date_dim, store where ss_item_sk = i_item_sk and ss_sold_date_sk = d_date_sk and ss_store_sk = s_store_sk and d_year in (2001) and ((i_category in ('Electronics','Books','Home') and i_class in ('scanners','parenting','wallpaper') ) or (i_category in ('Shoes','Sports','Women') and i_class in ('kids','archery','dresses') )) group by i_category, i_class, i_brand, s_store_name, s_company_name, d_moy) tmp1 where case when (avg_monthly_sales <> 0) then (abs(sum_sales - avg_monthly_sales) / avg_monthly_sales) else null end > 0.1 order by sum_sales - avg_monthly_sales, s_store_name LIMIT 100; -- end query 56 in stream 0 using template query89.tpl -- start query 57 in stream 0 using template query15.tpl select ca_zip ,sum(cs_sales_price) from catalog_sales ,customer ,customer_address ,date_dim where cs_bill_customer_sk = c_customer_sk and c_current_addr_sk = ca_address_sk and ( substr(ca_zip,1,5) in ('85669', '86197','88274','83405','86475', '85392', '85460', '80348', '81792') or ca_state in ('CA','WA','GA') or cs_sales_price > 500) and cs_sold_date_sk = d_date_sk and d_qoy = 2 and d_year = 2002 group by ca_zip order by ca_zip LIMIT 100; -- end query 57 in stream 0 using template query15.tpl -- start query 58 in stream 0 using template query6.tpl select a.ca_state state, count(*) cnt from customer_address a ,customer c ,store_sales s ,date_dim d ,item i where a.ca_address_sk = c.c_current_addr_sk and c.c_customer_sk = s.ss_customer_sk and s.ss_sold_date_sk = d.d_date_sk and s.ss_item_sk = i.i_item_sk and d.d_month_seq = (select distinct (d_month_seq) from date_dim where d_year = 1998 and d_moy = 6 ) and i.i_current_price > 1.2 * (select avg(j.i_current_price) from item j where j.i_category = i.i_category) group by a.ca_state having count(*) >= 10 order by cnt, a.ca_state LIMIT 100; -- end query 58 in 
stream 0 using template query6.tpl -- start query 59 in stream 0 using template query52.tpl select dt.d_year ,item.i_brand_id brand_id ,item.i_brand brand ,sum(ss_ext_sales_price) ext_price from date_dim dt ,store_sales ,item where dt.d_date_sk = store_sales.ss_sold_date_sk and store_sales.ss_item_sk = item.i_item_sk and item.i_manager_id = 1 and dt.d_moy=12 and dt.d_year=2002 group by dt.d_year ,item.i_brand ,item.i_brand_id order by dt.d_year ,ext_price desc ,brand_id LIMIT 100 ; -- end query 59 in stream 0 using template query52.tpl -- start query 60 in stream 0 using template query50.tpl select s_store_name ,s_company_id ,s_street_number ,s_street_name ,s_street_type ,s_suite_number ,s_city ,s_county ,s_state ,s_zip ,sum(case when (sr_returned_date_sk - ss_sold_date_sk <= 30 ) then 1 else 0 end) as `30 days` ,sum(case when (sr_returned_date_sk - ss_sold_date_sk > 30) and (sr_returned_date_sk - ss_sold_date_sk <= 60) then 1 else 0 end ) as `31-60 days` ,sum(case when (sr_returned_date_sk - ss_sold_date_sk > 60) and (sr_returned_date_sk - ss_sold_date_sk <= 90) then 1 else 0 end) as `61-90 days` ,sum(case when (sr_returned_date_sk - ss_sold_date_sk > 90) and (sr_returned_date_sk - ss_sold_date_sk <= 120) then 1 else 0 end) as `91-120 days` ,sum(case when (sr_returned_date_sk - ss_sold_date_sk > 120) then 1 else 0 end) as `>120 days` from store_sales ,store_returns ,store ,date_dim d1 ,date_dim d2 where d2.d_year = 2002 and d2.d_moy = 8 and ss_ticket_number = sr_ticket_number and ss_item_sk = sr_item_sk and ss_sold_date_sk = d1.d_date_sk and sr_returned_date_sk = d2.d_date_sk and ss_customer_sk = sr_customer_sk and ss_store_sk = s_store_sk group by s_store_name ,s_company_id ,s_street_number ,s_street_name ,s_street_type ,s_suite_number ,s_city ,s_county ,s_state ,s_zip order by s_store_name ,s_company_id ,s_street_number ,s_street_name ,s_street_type ,s_suite_number ,s_city ,s_county ,s_state ,s_zip LIMIT 100; -- end query 60 in stream 0 using template query50.tpl -- start query 61 in stream 0 using template query42.tpl select dt.d_year ,item.i_category_id ,item.i_category ,sum(ss_ext_sales_price) from date_dim dt ,store_sales ,item where dt.d_date_sk = store_sales.ss_sold_date_sk and store_sales.ss_item_sk = item.i_item_sk and item.i_manager_id = 1 and dt.d_moy=11 and dt.d_year=1999 group by dt.d_year ,item.i_category_id ,item.i_category order by sum(ss_ext_sales_price) desc,dt.d_year ,item.i_category_id ,item.i_category LIMIT 100 ; -- end query 61 in stream 0 using template query42.tpl -- start query 62 in stream 0 using template query41.tpl select distinct(i_product_name) from item i1 where i_manufact_id between 794 and 794+40 and (select count(*) as item_cnt from item where (i_manufact = i1.i_manufact and ((i_category = 'Women' and (i_color = 'pink' or i_color = 'yellow') and (i_units = 'Lb' or i_units = 'Pallet') and (i_size = 'small' or i_size = 'petite') ) or (i_category = 'Women' and (i_color = 'deep' or i_color = 'goldenrod') and (i_units = 'Bundle' or i_units = 'Oz') and (i_size = 'extra large' or i_size = 'economy') ) or (i_category = 'Men' and (i_color = 'peru' or i_color = 'cream') and (i_units = 'Case' or i_units = 'Ounce') and (i_size = 'medium' or i_size = 'N/A') ) or (i_category = 'Men' and (i_color = 'purple' or i_color = 'floral') and (i_units = 'Each' or i_units = 'Cup') and (i_size = 'small' or i_size = 'petite') ))) or (i_manufact = i1.i_manufact and ((i_category = 'Women' and (i_color = 'blue' or i_color = 'seashell') and (i_units = 'Pound' or i_units = 'Carton') 
and (i_size = 'small' or i_size = 'petite') ) or (i_category = 'Women' and (i_color = 'slate' or i_color = 'saddle') and (i_units = 'Gram' or i_units = 'Tsp') and (i_size = 'extra large' or i_size = 'economy') ) or (i_category = 'Men' and (i_color = 'midnight' or i_color = 'chiffon') and (i_units = 'Box' or i_units = 'Ton') and (i_size = 'medium' or i_size = 'N/A') ) or (i_category = 'Men' and (i_color = 'orchid' or i_color = 'magenta') and (i_units = 'Unknown' or i_units = 'Tbl') and (i_size = 'small' or i_size = 'petite') )))) > 0 order by i_product_name LIMIT 100; -- end query 62 in stream 0 using template query41.tpl -- start query 63 in stream 0 using template query8.tpl select s_store_name ,sum(ss_net_profit) from store_sales ,date_dim ,store, (select ca_zip from ( SELECT substr(ca_zip,1,5) ca_zip FROM customer_address WHERE substr(ca_zip,1,5) IN ( '43758','76357','20728','59309','19777','27690', '23681','52275','64367','24674','79465', '52936','53936','91889','89248','70394', '66020','56289','45541','29900','99055', '47395','16654','26748','74456','31039', '77674','87076','92273','31667','20150', '84426','75885','61588','57973','29487', '95008','65615','24339','84923','38463', '13811','44227','18570','40389','14584', '33007','61590','47363','57853','43499', '90755','47141','14392','33991','77031', '22854','20127','10624','15730','75295', '98460','17059','26953','82996','17095', '53227','34618','86978','33613','12541', '63977','53929','55459','11516','85350', '99888','23506','10569','66837','50031', '28282','83901','98554','54828','14616', '12743','42473','95507','30542','12883', '95097','61307','32530','37753','53116', '10989','87430','22114','68848','21246', '68327','28446','85870','11697','30541', '22933','70727','17570','55311','73355', '16347','61573','81229','95480','92091', '52603','51232','62666','12173','31993', '98202','78325','46798','63259','34167', '50435','56182','29390','51732','88435', '10366','46637','69283','18218','33324', '24139','16122','53142','16832','98386', '41451','85109','32534','83953','76537', '60857','59939','22271','38788','26296', '59937','14272','98651','38185','16322', '13735','56321','81398','36035','36512', '96290','40596','22748','77965','28512', '15540','20574','72340','81870','31905', '18121','26282','30345','38703','74274', '71129','23244','68810','10106','55461', '25528','71474','37071','21552','81846', '64930','13233','11694','17829','43790', '60379','11482','22714','40977','73320', '13928','78952','92802','66663','95765', '86101','19813','90867','81258','93891', '32755','21548','36452','50931','95773', '57046','14736','30562','44667','80519', '99886','97296','38505','29732','38693', '83898','88032','64442','25944','39303', '70781','92448','64252','89641','88070', '38159','27654','72120','41689','37122', '63776','90416','28479','14787','18038', '39783','50062','28010','13042','86777', '32380','80664','33558','43641','14627', '68858','57733','53458','73016','76141', '42375','12248','38778','50092','80825', '58934','12145','78407','57009','52782', '72140','35635','63926','35282','29292', '30149','33576','95945','48303','56310', '32214','69726','48249','91163','57311', '12361','20491','13551','61620','59648', '44466','53607','18410','99090','37973', '17986','80713','95948','35103','51799', '54707','52269','86117','44909','15530', '28999','80844','62823','46487','15144', '51445','81050','34943','45141','28541', '12414','56922','50548','16422','16780', '53104','60629','24405','61768','48257', '92852','27390','24411','17776','81487', 
'34848','45773','64188','24209','55276', '11379','33956','46173','67361','32337', '82112','73196','38461','43987','17980', '65414','12247','42107','15326','73018', '59993','85526','50231','60176','23889', '88012','27859','44921','50915','21742', '21272','64763','78761','62002','18502', '42208','49675','69413','46013','67034', '52739','94050','76249','25105','67299', '77588','50637','14333','39372','98030', '79792','12014','56236','61057','51347', '87879','71564','48478','33078','23325', '25526','52855','27570','78396','18695', '24397','76087','35195','97232','29136', '15812','18408','40746','78749') intersect select ca_zip from (SELECT substr(ca_zip,1,5) ca_zip,count(*) cnt FROM customer_address, customer WHERE ca_address_sk = c_current_addr_sk and c_preferred_cust_flag='Y' group by ca_zip having count(*) > 10)A1)A2) V1 where ss_store_sk = s_store_sk and ss_sold_date_sk = d_date_sk and d_qoy = 1 and d_year = 2000 and (substr(s_zip,1,2) = substr(V1.ca_zip,1,2)) group by s_store_name order by s_store_name LIMIT 100; -- end query 63 in stream 0 using template query8.tpl -- start query 64 in stream 0 using template query12.tpl select i_item_id ,i_item_desc ,i_category ,i_class ,i_current_price ,sum(ws_ext_sales_price) as itemrevenue ,sum(ws_ext_sales_price)*100/sum(sum(ws_ext_sales_price)) over (partition by i_class) as revenueratio from web_sales ,item ,date_dim where ws_item_sk = i_item_sk and i_category in ('Women', 'Children', 'Books') and ws_sold_date_sk = d_date_sk and d_date between cast('2001-02-28' as date) and (cast('2001-02-28' as date) + interval 30 days) group by i_item_id ,i_item_desc ,i_category ,i_class ,i_current_price order by i_category ,i_class ,i_item_id ,i_item_desc ,revenueratio LIMIT 100; -- end query 64 in stream 0 using template query12.tpl -- start query 65 in stream 0 using template query20.tpl select i_item_id ,i_item_desc ,i_category ,i_class ,i_current_price ,sum(cs_ext_sales_price) as itemrevenue ,sum(cs_ext_sales_price)*100/sum(sum(cs_ext_sales_price)) over (partition by i_class) as revenueratio from catalog_sales ,item ,date_dim where cs_item_sk = i_item_sk and i_category in ('Men', 'Home', 'Music') and cs_sold_date_sk = d_date_sk and d_date between cast('1999-03-08' as date) and (cast('1999-03-08' as date) + interval 30 days) group by i_item_id ,i_item_desc ,i_category ,i_class ,i_current_price order by i_category ,i_class ,i_item_id ,i_item_desc ,revenueratio LIMIT 100; -- end query 65 in stream 0 using template query20.tpl -- start query 66 in stream 0 using template query88.tpl select * from (select count(*) h8_30_to_9 from store_sales, household_demographics , time_dim, store where ss_sold_time_sk = time_dim.t_time_sk and ss_hdemo_sk = household_demographics.hd_demo_sk and ss_store_sk = s_store_sk and time_dim.t_hour = 8 and time_dim.t_minute >= 30 and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2)) and store.s_store_name = 'ese') s1, (select count(*) h9_to_9_30 from store_sales, household_demographics , time_dim, store where ss_sold_time_sk = time_dim.t_time_sk and ss_hdemo_sk = household_demographics.hd_demo_sk and ss_store_sk = s_store_sk and time_dim.t_hour = 9 and time_dim.t_minute < 30 and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or 
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2)) and store.s_store_name = 'ese') s2, (select count(*) h9_30_to_10 from store_sales, household_demographics , time_dim, store where ss_sold_time_sk = time_dim.t_time_sk and ss_hdemo_sk = household_demographics.hd_demo_sk and ss_store_sk = s_store_sk and time_dim.t_hour = 9 and time_dim.t_minute >= 30 and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2)) and store.s_store_name = 'ese') s3, (select count(*) h10_to_10_30 from store_sales, household_demographics , time_dim, store where ss_sold_time_sk = time_dim.t_time_sk and ss_hdemo_sk = household_demographics.hd_demo_sk and ss_store_sk = s_store_sk and time_dim.t_hour = 10 and time_dim.t_minute < 30 and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2)) and store.s_store_name = 'ese') s4, (select count(*) h10_30_to_11 from store_sales, household_demographics , time_dim, store where ss_sold_time_sk = time_dim.t_time_sk and ss_hdemo_sk = household_demographics.hd_demo_sk and ss_store_sk = s_store_sk and time_dim.t_hour = 10 and time_dim.t_minute >= 30 and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2)) and store.s_store_name = 'ese') s5, (select count(*) h11_to_11_30 from store_sales, household_demographics , time_dim, store where ss_sold_time_sk = time_dim.t_time_sk and ss_hdemo_sk = household_demographics.hd_demo_sk and ss_store_sk = s_store_sk and time_dim.t_hour = 11 and time_dim.t_minute < 30 and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2)) and store.s_store_name = 'ese') s6, (select count(*) h11_30_to_12 from store_sales, household_demographics , time_dim, store where ss_sold_time_sk = time_dim.t_time_sk and ss_hdemo_sk = household_demographics.hd_demo_sk and ss_store_sk = s_store_sk and time_dim.t_hour = 11 and time_dim.t_minute >= 30 and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2)) and store.s_store_name = 'ese') s7, (select count(*) h12_to_12_30 from store_sales, household_demographics , time_dim, store where ss_sold_time_sk = time_dim.t_time_sk and ss_hdemo_sk = household_demographics.hd_demo_sk and ss_store_sk = s_store_sk and time_dim.t_hour = 12 and time_dim.t_minute < 30 and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or 
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2)) and store.s_store_name = 'ese') s8 ; -- end query 66 in stream 0 using template query88.tpl -- start query 67 in stream 0 using template query82.tpl select i_item_id ,i_item_desc ,i_current_price from item, inventory, date_dim, store_sales where i_current_price between 9 and 9+30 and inv_item_sk = i_item_sk and d_date_sk=inv_date_sk and d_date between cast('2001-06-07' as date) and (cast('2001-06-07' as date) + interval 60 days) and i_manufact_id in (797,412,331,589) and inv_quantity_on_hand between 100 and 500 and ss_item_sk = i_item_sk group by i_item_id,i_item_desc,i_current_price order by i_item_id LIMIT 100; -- end query 67 in stream 0 using template query82.tpl -- start query 68 in stream 0 using template query23.tpl with frequent_ss_items as (select substr(i_item_desc,1,30) itemdesc,i_item_sk item_sk,d_date solddate,count(*) cnt from store_sales ,date_dim ,item where ss_sold_date_sk = d_date_sk and ss_item_sk = i_item_sk and d_year in (2000,2000+1,2000+2,2000+3) group by substr(i_item_desc,1,30),i_item_sk,d_date having count(*) >4), max_store_sales as (select max(csales) tpcds_cmax from (select c_customer_sk,sum(ss_quantity*ss_sales_price) csales from store_sales ,customer ,date_dim where ss_customer_sk = c_customer_sk and ss_sold_date_sk = d_date_sk and d_year in (2000,2000+1,2000+2,2000+3) group by c_customer_sk)), best_ss_customer as (select c_customer_sk,sum(ss_quantity*ss_sales_price) ssales from store_sales ,customer where ss_customer_sk = c_customer_sk group by c_customer_sk having sum(ss_quantity*ss_sales_price) > (95/100.0) * (select * from max_store_sales)) select sum(sales) from (select cs_quantity*cs_list_price sales from catalog_sales ,date_dim where d_year = 2000 and d_moy = 7 and cs_sold_date_sk = d_date_sk and cs_item_sk in (select item_sk from frequent_ss_items) and cs_bill_customer_sk in (select c_customer_sk from best_ss_customer) union all select ws_quantity*ws_list_price sales from web_sales ,date_dim where d_year = 2000 and d_moy = 7 and ws_sold_date_sk = d_date_sk and ws_item_sk in (select item_sk from frequent_ss_items) and ws_bill_customer_sk in (select c_customer_sk from best_ss_customer)) LIMIT 100; with frequent_ss_items as (select substr(i_item_desc,1,30) itemdesc,i_item_sk item_sk,d_date solddate,count(*) cnt from store_sales ,date_dim ,item where ss_sold_date_sk = d_date_sk and ss_item_sk = i_item_sk and d_year in (2000,2000 + 1,2000 + 2,2000 + 3) group by substr(i_item_desc,1,30),i_item_sk,d_date having count(*) >4), max_store_sales as (select max(csales) tpcds_cmax from (select c_customer_sk,sum(ss_quantity*ss_sales_price) csales from store_sales ,customer ,date_dim where ss_customer_sk = c_customer_sk and ss_sold_date_sk = d_date_sk and d_year in (2000,2000+1,2000+2,2000+3) group by c_customer_sk)), best_ss_customer as (select c_customer_sk,sum(ss_quantity*ss_sales_price) ssales from store_sales ,customer where ss_customer_sk = c_customer_sk group by c_customer_sk having sum(ss_quantity*ss_sales_price) > (95/100.0) * (select * from max_store_sales)) select c_last_name,c_first_name,sales from (select c_last_name,c_first_name,sum(cs_quantity*cs_list_price) sales from catalog_sales ,customer ,date_dim where d_year = 2000 and d_moy = 7 and cs_sold_date_sk = d_date_sk and cs_item_sk in (select item_sk from frequent_ss_items) and cs_bill_customer_sk in (select 
c_customer_sk from best_ss_customer) and cs_bill_customer_sk = c_customer_sk group by c_last_name,c_first_name union all select c_last_name,c_first_name,sum(ws_quantity*ws_list_price) sales from web_sales ,customer ,date_dim where d_year = 2000 and d_moy = 7 and ws_sold_date_sk = d_date_sk and ws_item_sk in (select item_sk from frequent_ss_items) and ws_bill_customer_sk in (select c_customer_sk from best_ss_customer) and ws_bill_customer_sk = c_customer_sk group by c_last_name,c_first_name) order by c_last_name,c_first_name,sales LIMIT 100; -- end query 68 in stream 0 using template query23.tpl -- start query 69 in stream 0 using template query14.tpl with cross_items as (select i_item_sk ss_item_sk from item, (select iss.i_brand_id brand_id ,iss.i_class_id class_id ,iss.i_category_id category_id from store_sales ,item iss ,date_dim d1 where ss_item_sk = iss.i_item_sk and ss_sold_date_sk = d1.d_date_sk and d1.d_year between 1999 AND 1999 + 2 intersect select ics.i_brand_id ,ics.i_class_id ,ics.i_category_id from catalog_sales ,item ics ,date_dim d2 where cs_item_sk = ics.i_item_sk and cs_sold_date_sk = d2.d_date_sk and d2.d_year between 1999 AND 1999 + 2 intersect select iws.i_brand_id ,iws.i_class_id ,iws.i_category_id from web_sales ,item iws ,date_dim d3 where ws_item_sk = iws.i_item_sk and ws_sold_date_sk = d3.d_date_sk and d3.d_year between 1999 AND 1999 + 2) where i_brand_id = brand_id and i_class_id = class_id and i_category_id = category_id ), avg_sales as (select avg(quantity*list_price) average_sales from (select ss_quantity quantity ,ss_list_price list_price from store_sales ,date_dim where ss_sold_date_sk = d_date_sk and d_year between 1999 and 1999 + 2 union all select cs_quantity quantity ,cs_list_price list_price from catalog_sales ,date_dim where cs_sold_date_sk = d_date_sk and d_year between 1999 and 1999 + 2 union all select ws_quantity quantity ,ws_list_price list_price from web_sales ,date_dim where ws_sold_date_sk = d_date_sk and d_year between 1999 and 1999 + 2) x) select channel, i_brand_id,i_class_id,i_category_id,sum(sales), sum(number_sales) from( select 'store' channel, i_brand_id,i_class_id ,i_category_id,sum(ss_quantity*ss_list_price) sales , count(*) number_sales from store_sales ,item ,date_dim where ss_item_sk in (select ss_item_sk from cross_items) and ss_item_sk = i_item_sk and ss_sold_date_sk = d_date_sk and d_year = 1999+2 and d_moy = 11 group by i_brand_id,i_class_id,i_category_id having sum(ss_quantity*ss_list_price) > (select average_sales from avg_sales) union all select 'catalog' channel, i_brand_id,i_class_id,i_category_id, sum(cs_quantity*cs_list_price) sales, count(*) number_sales from catalog_sales ,item ,date_dim where cs_item_sk in (select ss_item_sk from cross_items) and cs_item_sk = i_item_sk and cs_sold_date_sk = d_date_sk and d_year = 1999+2 and d_moy = 11 group by i_brand_id,i_class_id,i_category_id having sum(cs_quantity*cs_list_price) > (select average_sales from avg_sales) union all select 'web' channel, i_brand_id,i_class_id,i_category_id, sum(ws_quantity*ws_list_price) sales , count(*) number_sales from web_sales ,item ,date_dim where ws_item_sk in (select ss_item_sk from cross_items) and ws_item_sk = i_item_sk and ws_sold_date_sk = d_date_sk and d_year = 1999+2 and d_moy = 11 group by i_brand_id,i_class_id,i_category_id having sum(ws_quantity*ws_list_price) > (select average_sales from avg_sales) ) y group by rollup (channel, i_brand_id,i_class_id,i_category_id) order by channel,i_brand_id,i_class_id,i_category_id LIMIT 100; with 
cross_items as (select i_item_sk ss_item_sk from item, (select iss.i_brand_id brand_id ,iss.i_class_id class_id ,iss.i_category_id category_id from store_sales ,item iss ,date_dim d1 where ss_item_sk = iss.i_item_sk and ss_sold_date_sk = d1.d_date_sk and d1.d_year between 1999 AND 1999 + 2 intersect select ics.i_brand_id ,ics.i_class_id ,ics.i_category_id from catalog_sales ,item ics ,date_dim d2 where cs_item_sk = ics.i_item_sk and cs_sold_date_sk = d2.d_date_sk and d2.d_year between 1999 AND 1999 + 2 intersect select iws.i_brand_id ,iws.i_class_id ,iws.i_category_id from web_sales ,item iws ,date_dim d3 where ws_item_sk = iws.i_item_sk and ws_sold_date_sk = d3.d_date_sk and d3.d_year between 1999 AND 1999 + 2) x where i_brand_id = brand_id and i_class_id = class_id and i_category_id = category_id ), avg_sales as (select avg(quantity*list_price) average_sales from (select ss_quantity quantity ,ss_list_price list_price from store_sales ,date_dim where ss_sold_date_sk = d_date_sk and d_year between 1999 and 1999 + 2 union all select cs_quantity quantity ,cs_list_price list_price from catalog_sales ,date_dim where cs_sold_date_sk = d_date_sk and d_year between 1999 and 1999 + 2 union all select ws_quantity quantity ,ws_list_price list_price from web_sales ,date_dim where ws_sold_date_sk = d_date_sk and d_year between 1999 and 1999 + 2) x) select this_year.channel ty_channel ,this_year.i_brand_id ty_brand ,this_year.i_class_id ty_class ,this_year.i_category_id ty_category ,this_year.sales ty_sales ,this_year.number_sales ty_number_sales ,last_year.channel ly_channel ,last_year.i_brand_id ly_brand ,last_year.i_class_id ly_class ,last_year.i_category_id ly_category ,last_year.sales ly_sales ,last_year.number_sales ly_number_sales from (select 'store' channel, i_brand_id,i_class_id,i_category_id ,sum(ss_quantity*ss_list_price) sales, count(*) number_sales from store_sales ,item ,date_dim where ss_item_sk in (select ss_item_sk from cross_items) and ss_item_sk = i_item_sk and ss_sold_date_sk = d_date_sk and d_week_seq = (select d_week_seq from date_dim where d_year = 1999 + 1 and d_moy = 12 and d_dom = 28) group by i_brand_id,i_class_id,i_category_id having sum(ss_quantity*ss_list_price) > (select average_sales from avg_sales)) this_year, (select 'store' channel, i_brand_id,i_class_id ,i_category_id, sum(ss_quantity*ss_list_price) sales, count(*) number_sales from store_sales ,item ,date_dim where ss_item_sk in (select ss_item_sk from cross_items) and ss_item_sk = i_item_sk and ss_sold_date_sk = d_date_sk and d_week_seq = (select d_week_seq from date_dim where d_year = 1999 and d_moy = 12 and d_dom = 28) group by i_brand_id,i_class_id,i_category_id having sum(ss_quantity*ss_list_price) > (select average_sales from avg_sales)) last_year where this_year.i_brand_id= last_year.i_brand_id and this_year.i_class_id = last_year.i_class_id and this_year.i_category_id = last_year.i_category_id order by this_year.channel, this_year.i_brand_id, this_year.i_class_id, this_year.i_category_id LIMIT 100; -- end query 69 in stream 0 using template query14.tpl -- start query 70 in stream 0 using template query57.tpl with v1 as( select i_category, i_brand, cc_name, d_year, d_moy, sum(cs_sales_price) sum_sales, avg(sum(cs_sales_price)) over (partition by i_category, i_brand, cc_name, d_year) avg_monthly_sales, rank() over (partition by i_category, i_brand, cc_name order by d_year, d_moy) rn from item, catalog_sales, date_dim, call_center where cs_item_sk = i_item_sk and cs_sold_date_sk = d_date_sk and 
cc_call_center_sk= cs_call_center_sk and ( d_year = 1999 or ( d_year = 1999-1 and d_moy =12) or ( d_year = 1999+1 and d_moy =1) ) group by i_category, i_brand, cc_name , d_year, d_moy), v2 as( select v1.i_category, v1.i_brand ,v1.d_year, v1.d_moy ,v1.avg_monthly_sales ,v1.sum_sales, v1_lag.sum_sales psum, v1_lead.sum_sales nsum from v1, v1 v1_lag, v1 v1_lead where v1.i_category = v1_lag.i_category and v1.i_category = v1_lead.i_category and v1.i_brand = v1_lag.i_brand and v1.i_brand = v1_lead.i_brand and v1. cc_name = v1_lag. cc_name and v1. cc_name = v1_lead. cc_name and v1.rn = v1_lag.rn + 1 and v1.rn = v1_lead.rn - 1) select * from v2 where d_year = 1999 and avg_monthly_sales > 0 and case when avg_monthly_sales > 0 then abs(sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1 order by sum_sales - avg_monthly_sales, nsum LIMIT 100; -- end query 70 in stream 0 using template query57.tpl -- start query 71 in stream 0 using template query65.tpl select s_store_name, i_item_desc, sc.revenue, i_current_price, i_wholesale_cost, i_brand from store, item, (select ss_store_sk, avg(revenue) as ave from (select ss_store_sk, ss_item_sk, sum(ss_sales_price) as revenue from store_sales, date_dim where ss_sold_date_sk = d_date_sk and d_month_seq between 1212 and 1212+11 group by ss_store_sk, ss_item_sk) sa group by ss_store_sk) sb, (select ss_store_sk, ss_item_sk, sum(ss_sales_price) as revenue from store_sales, date_dim where ss_sold_date_sk = d_date_sk and d_month_seq between 1212 and 1212+11 group by ss_store_sk, ss_item_sk) sc where sb.ss_store_sk = sc.ss_store_sk and sc.revenue <= 0.1 * sb.ave and s_store_sk = sc.ss_store_sk and i_item_sk = sc.ss_item_sk order by s_store_name, i_item_desc LIMIT 100; -- end query 71 in stream 0 using template query65.tpl -- start query 72 in stream 0 using template query71.tpl select i_brand_id brand_id, i_brand brand,t_hour,t_minute, sum(ext_price) ext_price from item, (select ws_ext_sales_price as ext_price, ws_sold_date_sk as sold_date_sk, ws_item_sk as sold_item_sk, ws_sold_time_sk as time_sk from web_sales,date_dim where d_date_sk = ws_sold_date_sk and d_moy=12 and d_year=2002 union all select cs_ext_sales_price as ext_price, cs_sold_date_sk as sold_date_sk, cs_item_sk as sold_item_sk, cs_sold_time_sk as time_sk from catalog_sales,date_dim where d_date_sk = cs_sold_date_sk and d_moy=12 and d_year=2002 union all select ss_ext_sales_price as ext_price, ss_sold_date_sk as sold_date_sk, ss_item_sk as sold_item_sk, ss_sold_time_sk as time_sk from store_sales,date_dim where d_date_sk = ss_sold_date_sk and d_moy=12 and d_year=2002 ) tmp,time_dim where sold_item_sk = i_item_sk and i_manager_id=1 and time_sk = t_time_sk and (t_meal_time = 'breakfast' or t_meal_time = 'dinner') group by i_brand, i_brand_id,t_hour,t_minute order by ext_price desc, i_brand_id ; -- end query 72 in stream 0 using template query71.tpl -- start query 73 in stream 0 using template query34.tpl select c_last_name ,c_first_name ,c_salutation ,c_preferred_cust_flag ,ss_ticket_number ,cnt from (select ss_ticket_number ,ss_customer_sk ,count(*) cnt from store_sales,date_dim,store,household_demographics where store_sales.ss_sold_date_sk = date_dim.d_date_sk and store_sales.ss_store_sk = store.s_store_sk and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk and (date_dim.d_dom between 1 and 3 or date_dim.d_dom between 25 and 28) and (household_demographics.hd_buy_potential = '1001-5000' or household_demographics.hd_buy_potential = '0-500') and 
household_demographics.hd_vehicle_count > 0 and (case when household_demographics.hd_vehicle_count > 0 then household_demographics.hd_dep_count/ household_demographics.hd_vehicle_count else null end) > 1.2 and date_dim.d_year in (2000,2000+1,2000+2) and store.s_county in ('Williamson County','Walker County','Ziebach County','Walker County', 'Ziebach County','Ziebach County','Ziebach County','Ziebach County') group by ss_ticket_number,ss_customer_sk) dn,customer where ss_customer_sk = c_customer_sk and cnt between 15 and 20 order by c_last_name,c_first_name,c_salutation,c_preferred_cust_flag desc, ss_ticket_number; -- end query 73 in stream 0 using template query34.tpl -- start query 74 in stream 0 using template query48.tpl select sum (ss_quantity) from store_sales, store, customer_demographics, customer_address, date_dim where s_store_sk = ss_store_sk and ss_sold_date_sk = d_date_sk and d_year = 1998 and ( ( cd_demo_sk = ss_cdemo_sk and cd_marital_status = 'S' and cd_education_status = 'Secondary' and ss_sales_price between 100.00 and 150.00 ) or ( cd_demo_sk = ss_cdemo_sk and cd_marital_status = 'M' and cd_education_status = 'Primary' and ss_sales_price between 50.00 and 100.00 ) or ( cd_demo_sk = ss_cdemo_sk and cd_marital_status = 'W' and cd_education_status = '2 yr Degree' and ss_sales_price between 150.00 and 200.00 ) ) and ( ( ss_addr_sk = ca_address_sk and ca_country = 'United States' and ca_state in ('ND', 'KY', 'TX') and ss_net_profit between 0 and 2000 ) or (ss_addr_sk = ca_address_sk and ca_country = 'United States' and ca_state in ('WI', 'AR', 'GA') and ss_net_profit between 150 and 3000 ) or (ss_addr_sk = ca_address_sk and ca_country = 'United States' and ca_state in ('NC', 'SD', 'IL') and ss_net_profit between 50 and 25000 ) ) ; -- end query 74 in stream 0 using template query48.tpl -- start query 75 in stream 0 using template query30.tpl with customer_total_return as (select wr_returning_customer_sk as ctr_customer_sk ,ca_state as ctr_state, sum(wr_return_amt) as ctr_total_return from web_returns ,date_dim ,customer_address where wr_returned_date_sk = d_date_sk and d_year =2001 and wr_returning_addr_sk = ca_address_sk group by wr_returning_customer_sk ,ca_state) select c_customer_id,c_salutation,c_first_name,c_last_name,c_preferred_cust_flag ,c_birth_day,c_birth_month,c_birth_year,c_birth_country,c_login,c_email_address ,c_last_review_date_sk,ctr_total_return from customer_total_return ctr1 ,customer_address ,customer where ctr1.ctr_total_return > (select avg(ctr_total_return)*1.2 from customer_total_return ctr2 where ctr1.ctr_state = ctr2.ctr_state) and ca_address_sk = c_current_addr_sk and ca_state = 'MO' and ctr1.ctr_customer_sk = c_customer_sk order by c_customer_id,c_salutation,c_first_name,c_last_name,c_preferred_cust_flag ,c_birth_day,c_birth_month,c_birth_year,c_birth_country,c_login,c_email_address ,c_last_review_date_sk,ctr_total_return LIMIT 100; -- end query 75 in stream 0 using template query30.tpl -- start query 76 in stream 0 using template query74.tpl with year_total as ( select c_customer_id customer_id ,c_first_name customer_first_name ,c_last_name customer_last_name ,d_year as year ,sum(ss_net_paid) year_total ,'s' sale_type from customer ,store_sales ,date_dim where c_customer_sk = ss_customer_sk and ss_sold_date_sk = d_date_sk and d_year in (1998,1998+1) group by c_customer_id ,c_first_name ,c_last_name ,d_year union all select c_customer_id customer_id ,c_first_name customer_first_name ,c_last_name customer_last_name ,d_year as year ,sum(ws_net_paid) 
year_total ,'w' sale_type from customer ,web_sales ,date_dim where c_customer_sk = ws_bill_customer_sk and ws_sold_date_sk = d_date_sk and d_year in (1998,1998+1) group by c_customer_id ,c_first_name ,c_last_name ,d_year ) select t_s_secyear.customer_id, t_s_secyear.customer_first_name, t_s_secyear.customer_last_name from year_total t_s_firstyear ,year_total t_s_secyear ,year_total t_w_firstyear ,year_total t_w_secyear where t_s_secyear.customer_id = t_s_firstyear.customer_id and t_s_firstyear.customer_id = t_w_secyear.customer_id and t_s_firstyear.customer_id = t_w_firstyear.customer_id and t_s_firstyear.sale_type = 's' and t_w_firstyear.sale_type = 'w' and t_s_secyear.sale_type = 's' and t_w_secyear.sale_type = 'w' and t_s_firstyear.year = 1998 and t_s_secyear.year = 1998+1 and t_w_firstyear.year = 1998 and t_w_secyear.year = 1998+1 and t_s_firstyear.year_total > 0 and t_w_firstyear.year_total > 0 and case when t_w_firstyear.year_total > 0 then t_w_secyear.year_total / t_w_firstyear.year_total else null end > case when t_s_firstyear.year_total > 0 then t_s_secyear.year_total / t_s_firstyear.year_total else null end order by 2,1,3 LIMIT 100; -- end query 76 in stream 0 using template query74.tpl -- start query 77 in stream 0 using template query87.tpl select count(*) from ((select distinct c_last_name, c_first_name, d_date from store_sales, date_dim, customer where store_sales.ss_sold_date_sk = date_dim.d_date_sk and store_sales.ss_customer_sk = customer.c_customer_sk and d_month_seq between 1212 and 1212+11) except (select distinct c_last_name, c_first_name, d_date from catalog_sales, date_dim, customer where catalog_sales.cs_sold_date_sk = date_dim.d_date_sk and catalog_sales.cs_bill_customer_sk = customer.c_customer_sk and d_month_seq between 1212 and 1212+11) except (select distinct c_last_name, c_first_name, d_date from web_sales, date_dim, customer where web_sales.ws_sold_date_sk = date_dim.d_date_sk and web_sales.ws_bill_customer_sk = customer.c_customer_sk and d_month_seq between 1212 and 1212+11) ) cool_cust ; -- end query 77 in stream 0 using template query87.tpl -- start query 78 in stream 0 using template query77.tpl with ss as (select s_store_sk, sum(ss_ext_sales_price) as sales, sum(ss_net_profit) as profit from store_sales, date_dim, store where ss_sold_date_sk = d_date_sk and d_date between cast('2002-08-18' as date) and (cast('2002-08-18' as date) + interval 30 days) and ss_store_sk = s_store_sk group by s_store_sk) , sr as (select s_store_sk, sum(sr_return_amt) as returns, sum(sr_net_loss) as profit_loss from store_returns, date_dim, store where sr_returned_date_sk = d_date_sk and d_date between cast('2002-08-18' as date) and (cast('2002-08-18' as date) + interval 30 days) and sr_store_sk = s_store_sk group by s_store_sk), cs as (select cs_call_center_sk, sum(cs_ext_sales_price) as sales, sum(cs_net_profit) as profit from catalog_sales, date_dim where cs_sold_date_sk = d_date_sk and d_date between cast('2002-08-18' as date) and (cast('2002-08-18' as date) + interval 30 days) group by cs_call_center_sk ), cr as (select cr_call_center_sk, sum(cr_return_amount) as returns, sum(cr_net_loss) as profit_loss from catalog_returns, date_dim where cr_returned_date_sk = d_date_sk and d_date between cast('2002-08-18' as date) and (cast('2002-08-18' as date) + interval 30 days) group by cr_call_center_sk ), ws as ( select wp_web_page_sk, sum(ws_ext_sales_price) as sales, sum(ws_net_profit) as profit from web_sales, date_dim, web_page where ws_sold_date_sk = d_date_sk and d_date 
between cast('2002-08-18' as date) and (cast('2002-08-18' as date) + interval 30 days) and ws_web_page_sk = wp_web_page_sk group by wp_web_page_sk), wr as (select wp_web_page_sk, sum(wr_return_amt) as returns, sum(wr_net_loss) as profit_loss from web_returns, date_dim, web_page where wr_returned_date_sk = d_date_sk and d_date between cast('2002-08-18' as date) and (cast('2002-08-18' as date) + interval 30 days) and wr_web_page_sk = wp_web_page_sk group by wp_web_page_sk) select channel , id , sum(sales) as sales , sum(returns) as returns , sum(profit) as profit from (select 'store channel' as channel , ss.s_store_sk as id , sales , coalesce(returns, 0) as returns , (profit - coalesce(profit_loss,0)) as profit from ss left join sr on ss.s_store_sk = sr.s_store_sk union all select 'catalog channel' as channel , cs_call_center_sk as id , sales , returns , (profit - profit_loss) as profit from cs , cr union all select 'web channel' as channel , ws.wp_web_page_sk as id , sales , coalesce(returns, 0) returns , (profit - coalesce(profit_loss,0)) as profit from ws left join wr on ws.wp_web_page_sk = wr.wp_web_page_sk ) x group by rollup (channel, id) order by channel ,id LIMIT 100; -- end query 78 in stream 0 using template query77.tpl -- start query 79 in stream 0 using template query73.tpl select c_last_name ,c_first_name ,c_salutation ,c_preferred_cust_flag ,ss_ticket_number ,cnt from (select ss_ticket_number ,ss_customer_sk ,count(*) cnt from store_sales,date_dim,store,household_demographics where store_sales.ss_sold_date_sk = date_dim.d_date_sk and store_sales.ss_store_sk = store.s_store_sk and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk and date_dim.d_dom between 1 and 2 and (household_demographics.hd_buy_potential = '1001-5000' or household_demographics.hd_buy_potential = 'Unknown') and household_demographics.hd_vehicle_count > 0 and case when household_demographics.hd_vehicle_count > 0 then household_demographics.hd_dep_count/ household_demographics.hd_vehicle_count else null end > 1 and date_dim.d_year in (1999,1999+1,1999+2) and store.s_county in ('Walker County','Williamson County','Ziebach County','Walker County') group by ss_ticket_number,ss_customer_sk) dj,customer where ss_customer_sk = c_customer_sk and cnt between 1 and 5 order by cnt desc, c_last_name asc; -- end query 79 in stream 0 using template query73.tpl -- start query 80 in stream 0 using template query84.tpl select c_customer_id as customer_id , coalesce(c_last_name,'') || ', ' || coalesce(c_first_name,'') as customername from customer ,customer_address ,customer_demographics ,household_demographics ,income_band ,store_returns where ca_city = 'Fairfield' and c_current_addr_sk = ca_address_sk and ib_lower_bound >= 58125 and ib_upper_bound <= 58125 + 50000 and ib_income_band_sk = hd_income_band_sk and cd_demo_sk = c_current_cdemo_sk and hd_demo_sk = c_current_hdemo_sk and sr_cdemo_sk = cd_demo_sk order by c_customer_id LIMIT 100; -- end query 80 in stream 0 using template query84.tpl -- start query 81 in stream 0 using template query54.tpl with my_customers as ( select distinct c_customer_sk , c_current_addr_sk from ( select cs_sold_date_sk sold_date_sk, cs_bill_customer_sk customer_sk, cs_item_sk item_sk from catalog_sales union all select ws_sold_date_sk sold_date_sk, ws_bill_customer_sk customer_sk, ws_item_sk item_sk from web_sales ) cs_or_ws_sales, item, date_dim, customer where sold_date_sk = d_date_sk and item_sk = i_item_sk and i_category = 'Children' and i_class = 'toddlers' and c_customer_sk = 
cs_or_ws_sales.customer_sk and d_moy = 4 and d_year = 1999 ) , my_revenue as ( select c_customer_sk, sum(ss_ext_sales_price) as revenue from my_customers, store_sales, customer_address, store, date_dim where c_current_addr_sk = ca_address_sk and ca_county = s_county and ca_state = s_state and ss_sold_date_sk = d_date_sk and c_customer_sk = ss_customer_sk and d_month_seq between (select distinct d_month_seq+1 from date_dim where d_year = 1999 and d_moy = 4) and (select distinct d_month_seq+3 from date_dim where d_year = 1999 and d_moy = 4) group by c_customer_sk ) , segments as (select cast((revenue/50) as int) as segment from my_revenue ) select segment, count(*) as num_customers, segment*50 as segment_base from segments group by segment order by segment, num_customers LIMIT 100; -- end query 81 in stream 0 using template query54.tpl -- start query 82 in stream 0 using template query55.tpl select i_brand_id brand_id, i_brand brand, sum(ss_ext_sales_price) ext_price from date_dim, store_sales, item where d_date_sk = ss_sold_date_sk and ss_item_sk = i_item_sk and i_manager_id=76 and d_moy=12 and d_year=1999 group by i_brand, i_brand_id order by ext_price desc, i_brand_id LIMIT 100 ; -- end query 82 in stream 0 using template query55.tpl -- start query 83 in stream 0 using template query56.tpl with ss as ( select i_item_id,sum(ss_ext_sales_price) total_sales from store_sales, date_dim, customer_address, item where i_item_id in (select i_item_id from item where i_color in ('blush','hot','orange')) and ss_item_sk = i_item_sk and ss_sold_date_sk = d_date_sk and d_year = 2000 and d_moy = 5 and ss_addr_sk = ca_address_sk and ca_gmt_offset = -5 group by i_item_id), cs as ( select i_item_id,sum(cs_ext_sales_price) total_sales from catalog_sales, date_dim, customer_address, item where i_item_id in (select i_item_id from item where i_color in ('blush','hot','orange')) and cs_item_sk = i_item_sk and cs_sold_date_sk = d_date_sk and d_year = 2000 and d_moy = 5 and cs_bill_addr_sk = ca_address_sk and ca_gmt_offset = -5 group by i_item_id), ws as ( select i_item_id,sum(ws_ext_sales_price) total_sales from web_sales, date_dim, customer_address, item where i_item_id in (select i_item_id from item where i_color in ('blush','hot','orange')) and ws_item_sk = i_item_sk and ws_sold_date_sk = d_date_sk and d_year = 2000 and d_moy = 5 and ws_bill_addr_sk = ca_address_sk and ca_gmt_offset = -5 group by i_item_id) select i_item_id ,sum(total_sales) total_sales from (select * from ss union all select * from cs union all select * from ws) tmp1 group by i_item_id order by total_sales, i_item_id LIMIT 100; -- end query 83 in stream 0 using template query56.tpl -- start query 84 in stream 0 using template query2.tpl with wscs as (select sold_date_sk ,sales_price from (select ws_sold_date_sk sold_date_sk ,ws_ext_sales_price sales_price from web_sales union all select cs_sold_date_sk sold_date_sk ,cs_ext_sales_price sales_price from catalog_sales)), wswscs as (select d_week_seq, sum(case when (d_day_name='Sunday') then sales_price else null end) sun_sales, sum(case when (d_day_name='Monday') then sales_price else null end) mon_sales, sum(case when (d_day_name='Tuesday') then sales_price else null end) tue_sales, sum(case when (d_day_name='Wednesday') then sales_price else null end) wed_sales, sum(case when (d_day_name='Thursday') then sales_price else null end) thu_sales, sum(case when (d_day_name='Friday') then sales_price else null end) fri_sales, sum(case when (d_day_name='Saturday') then sales_price else null end) 
sat_sales from wscs ,date_dim where d_date_sk = sold_date_sk group by d_week_seq) select d_week_seq1 ,round(sun_sales1/sun_sales2,2) ,round(mon_sales1/mon_sales2,2) ,round(tue_sales1/tue_sales2,2) ,round(wed_sales1/wed_sales2,2) ,round(thu_sales1/thu_sales2,2) ,round(fri_sales1/fri_sales2,2) ,round(sat_sales1/sat_sales2,2) from (select wswscs.d_week_seq d_week_seq1 ,sun_sales sun_sales1 ,mon_sales mon_sales1 ,tue_sales tue_sales1 ,wed_sales wed_sales1 ,thu_sales thu_sales1 ,fri_sales fri_sales1 ,sat_sales sat_sales1 from wswscs,date_dim where date_dim.d_week_seq = wswscs.d_week_seq and d_year = 1998) y, (select wswscs.d_week_seq d_week_seq2 ,sun_sales sun_sales2 ,mon_sales mon_sales2 ,tue_sales tue_sales2 ,wed_sales wed_sales2 ,thu_sales thu_sales2 ,fri_sales fri_sales2 ,sat_sales sat_sales2 from wswscs ,date_dim where date_dim.d_week_seq = wswscs.d_week_seq and d_year = 1998+1) z where d_week_seq1=d_week_seq2-53 order by d_week_seq1; -- end query 84 in stream 0 using template query2.tpl -- start query 85 in stream 0 using template query26.tpl select i_item_id, avg(cs_quantity) agg1, avg(cs_list_price) agg2, avg(cs_coupon_amt) agg3, avg(cs_sales_price) agg4 from catalog_sales, customer_demographics, date_dim, item, promotion where cs_sold_date_sk = d_date_sk and cs_item_sk = i_item_sk and cs_bill_cdemo_sk = cd_demo_sk and cs_promo_sk = p_promo_sk and cd_gender = 'M' and cd_marital_status = 'S' and cd_education_status = '4 yr Degree' and (p_channel_email = 'N' or p_channel_event = 'N') and d_year = 1999 group by i_item_id order by i_item_id LIMIT 100; -- end query 85 in stream 0 using template query26.tpl -- start query 86 in stream 0 using template query40.tpl select w_state ,i_item_id ,sum(case when (cast(d_date as date) < cast ('1998-03-13' as date)) then cs_sales_price - coalesce(cr_refunded_cash,0) else 0 end) as sales_before ,sum(case when (cast(d_date as date) >= cast ('1998-03-13' as date)) then cs_sales_price - coalesce(cr_refunded_cash,0) else 0 end) as sales_after from catalog_sales left outer join catalog_returns on (cs_order_number = cr_order_number and cs_item_sk = cr_item_sk) ,warehouse ,item ,date_dim where i_current_price between 0.99 and 1.49 and i_item_sk = cs_item_sk and cs_warehouse_sk = w_warehouse_sk and cs_sold_date_sk = d_date_sk and d_date between (cast ('1998-03-13' as date) - interval 30 days) and (cast ('1998-03-13' as date) + interval 30 days) group by w_state,i_item_id order by w_state,i_item_id LIMIT 100; -- end query 86 in stream 0 using template query40.tpl -- start query 87 in stream 0 using template query72.tpl select i_item_desc ,w_warehouse_name ,d1.d_week_seq ,sum(case when p_promo_sk is null then 1 else 0 end) no_promo ,sum(case when p_promo_sk is not null then 1 else 0 end) promo ,count(*) total_cnt from catalog_sales join inventory on (cs_item_sk = inv_item_sk) join warehouse on (w_warehouse_sk=inv_warehouse_sk) join item on (i_item_sk = cs_item_sk) join customer_demographics on (cs_bill_cdemo_sk = cd_demo_sk) join household_demographics on (cs_bill_hdemo_sk = hd_demo_sk) join date_dim d1 on (cs_sold_date_sk = d1.d_date_sk) join date_dim d2 on (inv_date_sk = d2.d_date_sk) join date_dim d3 on (cs_ship_date_sk = d3.d_date_sk) left outer join promotion on (cs_promo_sk=p_promo_sk) left outer join catalog_returns on (cr_item_sk = cs_item_sk and cr_order_number = cs_order_number) where d1.d_week_seq = d2.d_week_seq and inv_quantity_on_hand < cs_quantity and d3.d_date > d1.d_date + 5 and hd_buy_potential = '501-1000' and d1.d_year = 2002 and 
cd_marital_status = 'M' group by i_item_desc,w_warehouse_name,d1.d_week_seq order by total_cnt desc, i_item_desc, w_warehouse_name, d_week_seq LIMIT 100; -- end query 87 in stream 0 using template query72.tpl -- start query 88 in stream 0 using template query53.tpl select * from (select i_manufact_id, sum(ss_sales_price) sum_sales, avg(sum(ss_sales_price)) over (partition by i_manufact_id) avg_quarterly_sales from item, store_sales, date_dim, store where ss_item_sk = i_item_sk and ss_sold_date_sk = d_date_sk and ss_store_sk = s_store_sk and d_month_seq in (1202,1202+1,1202+2,1202+3,1202+4,1202+5,1202+6,1202+7,1202+8,1202+9,1202+10,1202+11) and ((i_category in ('Books','Children','Electronics') and i_class in ('personal','portable','reference','self-help') and i_brand in ('scholaramalgamalg #14','scholaramalgamalg #7', 'exportiunivamalg #9','scholaramalgamalg #9')) or(i_category in ('Women','Music','Men') and i_class in ('accessories','classical','fragrances','pants') and i_brand in ('amalgimporto #1','edu packscholar #1','exportiimporto #1', 'importoamalg #1'))) group by i_manufact_id, d_qoy ) tmp1 where case when avg_quarterly_sales > 0 then abs (sum_sales - avg_quarterly_sales)/ avg_quarterly_sales else null end > 0.1 order by avg_quarterly_sales, sum_sales, i_manufact_id LIMIT 100; -- end query 88 in stream 0 using template query53.tpl -- start query 89 in stream 0 using template query79.tpl select c_last_name,c_first_name,substr(s_city,1,30),ss_ticket_number,amt,profit from (select ss_ticket_number ,ss_customer_sk ,store.s_city ,sum(ss_coupon_amt) amt ,sum(ss_net_profit) profit from store_sales,date_dim,store,household_demographics where store_sales.ss_sold_date_sk = date_dim.d_date_sk and store_sales.ss_store_sk = store.s_store_sk and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk and (household_demographics.hd_dep_count = 9 or household_demographics.hd_vehicle_count > -1) and date_dim.d_dow = 1 and date_dim.d_year in (2000,2000+1,2000+2) and store.s_number_employees between 200 and 295 group by ss_ticket_number,ss_customer_sk,ss_addr_sk,store.s_city) ms,customer where ss_customer_sk = c_customer_sk order by c_last_name,c_first_name,substr(s_city,1,30), profit LIMIT 100; -- end query 89 in stream 0 using template query79.tpl -- start query 90 in stream 0 using template query18.tpl select i_item_id, ca_country, ca_state, ca_county, avg( cast(cs_quantity as decimal(12,2))) agg1, avg( cast(cs_list_price as decimal(12,2))) agg2, avg( cast(cs_coupon_amt as decimal(12,2))) agg3, avg( cast(cs_sales_price as decimal(12,2))) agg4, avg( cast(cs_net_profit as decimal(12,2))) agg5, avg( cast(c_birth_year as decimal(12,2))) agg6, avg( cast(cd1.cd_dep_count as decimal(12,2))) agg7 from catalog_sales, customer_demographics cd1, customer_demographics cd2, customer, customer_address, date_dim, item where cs_sold_date_sk = d_date_sk and cs_item_sk = i_item_sk and cs_bill_cdemo_sk = cd1.cd_demo_sk and cs_bill_customer_sk = c_customer_sk and cd1.cd_gender = 'F' and cd1.cd_education_status = '4 yr Degree' and c_current_cdemo_sk = cd2.cd_demo_sk and c_current_addr_sk = ca_address_sk and c_birth_month in (4,2,12,10,11,3) and d_year = 2001 and ca_state in ('AR','GA','CO' ,'MS','ND','KS','KY') group by rollup (i_item_id, ca_country, ca_state, ca_county) order by ca_country, ca_state, ca_county, i_item_id LIMIT 100; -- end query 90 in stream 0 using template query18.tpl -- start query 91 in stream 0 using template query13.tpl select avg(ss_quantity) ,avg(ss_ext_sales_price) 
,avg(ss_ext_wholesale_cost) ,sum(ss_ext_wholesale_cost) from store_sales ,store ,customer_demographics ,household_demographics ,customer_address ,date_dim where s_store_sk = ss_store_sk and ss_sold_date_sk = d_date_sk and d_year = 2001 and((ss_hdemo_sk=hd_demo_sk and cd_demo_sk = ss_cdemo_sk and cd_marital_status = 'D' and cd_education_status = 'Advanced Degree' and ss_sales_price between 100.00 and 150.00 and hd_dep_count = 3 )or (ss_hdemo_sk=hd_demo_sk and cd_demo_sk = ss_cdemo_sk and cd_marital_status = 'U' and cd_education_status = '2 yr Degree' and ss_sales_price between 50.00 and 100.00 and hd_dep_count = 1 ) or (ss_hdemo_sk=hd_demo_sk and cd_demo_sk = ss_cdemo_sk and cd_marital_status = 'W' and cd_education_status = '4 yr Degree' and ss_sales_price between 150.00 and 200.00 and hd_dep_count = 1 )) and((ss_addr_sk = ca_address_sk and ca_country = 'United States' and ca_state in ('TX', 'OH', 'OK') and ss_net_profit between 100 and 200 ) or (ss_addr_sk = ca_address_sk and ca_country = 'United States' and ca_state in ('MS', 'NY', 'GA') and ss_net_profit between 150 and 300 ) or (ss_addr_sk = ca_address_sk and ca_country = 'United States' and ca_state in ('TN', 'IN', 'AL') and ss_net_profit between 50 and 250 )) ; -- end query 91 in stream 0 using template query13.tpl -- start query 92 in stream 0 using template query24.tpl with ssales as (select c_last_name ,c_first_name ,s_store_name ,ca_state ,s_state ,i_color ,i_current_price ,i_manager_id ,i_units ,i_size ,sum(ss_net_profit) netpaid from store_sales ,store_returns ,store ,item ,customer ,customer_address where ss_ticket_number = sr_ticket_number and ss_item_sk = sr_item_sk and ss_customer_sk = c_customer_sk and ss_item_sk = i_item_sk and ss_store_sk = s_store_sk and c_current_addr_sk = ca_address_sk and c_birth_country <> upper(ca_country) and s_zip = ca_zip and s_market_id=10 group by c_last_name ,c_first_name ,s_store_name ,ca_state ,s_state ,i_color ,i_current_price ,i_manager_id ,i_units ,i_size) select c_last_name ,c_first_name ,s_store_name ,sum(netpaid) paid from ssales where i_color = 'firebrick' group by c_last_name ,c_first_name ,s_store_name having sum(netpaid) > (select 0.05*avg(netpaid) from ssales) order by c_last_name ,c_first_name ,s_store_name ; with ssales as (select c_last_name ,c_first_name ,s_store_name ,ca_state ,s_state ,i_color ,i_current_price ,i_manager_id ,i_units ,i_size ,sum(ss_net_profit) netpaid from store_sales ,store_returns ,store ,item ,customer ,customer_address where ss_ticket_number = sr_ticket_number and ss_item_sk = sr_item_sk and ss_customer_sk = c_customer_sk and ss_item_sk = i_item_sk and ss_store_sk = s_store_sk and c_current_addr_sk = ca_address_sk and c_birth_country <> upper(ca_country) and s_zip = ca_zip and s_market_id = 10 group by c_last_name ,c_first_name ,s_store_name ,ca_state ,s_state ,i_color ,i_current_price ,i_manager_id ,i_units ,i_size) select c_last_name ,c_first_name ,s_store_name ,sum(netpaid) paid from ssales where i_color = 'sienna' group by c_last_name ,c_first_name ,s_store_name having sum(netpaid) > (select 0.05*avg(netpaid) from ssales) order by c_last_name ,c_first_name ,s_store_name ; -- end query 92 in stream 0 using template query24.tpl -- start query 93 in stream 0 using template query4.tpl with year_total as ( select c_customer_id customer_id ,c_first_name customer_first_name ,c_last_name customer_last_name ,c_preferred_cust_flag customer_preferred_cust_flag ,c_birth_country customer_birth_country ,c_login customer_login ,c_email_address 
customer_email_address ,d_year dyear ,sum(((ss_ext_list_price-ss_ext_wholesale_cost-ss_ext_discount_amt)+ss_ext_sales_price)/2) year_total ,'s' sale_type from customer ,store_sales ,date_dim where c_customer_sk = ss_customer_sk and ss_sold_date_sk = d_date_sk group by c_customer_id ,c_first_name ,c_last_name ,c_preferred_cust_flag ,c_birth_country ,c_login ,c_email_address ,d_year union all select c_customer_id customer_id ,c_first_name customer_first_name ,c_last_name customer_last_name ,c_preferred_cust_flag customer_preferred_cust_flag ,c_birth_country customer_birth_country ,c_login customer_login ,c_email_address customer_email_address ,d_year dyear ,sum((((cs_ext_list_price-cs_ext_wholesale_cost-cs_ext_discount_amt)+cs_ext_sales_price)/2) ) year_total ,'c' sale_type from customer ,catalog_sales ,date_dim where c_customer_sk = cs_bill_customer_sk and cs_sold_date_sk = d_date_sk group by c_customer_id ,c_first_name ,c_last_name ,c_preferred_cust_flag ,c_birth_country ,c_login ,c_email_address ,d_year union all select c_customer_id customer_id ,c_first_name customer_first_name ,c_last_name customer_last_name ,c_preferred_cust_flag customer_preferred_cust_flag ,c_birth_country customer_birth_country ,c_login customer_login ,c_email_address customer_email_address ,d_year dyear ,sum((((ws_ext_list_price-ws_ext_wholesale_cost-ws_ext_discount_amt)+ws_ext_sales_price)/2) ) year_total ,'w' sale_type from customer ,web_sales ,date_dim where c_customer_sk = ws_bill_customer_sk and ws_sold_date_sk = d_date_sk group by c_customer_id ,c_first_name ,c_last_name ,c_preferred_cust_flag ,c_birth_country ,c_login ,c_email_address ,d_year ) select t_s_secyear.customer_id ,t_s_secyear.customer_first_name ,t_s_secyear.customer_last_name ,t_s_secyear.customer_preferred_cust_flag from year_total t_s_firstyear ,year_total t_s_secyear ,year_total t_c_firstyear ,year_total t_c_secyear ,year_total t_w_firstyear ,year_total t_w_secyear where t_s_secyear.customer_id = t_s_firstyear.customer_id and t_s_firstyear.customer_id = t_c_secyear.customer_id and t_s_firstyear.customer_id = t_c_firstyear.customer_id and t_s_firstyear.customer_id = t_w_firstyear.customer_id and t_s_firstyear.customer_id = t_w_secyear.customer_id and t_s_firstyear.sale_type = 's' and t_c_firstyear.sale_type = 'c' and t_w_firstyear.sale_type = 'w' and t_s_secyear.sale_type = 's' and t_c_secyear.sale_type = 'c' and t_w_secyear.sale_type = 'w' and t_s_firstyear.dyear = 1999 and t_s_secyear.dyear = 1999+1 and t_c_firstyear.dyear = 1999 and t_c_secyear.dyear = 1999+1 and t_w_firstyear.dyear = 1999 and t_w_secyear.dyear = 1999+1 and t_s_firstyear.year_total > 0 and t_c_firstyear.year_total > 0 and t_w_firstyear.year_total > 0 and case when t_c_firstyear.year_total > 0 then t_c_secyear.year_total / t_c_firstyear.year_total else null end > case when t_s_firstyear.year_total > 0 then t_s_secyear.year_total / t_s_firstyear.year_total else null end and case when t_c_firstyear.year_total > 0 then t_c_secyear.year_total / t_c_firstyear.year_total else null end > case when t_w_firstyear.year_total > 0 then t_w_secyear.year_total / t_w_firstyear.year_total else null end order by t_s_secyear.customer_id ,t_s_secyear.customer_first_name ,t_s_secyear.customer_last_name ,t_s_secyear.customer_preferred_cust_flag LIMIT 100; -- end query 93 in stream 0 using template query4.tpl -- start query 94 in stream 0 using template query99.tpl select substr(w_warehouse_name,1,20) ,sm_type ,cc_name ,sum(case when (cs_ship_date_sk - cs_sold_date_sk <= 30 ) then 1 else 0 end) 
as `30 days` ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 30) and (cs_ship_date_sk - cs_sold_date_sk <= 60) then 1 else 0 end ) as `31-60 days` ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 60) and (cs_ship_date_sk - cs_sold_date_sk <= 90) then 1 else 0 end) as `61-90 days` ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 90) and (cs_ship_date_sk - cs_sold_date_sk <= 120) then 1 else 0 end) as `91-120 days` ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 120) then 1 else 0 end) as `>120 days` from catalog_sales ,warehouse ,ship_mode ,call_center ,date_dim where d_month_seq between 1222 and 1222 + 11 and cs_ship_date_sk = d_date_sk and cs_warehouse_sk = w_warehouse_sk and cs_ship_mode_sk = sm_ship_mode_sk and cs_call_center_sk = cc_call_center_sk group by substr(w_warehouse_name,1,20) ,sm_type ,cc_name order by substr(w_warehouse_name,1,20) ,sm_type ,cc_name LIMIT 100; -- end query 94 in stream 0 using template query99.tpl -- start query 95 in stream 0 using template query68.tpl select c_last_name ,c_first_name ,ca_city ,bought_city ,ss_ticket_number ,extended_price ,extended_tax ,list_price from (select ss_ticket_number ,ss_customer_sk ,ca_city bought_city ,sum(ss_ext_sales_price) extended_price ,sum(ss_ext_list_price) list_price ,sum(ss_ext_tax) extended_tax from store_sales ,date_dim ,store ,household_demographics ,customer_address where store_sales.ss_sold_date_sk = date_dim.d_date_sk and store_sales.ss_store_sk = store.s_store_sk and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk and store_sales.ss_addr_sk = customer_address.ca_address_sk and date_dim.d_dom between 1 and 2 and (household_demographics.hd_dep_count = 6 or household_demographics.hd_vehicle_count= 1) and date_dim.d_year in (1998,1998+1,1998+2) and store.s_city in ('Midway','Pleasant Hill') group by ss_ticket_number ,ss_customer_sk ,ss_addr_sk,ca_city) dn ,customer ,customer_address current_addr where ss_customer_sk = c_customer_sk and customer.c_current_addr_sk = current_addr.ca_address_sk and current_addr.ca_city <> bought_city order by c_last_name ,ss_ticket_number LIMIT 100; -- end query 95 in stream 0 using template query68.tpl -- start query 96 in stream 0 using template query83.tpl with sr_items as (select i_item_id item_id, sum(sr_return_quantity) sr_item_qty from store_returns, item, date_dim where sr_item_sk = i_item_sk and d_date in (select d_date from date_dim where d_week_seq in (select d_week_seq from date_dim where d_date in ('1998-05-29','1998-08-19','1998-11-10'))) and sr_returned_date_sk = d_date_sk group by i_item_id), cr_items as (select i_item_id item_id, sum(cr_return_quantity) cr_item_qty from catalog_returns, item, date_dim where cr_item_sk = i_item_sk and d_date in (select d_date from date_dim where d_week_seq in (select d_week_seq from date_dim where d_date in ('1998-05-29','1998-08-19','1998-11-10'))) and cr_returned_date_sk = d_date_sk group by i_item_id), wr_items as (select i_item_id item_id, sum(wr_return_quantity) wr_item_qty from web_returns, item, date_dim where wr_item_sk = i_item_sk and d_date in (select d_date from date_dim where d_week_seq in (select d_week_seq from date_dim where d_date in ('1998-05-29','1998-08-19','1998-11-10'))) and wr_returned_date_sk = d_date_sk group by i_item_id) select sr_items.item_id ,sr_item_qty ,sr_item_qty/(sr_item_qty+cr_item_qty+wr_item_qty)/3.0 * 100 sr_dev ,cr_item_qty ,cr_item_qty/(sr_item_qty+cr_item_qty+wr_item_qty)/3.0 * 100 cr_dev ,wr_item_qty ,wr_item_qty/(sr_item_qty+cr_item_qty+wr_item_qty)/3.0 * 100 wr_dev 
,(sr_item_qty+cr_item_qty+wr_item_qty)/3.0 average from sr_items ,cr_items ,wr_items where sr_items.item_id=cr_items.item_id and sr_items.item_id=wr_items.item_id order by sr_items.item_id ,sr_item_qty LIMIT 100; -- end query 96 in stream 0 using template query83.tpl -- start query 97 in stream 0 using template query61.tpl select promotions,total,cast(promotions as decimal(15,4))/cast(total as decimal(15,4))*100 from (select sum(ss_ext_sales_price) promotions from store_sales ,store ,promotion ,date_dim ,customer ,customer_address ,item where ss_sold_date_sk = d_date_sk and ss_store_sk = s_store_sk and ss_promo_sk = p_promo_sk and ss_customer_sk= c_customer_sk and ca_address_sk = c_current_addr_sk and ss_item_sk = i_item_sk and ca_gmt_offset = -6 and i_category = 'Sports' and (p_channel_dmail = 'Y' or p_channel_email = 'Y' or p_channel_tv = 'Y') and s_gmt_offset = -6 and d_year = 1998 and d_moy = 12) promotional_sales, (select sum(ss_ext_sales_price) total from store_sales ,store ,date_dim ,customer ,customer_address ,item where ss_sold_date_sk = d_date_sk and ss_store_sk = s_store_sk and ss_customer_sk= c_customer_sk and ca_address_sk = c_current_addr_sk and ss_item_sk = i_item_sk and ca_gmt_offset = -6 and i_category = 'Sports' and s_gmt_offset = -6 and d_year = 1998 and d_moy = 12) all_sales order by promotions, total LIMIT 100; -- end query 97 in stream 0 using template query61.tpl -- start query 98 in stream 0 using template query5.tpl with ssr as (select s_store_id, sum(sales_price) as sales, sum(profit) as profit, sum(return_amt) as returns, sum(net_loss) as profit_loss from ( select ss_store_sk as store_sk, ss_sold_date_sk as date_sk, ss_ext_sales_price as sales_price, ss_net_profit as profit, cast(0 as decimal(7,2)) as return_amt, cast(0 as decimal(7,2)) as net_loss from store_sales union all select sr_store_sk as store_sk, sr_returned_date_sk as date_sk, cast(0 as decimal(7,2)) as sales_price, cast(0 as decimal(7,2)) as profit, sr_return_amt as return_amt, sr_net_loss as net_loss from store_returns ) salesreturns, date_dim, store where date_sk = d_date_sk and d_date between cast('1998-08-21' as date) and (cast('1998-08-21' as date) + interval 14 days) and store_sk = s_store_sk group by s_store_id) , csr as (select cp_catalog_page_id, sum(sales_price) as sales, sum(profit) as profit, sum(return_amt) as returns, sum(net_loss) as profit_loss from ( select cs_catalog_page_sk as page_sk, cs_sold_date_sk as date_sk, cs_ext_sales_price as sales_price, cs_net_profit as profit, cast(0 as decimal(7,2)) as return_amt, cast(0 as decimal(7,2)) as net_loss from catalog_sales union all select cr_catalog_page_sk as page_sk, cr_returned_date_sk as date_sk, cast(0 as decimal(7,2)) as sales_price, cast(0 as decimal(7,2)) as profit, cr_return_amount as return_amt, cr_net_loss as net_loss from catalog_returns ) salesreturns, date_dim, catalog_page where date_sk = d_date_sk and d_date between cast('1998-08-21' as date) and (cast('1998-08-21' as date) + interval 14 days) and page_sk = cp_catalog_page_sk group by cp_catalog_page_id) , wsr as (select web_site_id, sum(sales_price) as sales, sum(profit) as profit, sum(return_amt) as returns, sum(net_loss) as profit_loss from ( select ws_web_site_sk as wsr_web_site_sk, ws_sold_date_sk as date_sk, ws_ext_sales_price as sales_price, ws_net_profit as profit, cast(0 as decimal(7,2)) as return_amt, cast(0 as decimal(7,2)) as net_loss from web_sales union all select ws_web_site_sk as wsr_web_site_sk, wr_returned_date_sk as date_sk, cast(0 as decimal(7,2)) as 
sales_price, cast(0 as decimal(7,2)) as profit, wr_return_amt as return_amt, wr_net_loss as net_loss from web_returns left outer join web_sales on ( wr_item_sk = ws_item_sk and wr_order_number = ws_order_number) ) salesreturns, date_dim, web_site where date_sk = d_date_sk and d_date between cast('1998-08-21' as date) and (cast('1998-08-21' as date) + interval 14 days) and wsr_web_site_sk = web_site_sk group by web_site_id) select channel , id , sum(sales) as sales , sum(returns) as returns , sum(profit) as profit from (select 'store channel' as channel , 'store' || s_store_id as id , sales , returns , (profit - profit_loss) as profit from ssr union all select 'catalog channel' as channel , 'catalog_page' || cp_catalog_page_id as id , sales , returns , (profit - profit_loss) as profit from csr union all select 'web channel' as channel , 'web_site' || web_site_id as id , sales , returns , (profit - profit_loss) as profit from wsr ) x group by rollup (channel, id) order by channel ,id LIMIT 100; -- end query 98 in stream 0 using template query5.tpl -- start query 99 in stream 0 using template query76.tpl select channel, col_name, d_year, d_qoy, i_category, COUNT(*) sales_cnt, SUM(ext_sales_price) sales_amt FROM ( SELECT 'store' as channel, 'ss_addr_sk' col_name, d_year, d_qoy, i_category, ss_ext_sales_price ext_sales_price FROM store_sales, item, date_dim WHERE ss_addr_sk IS NULL AND ss_sold_date_sk=d_date_sk AND ss_item_sk=i_item_sk UNION ALL SELECT 'web' as channel, 'ws_web_page_sk' col_name, d_year, d_qoy, i_category, ws_ext_sales_price ext_sales_price FROM web_sales, item, date_dim WHERE ws_web_page_sk IS NULL AND ws_sold_date_sk=d_date_sk AND ws_item_sk=i_item_sk UNION ALL SELECT 'catalog' as channel, 'cs_ship_mode_sk' col_name, d_year, d_qoy, i_category, cs_ext_sales_price ext_sales_price FROM catalog_sales, item, date_dim WHERE cs_ship_mode_sk IS NULL AND cs_sold_date_sk=d_date_sk AND cs_item_sk=i_item_sk) foo GROUP BY channel, col_name, d_year, d_qoy, i_category ORDER BY channel, col_name, d_year, d_qoy, i_category LIMIT 100; -- end query 99 in stream 0 using template query76.tpl ================================================ FILE: examples/spark-connect-gpu/client/notebook/README.md ================================================ # Demo Notebook Overview The `spark-connect-gpu-etl-ml.ipynb` notebook demonstrates: ## ETL Pipeline - **Data ingestion** from CSV with custom schema - **Complex transformations** including date parsing and delinquency calculations - **String-to-numeric encoding** for categorical features - **Data joins and aggregations** with mortgage performance data ## Machine Learning Workflow - **Feature engineering** with FeatureHasher and VectorAssembler - **Logistic Regression** training for multi-class prediction - **Model evaluation** with performance metrics - **GPU vs CPU timing comparisons** ## Key Code Examples **Connecting to Spark with GPU acceleration:** ```python from pyspark.sql import SparkSession spark = ( SparkSession.builder .remote('sc://spark-connect-server') .appName('GPU-Accelerated-ETL-ML-Demo') .getOrCreate() ) ``` In the actual demo code we find it handier to supply the connection string via the `SPARK_REMOTE` environment variable rather than hard-coding it, which also makes it easy to run the same code with Spark Classic.
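A minimal sketch of that approach, assuming the same `sc://spark-connect-server` endpoint as above; in practice the variable would be exported in the shell or container environment rather than set from Python:

```python
import os
from pyspark.sql import SparkSession

# Illustrative only: point SPARK_REMOTE at the demo's Spark Connect endpoint before
# the session is created. In a Spark Classic setup the variable is simply left unset
# and the same builder code runs unchanged (with the full pyspark package installed).
os.environ.setdefault('SPARK_REMOTE', 'sc://spark-connect-server')

spark = (
    SparkSession.builder
    .appName('GPU-Accelerated-ETL-ML-Demo')
    .getOrCreate()
)
```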
**Machine Learning with GPU acceleration:** ```python from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression from pyspark.ml.feature import VectorAssembler, FeatureHasher spark.conf.set('spark.connect.ml.backend.classes', 'com.nvidia.rapids.ml.Plugin') # Feature preparation hasher = FeatureHasher(inputCols=categorical_cols, outputCol='hashed_categorical') assembler = VectorAssembler().setInputCols(numerical_cols + ['hashed_categorical']).setOutputCol('features') # Model training logistic = LogisticRegression().setFeaturesCol('features').setLabelCol('delinquency_12') pipeline = Pipeline().setStages([hasher, assembler, logistic]) model = pipeline.fit(training_data) ``` ## Results The demo at the Data+AI Summit'25 used the following mortgage quarters ```bash $ du -h * 503M 2023Q1.csv 412M 2023Q2.csv 162M 2023Q3.csv 1.1G 2023Q4.csv ``` and was tested on a machine with a 6GiB RTX A3000 Laptop GPU ```bash $ nvidia-smi +-----------------------------------------------------------------------------------------+ | NVIDIA-SMI 560.35.05 Driver Version: 560.35.05 CUDA Version: 12.6 | |-----------------------------------------+------------------------+----------------------+ | GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | | | | MIG M. | |=========================================+========================+======================| | 0 NVIDIA RTX A3000 Laptop GPU Off | 00000000:01:00.0 Off | N/A | | N/A 56C P8 13W / 60W | 1353MiB / 6144MiB | 1% Default | | | | N/A | +-----------------------------------------+------------------------+----------------------+ ``` and a 2x8-core CPU ![GPU Acceleration Results](example-acceleration-chart.png) ================================================ FILE: examples/spark-connect-gpu/client/notebook/spark-connect-gpu-etl-ml.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# GPU-Accelerated Spark Connect - SQL/DF ETL and MLlib on Mortgage Dataset (Spark 4.0+)\n", "\n", "Based on the Data and AI Summit 2025 session: [GPU Accelerated Spark Connect](https://www.databricks.com/dataaisummit/session/gpu-accelerated-spark-connect)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import packages" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml import Pipeline\n", "from pyspark.ml.classification import LogisticRegression\n", "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n", "from pyspark.ml.feature import VectorAssembler, FeatureHasher\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.functions import *\n", "from pyspark.sql.types import IntegerType\n", "from pyspark.sql.window import Window\n", "import csv\n", "import os\n", "import pandas as pd\n", "import time" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Connect to Spark via Spark Connect\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create GPU-accelerated Spark session using Spark Connect 4.0+\n", "spark = (\n", " SparkSession.builder\n", " .appName('GPU-Accelerated Spark Connect - SQL/ETL and MLlib') \n", " .getOrCreate()\n", ")\n", "print(f'Spark Connect session id: {spark.session_id}')\n", "print(f'Spark version: {spark.version}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Local and Global Storage Access " ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# This can be a local storage location accessible to the thin Spark Connect app\n", "# such as a IPython kernel\n", "local_data_dir = 'work'\n", "\n", "# This would normally be a global storage location such as Cloud Object Storage\n", "# This notebook requires a writable directory on the host. It is mounted into containers\n", "# requiring access to it as /data from the host \n", "# This directory should contain directory `mortgage.input.csv` with files from the Mortgage dataset.\n", "# We also store here data useful across the container life cycle such as metrics from the previous runs\n", "# and Spark event logs. \n", "global_data_dir = '/data'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Normalize references to the same bank " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with open(f'{local_data_dir}/name_mapping.csv', 'r') as name_mapping_file:\n", " nm_reader = csv.reader(name_mapping_file,)\n", " name_mapping = [r for r in nm_reader]\n", "name_mapping_df = spark.createDataFrame(name_mapping, ['from_seller_name', 'to_seller_name'])\n", "\n", "(\n", " name_mapping_df\n", " .where(col('to_seller_name') == 'Wells Fargo' )\n", " .show(truncate=False)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# String columns\n", "cate_col_names = [\n", " 'orig_channel',\n", " 'first_home_buyer',\n", " 'loan_purpose',\n", " 'property_type',\n", " 'occupancy_status',\n", " 'property_state',\n", " 'product_type',\n", " 'relocation_mortgage_indicator',\n", " 'seller_name',\n", " 'mod_flag'\n", "]\n", "# Numeric columns\n", "label_col_name = 'delinquency_12'\n", "numeric_col_names = [\n", " 'orig_interest_rate',\n", " 'orig_upb',\n", " 'orig_loan_term',\n", " 'orig_ltv',\n", " 'orig_cltv',\n", " 'num_borrowers',\n", " 'dti',\n", " 'borrower_credit_score',\n", " 'num_units',\n", " 'zip',\n", " 'mortgage_insurance_percent',\n", " 'current_loan_delinquency_status',\n", " 'current_actual_upb',\n", " 'interest_rate',\n", " 'loan_age',\n", " 'msa',\n", " 'non_interest_bearing_upb',\n", " label_col_name\n", "]\n", "all_col_names = cate_col_names + numeric_col_names" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define ETL Process" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Functions to read raw columns" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def read_raw_csv(spark, path):\n", " def _get_quarter_from_csv_file_name():\n", " return substring_index(substring_index(input_file_name(), '.', 1), '/', -1)\n", "\n", " with open(f'{local_data_dir}/csv_raw_schema.ddl', 'r') as f:\n", " _csv_raw_schema_str = f.read()\n", " \n", " return (\n", " spark.read\n", " .format('csv') \n", " .option('nullValue', '') \n", " .option('header', False) \n", " .option('delimiter', '|') \n", " .schema(_csv_raw_schema_str) \n", " .load(path) \n", " .withColumn('quarter', _get_quarter_from_csv_file_name())\n", " )\n", "\n", "def extract_perf_columns(rawDf):\n", " perfDf = rawDf.select(\n", " col('loan_id'),\n", " date_format(to_date(col('monthly_reporting_period'),'MMyyyy'), 'MM/dd/yyyy').alias('monthly_reporting_period'),\n", " upper(col('servicer')).alias('servicer'),\n", " col('interest_rate'),\n", " col('current_actual_upb'),\n", " col('loan_age'),\n", " col('remaining_months_to_legal_maturity'),\n", " 
col('adj_remaining_months_to_maturity'),\n", " date_format(to_date(col('maturity_date'),'MMyyyy'), 'MM/yyyy').alias('maturity_date'),\n", " col('msa'),\n", " col('current_loan_delinquency_status'),\n", " col('mod_flag'),\n", " col('zero_balance_code'),\n", " date_format(to_date(col('zero_balance_effective_date'),'MMyyyy'), 'MM/yyyy').alias('zero_balance_effective_date'),\n", " date_format(to_date(col('last_paid_installment_date'),'MMyyyy'), 'MM/dd/yyyy').alias('last_paid_installment_date'),\n", " date_format(to_date(col('foreclosed_after'),'MMyyyy'), 'MM/dd/yyyy').alias('foreclosed_after'),\n", " date_format(to_date(col('disposition_date'),'MMyyyy'), 'MM/dd/yyyy').alias('disposition_date'),\n", " col('foreclosure_costs'),\n", " col('prop_preservation_and_repair_costs'),\n", " col('asset_recovery_costs'),\n", " col('misc_holding_expenses'),\n", " col('holding_taxes'),\n", " col('net_sale_proceeds'),\n", " col('credit_enhancement_proceeds'),\n", " col('repurchase_make_whole_proceeds'),\n", " col('other_foreclosure_proceeds'),\n", " col('non_interest_bearing_upb'),\n", " col('principal_forgiveness_upb'),\n", " col('repurchase_make_whole_proceeds_flag'),\n", " col('foreclosure_principal_write_off_amount'),\n", " col('servicing_activity_indicator'),\n", " col('quarter')\n", " )\n", " return perfDf.select('*').filter('current_actual_upb != 0.0')\n", "\n", "def extract_acq_columns(rawDf):\n", " acqDf = rawDf.select(\n", " col('loan_id'),\n", " col('orig_channel'),\n", " upper(col('seller_name')).alias('seller_name'),\n", " col('orig_interest_rate'),\n", " col('orig_upb'),\n", " col('orig_loan_term'),\n", " date_format(to_date(col('orig_date'),'MMyyyy'), 'MM/yyyy').alias('orig_date'),\n", " date_format(to_date(col('first_pay_date'),'MMyyyy'), 'MM/yyyy').alias('first_pay_date'),\n", " col('orig_ltv'),\n", " col('orig_cltv'),\n", " col('num_borrowers'),\n", " col('dti'),\n", " col('borrower_credit_score'),\n", " col('first_home_buyer'),\n", " col('loan_purpose'),\n", " col('property_type'),\n", " col('num_units'),\n", " col('occupancy_status'),\n", " col('property_state'),\n", " col('zip'),\n", " col('mortgage_insurance_percent'),\n", " col('product_type'),\n", " col('coborrow_credit_score'),\n", " col('mortgage_insurance_type'),\n", " col('relocation_mortgage_indicator'),\n", " dense_rank().over(Window.partitionBy('loan_id').orderBy(to_date(col('monthly_reporting_period'),'MMyyyy'))).alias('rank'),\n", " col('quarter')\n", " )\n", "\n", " return acqDf.select('*').filter(col('rank')==1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define function to parse date in Performance data " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _parse_dates(perf):\n", " return (\n", " perf.withColumn('monthly_reporting_period', to_date(col('monthly_reporting_period'), 'MM/dd/yyyy')) \n", " .withColumn('monthly_reporting_period_month', month(col('monthly_reporting_period'))) \n", " .withColumn('monthly_reporting_period_year', year(col('monthly_reporting_period'))) \n", " .withColumn('monthly_reporting_period_day', dayofmonth(col('monthly_reporting_period'))) \n", " .withColumn('last_paid_installment_date', to_date(col('last_paid_installment_date'), 'MM/dd/yyyy')) \n", " .withColumn('foreclosed_after', to_date(col('foreclosed_after'), 'MM/dd/yyyy')) \n", " .withColumn('disposition_date', to_date(col('disposition_date'), 'MM/dd/yyyy')) \n", " .withColumn('maturity_date', to_date(col('maturity_date'), 'MM/yyyy')) \n", " )" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "### Define function to create deliquency data frame from Performance data. \n", "\n", "The computed `delinquency_12` column denotes whether a loan will become delinquent by 3, 6, or 9 months, \n", "or not delinquent, within the next 12 month period. \n", "\n", "It will be the target label for ML multi-class prediction." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _create_perf_deliquency(spark, perf):\n", " aggDF = (\n", " perf\n", " .select(\n", " col('quarter'),\n", " col('loan_id'),\n", " col('current_loan_delinquency_status'),\n", " when(col('current_loan_delinquency_status') >= 1, col('monthly_reporting_period')).alias('delinquency_30'),\n", " when(col('current_loan_delinquency_status') >= 3, col('monthly_reporting_period')).alias('delinquency_90'),\n", " when(col('current_loan_delinquency_status') >= 6, col('monthly_reporting_period')).alias('delinquency_180')\n", " ).groupBy('quarter', 'loan_id')\n", " .agg(\n", " max('current_loan_delinquency_status').alias('delinquency_12'),\n", " min('delinquency_30').alias('delinquency_30'),\n", " min('delinquency_90').alias('delinquency_90'),\n", " min('delinquency_180').alias('delinquency_180')\n", " ).select(\n", " col('quarter'),\n", " col('loan_id'),\n", " (col('delinquency_12') >= 1).alias('ever_30'),\n", " (col('delinquency_12') >= 3).alias('ever_90'),\n", " (col('delinquency_12') >= 6).alias('ever_180'),\n", " col('delinquency_30'),\n", " col('delinquency_90'),\n", " col('delinquency_180')\n", " )\n", " )\n", " #aggDF.printSchema()\n", " joinedDf = (\n", " perf\n", " .withColumnRenamed('monthly_reporting_period', 'timestamp')\n", " .withColumnRenamed('monthly_reporting_period_month', 'timestamp_month') \n", " .withColumnRenamed('monthly_reporting_period_year', 'timestamp_year') \n", " .withColumnRenamed('current_loan_delinquency_status', 'delinquency_12') \n", " .withColumnRenamed('current_actual_upb', 'upb_12') \n", " .select('quarter', 'loan_id', 'timestamp', 'delinquency_12', 'upb_12', 'timestamp_month', 'timestamp_year') \n", " .join(aggDF, ['loan_id', 'quarter'], 'left_outer')\n", " )\n", " # calculate the 12 month delinquency and upb values\n", " months = 12\n", " monthArray = [lit(x) for x in range(0, 12)]\n", " \n", " testDf = ( \n", " joinedDf\n", " .withColumn('month_y', explode(array(monthArray)))\n", " .select(\n", " col('quarter'),\n", " floor(((col('timestamp_year') * 12 + col('timestamp_month')) - 24000) / months).alias('josh_mody'),\n", " floor(((col('timestamp_year') * 12 + col('timestamp_month')) - 24000 - col('month_y')) / months).alias('josh_mody_n'),\n", " col('ever_30'),\n", " col('ever_90'),\n", " col('ever_180'),\n", " col('delinquency_30'),\n", " col('delinquency_90'),\n", " col('delinquency_180'),\n", " col('loan_id'),\n", " col('month_y'),\n", " col('delinquency_12'),\n", " col('upb_12')\n", " ).groupBy('quarter', 'loan_id', 'josh_mody_n', 'ever_30', 'ever_90', 'ever_180', 'delinquency_30', 'delinquency_90', 'delinquency_180', 'month_y')\n", " .agg(max('delinquency_12').alias('delinquency_12'), min('upb_12').alias('upb_12')) \n", " .withColumn('timestamp_year', floor((lit(24000) + (col('josh_mody_n') * lit(months)) + (col('month_y') - 1)) / lit(12))) \n", " .selectExpr('*', f'pmod(24000 + (josh_mody_n * {months}) + month_y, 12) as timestamp_month_tmp') \n", " .withColumn('timestamp_month', when(col('timestamp_month_tmp') == lit(0), lit(12)).otherwise(col('timestamp_month_tmp'))) \n", " .withColumn('delinquency_12', 
((col('delinquency_12') > 9).cast('int') + (col('delinquency_12') > 6).cast('int') + (col('delinquency_12') > 3).cast('int') + (col('upb_12') == 0).cast('int')).alias('delinquency_12')) \n", " .drop('timestamp_month_tmp', 'josh_mody_n', 'month_y')\n", " )\n", "\n", " return (\n", " perf\n", " .withColumnRenamed('monthly_reporting_period_month', 'timestamp_month')\n", " .withColumnRenamed('monthly_reporting_period_year', 'timestamp_year')\n", " .join(testDf, ['quarter', 'loan_id', 'timestamp_year', 'timestamp_month'], 'left')\n", " .drop('timestamp_year', 'timestamp_month')\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define function to create acquisition data frame from Acquisition data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _create_acquisition(spark, acq):\n", " return (\n", " acq.join(name_mapping_df, col('seller_name') == col('from_seller_name'), 'left')\n", " .drop('from_seller_name') \n", " .withColumn('old_name', col('seller_name')) \n", " .withColumn('seller_name', coalesce(col('to_seller_name'), col('seller_name'))) \n", " .drop('to_seller_name') \n", " .withColumn('orig_date', to_date(col('orig_date'), 'MM/yyyy')) \n", " .withColumn('first_pay_date', to_date(col('first_pay_date'), 'MM/yyyy')) \n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define Casting Process\n", "\n", "\n", "This part is casting String column to Numeric one. \n", "Example:\n", "```\n", "col_1\n", " \"a\"\n", " \"b\"\n", " \"c\"\n", " \"a\"\n", "# After String ====> Numeric\n", "col_1\n", " 0\n", " 1\n", " 2\n", " 0\n", "``` \n", "\n", "### Define function to get column dictionary\n", "\n", "Example\n", "\n", "```\n", "col1 = [row(data=\"a\",id=0), row(data=\"b\",id=1)]\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _gen_dictionary(etl_df, col_names):\n", " cnt_table = (\n", " etl_df.select(posexplode(array([col(i) for i in col_names])))\n", " .withColumnRenamed('pos', 'column_id')\n", " .withColumnRenamed('col', 'data')\n", " .filter('data is not null')\n", " .groupBy('column_id', 'data')\n", " .count()\n", " )\n", " windowed = Window.partitionBy('column_id').orderBy(desc('count'))\n", " return cnt_table.withColumn('id', row_number().over(windowed)).drop('count')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define function to convert string columns to numeric\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _cast_string_columns_to_numeric(spark, input_df):\n", " cached_dict_df = _gen_dictionary(input_df, cate_col_names) \n", " # .cache()\n", " # Uncomment above line to cache the dictionary dataframe. You need to spark.catalog.clearCache()\n", " # when running the notebook multiple times switching between CPU and GPU.\n", " \n", " output_df = input_df\n", " # Generate the final table with all columns being numeric.\n", " for col_pos, col_name in enumerate(cate_col_names):\n", " col_dict_df = (\n", " cached_dict_df.filter(col('column_id') == col_pos)\n", " .drop('column_id')\n", " .withColumnRenamed('data', col_name)\n", " )\n", " output_df = (\n", " output_df.join(broadcast(col_dict_df), col_name, 'left')\n", " .drop(col_name)\n", " .withColumnRenamed('id', col_name)\n", " )\n", " return output_df " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define Main Function\n", "\n", "In this function:\n", "1. 
Parse date in Performance data by calling _parse_dates (parsed_perf)\n", "2. Create deliqency dataframe(perf_deliqency) form Performance data by calling _create_perf_deliquency\n", "3. Create cleaned acquisition dataframe(cleaned_acq) from Acquisition data by calling _create_acquisition\n", "4. Join deliqency dataframe(perf_deliqency) and cleaned acquisition dataframe(cleaned_acq), get clean_df\n", "5. Cast String column to Numeric in clean_df by calling _cast_string_columns_to_numeric, get casted_clean_df\n", "6. Return casted_clean_df as final result" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def run_mortgage(spark, perf, acq):\n", " parsed_perf = _parse_dates(perf)\n", " perf_deliqency = _create_perf_deliquency(spark, parsed_perf)\n", " cleaned_acq = _create_acquisition(spark, acq)\n", " clean_df = perf_deliqency.join(cleaned_acq, ['loan_id', 'quarter'], 'inner').drop('quarter')\n", " casted_clean_df = (\n", " _cast_string_columns_to_numeric(spark, clean_df)\n", " .select(all_col_names)\n", " .withColumn(label_col_name, when(col(label_col_name) > 0, col(label_col_name)).otherwise(0))\n", " .fillna(float(0))\n", " )\n", " return casted_clean_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Knobs for running the pipelines" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Should raw csv input be used or input persisted to Parquet \n", "read_from_csv = False\n", "# if not read_from_csv, include conversion to Parquet in this run?\n", "convert_csv_to_parquet = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Execute SQL and ML on GPU ?\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "accelerate_on_gpu = True" }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### ETL on GPU?" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "spark.conf.set('spark.rapids.sql.enabled', accelerate_on_gpu) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### ML on GPU?" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if accelerate_on_gpu:\n", " spark.conf.set('spark.connect.ml.backend.classes', 'com.nvidia.rapids.ml.Plugin')\n", "else:\n", " spark.conf.unset('spark.connect.ml.backend.classes')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run ETL Pipeline\n", "\n", "#### Read Raw Data and Run ETL Process, Save the Result\n", "\n", "##### Convert CSV to Parquet" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if read_from_csv:\n", " mortgage_csv = read_raw_csv(spark, f'{global_data_dir}/mortgage.input.csv')\n", "elif convert_csv_to_parquet:\n", " read_raw_csv(spark, f'{global_data_dir}/mortgage.input.csv')\\\n", " .write.parquet(f'{global_data_dir}/mortgage_input.pq', mode='overwrite')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### ETL from Parquet or raw CSV Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mortgage = mortgage_csv if read_from_csv else spark.read.parquet(f'{global_data_dir}/mortgage_input.pq')\n", "acq = extract_acq_columns(mortgage)\n", "perf = extract_perf_columns(mortgage)\n", "# run main function to process data\n", "preprocessed = run_mortgage(spark, perf, acq)\n", "# save processed data\n", "\n", "start = time.time()\n", "preprocessed.write.parquet(f'{global_data_dir}/mortgage_preprocessed.pq' , mode='overwrite')\n", "end = time.time()\n", "\n", "etl_dur = end - start\n", "print(f'ETL takes {etl_dur}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Modeling Pipeline\n", "\n", "#### The ML modeling phase of the example uses the `spark.ml` Pipeline API to carry out the following steps on a random subsample of the ETL output:\n", " - use `spark.ml FeatureHasher` to map the int type columns in the ETL output to a 2^15 dimensional sparse feature vector with a non-zero entry in each location corresponding to hash value of each input column value + column name.\n", " - use `spark.ml VectorAssembler` to combine the output of `FeatureHasher` with the original float type columns into a single `VectorUDT` type feature vector\n", " - train a model using `LogisticRegression` to predict the multi-class (4 class values) label \"delinquency_12\"." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "etlDf = spark.read.parquet(f'{global_data_dir}/mortgage_preprocessed.pq')\n", "etlDf = etlDf.sample(fraction=0.1, seed=1234)\n", "etlDf.describe().filter(col('summary') == 'mean').show(vertical=True, truncate=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "etlDf = etlDf.withColumn('loc',(etlDf.msa*1000+etlDf.zip).cast('int')).drop('zip' ,'msa')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "label_col_name = 'delinquency_12'\n", "schema = etlDf.schema\n", "raw_features = [ x for x in schema.fields if x.name != label_col_name ]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "categorical_cols = [f.name for f in raw_features if f.dataType == IntegerType()]\n", "numerical_cols = [f.name for f in raw_features if f.name not in categorical_cols]\n", "hasher = FeatureHasher(inputCols=categorical_cols, outputCol='hashed_categorical', \n", " categoricalCols=categorical_cols, numFeatures=(1 << 15))\n", "va = VectorAssembler().setInputCols(numerical_cols + [hasher.getOutputCol()]).setOutputCol('features')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "logistic = ( \n", " LogisticRegression()\n", " .setMaxIter(200)\n", " .setRegParam(0.00002)\n", " .setElasticNetParam(0.1)\n", " .setTol(1.0e-12)\n", " .setFeaturesCol('features')\n", " .setLabelCol(label_col_name)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "[df_train, df_test] = etlDf.randomSplit([0.8, 0.2], seed=1234)\n", "pipeline = Pipeline().setStages([hasher, va, logistic])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "start = time.time()\n", "# gpu lr, gpu etl, gpu transform, 200 iters, double precision, elasticnet=0.1, featurehasher, 0.1 sample, multiclass, float64\n", "pipeline_model = pipeline.fit(df_train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictions = pipeline_model.transform(df_test)\n", "predictions.sample(0.1).show(1, vertical=True, truncate=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "evaluator = MulticlassClassificationEvaluator().setMetricName('logLoss').setLabelCol(label_col_name)\n", "eval_res = evaluator.evaluate(predictions)\n", "end = time.time()\n", "print(f'Evaluation result: {eval_res}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ml_dur = end - start\n", "print(f'ML takes {ml_dur}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save current run times " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Preserve across container restarts\n", "\n", "local_gpu_times_file = f'gpu_times.csv'\n", "local_cpu_times_file = f'cpu_times.csv'\n", "\n", "run_times = pd.Series({'etl' : etl_dur, 'ml' : ml_dur})\n", "run_times.to_csv(local_gpu_times_file if accelerate_on_gpu else local_cpu_times_file, index=True, header=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize acceleration" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if os.path.exists(local_cpu_times_file) and 
os.path.exists(local_gpu_times_file):\n", " cpu_times = pd.read_csv(local_cpu_times_file, header=None, index_col=0)\n", " gpu_times = pd.read_csv(local_gpu_times_file, header=None, index_col=0)\n", " gpu_speedup = cpu_times / gpu_times\n", " gpu_speedup.plot(kind='bar', \n", " title='GPU Acceleration Factor (> 1.0 is good)', \n", " color='#76B900', \n", " legend=False)\n", " cpu_times = cpu_times[1].rename('cpu')\n", " gpu_times = gpu_times[1].rename('gpu')\n", " times = pd.DataFrame([cpu_times, gpu_times]).transpose()\n", " times.plot(kind='bar', \n", " title = 'ETL and ML elapsed times for CPU and GPU (lower is better)', \n", " color=['blue', '#76B900'])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: examples/spark-connect-gpu/client/notebook/work/csv_raw_schema.ddl ================================================ reference_pool_id STRING, loan_id BIGINT, monthly_reporting_period STRING, orig_channel STRING, seller_name STRING, servicer STRING, master_servicer STRING, orig_interest_rate DOUBLE, interest_rate DOUBLE, orig_upb DOUBLE, upb_at_issuance STRING, current_actual_upb DOUBLE, orig_loan_term INT, orig_date STRING, first_pay_date STRING, loan_age DOUBLE, remaining_months_to_legal_maturity DOUBLE, adj_remaining_months_to_maturity DOUBLE, maturity_date STRING, orig_ltv DOUBLE, orig_cltv DOUBLE, num_borrowers DOUBLE, dti DOUBLE, borrower_credit_score DOUBLE, coborrow_credit_score DOUBLE, first_home_buyer STRING, loan_purpose STRING, property_type STRING, num_units INT, occupancy_status STRING, property_state STRING, msa DOUBLE, zip INT, mortgage_insurance_percent DOUBLE, product_type STRING, prepayment_penalty_indicator STRING, interest_only_loan_indicator STRING, interest_only_first_principal_and_interest_payment_date STRING, months_to_amortization STRING, current_loan_delinquency_status INT, loan_payment_history STRING, mod_flag STRING, mortgage_insurance_cancellation_indicator STRING, zero_balance_code STRING, zero_balance_effective_date STRING, upb_at_the_time_of_removal STRING, repurchase_date STRING, scheduled_principal_current STRING, total_principal_current STRING, unscheduled_principal_current STRING, last_paid_installment_date STRING, foreclosed_after STRING, disposition_date STRING, foreclosure_costs DOUBLE, prop_preservation_and_repair_costs DOUBLE, asset_recovery_costs DOUBLE, misc_holding_expenses DOUBLE, holding_taxes DOUBLE, net_sale_proceeds DOUBLE, credit_enhancement_proceeds DOUBLE, repurchase_make_whole_proceeds STRING, other_foreclosure_proceeds DOUBLE, non_interest_bearing_upb DOUBLE, principal_forgiveness_upb STRING, original_list_start_date STRING, original_list_price STRING, current_list_start_date STRING, current_list_price STRING, borrower_credit_score_at_issuance STRING, `co-borrower_credit_score_at_issuance` STRING, borrower_credit_score_current STRING, `co-Borrower_credit_score_current` STRING, mortgage_insurance_type DOUBLE, servicing_activity_indicator STRING, current_period_modification_loss_amount STRING, cumulative_modification_loss_amount STRING, current_period_credit_event_net_gain_or_loss STRING, cumulative_credit_event_net_gain_or_loss STRING, 
homeready_program_indicator STRING, foreclosure_principal_write_off_amount STRING, relocation_mortgage_indicator STRING, zero_balance_code_change_date STRING, loan_holdback_indicator STRING, loan_holdback_effective_date STRING, delinquent_accrued_interest STRING, property_valuation_method STRING, high_balance_loan_indicator STRING, `arm_initial_fixed-rate_period_lt_5_yr_indicator` STRING, arm_product_type STRING, `initial_fixed-rate_period` STRING, interest_rate_adjustment_frequency STRING, next_interest_rate_adjustment_date STRING, next_payment_change_date STRING, index STRING, arm_cap_structure STRING, initial_interest_rate_cap_up_percent STRING, periodic_interest_rate_cap_up_percent STRING, lifetime_interest_rate_cap_up_percent STRING, mortgage_margin STRING, arm_balloon_indicator STRING, arm_plan_number STRING, borrower_assistance_plan STRING, hltv_refinance_option_indicator STRING, deal_name STRING, repurchase_make_whole_proceeds_flag STRING, alternative_delinquency_resolution STRING, alternative_delinquency_resolution_count STRING, total_deferral_amount STRING ================================================ FILE: examples/spark-connect-gpu/client/notebook/work/name_mapping.csv ================================================ "WITMER FUNDING, LLC",Witmer WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015,Wells Fargo "WELLS FARGO BANK, NA",Wells Fargo "WELLS FARGO BANK, N.A.",Wells Fargo "WELLS FARGO BANK, NA",Wells Fargo USAA FEDERAL SAVINGS BANK,USAA "UNITED SHORE FINANCIAL SERVICES, LLC D\/B\/A UNITED WHOLESALE MORTGAGE",United Seq(e U.S. BANK N.A.,US Bank SUNTRUST MORTGAGE INC.,Suntrust STONEGATE MORTGAGE CORPORATION,Stonegate Mortgage "STEARNS LENDING, LLC",Stearns Lending "STEARNS LENDING, INC.",Stearns Lending "SIERRA PACIFIC MORTGAGE COMPANY, INC.",Sierra Pacific Mortgage REGIONS BANK,Regions RBC MORTGAGE COMPANY,RBC QUICKEN LOANS INC.,Quicken Loans "PULTE MORTGAGE, L.L.C.",Pulte Mortgage "PROVIDENT FUNDING ASSOCIATES, L.P.",Provident Funding "PROSPECT MORTGAGE, LLC",Prospect Mortgage "PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC",Principal Residential "PNC BANK, N.A.",PNC PMT CREDIT RISK TRANSFER TRUST 2015-2,PennyMac PHH MORTGAGE CORPORATION,PHH Mortgage PENNYMAC CORP.,PennyMac "PACIFIC UNION FINANCIAL, LLC",Other OTHER,Other "NYCB MORTGAGE COMPANY, LLC",NYCB NEW YORK COMMUNITY BANK,NYCB NETBANK FUNDING SERVICES,Netbank "NATIONSTAR MORTGAGE, LLC",Nationstar Mortgage "METLIFE BANK, NA",Metlife "LOANDEPOT.COM, LLC",LoanDepot.com "J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1",JP Morgan Chase "J.P. 
MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1",JP Morgan Chase "JPMORGAN CHASE BANK, NATIONAL ASSOCIATION",JP Morgan Chase "JPMORGAN CHASE BANK, NA",JP Morgan Chase "JP MORGAN CHASE BANK, NA",JP Morgan Chase "IRWIN MORTGAGE, CORPORATION",Irwin Mortgage IMPAC MORTGAGE CORP.,Impac Mortgage "HSBC BANK USA, NATIONAL ASSOCIATION",HSBC "HOMEWARD RESIDENTIAL, INC.",Homeward Mortgage HOMESTREET BANK,Other "HOMEBRIDGE FINANCIAL SERVICES, INC.",HomeBridge "HARWOOD STREET FUNDING I, LLC",Harwood Mortgage GUILD MORTGAGE COMPANY,Guild Mortgage "GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)",GMAC "GMAC MORTGAGE, LLC",GMAC GMAC (USAA),GMAC FREMONT BANK,Fremont Bank FREEDOM MORTGAGE CORP.,Freedom Mortgage FRANKLIN AMERICAN MORTGAGE COMPANY,Franklin America FLEET NATIONAL BANK,Fleet National FLAGSTAR CAPITAL MARKETS CORPORATION,Flagstar Bank "FLAGSTAR BANK, FSB",Flagstar Bank FIRST TENNESSEE BANK NATIONAL ASSOCIATION,Other FIFTH THIRD BANK,Fifth Third Bank FEDERAL HOME LOAN BANK OF CHICAGO,Fedral Home of Chicago "FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB",FDIC "DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A.",Downey Mortgage DITECH FINANCIAL LLC,Ditech "CITIMORTGAGE, INC.",Citi CHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY,Chicago Mortgage CHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY,Chicago Mortgage "CHASE HOME FINANCE, LLC",JP Morgan Chase CHASE HOME FINANCE FRANKLIN AMERICAN MORTGAGE COMPANY,JP Morgan Chase CHASE HOME FINANCE (CIE 1),JP Morgan Chase CHASE HOME FINANCE,JP Morgan Chase "CASHCALL, INC.",CashCall "CAPITAL ONE, NATIONAL ASSOCIATION",Capital One "CALIBER HOME LOANS, INC.",Caliber Funding BISHOPS GATE RESIDENTIAL MORTGAGE TRUST,Bishops Gate Mortgage "BANK OF AMERICA, N.A.",Bank of America AMTRUST BANK,AmTrust AMERISAVE MORTGAGE CORPORATION,Amerisave "AMERIHOME MORTGAGE COMPANY, LLC",AmeriHome Mortgage ALLY BANK,Ally Bank ACADEMY MORTGAGE CORPORATION,Academy Mortgage NO CASH-OUT REFINANCE,OTHER REFINANCE REFINANCE - NOT SPECIFIED,OTHER REFINANCE Other REFINANCE,OTHER REFINANCE ================================================ FILE: examples/spark-connect-gpu/client/python/batch-job.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "initial_id", "metadata": { "collapsed": true }, "outputs": [], "source": "%run batch-job.py" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/spark-connect-gpu/client/python/batch-job.py ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from pyspark.sql import SparkSession from pyspark.sql.functions import * spark = (SparkSession .builder .getOrCreate() ) df = ( spark.range(2 ** 12) .withColumn("mod10", col("id") % lit(10)) .groupBy("mod10").agg(count("*")) .orderBy("mod10") ) # workaround to get a plan with GpuOverrides applied by disabling adaptive execution def explain(dataframe): spark.conf.set("spark.sql.adaptive.enabled", False) dataframe.explain(mode="extended") spark.conf.set("spark.sql.adaptive.enabled", True) ## Disable GPU accelerating print("--------------- CPU running by disabling spark.rapids.sql.enabled ---------------") spark.conf.set("spark.rapids.sql.enabled", False) explain(df) df.show() ## Enable GPU accelerating spark.conf.set("spark.rapids.sql.enabled", True) print("--------------- GPU running by enabling spark.rapids.sql.enabled ---------------") explain(df) df.show() spark.stop() ================================================ FILE: examples/spark-connect-gpu/client/requirements.txt ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. jupyterlab matplotlib # https://spark.apache.org/docs/latest/api/python/getting_started/install.html#python-spark-connect-client # ... pure Python library pyspark-client==4.0.0 ================================================ FILE: examples/spark-connect-gpu/client/scala/.gitignore ================================================ target .idea ================================================ FILE: examples/spark-connect-gpu/client/scala/pom.xml ================================================ 4.0.0 com.example spark-connect-demo 1.0-SNAPSHOT 4.0.0 2.13.16 2.13 org.apache.spark spark-connect-client-jvm_${scala.binary.version} ${spark.connect.version} org.scala-lang scala-library ${scala.version} net.alchim31.maven scala-maven-plugin 4.9.6 -XDignore.symbol.file true scala-compile-first process-resources add-source compile org.apache.maven.plugins maven-assembly-plugin 3.7.1 jar-with-dependencies assembly package single ================================================ FILE: examples/spark-connect-gpu/client/scala/run.sh ================================================ #! 
/bin/bash # work for jdk 17 java \ --add-exports=java.base/sun.nio.ch=ALL-UNNAMED \ --add-opens=java.base/java.nio=ALL-UNNAMED \ --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \ --add-opens=java.base/java.util=ALL-UNNAMED \ --add-opens=java.base/sun.security.action=ALL-UNNAMED \ -cp spark-connect-demo-1.0-SNAPSHOT-jar-with-dependencies.jar connect ================================================ FILE: examples/spark-connect-gpu/client/scala/scala-run.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "initial_id", "metadata": { "collapsed": true }, "outputs": [], "source": [ "%%bash\n", "java \\\n", " --add-exports=java.base/sun.nio.ch=ALL-UNNAMED \\\n", " --add-opens=java.base/java.nio=ALL-UNNAMED \\\n", " --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \\\n", " --add-opens=java.base/java.util=ALL-UNNAMED \\\n", " --add-opens=java.base/sun.security.action=ALL-UNNAMED \\\n", " -cp spark-connect-demo-1.0-SNAPSHOT-jar-with-dependencies.jar connect" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/spark-connect-gpu/client/scala/src/main/scala/connect.scala ================================================ // Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. // Licensed to the Apache Software Foundation (ASF) under one or more // contributor license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright ownership. // The ASF licenses this file to You under the Apache License, Version 2.0 // (the "License"); you may not use this file except in compliance with // the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions._ object connect extends Serializable { def explain(df: DataFrame): Unit = { // workaround to get a plan with GpuOverrides applied by disabling adaptive execution df.sparkSession.conf.set("spark.sql.adaptive.enabled", "false") df.explain(mode = "extended") df.sparkSession.conf.set("spark.sql.adaptive.enabled", "true") } def main(args: Array[String]): Unit = { val spark = SparkSession.builder().getOrCreate() val df = spark.range(1L << 12) .withColumn("mod10", col("id") % lit(10)) .groupBy("mod10").agg(count("*")) .orderBy("mod10") // Disable GPU acceleration println("--------------- CPU running by disabling spark.rapids.sql.enabled ---------------") spark.conf.set("spark.rapids.sql.enabled", "false") explain(df) df.show() // Enable GPU acceleration spark.conf.set("spark.rapids.sql.enabled", "true") println("--------------- GPU running by enabling spark.rapids.sql.enabled ---------------") explain(df) df.show() spark.stop() } } ================================================ FILE: examples/spark-connect-gpu/server/README.md ================================================ # GPU-Accelerated Spark Connect Server This project demonstrates how to set up a GPU-accelerated Spark server using Apache Spark 4.0 with Spark Connect, featuring the RAPIDS Accelerator. ## 🚀 Key Features - **Apache Spark 4.0** with cutting-edge Spark Connect capabilities - **GPU acceleration** via RAPIDS Accelerator - **MLlib over Spark Connect** - new in Spark 4.0 - **Zero-code-change acceleration** - existing Spark applications automatically benefit - **Jupyter Lab integration** for interactive development - **Docker Compose** setup for easy deployment, with a clear distinction of which dependencies are required by which service and where GPUs are actually used ## 🏗️ Architecture The setup consists of four Docker services, plus your web browser as the frontend: ### Apache Spark Standalone Cluster 1. **Spark Master** (`spark-master`) - Cluster coordination and job scheduling. This container does not have GPU capability 2. **Spark Worker** (`spark-worker`) - GPU-enabled worker node for task execution. This is the only service requiring and having access to the host GPUs ### Middle Tier 3. **Spark Connect Server** (`spark-connect-server`) - gRPC interface with the RAPIDS integration ### Proxy Service 4. nginx configured to provide access to the various Apache Spark Web UIs over the Docker network ### Frontend Web Browser 5. Web UI access to the Spark Connect Server and the Spark Standalone Cluster To reduce the complexity of the demo, no global storage service is included. The demo relies on the **DATA_DIR** location mounted from the host in place of a storage service. This location is also used for convenience to preserve metrics and Spark event logs beyond the container life cycle for analysis or debugging. When the **DATA_DIR** is accessed in a way that would normally require global access, we indicate this by using the `global_` prefix for the variable storing the complete path. Otherwise, we use variables starting with `local_`.
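For reference, here is a minimal sketch of how a Spark Connect client could reach this server once the stack is up. It assumes the default gRPC endpoint `sc://localhost:15002` exposed by this Docker Compose setup and mirrors the logic of the bundled client examples under `examples/spark-connect-gpu/client`; see those for the complete versions.

```python
# Minimal Spark Connect client sketch (assumes the default endpoint
# sc://localhost:15002 exposed by this Docker Compose setup).
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, lit

# Connect to the remote Spark Connect server instead of a local Spark context.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = (spark.range(2 ** 12)
      .withColumn("mod10", col("id") % lit(10))
      .groupBy("mod10").agg(count("*"))
      .orderBy("mod10"))

# The RAPIDS Accelerator can be toggled per session without any code changes.
spark.conf.set("spark.rapids.sql.enabled", True)
df.show()

spark.stop()
```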
## 📋 Prerequisites ### Required - [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/linux) - At least 8GB of available RAM - Available ports: 2080, 8080, 8081, 8888, 7077, 4040, 15002 ### For GPU Acceleration - NVIDIA GPU with CUDA compute capability supported by RAPIDS - [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) - The Docker Compose version should be `2.30.x` or newer to avoid an NVIDIA Container Toolkit related bug. [Update](https://docs.docker.com/compose/install/linux) if necessary - CUDA 12.x drivers ## 🚀 Quick Start 1. **Clone and navigate to the project:** ```bash cd examples/spark-connect-gpu/server ``` 2. **Set up data directory (if needed):** ```bash export DATA_DIR=$(pwd)/data mkdir -p $DATA_DIR/mortgage.input.csv $DATA_DIR/spark-events $DATA_DIR/nds chmod 1777 $DATA_DIR $DATA_DIR/spark-events ``` Download a few quarters' worth of the [Mortgage Dataset](https://capitalmarkets.fanniemae.com/credit-risk-transfer/single-family-credit-risk-transfer/fannie-mae-single-family-loan-performance-data) to the `$DATA_DIR/mortgage.input.csv` location. For more details, refer to [How to download the Mortgage dataset](https://github.com/NVIDIA/spark-rapids-examples/blob/main/docs/get-started/xgboost-examples/dataset/mortgage.md). To run NDS (see [NDS v2.0 Automation](https://github.com/NVIDIA/spark-rapids-benchmarks/tree/dev/nds#nds-v20-automation)), generate the dataset and place it in "$DATA_DIR/nds". For more details, refer to [NDS Data Generation](https://github.com/NVIDIA/spark-rapids-benchmarks/tree/dev/nds#data-generation). 3. **Start all services:** ```bash $ docker compose up -d ``` (`docker compose` can be used in place of `docker-compose` here and throughout) 4. **Access the Web UI interfaces:** ***Option 1 (default)*** All containers' web UIs are available via localhost URLs by default - **Spark Master UI**: http://localhost:8080 - Cluster coordination and resource management - **Spark Worker UI**: http://localhost:8081 - GPU-enabled worker node status and tasks - **Spark Driver UI**: http://localhost:4040 - Application monitoring and SQL queries ***Option 2*** If you launch Docker Compose with `SPARK_PUBLIC_DNS=container-hostname` set in the environment, all containers' web UIs except Jupyter Lab are available via the corresponding container hostnames, such as spark-master - **Spark Master UI**: http://spark-master:8080 - Cluster coordination and resource management - **Spark Worker UI**: http://spark-worker:8081 - GPU-enabled worker node status and tasks - **Spark Driver UI**: http://spark-connect-server:4040 - Application monitoring and SQL queries Using Docker DNS names requires configuring your browser to use the HTTP proxy on the Docker network exposed at http://localhost:2080.
Here are examples of launching Google Chrome with a temporary user profile, so no persistent changes are made to the browser. ***Linux*** ```bash $ google-chrome --user-data-dir="/tmp/chrome-proxy-profile" --proxy-server="http=http://localhost:2080" ``` ***macOS*** ```bash $ open -n -a "Google Chrome" --args --user-data-dir="/tmp/chrome-proxy-profile" --proxy-server="http=http://localhost:2080" ``` ***Launching containers on a remote machine*** Your local machine might not have a GPU, and it is common in this case to use a remote machine/cluster with GPUs residing in a remote cloud or on-prem environment. If you followed the default Option 1, make sure to create local port forwards for every web UI port: ```bash ssh -L 8888:localhost:8888 -L 8080:localhost:8080 -L 8081:localhost:8081 -L 4040:localhost:4040 ``` If you used Option 2, it is sufficient to forward ports only for the HTTP proxy and the Notebook app: ```bash ssh -L 2080:localhost:2080 -L 8888:localhost:8888 ``` ## 🐳 Service Details ### Spark Master - **Image**: `apache/spark:4.0.0` - **Ports**: 8080 (Web UI), 7077 (Master) - **Role**: Cluster coordination and resource management ### Spark Worker (the only GPU node role) - **Image**: Custom build based on `apache/spark:4.0.0` - **GPU**: NVIDIA GPU support via Docker Compose deploy configuration - **Ports**: 8081 (Web UI) - **Features**: GPU resource discovery and task execution ### Spark Connect Server - **Image**: Custom build based on `apache/spark:4.0.0` with Spark RAPIDS ETL and ML Plugins - **RAPIDS Version**: 26.02.0 for CUDA 12 - **Ports**: 15002 (gRPC), 4040 (Driver UI) - **Configuration**: Optimized for GPU acceleration with memory management ## 📊 Performance Monitoring You can use tools like nvtop, nvitop, btop, or jupyterlab_nvdashboard running on the GPU host(s). ## 🧹 Cleanup Stop and remove all services: ```bash docker-compose down -v ``` Remove built images: ```bash docker-compose down --rmi all -v ``` ### Logs Logs for the Spark driver/Connect server, standalone master, standalone worker, and Jupyter server can be viewed using the respective commands: ```bash docker logs spark-connect-server docker logs spark-master docker logs spark-worker ``` Spark executor logs can be accessed via the Spark UI as usual. ## 📖 Additional Resources - [Apache Spark 4.0 Documentation](https://spark.apache.org/docs/latest/) - [Spark Connect Guide](https://spark.apache.org/docs/latest/spark-connect-overview.html) - [NVIDIA RAPIDS Accelerator](https://nvidia.github.io/spark-rapids/) - [Data and AI Summit Session](https://www.databricks.com/dataaisummit/session/gpu-accelerated-spark-connect) ================================================ FILE: examples/spark-connect-gpu/server/docker-compose.yaml ================================================ # Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. # YAML anchors for shared configurations x-spark-common: &spark-common networks: - spark-network volumes: - ${DATA_DIR:-${PWD}/data}:/data x-spark-common-env: &spark-common-env SPARK_PUBLIC_DNS: "${SPARK_PUBLIC_DNS:-localhost}" SPARK_NO_DAEMONIZE: "1" services: # Spark Master Node spark-master: <<: *spark-common image: spark-master-image build: context: ./spark-master dockerfile: Dockerfile container_name: spark-master hostname: spark-master environment: <<: *spark-common-env ports: - "8080:8080" # Spark Master Web UI - "7077:7077" # Spark Master Port command: /opt/spark/sbin/start-master.sh # Spark Worker Node (GPU-enabled) spark-worker: <<: *spark-common image: spark-worker-image build: context: ./spark-worker dockerfile: Dockerfile container_name: spark-worker hostname: spark-worker environment: <<: *spark-common-env ports: - "8081:8081" # Spark Worker WebUI depends_on: - spark-master command: /opt/spark/sbin/start-worker.sh spark://spark-master:7077 deploy: resources: reservations: devices: - driver: nvidia capabilities: [gpu] # Spark Connect Server spark-connect-server: <<: *spark-common image: spark-connect-server-image build: context: ./spark-connect-server dockerfile: Dockerfile args: - CUDA_VERSION=${CUDA_VERSION:-12} - RAPIDS_VERSION=${RAPIDS_VERSION:-26.02.0} - REPO_URL=${REPO_URL:-https://repo1.maven.org/maven2} container_name: spark-connect-server hostname: spark-connect-server environment: <<: *spark-common-env ports: - "4040:4040" # Spark Driver WebUI - "15002:15002" # Spark Connect grpc depends_on: - spark-master - spark-worker command: > /opt/spark/sbin/start-connect-server.sh --driver-memory=24G --conf spark.executor.memory=28G --conf spark.executor.cores=8 proxy-service: build: context: ./proxy-service dockerfile: Dockerfile container_name: proxy-service ports: - "2080:2080" networks: - spark-network depends_on: - spark-master - spark-worker - spark-connect-server restart: unless-stopped networks: spark-network: driver: bridge ================================================ FILE: examples/spark-connect-gpu/server/proxy-service/Dockerfile ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. FROM nginx:latest # Copy the nginx configuration file into the container COPY nginx.conf /etc/nginx/nginx.conf ================================================ FILE: examples/spark-connect-gpu/server/proxy-service/nginx.conf ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. 
See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. events { worker_connections 1024; } http { resolver 127.0.0.11; server { listen 2080; location / { proxy_set_header Host $http_host; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; proxy_pass $scheme://$http_host; proxy_read_timeout 90; } } } ================================================ FILE: examples/spark-connect-gpu/server/spark-connect-server/Dockerfile ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. FROM apache/spark:4.0.0 ARG CUDA_VERSION ARG RAPIDS_VERSION ARG REPO_URL USER root COPY requirements.txt /tmp/requirements.txt RUN pip3 install -r /tmp/requirements.txt RUN mkdir -p /opt/spark-rapids/jars RUN chown -R spark:spark /opt/spark-rapids USER spark ENV CUDA_VERSION=${CUDA_VERSION} ENV RAPIDS_VERSION=${RAPIDS_VERSION} ENV REPO_URL=${REPO_URL} RUN wget -q ${REPO_URL}/com/nvidia/rapids-4-spark_2.13/${RAPIDS_VERSION}/rapids-4-spark_2.13-${RAPIDS_VERSION}-cuda${CUDA_VERSION}.jar -O /opt/spark-rapids/jars/rapids-4-spark-sql.jar COPY spark-defaults.conf /opt/spark/conf/spark-defaults.conf COPY spark-env.sh /opt/spark/conf/spark-env.sh ================================================ FILE: examples/spark-connect-gpu/server/spark-connect-server/requirements.txt ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. packaging==23.1 pandas==2.2.3 psutil pyarrow scikit-learn>=1.2.1 spark-rapids-ml==25.8.0 ================================================ FILE: examples/spark-connect-gpu/server/spark-connect-server/spark-defaults.conf ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. spark.driver.maxResultSize=3g spark.eventLog.compress=true spark.eventLog.dir=/data/spark-events spark.eventLog.enabled=true spark.eventLog.compress=true spark.executor.resource.gpu.amount=1 spark.executor.resource.gpu.discoveryScript=/opt/spark/examples/src/main/scripts/getGpusResources.sh spark.jars=/opt/spark-rapids/jars/rapids-4-spark-sql.jar,/usr/local/lib/python3.10/dist-packages/spark_rapids_ml/jars/com.nvidia.rapids.ml-25.08.0.jar spark.local.dir=/opt/spark/work spark.locality.wait=0 spark.master=spark://spark-master:7077 spark.plugins=com.nvidia.spark.SQLPlugin spark.rapids.memory.gpu.allocFraction=0.45 spark.rapids.memory.gpu.maxAllocFraction=0.45 spark.rapids.memory.gpu.minAllocFraction=0.0 spark.rapids.ml.float32_inputs=false spark.rapids.ml.python.transform.enabled=false spark.rapids.ml.verbose=6 spark.rapids.sql.batchSizeBytes=512m spark.rapids.sql.concurrentGpuTasks=4 spark.rapids.sql.debug.logTransformations=true spark.rapids.sql.explain=ALL spark.shuffle.manager=com.nvidia.spark.rapids.spark400.RapidsShuffleManager spark.sql.ansi.enabled=false spark.sql.files.maxPartitionBytes=512m spark.sql.session.timeZone=UTC spark.task.resource.gpu.amount=0.0625 ================================================ FILE: examples/spark-connect-gpu/server/spark-connect-server/spark-env.sh ================================================ #!/bin/bash # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. if [[ "$SPARK_PUBLIC_DNS" == "container-hostname" ]]; then export SPARK_PUBLIC_DNS=$(hostname) elif [[ "$SPARK_PUBLIC_DNS" != "" ]]; then # handles default localhost or any other custom value export SPARK_PUBLIC_DNS fi ================================================ FILE: examples/spark-connect-gpu/server/spark-master/Dockerfile ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. FROM apache/spark:4.0.0 COPY spark-env.sh /opt/spark/conf/spark-env.sh ================================================ FILE: examples/spark-connect-gpu/server/spark-master/spark-env.sh ================================================ #!/bin/bash # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. if [[ "$SPARK_PUBLIC_DNS" == "container-hostname" ]]; then export SPARK_PUBLIC_DNS=$(hostname) elif [[ "$SPARK_PUBLIC_DNS" != "" ]]; then # handles default localhost or any other custom value export SPARK_PUBLIC_DNS fi ================================================ FILE: examples/spark-connect-gpu/server/spark-worker/Dockerfile ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. FROM apache/spark:4.0.0 USER root COPY requirements.txt /tmp/requirements.txt RUN pip3 install --extra-index-url=https://pypi.nvidia.com -r /tmp/requirements.txt # TODO hack to avoid configuring cupy compiler path RUN mkdir -p /home/spark RUN chown -R spark:spark /home/spark RUN usermod -d /home/spark spark USER spark COPY spark-env.sh /opt/spark/conf/spark-env.sh ================================================ FILE: examples/spark-connect-gpu/server/spark-worker/requirements.txt ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. cuml-cu12~=25.8.0 numpy~=1.0 spark-rapids-ml==25.8.0 ================================================ FILE: examples/spark-connect-gpu/server/spark-worker/spark-env.sh ================================================ #!/bin/bash # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
if [[ "$SPARK_PUBLIC_DNS" == "container-hostname" ]]; then export SPARK_PUBLIC_DNS=$(hostname) elif [[ "$SPARK_PUBLIC_DNS" != "" ]]; then # handles default localhost or any other custom value export SPARK_PUBLIC_DNS fi GPU_COUNT_MAX=$(nvidia-smi -L | wc -l) export SPARK_WORKER_OPTS=" -Dspark.worker.resource.gpu.amount=${GPU_COUNT_MAX} -Dspark.worker.resource.gpu.discoveryScript=/opt/spark/examples/src/main/scripts/getGpusResources.sh " # workaround for wheels installation not setting the correct LD_LIBRARY_PATH # https://github.com/rapidsai/cuml/issues/5300#issuecomment-2084646729 LD_LIBRARY_PATH=$(find /usr/local/lib/python3.10/dist-packages/nvidia -name lib -type d | xargs printf '%s:'):$LD_LIBRARY_PATH export LD_LIBRARY_PATH ================================================ FILE: scripts/README.md ================================================ ### Encoding Tool This tool is to convert the values from categorical type to numerical type in certain columns. Currently we supoort `mean encoding` and `one-hot encoding`. ### Main Procedure 1. User should firstly use our tool to profile the raw data source to get a "dictinary"(We call this dictionary `model`) that maps categorical values to certain numerical values. We call this method `train`. Each column will have its own `model` 2. User will use the `model` they got from step 1 to replace those categorical values with numerical values. ### Usage 1. `cd encoding/python` 2. `zip -r sample.zip com` to get a python encoding tool library 3. submit the encoding job to your Spark host You can find full use cases in `encoding-sample/run.sh` ### Application Parameters - mainClass: - `com.nvidia.spark.encoding.criteo.one_hot_cpu_main`: one-hot encoding - `com.nvidia.spark.encoding.criteo.target_cpu_main`: target(mean) encoding - mode: - `train`: use raw data to get encoding model - `transform`: use encoding moddel to convert raw data - format: - `csv`: only csv is supported - columns: - the target columns user wants to convert, e.g. `_34,_35` means user wants to get dictionary for both `_34` and `_35` columns - modelPaths: - for `train` mode, it points to the path where user wants to save the encoding model - for `transform` mode, it points to the model that the encoding conversion needs. - it is 1-1 mapped to `columns`. If user wants to encode 2 columns, he must provide 2 `modelPaths`. e.g. `model_34,model_35` - inputPaths: - raw data user wants to get encoding model from, or to convert - outputPaths: - only used in `transform` mode. - overwrite: - whether overwrite the exsiting model or output data - numRows: - optinal. show some rows in command line when encoding is finished. - labelColumn: - required in `target encoding`. Set the label column of raw data. ### Optimization 1. Due to default behaviors from some Spark methods, Some value may contain useless precison which causes the large size of `model`.e.g. 0.000000 and 1.000000 are identical to 0 and 1 in value perspective, but the csv model file that contains those values costs more disk space. We provide `truncate-model.py` in `encoding-sample` to remove the extra useless precisions. 2. We provide a repartition kit `repartition.py` to reparitition your output data. The usage can also be found in `encoding-sample/run.sh` ================================================ FILE: scripts/building/python_build.sh ================================================ #!/bin/bash # Copyright (c) 2024-2025, NVIDIA CORPORATION. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Follow these steps to package the Python zip file cd ../../examples/XGBoost-Examples cd agaricus/python ; zip -r ../../samples.zip com ; cd ../.. cd mortgage/python ; zip -r ../../samples.zip com ; cd ../.. cd taxi/python ; zip -r ../../samples.zip com ; cd ../.. cd utility/python ; zip -r ../../samples.zip com ; cd ../.. ================================================ FILE: scripts/csp-startup-scripts/README.md ================================================ # Startup Scripts for CSPs with Spark Rapids With the exception of Dataproc, CSP offerings like EMR have specific set of steps that are required to enable the Spark Rapids Plugin in their environment. The set of scripts here automate parts of that process, for EMR currently. The exact usage can be found in our docs [here](https://docs.nvidia.com/spark-rapids/user-guide/latest/getting-started/aws-emr.html) ================================================ FILE: scripts/csp-startup-scripts/emr/cgroup-bootstrap-action-emr6.sh ================================================ #!/bin/bash # # Copyright (c) 2024-2026, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # set -ex sudo chmod a+rwx -R /sys/fs/cgroup/cpu,cpuacct sudo chmod a+rwx -R /sys/fs/cgroup/devices ================================================ FILE: scripts/csp-startup-scripts/emr/cgroup-bootstrap-action-emr7.sh ================================================ #!/bin/bash # # Copyright (c) 2024-2026, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# set -ex sudo mkdir -p /spark-rapids-cgroup/devices sudo mount -t cgroup -o devices cgroupv1-devices /spark-rapids-cgroup/devices sudo chmod a+rwx -R /spark-rapids-cgroup ================================================ FILE: scripts/csp-startup-scripts/emr/config-emr6.json ================================================ [ { "Classification":"spark", "Properties":{ "enableSparkRapids":"true" } }, { "Classification":"yarn-site", "Properties":{ "yarn.nodemanager.resource-plugins":"yarn.io/gpu", "yarn.resource-types":"yarn.io/gpu", "yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices":"auto", "yarn.nodemanager.resource-plugins.gpu.path-to-discovery-executables":"/usr/bin", "yarn.nodemanager.linux-container-executor.cgroups.mount":"true", "yarn.nodemanager.linux-container-executor.cgroups.mount-path":"/sys/fs/cgroup", "yarn.nodemanager.linux-container-executor.cgroups.hierarchy":"yarn", "yarn.nodemanager.container-executor.class":"org.apache.hadoop.yarn.server.nodemanager.LinuxContainerExecutor" } }, { "Classification":"container-executor", "Properties":{ }, "Configurations":[ { "Classification":"gpu", "Properties":{ "module.enabled":"true" } }, { "Classification":"cgroups", "Properties":{ "root":"/sys/fs/cgroup", "yarn-hierarchy":"yarn" } } ] }, { "Classification":"spark-defaults", "Properties":{ "spark.plugins":"com.nvidia.spark.SQLPlugin", "spark.executor.resource.gpu.discoveryScript":"/usr/lib/spark/scripts/gpu/getGpusResources.sh", "spark.submit.pyFiles":"/usr/lib/spark/jars/xgboost4j-spark_3.0-1.4.2-0.3.0.jar", "spark.executor.extraLibraryPath":"/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/compat/lib:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/lib/hadoop/lib/native:/usr/lib/hadoop-lzo/lib/native:/docker/usr/lib/hadoop/lib/native:/docker/usr/lib/hadoop-lzo/lib/native", "spark.rapids.sql.concurrentGpuTasks":"2", "spark.executor.resource.gpu.amount":"1", "spark.executor.cores":"${executor_cores}", "spark.task.cpus":"1", "spark.task.resource.gpu.amount":"${task_gpu_amount}", "spark.rapids.memory.pinnedPool.size":"2G", "spark.executor.memoryOverhead":"2G", "spark.sql.files.maxPartitionBytes":"256m", "spark.sql.adaptive.enabled":"false" } }, { "Classification":"capacity-scheduler", "Properties":{ "yarn.scheduler.capacity.resource-calculator":"org.apache.hadoop.yarn.util.resource.DominantResourceCalculator" } } ] ================================================ FILE: scripts/csp-startup-scripts/emr/config-emr7.json ================================================ [ { "Classification": "spark", "Properties": { "enableSparkRapids": "true" } }, { "Classification": "yarn-site", "Properties": { "yarn.nodemanager.resource-plugins": "yarn.io/gpu", "yarn.resource-types": "yarn.io/gpu", "yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices": "auto", "yarn.nodemanager.resource-plugins.gpu.path-to-discovery-executables": "/usr/bin", "yarn.nodemanager.linux-container-executor.cgroups.mount": "true", "yarn.nodemanager.linux-container-executor.cgroups.mount-path": "/spark-rapids-cgroup", "yarn.nodemanager.linux-container-executor.cgroups.hierarchy": "yarn", "yarn.nodemanager.container-executor.class": "org.apache.hadoop.yarn.server.nodemanager.LinuxContainerExecutor" } }, { "Classification": "container-executor", "Properties": {}, "Configurations": [ { "Classification": "gpu", "Properties": { "module.enabled": "true" } }, { "Classification": "cgroups", "Properties": { "root": "/spark-rapids-cgroup", "yarn-hierarchy": "yarn" } } ] }, { 
"Classification": "spark-defaults", "Properties": { "spark.plugins": "com.nvidia.spark.SQLPlugin", "spark.executor.resource.gpu.discoveryScript": "/usr/lib/spark/scripts/gpu/getGpusResources.sh", "spark.submit.pyFiles": "/usr/lib/spark/jars/xgboost4j-spark_3.0-1.4.2-0.3.0.jar", "spark.executor.extraLibraryPath": "/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/compat/lib:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/lib/hadoop/lib/native:/usr/lib/hadoop-lzo/lib/native:/docker/usr/lib/hadoop/lib/native:/docker/usr/lib/hadoop-lzo/lib/native", "spark.rapids.sql.concurrentGpuTasks": "2", "spark.executor.resource.gpu.amount": "1", "spark.executor.cores": "${executor_cores}", "spark.task.cpus": "1", "spark.task.resource.gpu.amount": "${task_gpu_amount}", "spark.rapids.memory.pinnedPool.size": "2G", "spark.executor.memoryOverhead": "2G", "spark.sql.files.maxPartitionBytes": "256m", "spark.sql.adaptive.enabled": "false" } }, { "Classification": "capacity-scheduler", "Properties": { "yarn.scheduler.capacity.resource-calculator": "org.apache.hadoop.yarn.util.resource.DominantResourceCalculator" } } ] ================================================ FILE: scripts/csp-startup-scripts/emr/emr-spark-plugin-startup.py ================================================ # # Copyright (c) 2024-2026, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import argparse import json import os import subprocess import tempfile import boto3 from botocore.exceptions import NoCredentialsError, PartialCredentialsError def upload_file_to_s3(file_name, bucket_name, object_name=None): s3 = boto3.client('s3') # If no object name is specified, use the file name if object_name is None: object_name = file_name try: s3.upload_file(file_name, bucket_name, object_name) print(f"File '{file_name}' uploaded successfully to bucket '{bucket_name}' as '{object_name}'") return True except FileNotFoundError: print(f"Error: The file {file_name} was not found.") except NoCredentialsError: print("Error: AWS credentials not found.") except PartialCredentialsError: print("Error: Incomplete AWS credentials.") except Exception as e: print(f"An error occurred: {e}") return False g4dn_instance_map = { "g4dn.xlarge": 4, "g4dn.2xlarge": 8, "g4dn.4xlarge": 16, "g4dn.12xlarge": 48, "g4dn.16xlarge": 64 } _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) def create_emr_cluster(release_label, key_name, service_role, subnet_id, az, instance_profile, worker_instance, s3_bucket_name): try: conf_json_fn = None bootstrap_fn = None if "emr-7" in release_label: conf_json_fn="config-emr7.json" bootstrap_fn="cgroup-bootstrap-action-emr7.sh" else: conf_json_fn="config-emr6.json" bootstrap_fn="cgroup-bootstrap-action-emr6.sh" # Replace the fields in the json exec_cores = g4dn_instance_map.get(worker_instance) if exec_cores is None: print(f"Error: Unsupported worker instance type '{worker_instance}'. 
" f"Supported types: {list(g4dn_instance_map.keys())}") return conf_json_path = os.path.join(_SCRIPT_DIR, conf_json_fn) bootstrap_path = os.path.join(_SCRIPT_DIR, bootstrap_fn) print("Config Json" + conf_json_fn) with open(conf_json_path, 'r') as file: data = json.load(file) json_string = json.dumps(data) # Replace the placeholder with the actual variable json_string = json_string.replace("${task_gpu_amount}", str(1/exec_cores)) json_string = json_string.replace("${executor_cores}", str(exec_cores)) updated_data = json.loads(json_string) print(json.dumps(updated_data, indent=4)) if not upload_file_to_s3(bootstrap_path, s3_bucket_name, bootstrap_fn): return with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as config_file: json.dump(updated_data, config_file) config_file.flush() command = [ "aws", "emr", "create-cluster", "--release-label", release_label, "--applications", "Name=Hadoop", "Name=Spark", "Name=Livy", "Name=JupyterEnterpriseGateway", "--service-role", service_role, "--ec2-attributes", f"KeyName={key_name},SubnetId={subnet_id},AvailabilityZone={az},InstanceProfile={instance_profile}", "--instance-groups", "InstanceGroupType=MASTER,InstanceCount=1,InstanceType=m4.4xlarge", f"InstanceGroupType=CORE,InstanceCount=1,InstanceType={worker_instance}", "--configurations", f"file://{config_file.name}", "--bootstrap-actions", f"Name='Setup cgroups bootstrap',Path=s3://{s3_bucket_name}/{bootstrap_fn}" ] result = subprocess.run(command, check=True, text=True, capture_output=True) print("Cluster created successfully!") print(result.stdout) except subprocess.CalledProcessError as e: print("Error creating EMR cluster:", e.stderr) parser = argparse.ArgumentParser(description="A script that takes command-line arguments.") # Define arguments parser.add_argument("-r", "--release_label", type=str, default="emr-7.1.0", help="EMR Release Label, emr-7.1.0 for example") parser.add_argument("-k", "--key_name", type=str, required=True, help="Access Key Name") parser.add_argument("-s", "--service_role", type=str, required=True, help="AWS EMR service Role") parser.add_argument("-n", "--subnet", type=str, required=True, help="Subnet ID") parser.add_argument("-z", "--availability_zone", type=str, default="us-west-2b", help="Availability Zone") parser.add_argument("-i", "--instance_profile", type=str, required=True, help="Instance Profile") parser.add_argument("-w", "--worker_instance", type=str, default="g4dn.2xlarge", help="Worker Instance g4dn.xxxx") parser.add_argument("-b", "--s3_bucket_name", type=str, required=True, help="S3 Bucket Name to store the bootstrap and config info") args = parser.parse_args() release_label = args.release_label key_name = args.key_name service_role = args.service_role subnet_id = args.subnet az = args.availability_zone instance_profile = args.instance_profile worker_instance = args.worker_instance s3_bucket_name = args.s3_bucket_name create_emr_cluster(release_label, key_name, service_role, subnet_id, az, instance_profile, worker_instance, s3_bucket_name) ================================================ FILE: scripts/encoding/python/.gitignore ================================================ .idea ================================================ FILE: scripts/encoding/python/com/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: scripts/encoding/python/com/nvidia/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: scripts/encoding/python/com/nvidia/spark/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: scripts/encoding/python/com/nvidia/spark/encoding/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: scripts/encoding/python/com/nvidia/spark/encoding/criteo/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
================================================ FILE: scripts/encoding/python/com/nvidia/spark/encoding/criteo/common.py ================================================ # # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # def customize_reader(reader): (reader .option('sep', '\t')) def customize_writer(writer): (writer .option('sep', '\t') .option('nullValue', None)) ================================================ FILE: scripts/encoding/python/com/nvidia/spark/encoding/criteo/one_hot_cpu_main.py ================================================ # # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from com.nvidia.spark.encoding.criteo.common import * from com.nvidia.spark.encoding.utility.utils import * from pyspark.ml.feature import StringIndexer, StringIndexerModel from pyspark.sql import SparkSession from pyspark.sql.functions import col def index(df, column): column_index = column + '_index' return (StringIndexer(inputCol=column, outputCol=column_index) .setHandleInvalid('keep') .fit(df)) def expand(indexer, df, column): column_index = column + '_index' df = (indexer .transform(df) .withColumn(column_index, col(column_index).cast('int'))) for i in range(0, len(indexer.labels)): df = df.withColumn(column + '_' + str(i), (col(column_index) == i).cast('int')) return df.drop(column, column_index) def main(args): spark = (SparkSession .builder .appName(args.mainClass) .getOrCreate()) if args.mode == 'train': df = load_data(spark, args.inputPaths, args, customize_reader).cache() for column, path in zip(args.columns, args.modelPaths): indexer = index(df, column) save_model(indexer, path, args) if args.mode == 'transform': indexers = list(zip(args.columns, load_models(StringIndexerModel, args.modelPaths))) for input_path, output_path in zip(args.inputPaths, args.outputPaths): df = load_data(spark, input_path, args, customize_reader) for column, indexer in indexers: df = expand(indexer, df, column) args.numRows and df.show(args.numRows) save_data(df, output_path, args, customize_writer) spark.stop() ================================================ FILE: scripts/encoding/python/com/nvidia/spark/encoding/criteo/target_cpu_main.py ================================================ # # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from com.nvidia.spark.encoding.criteo.common import * from com.nvidia.spark.encoding.utility.utils import * from pyspark.sql import SparkSession from pyspark.sql.functions import udf from pyspark.sql import functions as F from pyspark.sql.types import FloatType, DoubleType import time def get_dict_df(train_df, target_col, label_col): ''' get one dict dataframe for one column ''' col_target_df = train_df.groupBy(target_col).agg(F.mean(label_col)) return col_target_df def encode_df(original_df, dict_df, col_name): dict_df_rename = dict_df.withColumnRenamed('_c0', 'hash').withColumnRenamed('_c1', col_name+'_mean') df_mean = (original_df.join(dict_df_rename, original_df[col_name] == dict_df_rename['hash'], how='left').drop('hash').drop(col_name) .na.fill(-1, [col_name + '_mean'])) return df_mean def main(args): spark = (SparkSession .builder .appName(args.mainClass) .getOrCreate()) if args.mode == 'train': for col_name, model_path in zip(args.columns, args.modelPaths): df = load_data(spark, args.inputPaths, args, customize_reader).cache() dict_df = get_dict_df(df, col_name, args.labelColumn) dict_df.repartition(1).write.csv(model_path) if args.mode == 'transform': dict_dfs = [ load_dict_df(spark, path).withColumn('_c1', F.col('_c1').cast(DoubleType())).cache() for path in args.modelPaths ] for input_path, output_path in zip(args.inputPaths, args.outputPaths): df = load_data(spark, input_path, args, customize_reader) for col_name, dict_df in zip(args.columns, dict_dfs): df = encode_df(df, dict_df, col_name) save_data(df, output_path, args, customize_writer) ================================================ FILE: scripts/encoding/python/com/nvidia/spark/encoding/main.py ================================================ # # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from com.nvidia.spark.encoding.utility.args import parse_arguments from importlib import import_module def main(): args = parse_arguments() getattr(import_module(args.mainClass), 'main')(args) ================================================ FILE: scripts/encoding/python/com/nvidia/spark/encoding/utility/__init__.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: scripts/encoding/python/com/nvidia/spark/encoding/utility/args.py ================================================ # # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import sys from argparse import ArgumentParser from distutils.util import strtobool def _to_bool(literal): return bool(strtobool(literal)) def _to_str_list(literal): return [x for x in literal.split(',') if x] _examples = [ 'com.nvidia.spark.encoding.criteo.one_hot_cpu_main', 'com.nvidia.spark.encoding.criteo.target_cpu_main' ] def _validate_args(args): usage = '' if args.mode == 'transform' and not args.outputPaths: usage += ' --outputPaths required for transform.\n' # for production: # validates that --columns and --inputPaths exists # validates that --inputPath and --outputPath matches for transform if (args.mainClass == 'com.nvidia.spark.encoding.criteo.target_cpu_main' and args.mode == 'train' and not args.labelColumn): usage += ' --labelColumn required for target encoding. \n' if usage: print('-' * 80) print('Usage:\n' + usage) sys.exit(1) def parse_arguments(): parser = ArgumentParser() # application arguments parser.add_argument('--mainClass', required=True, choices=_examples) parser.add_argument('--mode', choices=['train', 'transform'], required=True) parser.add_argument('--format', choices=['csv'], default='csv') parser.add_argument('--columns', type=_to_str_list, required=True) parser.add_argument('--modelPaths', type=_to_str_list, required=True) parser.add_argument('--inputPaths', type=_to_str_list, required=True) parser.add_argument('--outputPaths', type=_to_str_list) # for transform, required parser.add_argument('--overwrite', type=_to_bool, default=False) parser.add_argument('--numRows', type=int) # for transform, optional parser.add_argument('--labelColumn', help='name of the label column') # for target encoding, required parsed = parser.parse_args() _validate_args(parsed) return parsed ================================================ FILE: scripts/encoding/python/com/nvidia/spark/encoding/utility/utils.py ================================================ # # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
================================================
FILE: scripts/encoding/python/com/nvidia/spark/encoding/utility/utils.py
================================================
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pickle


def load_data(spark, paths, args, customize=None):
    reader = (spark
              .read
              .format(args.format))
    customize and customize(reader)
    return reader.load(paths)


def save_data(data_frame, path, args, customize=None):
    writer = (data_frame
              .write
              .format(args.format))
    args.overwrite and writer.mode('overwrite')
    customize and customize(writer)
    writer.save(path)


def load_model(model_class, path):
    return model_class.load(path)


def load_models(model_class, paths):
    return [load_model(model_class, path) for path in paths]


def save_model(model, path, args):
    writer = model.write().overwrite() if args.overwrite else model
    writer.save(path)


def save_dict(mean_dict, target_path):
    '''
    target_path: full path of the target location to save the dict
    '''
    with open(target_path+'.pkl', 'wb') as f:
        pickle.dump(mean_dict, f, pickle.HIGHEST_PROTOCOL)


def load_dict(dict_path):
    '''
    dict_path: full path of target dict with '.pkl' tail.
    '''
    with open(dict_path, 'rb') as f:
        return pickle.load(f)


def load_dict_df(spark, dict_df_path):
    return spark.read.option("header","false").csv(dict_df_path)

================================================
FILE: scripts/encoding/python/main.py
================================================
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from com.nvidia.spark.encoding.main import main

main()

================================================
FILE: scripts/encoding-sample/repartition.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: Please modify the data source options for your case.

import sys

from pyspark.sql import SparkSession

(SparkSession
    .builder
    .getOrCreate()
    .read
    .option('sep', '\t')
    .csv(sys.argv[1])
    .repartition(int(sys.argv[3]))
    .write
    .option('sep', '\t')
    .option('nullValue', None)
    .csv(sys.argv[2]))
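# ---- Illustrative example (not part of the repository) ----
# A minimal sketch of the optional `customize` callback accepted by load_data and
# save_data in utility/utils.py above. In the real job the callbacks come from
# com.nvidia.spark.encoding.criteo.common (customize_reader/customize_writer); the
# functions and paths below are hypothetical stand-ins with the same shape, and the
# snippet assumes scripts/encoding/python is on PYTHONPATH.
from types import SimpleNamespace

from pyspark.sql import SparkSession

from com.nvidia.spark.encoding.utility.utils import load_data, save_data

def my_reader_options(reader):
    # The callback mutates the DataFrameReader in place; its return value is ignored.
    reader.option('sep', '\t').option('header', 'false')

def my_writer_options(writer):
    writer.option('sep', '\t').option('nullValue', None)

spark = SparkSession.builder.getOrCreate()
args = SimpleNamespace(format='csv', overwrite=True)

df = load_data(spark, 'raw-1.csv', args, my_reader_options)
save_data(df, 'out-1', args, my_writer_options)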
================================================
FILE: scripts/encoding-sample/run.sh
================================================
#!/bin/bash
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# clear
rm -f encoding.zip main.py
rm -f raw-*.csv
rm -rf model target-* onehot-* final-*

# prepare data
head -n 500 ../../datasets/clicklog.csv > raw-1.csv
head -n 750 ../../datasets/clicklog.csv | tail -n 250 > raw-2.csv
tail -n 250 ../../datasets/clicklog.csv > raw-3.csv

# assemble python libs
pushd ../encoding/python/
zip -r ../../encoding-sample/encoding.zip com/
cp main.py ../../encoding-sample/
popd

# train target models/dicts
spark-submit --py-files encoding.zip main.py \
    --mainClass=com.nvidia.spark.encoding.criteo.target_cpu_main --mode=train \
    --format=csv --inputPaths=raw-1.csv,raw-2.csv \
    --labelColumn=_c0 --columns=_c34,_c35 --modelPaths=model/c34.dict,model/c35.dict
spark-submit truncate-model.py model/c34.dict model/c34_truncated.dict
spark-submit truncate-model.py model/c35.dict model/c35_truncated.dict

# train onehot models/indexers
spark-submit --py-files encoding.zip main.py \
    --mainClass=com.nvidia.spark.encoding.criteo.one_hot_cpu_main --mode=train \
    --format=csv --inputPaths=raw-1.csv,raw-2.csv \
    --columns=_c19,_c26 --modelPaths=model/_c19,model/_c26

# target encoding
spark-submit --py-files encoding.zip main.py \
    --mainClass=com.nvidia.spark.encoding.criteo.target_cpu_main --mode=transform \
    --columns=_c34,_c35 --modelPaths=model/c34_truncated.dict,model/c35_truncated.dict \
    --format=csv --inputPaths=raw-1.csv,raw-2.csv,raw-3.csv --outputPaths=target-1,target-2,target-3

# onehot encoding
# NOTE: If the column index changed after target encoding, you should change the metadata of all
# models accordingly. E.g., change "outputCol":"_c26_index","inputCol":"_c26" to
# "outputCol":"_c25_index","inputCol":"_c25" for file model/_c26/metadata/part-00000.
# This is verified on Spark 2.x.
spark-submit --py-files encoding.zip main.py \
    --mainClass=com.nvidia.spark.encoding.criteo.one_hot_cpu_main --mode=transform \
    --columns=_c19,_c26 --modelPaths=model/_c19,model/_c26 \
    --format=csv --inputPaths=target-1,target-2,target-3 --outputPaths=onehot-1,onehot-2,onehot-3

# NOTE: As an example, not all categorical columns are encoded here.
# But please encode all categorical columns in a production environment.

# repartition
spark-submit repartition.py onehot-1 final-1 5
spark-submit repartition.py onehot-2 final-2 5
spark-submit repartition.py onehot-3 final-3 5

# known issues:
# - Issue: "org.apache.spark.shuffle.FetchFailedException: Too large frame: ...":
#   Solution: Add "--conf spark.maxRemoteBlockSizeFetchToMem=1G"

================================================
FILE: scripts/encoding-sample/truncate-model.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys from pyspark.sql import SparkSession from pyspark.sql.functions import * (SparkSession .builder .getOrCreate() .read .csv(sys.argv[1]) .withColumn('_c1', format_string('%.6f', col('_c1').cast('float'))) .withColumn('_c1', when(col('_c1') == '0.000000', lit('0.0')).otherwise(col('_c1'))) .withColumn('_c1', when(col('_c1') == '1.000000', lit('1.0')).otherwise(col('_c1'))) .repartition(1) .write .option('nullValue', None) .csv(sys.argv[2])) ================================================ FILE: tools/databricks/README.md ================================================ # Databricks Qualification/Profiling Quick Start Notebooks The RAPIDS Accelerator for Apache Spark includes two key tools for understanding the benefits of GPU acceleration as well as analyzing GPU Spark jobs. For customers on Databricks, the quick start notebooks offer a simple interface for running the tools given a set of Spark event logs from CPU (qualification) or GPU (profiling) application runs. To use a demo notebook, you can import the notebook in the Databricks Notebook UI via File->Import Notebook. Once the demo notebook is imported, you can select run to activate the notebook to an available compute cluster. Once the notebook is activated, you can enter in the log path location in the text widget at the top of the notebook. After that, select *Run all* to execute the tools for the specific logs in the log path. ## Limitations 1. Currently local, S3 or DBFS event log paths are supported. 1. S3 path is only supported on Databricks AWS using [instance profiles](https://docs.databricks.com/en/connect/storage/tutorial-s3-instance-profile.html). 1. Eventlog path must follow the formats `/dbfs/path/to/eventlog` or `dbfs:/path/to/eventlog` for logs stored in DBFS. 1. Use wildcards for nested lookup of eventlogs. - For example: `/dbfs/path/to/clusterlogs/*/*` 1. Multiple event logs must be comma-separated. - For example: `/dbfs/path/to/eventlog1,/dbfs/path/to/eventlog2` **Latest Tools Version Supported** 26.02.0 ================================================ FILE: tools/databricks/[RAPIDS Accelerator for Apache Spark] Profiling Tool Notebook Template.ipynb ================================================ { "cells": [ { "metadata": {}, "cell_type": "raw", "source": [ "{\n", " \"cells\": [\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"df33c614-2ecc-47a0-8600-bc891681997f\",\n", " \"showTitle\": false,\n", " \"title\": \"\"\n", " }\n", " },\n", " \"source\": [\n", " \"## Welcome to the Profiling Tool for the RAPIDS Accelerator for Apache Spark\\n\",\n", " \"\\n\",\n", " \"To run the profiling tool, enter the log path that represents the DBFS location of your Spark GPU event logs. Then, select \\\"Run all\\\" to execute the notebook. Once the notebook completes, various output tables will appear below. 
For more options on running the profiling tool, please refer to the [Profiling Tool User Guide](https://docs.nvidia.com/spark-rapids/user-guide/latest/profiling/quickstart.html#running-the-tool).\\n\",\n", " \"\\n\",\n", " \"### Note\\n\",\n", " \"- Currently, local, S3 or DBFS event log paths are supported.\\n\",\n", " \"- S3 path is only supported on Databricks AWS using [instance profiles](https://docs.databricks.com/en/connect/storage/tutorial-s3-instance-profile.html).\\n\",\n", " \"- Eventlog path must follow the formats `/dbfs/path/to/eventlog` or `dbfs:/path/to/eventlog` for logs stored in DBFS.\\n\",\n", " \"- Use wildcards for nested lookup of eventlogs. \\n\",\n", " \" - For example: `/dbfs/path/to/clusterlogs/*/*`\\n\",\n", " \"- Multiple event logs must be comma-separated. \\n\",\n", " \" - For example: `/dbfs/path/to/eventlog1,/dbfs/path/to/eventlog2`\\n\",\n", " \"\\n\",\n", " \"### Per-Job Profile\\n\",\n", " \"\\n\",\n", " \"The profiler output includes information about the application, data sources, executors, SQL stages, Spark properties, and key application metrics at the job and stage levels.\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"5e9f5796-46ed-49ac-9d08-c8b98a87c39d\",\n", " \"showTitle\": true,\n", " \"title\": \"Set Tools Version\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"DEFAULT_TOOLS_VER = \\\"24.12.4\\\"\\n\",\n", " \"TOOLS_VER_ARG = dbutils.widgets.get(\\\"Tools Version\\\")\\n\",\n", " \"TOOLS_VER = TOOLS_VER_ARG if TOOLS_VER_ARG else DEFAULT_TOOLS_VER\\n\",\n", " \"print(f\\\"Using Tools Version: {TOOLS_VER}\\\")\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"313ee58b-61b3-4010-9d60-d21eceea796c\",\n", " \"showTitle\": true,\n", " \"title\": \"Install Package\"\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"%pip install spark-rapids-user-tools==$TOOLS_VER > /dev/null\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"34492d18-1130-45be-b9f7-e6931d3fa66b\",\n", " \"showTitle\": true,\n", " \"title\": \"Environment Setup\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"import os\\n\",\n", " \"import pandas as pd\\n\",\n", " \"\\n\",\n", " \"\\n\",\n", " \"def convert_dbfs_path(path):\\n\",\n", " \" return path.replace(\\\"dbfs:/\\\", \\\"/dbfs/\\\")\\n\",\n", " \" \\n\",\n", " \"# Detect cloud provider from cluster usage tags\\n\",\n", " \"valid_csps = [\\\"aws\\\", \\\"azure\\\"]\\n\",\n", " \"CSP=spark.conf.get(\\\"spark.databricks.clusterUsageTags.cloudProvider\\\", \\\"\\\").lower()\\n\",\n", " \"if CSP not in valid_csps:\\n\",\n", " \" print(f\\\"ERROR: Cannot detect cloud provider from cluster usage tags. 
Using '{valid_csps[0]}' as default. \\\")\\n\",\n", " \" CSP = valid_csps[0]\\n\",\n", " \"else:\\n\",\n", " \" print(f\\\"Detected Cloud Provider from Spark Configs: '{CSP}'\\\")\\n\",\n", " \"\\n\",\n", " \"# Initialize variables from widgets\\n\",\n", " \"dbutils.widgets.text(\\\"Eventlog Path\\\", \\\"/dbfs/user1/profiling_logs\\\")\\n\",\n", " \"EVENTLOG_PATH=dbutils.widgets.get(\\\"Eventlog Path\\\")\\n\",\n", " \"EVENTLOG_PATH=convert_dbfs_path(EVENTLOG_PATH)\\n\",\n", " \"\\n\",\n", " \"dbutils.widgets.text(\\\"Output Path\\\", \\\"/tmp\\\")\\n\",\n", " \"OUTPUT_PATH=dbutils.widgets.get(\\\"Output Path\\\")\\n\",\n", " \"\\n\",\n", " \"# Setup environment variables\\n\",\n", " \"os.environ[\\\"CSP\\\"] = CSP\\n\",\n", " \"os.environ[\\\"EVENTLOG_PATH\\\"] = EVENTLOG_PATH\\n\",\n", " \"os.environ[\\\"OUTPUT_PATH\\\"] = OUTPUT_PATH\\n\",\n", " \"\\n\",\n", " \"# Setup console output file\\n\",\n", " \"CONSOLE_OUTPUT_PATH = os.path.join(OUTPUT_PATH, 'console_output.log')\\n\",\n", " \"CONSOLE_ERROR_PATH = os.path.join(OUTPUT_PATH, 'console_error.log')\\n\",\n", " \"os.environ['CONSOLE_OUTPUT_PATH'] = CONSOLE_OUTPUT_PATH\\n\",\n", " \"os.environ['CONSOLE_ERROR_PATH'] = CONSOLE_ERROR_PATH\\n\",\n", " \"print(f'Console output will be stored at {CONSOLE_OUTPUT_PATH} and errors will be stored at {CONSOLE_ERROR_PATH}')\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"693b5ee0-7500-43f3-b3e2-717fd5468aa8\",\n", " \"showTitle\": true,\n", " \"title\": \"Run Profiling Tool\"\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"%sh\\n\",\n", " \"spark_rapids profiling --platform databricks-$CSP --eventlogs \\\"$EVENTLOG_PATH\\\" -o \\\"$OUTPUT_PATH\\\" --verbose > \\\"$CONSOLE_OUTPUT_PATH\\\" 2> \\\"$CONSOLE_ERROR_PATH\\\"\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"f83af6c8-5a79-4a46-965b-38a4cb621877\",\n", " \"showTitle\": false,\n", " \"title\": \"\"\n", " }\n", " },\n", " \"source\": [\n", " \"## Console Output\\n\",\n", " \"Console output shows the recommended configurations for each app\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"c61527b7-a21a-492c-bab8-77f83dc5cabf\",\n", " \"showTitle\": true,\n", " \"title\": \"Show Console Output\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"%sh\\n\",\n", " \"cat $CONSOLE_OUTPUT_PATH\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"f3c68b28-fc62-40ae-8528-799f3fc7507e\",\n", " \"showTitle\": true,\n", " \"title\": \"Show Logs\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " 
},\n", " \"outputs\": [],\n", " \"source\": [\n", " \"%sh\\n\",\n", " \"cat $CONSOLE_ERROR_PATH\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"05f96ca1-1b08-494c-a12b-7e6cc3dcc546\",\n", " \"showTitle\": true,\n", " \"title\": \"Parse Output\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"import re\\n\",\n", " \"import shutil\\n\",\n", " \"import os\\n\",\n", " \"\\n\",\n", " \"def extract_file_info(console_output_path, output_base_path):\\n\",\n", " \" try:\\n\",\n", " \" with open(console_output_path, 'r') as file:\\n\",\n", " \" stdout_text = file.read()\\n\",\n", " \" \\n\",\n", " \" # Extract log file location\\n\",\n", " \" location_match = re.search(r\\\"Location: (.+)\\\", stdout_text)\\n\",\n", " \" if not location_match:\\n\",\n", " \" raise ValueError(\\\"Log file location not found in the provided text.\\\")\\n\",\n", " \" \\n\",\n", " \" log_file_location = location_match.group(1)\\n\",\n", " \" \\n\",\n", " \" # Extract profiling output folder\\n\",\n", " \" prof_match = re.search(r\\\"prof_[^/]+(?=\\\\.log)\\\", log_file_location)\\n\",\n", " \" if not prof_match:\\n\",\n", " \" raise ValueError(\\\"Output folder not found in the log file location.\\\")\\n\",\n", " \" \\n\",\n", " \" output_folder_name = prof_match.group(0)\\n\",\n", " \" output_folder = os.path.join(output_base_path, output_folder_name)\\n\",\n", " \" return output_folder, log_file_location\\n\",\n", " \" \\n\",\n", " \" except Exception as e:\\n\",\n", " \" raise RuntimeError(f\\\"Cannot parse console output. Reason: {e}\\\")\\n\",\n", " \"\\n\",\n", " \"def copy_logs(destination_folder, *log_files):\\n\",\n", " \" try:\\n\",\n", " \" log_folder = os.path.join(destination_folder, \\\"logs\\\")\\n\",\n", " \" os.makedirs(log_folder, exist_ok=True)\\n\",\n", " \" \\n\",\n", " \" for log_file in log_files:\\n\",\n", " \" if os.path.exists(log_file):\\n\",\n", " \" shutil.copy2(log_file, log_folder)\\n\",\n", " \" else:\\n\",\n", " \" print(f\\\"Log file not found: {log_file}\\\")\\n\",\n", " \" except Exception as e:\\n\",\n", " \" raise RuntimeError(f\\\"Cannot copy logs to output. 
Reason: {e}\\\")\\n\",\n", " \"\\n\",\n", " \"try:\\n\",\n", " \" output_folder, log_file_location = extract_file_info(CONSOLE_OUTPUT_PATH, OUTPUT_PATH)\\n\",\n", " \" print(f\\\"Output folder detected {output_folder}\\\")\\n\",\n", " \" copy_logs(output_folder, log_file_location, CONSOLE_OUTPUT_PATH, CONSOLE_ERROR_PATH)\\n\",\n", " \" print(f\\\"Logs successfully copied to {output_folder}\\\")\\n\",\n", " \"except Exception as e:\\n\",\n", " \" print(e)\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"8c65adcd-a933-482e-a50b-d40fa8f50e16\",\n", " \"showTitle\": true,\n", " \"title\": \"Download Output\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"import shutil\\n\",\n", " \"import os\\n\",\n", " \"import re\\n\",\n", " \"\\n\",\n", " \"current_working_directory = os.getcwd()\\n\",\n", " \"\\n\",\n", " \"def create_destination_folders(folder_name):\\n\",\n", " \" os.makedirs(folder_name, exist_ok=True)\\n\",\n", " \" base_download_folder_path = os.path.join(\\\"/dbfs/FileStore/\\\", folder_name)\\n\",\n", " \" os.makedirs(base_download_folder_path, exist_ok=True) \\n\",\n", " \" return base_download_folder_path\\n\",\n", " \"\\n\",\n", " \"def create_download_link(source_folder, destination_folder_name):\\n\",\n", " \" folder_to_compress = os.path.basename(source_folder)\\n\",\n", " \" zip_file_name = folder_to_compress + '.zip'\\n\",\n", " \" local_zip_file_path = os.path.join(current_working_directory, destination_folder_name, zip_file_name)\\n\",\n", " \" download_folder_path = os.path.join(destination_folder_name, zip_file_name)\\n\",\n", " \" try:\\n\",\n", " \" base_download_folder_path = create_destination_folders(destination_folder_name)\\n\",\n", " \" shutil.make_archive(folder_to_compress, 'zip', source_folder)\\n\",\n", " \" shutil.copy2(zip_file_name, base_download_folder_path)\\n\",\n", " \" if os.path.exists(local_zip_file_path):\\n\",\n", " \" os.remove(local_zip_file_path)\\n\",\n", " \" shutil.move(zip_file_name, local_zip_file_path)\\n\",\n", " \" \\n\",\n", " \" download_button_html = f\\\"\\\"\\\"\\n\",\n", " \" \\n\",\n", " \" \\n\",\n", " \"

\\n\",\n", " \" Zipped output file created at {local_zip_file_path}\\n\",\n", " \"
\\n\",\n", " \"
\\n\",\n", " \" Download Output\\n\",\n", " \"
\\n\",\n", " \" \\\"\\\"\\\"\\n\",\n", " \" displayHTML(download_button_html)\\n\",\n", " \" except Exception as e:\\n\",\n", " \" error_message_html = f\\\"\\\"\\\"\\n\",\n", " \"
\\n\",\n", " \" Error: Cannot create download link for {source_folder}. Reason: {e}\\n\",\n", " \"
\\n\",\n", " \" \\\"\\\"\\\"\\n\",\n", " \" displayHTML(error_message_html)\\n\",\n", " \"\\n\",\n", " \"destination_folder_name = \\\"Tools_Output\\\"\\n\",\n", " \"create_download_link(output_folder, destination_folder_name)\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"bbe50fde-0bd6-4281-95fd-6a1ec6f17ab2\",\n", " \"showTitle\": false,\n", " \"title\": \"\"\n", " }\n", " },\n", " \"source\": [\n", " \"\\n\",\n", " \"%md\\n\",\n", " \"\\n\",\n", " \"## GPU Job Tuning Recommendations\\n\",\n", " \"This has general suggestions for tuning your applications to run optimally on GPUs.\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"b8bca4a6-16d8-4b60-ba7b-9aff64bdcaa1\",\n", " \"showTitle\": true,\n", " \"title\": \"Show Recommendations\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"jar_output_folder = os.path.join(output_folder, \\\"rapids_4_spark_profile\\\")\\n\",\n", " \"app_df = pd.DataFrame(columns=['appId', 'appName'])\\n\",\n", " \"\\n\",\n", " \"for x in os.scandir(jar_output_folder):\\n\",\n", " \" if x.is_dir():\\n\",\n", " \" csv_path = os.path.join(x.path, \\\"application_information.csv\\\")\\n\",\n", " \" if os.path.exists(csv_path):\\n\",\n", " \" tmp_df = pd.read_csv(csv_path)\\n\",\n", " \" app_df = pd.concat([app_df, tmp_df[['appId', 'appName']]])\\n\",\n", " \"\\n\",\n", " \"\\n\",\n", " \"app_list = app_df[\\\"appId\\\"].tolist()\\n\",\n", " \"app_recommendations = pd.DataFrame(columns=['app', 'recommendations'])\\n\",\n", " \"\\n\",\n", " \"for app in app_list:\\n\",\n", " \" app_file = open(os.path.join(jar_output_folder, app, \\\"profile.log\\\"))\\n\",\n", " \" recommendations_start = 0\\n\",\n", " \" recommendations_str = \\\"\\\"\\n\",\n", " \" for line in app_file:\\n\",\n", " \" if recommendations_start == 1:\\n\",\n", " \" recommendations_str = recommendations_str + line\\n\",\n", " \" if \\\"### D. 
Recommended Configuration ###\\\" in line:\\n\",\n", " \" recommendations_start = 1\\n\",\n", " \" app_recommendations = pd.concat([app_recommendations, pd.DataFrame({'app': [app], 'recommendations': [recommendations_str]})], ignore_index=True)\\n\",\n", " \"display(app_recommendations)\"\n", " ]\n", " }\n", " ],\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+notebook\": {\n", " \"dashboards\": [\n", " {\n", " \"elements\": [],\n", " \"globalVars\": {},\n", " \"guid\": \"\",\n", " \"layoutOption\": {\n", " \"grid\": true,\n", " \"stack\": true\n", " },\n", " \"nuid\": \"91c1bfb2-695a-4e5c-8a25-848a433108dc\",\n", " \"origId\": 2173122769183713,\n", " \"title\": \"Executive View\",\n", " \"version\": \"DashboardViewV1\",\n", " \"width\": 1600\n", " },\n", " {\n", " \"elements\": [],\n", " \"globalVars\": {},\n", " \"guid\": \"\",\n", " \"layoutOption\": {\n", " \"grid\": true,\n", " \"stack\": true\n", " },\n", " \"nuid\": \"62243296-4562-4f06-90ac-d7a609f19c16\",\n", " \"origId\": 2173122769183714,\n", " \"title\": \"App View\",\n", " \"version\": \"DashboardViewV1\",\n", " \"width\": 1920\n", " }\n", " ],\n", " \"environmentMetadata\": null,\n", " \"language\": \"python\",\n", " \"notebookMetadata\": {\n", " \"mostRecentlyExecutedCommandWithImplicitDF\": {\n", " \"commandId\": 2173122769183692,\n", " \"dataframes\": [\n", " \"_sqldf\"\n", " ]\n", " },\n", " \"pythonIndentUnit\": 2,\n", " \"widgetLayout\": [\n", " {\n", " \"breakBefore\": false,\n", " \"name\": \"Eventlog Path\",\n", " \"width\": 778\n", " },\n", " {\n", " \"breakBefore\": false,\n", " \"name\": \"Output Path\",\n", " \"width\": 302\n", " }\n", " ]\n", " },\n", " \"notebookName\": \"[RAPIDS Accelerator for Apache Spark] Profiling Tool Notebook Template\",\n", " \"widgets\": {\n", " \"Eventlog Path\": {\n", " \"currentValue\": \"/dbfs/user1/profiling_logs\",\n", " \"nuid\": \"1272501d-5ad9-42be-ab62-35768b2fc384\",\n", " \"typedWidgetInfo\": null,\n", " \"widgetInfo\": {\n", " \"defaultValue\": \"/dbfs/user1/profiling_logs\",\n", " \"label\": \"\",\n", " \"name\": \"Eventlog Path\",\n", " \"options\": {\n", " \"autoCreated\": false,\n", " \"validationRegex\": null,\n", " \"widgetType\": \"text\"\n", " },\n", " \"widgetType\": \"text\"\n", " }\n", " },\n", " \"Output Path\": {\n", " \"currentValue\": \"/tmp\",\n", " \"nuid\": \"ab7e082c-1ef9-4912-8fd7-51bf985eb9c1\",\n", " \"typedWidgetInfo\": null,\n", " \"widgetInfo\": {\n", " \"defaultValue\": \"/tmp\",\n", " \"label\": null,\n", " \"name\": \"Output Path\",\n", " \"options\": {\n", " \"autoCreated\": null,\n", " \"validationRegex\": null,\n", " \"widgetType\": \"text\"\n", " },\n", " \"widgetType\": \"text\"\n", " }\n", " }\n", " }\n", " },\n", " \"language_info\": {\n", " \"name\": \"python\"\n", " }\n", " },\n", " \"nbformat\": 4,\n", " \"nbformat_minor\": 0\n", "}\n" ], "id": "4e6b53d2fff1910e" } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: tools/databricks/[RAPIDS Accelerator for Apache Spark] Qualification Tool Notebook Template.ipynb ================================================ { "cells": [ { "metadata": {}, "cell_type": "raw", "source": [ "{\n", " \"cells\": [\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"df33c614-2ecc-47a0-8600-bc891681997f\",\n", " \"showTitle\": false,\n", " 
\"title\": \"\"\n", " }\n", " },\n", " \"source\": [\n", " \"## Welcome to the Qualification Tool for the RAPIDS Accelerator for Apache Spark\\n\",\n", " \"\\n\",\n", " \"To run the qualification tool, enter the log path that represents the DBFS location of your Spark GPU event logs. Then, select \\\"Run all\\\" to execute the notebook. Once the notebook completes, various output tables will appear below. For more options on running the profiling tool, please refer to the [Qualification Tool User Guide](https://docs.nvidia.com/spark-rapids/user-guide/latest/qualification/quickstart.html#running-the-tool).\\n\",\n", " \"\\n\",\n", " \"### Note\\n\",\n", " \"- Currently, local, S3 or DBFS event log paths are supported.\\n\",\n", " \"- S3 path is only supported on Databricks AWS using [instance profiles](https://docs.databricks.com/en/connect/storage/tutorial-s3-instance-profile.html).\\n\",\n", " \"- Eventlog path must follow the formats `/dbfs/path/to/eventlog` or `dbfs:/path/to/eventlog` for logs stored in DBFS.\\n\",\n", " \"- Use wildcards for nested lookup of eventlogs. \\n\",\n", " \" - For example: `/dbfs/path/to/clusterlogs/*/*`\\n\",\n", " \"- Multiple event logs must be comma-separated. \\n\",\n", " \" - For example: `/dbfs/path/to/eventlog1,/dbfs/path/to/eventlog2`\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"5e9f5796-46ed-49ac-9d08-c8b98a87c39d\",\n", " \"showTitle\": true,\n", " \"title\": \"Set Tools Version\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"DEFAULT_TOOLS_VER = \\\"24.12.4\\\"\\n\",\n", " \"TOOLS_VER_ARG = dbutils.widgets.get(\\\"Tools Version\\\")\\n\",\n", " \"TOOLS_VER = TOOLS_VER_ARG if TOOLS_VER_ARG else DEFAULT_TOOLS_VER\\n\",\n", " \"print(f\\\"Using Tools Version: {TOOLS_VER}\\\")\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"313ee58b-61b3-4010-9d60-d21eceea796c\",\n", " \"showTitle\": true,\n", " \"title\": \"Install Package\"\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"%pip install spark-rapids-user-tools==$TOOLS_VER > /dev/null\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"acf401a3-12d3-4236-a6c5-8fe8990b153a\",\n", " \"showTitle\": true,\n", " \"title\": \"Environment Setup\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"import os\\n\",\n", " \"import pandas as pd\\n\",\n", " \"\\n\",\n", " \"\\n\",\n", " \"def convert_dbfs_path(path):\\n\",\n", " \" return path.replace(\\\"dbfs:/\\\", \\\"/dbfs/\\\")\\n\",\n", " \" \\n\",\n", " \"# Detect cloud provider from cluster usage tags\\n\",\n", " \"valid_csps = [\\\"aws\\\", \\\"azure\\\"]\\n\",\n", " \"CSP=spark.conf.get(\\\"spark.databricks.clusterUsageTags.cloudProvider\\\", 
\\\"\\\").lower()\\n\",\n", " \"if CSP not in valid_csps:\\n\",\n", " \" print(f\\\"ERROR: Cannot detect cloud provider from cluster usage tags. Using '{valid_csps[0]}' as default. \\\")\\n\",\n", " \" CSP = valid_csps[0]\\n\",\n", " \"else:\\n\",\n", " \" print(f\\\"Detected Cloud Provider from Spark Configs: '{CSP}'\\\")\\n\",\n", " \"\\n\",\n", " \"# Initialize variables from widgets\\n\",\n", " \"dbutils.widgets.text(\\\"Eventlog Path\\\", \\\"/dbfs/user1/qualification_logs\\\")\\n\",\n", " \"EVENTLOG_PATH=dbutils.widgets.get(\\\"Eventlog Path\\\")\\n\",\n", " \"EVENTLOG_PATH=convert_dbfs_path(EVENTLOG_PATH)\\n\",\n", " \"\\n\",\n", " \"dbutils.widgets.text(\\\"Output Path\\\", \\\"/tmp\\\")\\n\",\n", " \"OUTPUT_PATH=dbutils.widgets.get(\\\"Output Path\\\")\\n\",\n", " \"\\n\",\n", " \" \\n\",\n", " \"# Setup environment variables\\n\",\n", " \"os.environ[\\\"CSP\\\"] = CSP\\n\",\n", " \"os.environ[\\\"EVENTLOG_PATH\\\"] = EVENTLOG_PATH\\n\",\n", " \"os.environ[\\\"OUTPUT_PATH\\\"] = OUTPUT_PATH\\n\",\n", " \"\\n\",\n", " \"# Setup console output file\\n\",\n", " \"CONSOLE_OUTPUT_PATH = os.path.join(OUTPUT_PATH, 'console_output.log')\\n\",\n", " \"CONSOLE_ERROR_PATH = os.path.join(OUTPUT_PATH, 'console_error.log')\\n\",\n", " \"os.environ['CONSOLE_OUTPUT_PATH'] = CONSOLE_OUTPUT_PATH\\n\",\n", " \"os.environ['CONSOLE_ERROR_PATH'] = CONSOLE_ERROR_PATH\\n\",\n", " \"print(f'Console output will be stored at {CONSOLE_OUTPUT_PATH} and errors will be stored at {CONSOLE_ERROR_PATH}')\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"693b5ee0-7500-43f3-b3e2-717fd5468aa8\",\n", " \"showTitle\": true,\n", " \"title\": \"Run Qualification Tool\"\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"%sh\\n\",\n", " \"spark_rapids qualification --platform databricks-$CSP --eventlogs \\\"$EVENTLOG_PATH\\\" -o \\\"$OUTPUT_PATH\\\" --verbose > \\\"$CONSOLE_OUTPUT_PATH\\\" 2> \\\"$CONSOLE_ERROR_PATH\\\"\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"f83af6c8-5a79-4a46-965b-38a4cb621877\",\n", " \"showTitle\": false,\n", " \"title\": \"\"\n", " }\n", " },\n", " \"source\": [\n", " \"## Console Output\\n\",\n", " \"Console output shows the top candidates and their estimated GPU speedup.\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"c61527b7-a21a-492c-bab8-77f83dc5cabf\",\n", " \"showTitle\": true,\n", " \"title\": \"Show Console Output\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"%sh\\n\",\n", " \"cat $CONSOLE_OUTPUT_PATH\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " 
\"nuid\": \"f3c68b28-fc62-40ae-8528-799f3fc7507e\",\n", " \"showTitle\": true,\n", " \"title\": \"Show Logs\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"%sh\\n\",\n", " \"cat $CONSOLE_ERROR_PATH\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"05f96ca1-1b08-494c-a12b-7e6cc3dcc546\",\n", " \"showTitle\": true,\n", " \"title\": \"Parse Output\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"import re\\n\",\n", " \"import shutil\\n\",\n", " \"import os\\n\",\n", " \"\\n\",\n", " \"def extract_file_info(console_output_path, output_base_path):\\n\",\n", " \" try:\\n\",\n", " \" with open(console_output_path, 'r') as file:\\n\",\n", " \" stdout_text = file.read()\\n\",\n", " \" \\n\",\n", " \" # Extract log file location\\n\",\n", " \" location_match = re.search(r\\\"Location: (.+)\\\", stdout_text)\\n\",\n", " \" if not location_match:\\n\",\n", " \" raise ValueError(\\\"Log file location not found in the provided text.\\\")\\n\",\n", " \" \\n\",\n", " \" log_file_location = location_match.group(1)\\n\",\n", " \" \\n\",\n", " \" # Extract qualification output folder\\n\",\n", " \" qual_match = re.search(r\\\"qual_[^/]+(?=\\\\.log)\\\", log_file_location)\\n\",\n", " \" if not qual_match:\\n\",\n", " \" raise ValueError(\\\"Output folder not found in the log file location.\\\")\\n\",\n", " \" \\n\",\n", " \" output_folder_name = qual_match.group(0)\\n\",\n", " \" output_folder = os.path.join(output_base_path, output_folder_name)\\n\",\n", " \" return output_folder, log_file_location\\n\",\n", " \" \\n\",\n", " \" except Exception as e:\\n\",\n", " \" raise RuntimeError(f\\\"Cannot parse console output. Reason: {e}\\\")\\n\",\n", " \"\\n\",\n", " \"def copy_logs(destination_folder, *log_files):\\n\",\n", " \" try:\\n\",\n", " \" log_folder = os.path.join(destination_folder, \\\"logs\\\")\\n\",\n", " \" os.makedirs(log_folder, exist_ok=True)\\n\",\n", " \" \\n\",\n", " \" for log_file in log_files:\\n\",\n", " \" if os.path.exists(log_file):\\n\",\n", " \" shutil.copy2(log_file, log_folder)\\n\",\n", " \" else:\\n\",\n", " \" print(f\\\"Log file not found: {log_file}\\\")\\n\",\n", " \" except Exception as e:\\n\",\n", " \" raise RuntimeError(f\\\"Cannot copy logs to output. 
Reason: {e}\\\")\\n\",\n", " \"\\n\",\n", " \"try:\\n\",\n", " \" output_folder, log_file_location = extract_file_info(CONSOLE_OUTPUT_PATH, OUTPUT_PATH)\\n\",\n", " \" jar_output_folder = os.path.join(output_folder, \\\"rapids_4_spark_qualification_output\\\")\\n\",\n", " \" print(f\\\"Output folder detected {output_folder}\\\")\\n\",\n", " \" copy_logs(output_folder, log_file_location, CONSOLE_OUTPUT_PATH, CONSOLE_ERROR_PATH)\\n\",\n", " \" print(f\\\"Logs successfully copied to {output_folder}\\\")\\n\",\n", " \"except Exception as e:\\n\",\n", " \" print(e)\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"8c65adcd-a933-482e-a50b-d40fa8f50e16\",\n", " \"showTitle\": true,\n", " \"title\": \"Download Output\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"import shutil\\n\",\n", " \"import os\\n\",\n", " \"import re\\n\",\n", " \"\\n\",\n", " \"current_working_directory = os.getcwd()\\n\",\n", " \"\\n\",\n", " \"def create_destination_folders(folder_name):\\n\",\n", " \" os.makedirs(folder_name, exist_ok=True)\\n\",\n", " \" base_download_folder_path = os.path.join(\\\"/dbfs/FileStore/\\\", folder_name)\\n\",\n", " \" os.makedirs(base_download_folder_path, exist_ok=True) \\n\",\n", " \" return base_download_folder_path\\n\",\n", " \"\\n\",\n", " \"def create_download_link(source_folder, destination_folder_name):\\n\",\n", " \" folder_to_compress = os.path.basename(source_folder)\\n\",\n", " \" zip_file_name = folder_to_compress + '.zip'\\n\",\n", " \" local_zip_file_path = os.path.join(current_working_directory, destination_folder_name, zip_file_name)\\n\",\n", " \" download_folder_path = os.path.join(destination_folder_name, zip_file_name)\\n\",\n", " \" try:\\n\",\n", " \" base_download_folder_path = create_destination_folders(destination_folder_name)\\n\",\n", " \" shutil.make_archive(folder_to_compress, 'zip', source_folder)\\n\",\n", " \" shutil.copy2(zip_file_name, base_download_folder_path)\\n\",\n", " \" if os.path.exists(local_zip_file_path):\\n\",\n", " \" os.remove(local_zip_file_path)\\n\",\n", " \" shutil.move(zip_file_name, local_zip_file_path)\\n\",\n", " \" \\n\",\n", " \" download_button_html = f\\\"\\\"\\\"\\n\",\n", " \" \\n\",\n", " \" \\n\",\n", " \"
\\n\",\n", " \" Zipped output file created at {local_zip_file_path}\\n\",\n", " \"
\\n\",\n", " \"
\\n\",\n", " \" Download Output\\n\",\n", " \"
\\n\",\n", " \" \\\"\\\"\\\"\\n\",\n", " \" displayHTML(download_button_html)\\n\",\n", " \" except Exception as e:\\n\",\n", " \" error_message_html = f\\\"\\\"\\\"\\n\",\n", " \"
\\n\",\n", " \" Error: Cannot create download link for {source_folder}. Reason: {e}\\n\",\n", " \"
\\n\",\n", " \" \\\"\\\"\\\"\\n\",\n", " \" displayHTML(error_message_html)\\n\",\n", " \"\\n\",\n", " \"destination_folder_name = \\\"Tools_Output\\\"\\n\",\n", " \"create_download_link(output_folder, destination_folder_name)\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"bbe50fde-0bd6-4281-95fd-6a1ec6f17ab2\",\n", " \"showTitle\": false,\n", " \"title\": \"\"\n", " }\n", " },\n", " \"source\": [\n", " \"\\n\",\n", " \"## Summary Output\\n\",\n", " \"\\n\",\n", " \"The report provides a comprehensive overview of the entire application execution, estimated speedup, including unsupported operators and non-SQL operations. By default, the applications and queries are sorted in descending order based on the following fields:\\n\",\n", " \"\\n\",\n", " \"- Estimated GPU Speedup Category\\n\",\n", " \"- Estimated GPU Speedup\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"b8bca4a6-16d8-4b60-ba7b-9aff64bdcaa1\",\n", " \"showTitle\": true,\n", " \"title\": \"qualification_summary.csv\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"summary_output=pd.read_csv(os.path.join(output_folder, \\\"qualification_summary.csv\\\"))\\n\",\n", " \"summary_output=summary_output.drop(columns=[\\\"Unnamed: 0\\\"]).rename_axis('Index').reset_index()\\n\",\n", " \"display(summary_output)\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {},\n", " \"inputWidgets\": {},\n", " \"nuid\": \"73b5e0b0-3a96-4cc6-8e6c-840e4b0d9d43\",\n", " \"showTitle\": false,\n", " \"title\": \"\"\n", " }\n", " },\n", " \"source\": [\n", " \"\\n\",\n", " \"## Application Status\\n\",\n", " \"\\n\",\n", " \"The report show the status of each eventlog file that was provided\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"c9ffbfdb-dbb6-4736-b9cb-2ac457cc6714\",\n", " \"showTitle\": true,\n", " \"title\": \"rapids_4_spark_qualification_output_status.csv\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"status_output=pd.read_csv(os.path.join(jar_output_folder, \\\"rapids_4_spark_qualification_output_status.csv\\\"))\\n\",\n", " \"display(status_output)\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {},\n", " \"inputWidgets\": {},\n", " \"nuid\": \"09945d39-f9c2-4f4a-8afd-4f309f24f8e0\",\n", " \"showTitle\": false,\n", " \"title\": \"\"\n", " }\n", " },\n", " \"source\": [\n", " \"\\n\",\n", " \"## Metadata for Migration\\n\",\n", " \"\\n\",\n", " \"The report show the metadata of each app as:\\n\",\n", " \"- Recommended GPU cluster\\n\",\n", " \"- File location of 
full cluster config recommendations\\n\",\n", " \"- File location of only Gpu specific config recommendations\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"133cf1bd-33b6-4a62-9ae2-5505717092d1\",\n", " \"showTitle\": true,\n", " \"title\": \"app_metadata.json\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"import json\\n\",\n", " \"metadata_file = os.path.join(output_folder, \\\"app_metadata.json\\\")\\n\",\n", " \"def camel_to_title(name):\\n\",\n", " \" return re.sub('([a-z])([A-Z])', r'\\\\1 \\\\2', name).title()\\n\",\n", " \" \\n\",\n", " \"with open(metadata_file, 'r') as file:\\n\",\n", " \" json_data = json.load(file)\\n\",\n", " \"\\n\",\n", " \"df = pd.DataFrame(json_data)\\n\",\n", " \"df['recommendedGpuCluster'] = df['clusterInfo'].apply(lambda x: x['recommendedCluster'])\\n\",\n", " \"df['sourceCluster'] = df['clusterInfo'].apply(lambda x: x['sourceCluster'])\\n\",\n", " \"df.drop(columns=['clusterInfo'], inplace=True)\\n\",\n", " \"df = df[['appId', 'appName', 'estimatedGpuSpeedupCategory', 'recommendedGpuCluster', 'fullClusterConfigRecommendations', 'gpuConfigRecommendationBreakdown']]\\n\",\n", " \"df.columns = [camel_to_title(col) for col in df.columns]\\n\",\n", " \"display(df)\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"6756159b-30ca-407a-ab6b-9c29ced01ea6\",\n", " \"showTitle\": false,\n", " \"title\": \"\"\n", " }\n", " },\n", " \"source\": [\n", " \"## Stages Output\\n\",\n", " \"\\n\",\n", " \"For each stage used in SQL operations, the Qualification tool generates the following information:\\n\",\n", " \"\\n\",\n", " \"1. App ID\\n\",\n", " \"2. Stage ID\\n\",\n", " \"3. Average Speedup Factor: The average estimated speed-up of all the operators in the given stage.\\n\",\n", " \"4. Stage Task Duration: The amount of time spent in tasks of SQL DataFrame operations for the given stage.\\n\",\n", " \"5. Unsupported Task Duration: The sum of task durations for the unsupported operators. For more details, see [Supported Operators](https://nvidia.github.io/spark-rapids/docs/supported_ops.html).\\n\",\n", " \"6. 
Stage Estimated: Indicates if the stage duration had to be estimated (True or False).\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"cdde6177-db5f-434a-995b-776678a64a3a\",\n", " \"showTitle\": true,\n", " \"title\": \"rapids_4_spark_qualification_output_stages.csv\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"stages_output=pd.read_csv(os.path.join(jar_output_folder, \\\"rapids_4_spark_qualification_output_stages.csv\\\"))\\n\",\n", " \"display(stages_output)\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"markdown\",\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"4d7ce219-ae75-4a0c-a78c-4e7f25b8cd6f\",\n", " \"showTitle\": false,\n", " \"title\": \"\"\n", " }\n", " },\n", " \"source\": [\n", " \"## Execs Output\\n\",\n", " \"\\n\",\n", " \"The Qualification tool generates a report of the “Exec” in the “SparkPlan” or “Executor Nodes” along with the estimated acceleration on the GPU. Please refer to the [Supported Operators guide](https://nvidia.github.io/spark-rapids/docs/supported_ops.html) for more details on limitations on UDFs and unsupported operators.\\n\",\n", " \"\\n\",\n", " \"1. App ID\\n\",\n", " \"2. SQL ID\\n\",\n", " \"3. Exec Name: Example: Filter, HashAggregate\\n\",\n", " \"4. Expression Name\\n\",\n", " \"5. Task Speedup Factor: The average acceleration of the operators based on the original CPU duration of the operator divided by the GPU duration. The tool uses historical queries and benchmarks to estimate a speed-up at an individual operator level to calculate how much a specific operator would accelerate on GPU.\\n\",\n", " \"6. Exec Duration: Wall-clock time measured from when the operator starts until it is completed.\\n\",\n", " \"7. SQL Node ID\\n\",\n", " \"8. Exec Is Supported: Indicates whether the Exec is supported by RAPIDS. Refer to the Supported Operators section for details.\\n\",\n", " \"9. Exec Stages: An array of stage IDs.\\n\",\n", " \"10. Exec Children\\n\",\n", " \"11. Exec Children Node IDs\\n\",\n", " \"12. 
Exec Should Remove: Indicates whether the Op is removed from the migrated plan.\\n\"\n", " ]\n", " },\n", " {\n", " \"cell_type\": \"code\",\n", " \"execution_count\": 0,\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+cell\": {\n", " \"cellMetadata\": {\n", " \"byteLimit\": 2048000,\n", " \"rowLimit\": 10000\n", " },\n", " \"inputWidgets\": {},\n", " \"nuid\": \"998b0c51-0cb6-408e-a01a-d1f5b1a61e1f\",\n", " \"showTitle\": true,\n", " \"title\": \"rapids_4_spark_qualification_output_execs.csv\"\n", " },\n", " \"jupyter\": {\n", " \"source_hidden\": true\n", " }\n", " },\n", " \"outputs\": [],\n", " \"source\": [\n", " \"execs_output=pd.read_csv(os.path.join(jar_output_folder, \\\"rapids_4_spark_qualification_output_execs.csv\\\"))\\n\",\n", " \"display(execs_output)\"\n", " ]\n", " }\n", " ],\n", " \"metadata\": {\n", " \"application/vnd.databricks.v1+notebook\": {\n", " \"dashboards\": [\n", " {\n", " \"elements\": [],\n", " \"globalVars\": {},\n", " \"guid\": \"\",\n", " \"layoutOption\": {\n", " \"grid\": true,\n", " \"stack\": true\n", " },\n", " \"nuid\": \"91c1bfb2-695a-4e5c-8a25-848a433108dc\",\n", " \"origId\": 2173122769183715,\n", " \"title\": \"Executive View\",\n", " \"version\": \"DashboardViewV1\",\n", " \"width\": 1600\n", " },\n", " {\n", " \"elements\": [],\n", " \"globalVars\": {},\n", " \"guid\": \"\",\n", " \"layoutOption\": {\n", " \"grid\": true,\n", " \"stack\": true\n", " },\n", " \"nuid\": \"62243296-4562-4f06-90ac-d7a609f19c16\",\n", " \"origId\": 2173122769183716,\n", " \"title\": \"App View\",\n", " \"version\": \"DashboardViewV1\",\n", " \"width\": 1920\n", " },\n", " {\n", " \"elements\": [],\n", " \"globalVars\": {},\n", " \"guid\": \"\",\n", " \"layoutOption\": {\n", " \"grid\": true,\n", " \"stack\": true\n", " },\n", " \"nuid\": \"854f9c75-5977-42aa-b3dd-c680b8331f19\",\n", " \"origId\": 2173122769183722,\n", " \"title\": \"Untitled\",\n", " \"version\": \"DashboardViewV1\",\n", " \"width\": 1024\n", " }\n", " ],\n", " \"environmentMetadata\": null,\n", " \"language\": \"python\",\n", " \"notebookMetadata\": {\n", " \"mostRecentlyExecutedCommandWithImplicitDF\": {\n", " \"commandId\": 2173122769183704,\n", " \"dataframes\": [\n", " \"_sqldf\"\n", " ]\n", " },\n", " \"pythonIndentUnit\": 2,\n", " \"widgetLayout\": [\n", " {\n", " \"breakBefore\": false,\n", " \"name\": \"Eventlog Path\",\n", " \"width\": 778\n", " },\n", " {\n", " \"breakBefore\": false,\n", " \"name\": \"Output Path\",\n", " \"width\": 302\n", " }\n", " ]\n", " },\n", " \"notebookName\": \"[RAPIDS Accelerator for Apache Spark] Qualification Tool Notebook Template\",\n", " \"widgets\": {\n", " \"Eventlog Path\": {\n", " \"currentValue\": \"/dbfs/user1/qualification_logs\",\n", " \"nuid\": \"1272501d-5ad9-42be-ab62-35768b2fc384\",\n", " \"typedWidgetInfo\": null,\n", " \"widgetInfo\": {\n", " \"defaultValue\": \"/dbfs/user1/qualification_logs\",\n", " \"label\": \"\",\n", " \"name\": \"Eventlog Path\",\n", " \"options\": {\n", " \"autoCreated\": false,\n", " \"validationRegex\": null,\n", " \"widgetType\": \"text\"\n", " },\n", " \"widgetType\": \"text\"\n", " }\n", " },\n", " \"Output Path\": {\n", " \"currentValue\": \"/tmp\",\n", " \"nuid\": \"ab7e082c-1ef9-4912-8fd7-51bf985eb9c1\",\n", " \"typedWidgetInfo\": null,\n", " \"widgetInfo\": {\n", " \"defaultValue\": \"/tmp\",\n", " \"label\": null,\n", " \"name\": \"Output Path\",\n", " \"options\": {\n", " \"autoCreated\": null,\n", " \"validationRegex\": null,\n", " \"widgetType\": \"text\"\n", " },\n", " \"widgetType\": 
\"text\"\n", " }\n", " }\n", " }\n", " },\n", " \"language_info\": {\n", " \"name\": \"python\"\n", " }\n", " },\n", " \"nbformat\": 4,\n", " \"nbformat_minor\": 0\n", "}\n" ], "id": "4ba18da2c217d2f1" } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: tools/emr/README.md ================================================ # EMR Qualification/Profiling Quick Start Notebooks The RAPIDS Accelerator for Apache Spark includes two key tools for understanding the benefits of GPU acceleration as well as analyzing GPU Spark jobs. For customers on EMR, the quick start notebooks offer a simple interface for running the tools given a set of Spark event logs from CPU (qualification) or GPU (profiling) application runs. ## Usage ### Pre-requisites: Setup EMR Studio and Workspace 1. Ensure that you have an **EMR cluster** running. 2. Set up **EMR Studio** and **Workspace** by following the instructions in the [AWS Documentation](https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-studio-create-studio.html): - Select **Custom Settings** while creating the Studio. - Choose the **VPC** and **Subnet** where the EMR cluster is running. 3. Attach the Workspace to the running EMR cluster. For more details, refer to the [AWS Documentation](https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-studio-create-use-clusters.html). ### Running the Notebook 1. Import the notebook into the EMR Workspace by dragging and dropping the notebook file. 2. In the **User Input** section of the notebook, enter the path to event log files. 3. Click the **fast-forward** icon labeled *Restart the kernel, then re-run the whole notebook* to process the logs at the specified path. ## Limitations 1. Currently, local and S3 event log paths are supported. 1. Eventlog path must follow the formats `/local/path/to/eventlog` for local logs or `s3://my-bucket/path/to/eventlog` for logs stored in S3. 1. The specified path can also be a directory. In such cases, the tool will recursively search for event logs within the directory. - For example: `/path/to/clusterlogs` 1. To specify multiple event logs, separate the paths with commas. - For example: `s3://my-bucket/path/to/eventlog1,s3://my-bucket/path/to/eventlog2` **Latest Tools Version Supported** 24.08.2 ================================================ FILE: tools/emr/[RAPIDS Accelerator for Apache Spark] Profiling Tool Notebook Template.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "df33c614-2ecc-47a0-8600-bc891681997f", "showTitle": false, "title": "" } }, "source": [ "## Profiling Tool for the RAPIDS Accelerator for Apache Spark\n", "\n", "To run the profiling tool, enter the log path that represents the location of your Spark GPU event logs. Then, select \"Run all\" to execute the notebook. Once the notebook completes, various output tables will appear below. For more options on running the profiling tool, please refer to the [Profiling Tool User Guide](https://docs.nvidia.com/spark-rapids/user-guide/latest/profiling/quickstart.html#running-the-tool).\n", "\n", "### Note\n", "- Currently, local and S3 event log paths are supported.\n", "- Eventlog path must follow the formats `/local/path/to/eventlog` for local logs or `s3://my-bucket/path/to/eventlog` for logs stored in S3.\n", "- The specified path can also be a directory. 
In such cases, the tool will recursively search for event logs within the directory.\n", " - For example: `/path/to/clusterlogs`\n", "- To specify multiple event logs, separate the paths with commas.\n", " - For example: `s3://my-bucket/path/to/eventlog1,s3://my-bucket/path/to/eventlog2`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## User Input" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Path to the event log in S3 (or local path)\n", "EVENTLOG_PATH = \"s3://my-bucket/path/to/eventlog\" # or \"/local/path/to/eventlog\"\n", "\n", "# S3 path with write access where the output will be copied. \n", "S3_OUTPUT_PATH = \"s3://my-bucket/path/to/output\"" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "## Setup Environment" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "from IPython.display import display, Markdown\n", "\n", "TOOLS_VER = \"24.08.2\"\n", "display(Markdown(f\"**Using Spark RAPIDS Tools Version:** {TOOLS_VER}\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "%pip install spark-rapids-user-tools==$TOOLS_VER --user > /dev/null 2>&1" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "acf401a3-12d3-4236-a6c5-8fe8990b153a", "showTitle": true, "title": "Environment Setup" }, "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "\n", "# Update PATH to include local binaries\n", "os.environ['PATH'] += os.pathsep + os.path.expanduser(\"~/.local/bin\")\n", "\n", "OUTPUT_PATH = \"/tmp\"\n", "DEST_FOLDER_NAME = \"prof-tool-result\"\n", "\n", "# Set environment variables\n", "os.environ[\"EVENTLOG_PATH\"] = EVENTLOG_PATH \n", "os.environ[\"OUTPUT_PATH\"] = OUTPUT_PATH\n", "\n", "CONSOLE_OUTPUT_PATH = os.path.join(OUTPUT_PATH, 'console_output.log')\n", "CONSOLE_ERROR_PATH = os.path.join(OUTPUT_PATH, 'console_error.log')\n", "\n", "os.environ['CONSOLE_OUTPUT_PATH'] = CONSOLE_OUTPUT_PATH\n", "os.environ['CONSOLE_ERROR_PATH'] = CONSOLE_ERROR_PATH\n", "\n", "print(f'Console output will be stored at {CONSOLE_OUTPUT_PATH} and errors will be stored at {CONSOLE_ERROR_PATH}')\n" ] }, { "cell_type": "markdown", "metadata": { "execution": { "iopub.execute_input": "2024-10-24T21:27:00.924906Z", "iopub.status.busy": "2024-10-24T21:27:00.924587Z", "iopub.status.idle": "2024-10-24T21:27:00.928129Z", "shell.execute_reply": "2024-10-24T21:27:00.927454Z", "shell.execute_reply.started": "2024-10-24T21:27:00.924879Z" } }, "source": [ "## Run Profiling Tool" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "693b5ee0-7500-43f3-b3e2-717fd5468aa8", "showTitle": true, "title": "Run Profiling Tool" }, "tags": [] }, "outputs": [], "source": [ "%%sh\n", "spark_rapids profiling --platform emr --eventlogs \"$EVENTLOG_PATH\" -o \"$OUTPUT_PATH\" --verbose > \"$CONSOLE_OUTPUT_PATH\" 2> \"$CONSOLE_ERROR_PATH\"" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": 
"f83af6c8-5a79-4a46-965b-38a4cb621877", "showTitle": false, "title": "" } }, "source": [ "## Console Output\n", "Console output shows the top candidates and their estimated GPU speedup.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "c61527b7-a21a-492c-bab8-77f83dc5cabf", "showTitle": true, "title": "Show Console Output" }, "scrolled": true, "tags": [] }, "outputs": [], "source": [ "%%sh\n", "cat $CONSOLE_OUTPUT_PATH" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "f3c68b28-fc62-40ae-8528-799f3fc7507e", "showTitle": true, "title": "Show Logs" }, "scrolled": true, "tags": [] }, "outputs": [], "source": [ "%%sh\n", "cat $CONSOLE_ERROR_PATH" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "05f96ca1-1b08-494c-a12b-7e6cc3dcc546", "showTitle": true, "title": "Parse Output" }, "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "import re\n", "import shutil\n", "import os\n", "\n", "\n", "def extract_file_info(console_output_path, output_base_path):\n", " try:\n", " with open(console_output_path, 'r') as file:\n", " stdout_text = file.read()\n", "\n", " # Extract log file location\n", " location_match = re.search(r\"Location: (.+)\", stdout_text)\n", " if not location_match:\n", " raise ValueError(\n", " \"Log file location not found in the provided text.\")\n", "\n", " log_file_location = location_match.group(1)\n", "\n", " # Extract profiling output folder\n", " qual_match = re.search(r\"prof_[^/]+(?=\\.log)\", log_file_location)\n", " if not qual_match:\n", " raise ValueError(\n", " \"Output folder not found in the log file location.\")\n", "\n", " output_folder_name = qual_match.group(0)\n", " output_folder = os.path.join(output_base_path, output_folder_name)\n", " return output_folder, log_file_location\n", "\n", " except Exception as e:\n", " raise RuntimeError(f\"Cannot parse console output. Reason: {e}\")\n", "\n", "\n", "def copy_logs(destination_folder, *log_files):\n", " try:\n", " log_folder = os.path.join(destination_folder, \"logs\")\n", " os.makedirs(log_folder, exist_ok=True)\n", "\n", " for log_file in log_files:\n", " if os.path.exists(log_file):\n", " shutil.copy2(log_file, log_folder)\n", " else:\n", " print(f\"Log file not found: {log_file}\")\n", " except Exception as e:\n", " raise RuntimeError(f\"Cannot copy logs to output. 
Reason: {e}\")\n", "\n", "\n", "try:\n", " output_folder, log_file_location = extract_file_info(\n", " CONSOLE_OUTPUT_PATH, OUTPUT_PATH)\n", " jar_output_folder = os.path.join(output_folder,\n", " \"rapids_4_spark_profile\")\n", " print(f\"Output folder detected {output_folder}\")\n", " copy_logs(output_folder, log_file_location, CONSOLE_OUTPUT_PATH,\n", " CONSOLE_ERROR_PATH)\n", " print(f\"Logs successfully copied to {output_folder}\")\n", "except Exception as e:\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download Output" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "8c65adcd-a933-482e-a50b-d40fa8f50e16", "showTitle": true, "title": "Download Output" }, "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "import shutil\n", "import os\n", "import subprocess\n", "from IPython.display import HTML, display\n", "from urllib.parse import urlparse\n", "\n", "def display_error_message(error_message, exception):\n", " error_message_html = f\"\"\"\n", "
\n", " Error: {error_message}.\n", "
\n", " Exception: {exception}\n", "
\n", " \"\"\"\n", " display(HTML(error_message_html))\n", "\n", "def copy_file_to_s3(local_file: str, bucket: str, destination_folder_name: str):\n", " try:\n", " file_name = os.path.basename(local_file)\n", " s3_path = f\"s3://{bucket}/{destination_folder_name}/{file_name}\"\n", " subprocess.run([\"aws\", \"s3\", \"cp\", local_file, s3_path], check=True, capture_output=True, text=True)\n", " return construct_download_url(file_name, bucket, destination_folder_name)\n", " except subprocess.CalledProcessError as e:\n", " raise Exception(f\"Error copying file to S3: {e.stderr}\") from e\n", "\n", "def get_default_aws_region():\n", " try:\n", " return subprocess.check_output(\n", " \"aws configure list | grep region | awk '{print $2}'\",\n", " shell=True,\n", " text=True\n", " ).strip()\n", " except subprocess.CalledProcessError:\n", " return \"Error: Unable to retrieve the region.\"\n", "\n", "def construct_download_url(file_name: str, bucket_name: str, destination_folder_name: str):\n", " region = get_default_aws_region()\n", " return f\"https://{region}.console.aws.amazon.com/s3/object/{bucket_name}?region={region}&prefix={destination_folder_name}/{file_name}\"\n", "\n", "def create_download_link(source_folder, bucket_name, destination_folder_name):\n", " folder_to_compress = os.path.join(\"/tmp\", os.path.basename(source_folder))\n", " local_zip_file_path = shutil.make_archive(folder_to_compress, 'zip', source_folder)\n", " download_url = copy_file_to_s3(local_zip_file_path, bucket_name, destination_folder_name)\n", "\n", " download_button_html = f\"\"\"\n", " \n", "\n", "
\n", " Zipped output file created at {download_url}\n", "
\n", "
\n", " Download Output\n", "
\n", " \"\"\"\n", " display(HTML(download_button_html))\n", "\n", "try:\n", " current_working_directory = os.getcwd()\n", " parsed_s3_output_path = urlparse(S3_OUTPUT_PATH)\n", " bucket_name = parsed_s3_output_path.netloc\n", " destination_path = os.path.join(parsed_s3_output_path.path.strip(\"/\"), DEST_FOLDER_NAME.strip(\"/\"))\n", " create_download_link(output_folder, bucket_name, destination_path)\n", " \n", "except Exception as e:\n", " error_msg = f\"Failed to create download link for {output_folder}\"\n", " display_error_message(error_msg, e)" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": {}, "inputWidgets": {}, "nuid": "73b5e0b0-3a96-4cc6-8e6c-840e4b0d9d43", "showTitle": false, "title": "" } }, "source": [ "\n", "## Application Status\n", "\n", "The report show the status of each eventlog file that was provided\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "c9ffbfdb-dbb6-4736-b9cb-2ac457cc6714", "showTitle": true, "title": "profiling_status.csv" }, "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "try:\n", " status_output = pd.read_csv(\n", " os.path.join(jar_output_folder, \"profiling_status.csv\"))\n", "\n", " # Set options to display the full content of the DataFrame\n", " pd.set_option('display.max_rows', None) # Show all rows\n", " pd.set_option('display.max_columns', None) # Show all columns\n", " pd.set_option('display.width', None) # Adjust column width to fit the display\n", " pd.set_option('display.max_colwidth', None) # Display full content of each column\n", "\n", " display(status_output)\n", "except Exception as e:\n", " error_msg = \"Unable to show Application Status\"\n", " display_error_message(error_msg, e) \n", " \n", " " ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "6756159b-30ca-407a-ab6b-9c29ced01ea6", "showTitle": false, "title": "" } }, "source": [ "## GPU Job Tuning Recommendations\n", "This has general suggestions for tuning your applications to run optimally on GPUs.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "cdde6177-db5f-434a-995b-776678a64a3a", "showTitle": true, "title": "application_information.csv" }, "jupyter": { "source_hidden": true }, "scrolled": true, "tags": [] }, "outputs": [], "source": [ "try:\n", " jar_output_folder = os.path.join(output_folder, \"rapids_4_spark_profile\")\n", " app_df = pd.DataFrame(columns=['appId', 'appName'])\n", "\n", " for x in os.scandir(jar_output_folder):\n", " if x.is_dir():\n", " csv_path = os.path.join(x.path, \"application_information.csv\")\n", " if os.path.exists(csv_path):\n", " tmp_df = pd.read_csv(csv_path)\n", " app_df = pd.concat([app_df, tmp_df[['appId', 'appName']]])\n", "\n", "\n", " app_list = app_df[\"appId\"].tolist()\n", " app_recommendations = pd.DataFrame(columns=['App', 'Recommended Configuration'])\n", "\n", " for app in app_list:\n", " app_file = open(os.path.join(jar_output_folder, app, \"profile.log\"))\n", " recommendations_start = 0\n", " recommendations_str = \"\"\n", " for line in app_file:\n", " if recommendations_start == 1:\n", " recommendations_str = 
recommendations_str + line\n", " if \"### D. Recommended Configuration ###\" in line:\n", " recommendations_start = 1\n", " app_recommendations = pd.concat([app_recommendations, pd.DataFrame({'App': [app], 'Recommended Configuration': [recommendations_str]})], ignore_index=True)\n", " html = app_recommendations.to_html().replace(\"\\\\n\", \"<br/>
\")\n", " style = \"\"\n", " display(HTML(html + style))\n", "except Exception as e:\n", " error_msg = \"Unable to show stage output\"\n", " display_error_message(error_msg, e) " ] } ], "metadata": { "application/vnd.databricks.v1+notebook": { "dashboards": [ { "elements": [], "globalVars": {}, "guid": "", "layoutOption": { "grid": true, "stack": true }, "nuid": "91c1bfb2-695a-4e5c-8a25-848a433108dc", "origId": 2173122769183715, "title": "Executive View", "version": "DashboardViewV1", "width": 1600 }, { "elements": [], "globalVars": {}, "guid": "", "layoutOption": { "grid": true, "stack": true }, "nuid": "62243296-4562-4f06-90ac-d7a609f19c16", "origId": 2173122769183716, "title": "App View", "version": "DashboardViewV1", "width": 1920 }, { "elements": [], "globalVars": {}, "guid": "", "layoutOption": { "grid": true, "stack": true }, "nuid": "854f9c75-5977-42aa-b3dd-c680b8331f19", "origId": 2173122769183722, "title": "Untitled", "version": "DashboardViewV1", "width": 1024 } ], "environmentMetadata": null, "language": "python", "notebookMetadata": { "mostRecentlyExecutedCommandWithImplicitDF": { "commandId": 2173122769183704, "dataframes": [ "_sqldf" ] }, "pythonIndentUnit": 2, "widgetLayout": [ { "breakBefore": false, "name": "Eventlog Path", "width": 778 }, { "breakBefore": false, "name": "Output Path", "width": 302 } ] }, "notebookName": "[RAPIDS Accelerator for Apache Spark] Profiling Tool Notebook Template" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: tools/emr/[RAPIDS Accelerator for Apache Spark] Qualification Tool Notebook Template.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "df33c614-2ecc-47a0-8600-bc891681997f", "showTitle": false, "title": "" } }, "source": [ "## Qualification Tool for the RAPIDS Accelerator for Apache Spark\n", "\n", "To run the qualification tool, enter the log path that represents the location of your Spark CPU event logs. Then, select \"Run all\" to execute the notebook. Once the notebook completes, various output tables will appear below. For more options on running the qualification tool, please refer to the [Qualification Tool User Guide](https://docs.nvidia.com/spark-rapids/user-guide/latest/qualification/quickstart.html#running-the-tool).\n", "\n", "### Note\n", "- Currently, local and S3 event log paths are supported.\n", "- Eventlog path must follow the formats `/local/path/to/eventlog` for local logs or `s3://my-bucket/path/to/eventlog` for logs stored in S3.\n", "- The specified path can also be a directory. 
In such cases, the tool will recursively search for event logs within the directory.\n", " - For example: `/path/to/clusterlogs`\n", "- To specify multiple event logs, separate the paths with commas.\n", " - For example: `s3://my-bucket/path/to/eventlog1,s3://my-bucket/path/to/eventlog2`\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## User Input" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Path to the event log in S3 (or local path)\n", "EVENTLOG_PATH = \"s3://my-bucket/path/to/eventlog\" # or \"/local/path/to/eventlog\"\n", "\n", "# S3 path with write access where the output will be copied. \n", "S3_OUTPUT_PATH = \"s3://my-bucket/path/to/output\"" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "## Setup Environment" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "from IPython.display import display, Markdown\n", "\n", "TOOLS_VER = \"24.08.2\"\n", "display(Markdown(f\"**Using Spark RAPIDS Tools Version:** {TOOLS_VER}\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "%pip install spark-rapids-user-tools==$TOOLS_VER --user > /dev/null 2>&1" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "acf401a3-12d3-4236-a6c5-8fe8990b153a", "showTitle": true, "title": "Environment Setup" }, "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "\n", "# Update PATH to include local binaries\n", "os.environ['PATH'] += os.pathsep + os.path.expanduser(\"~/.local/bin\")\n", "\n", "OUTPUT_PATH = \"/tmp\"\n", "DEST_FOLDER_NAME = \"qual-tool-result\"\n", "\n", "# Set environment variables\n", "os.environ[\"EVENTLOG_PATH\"] = EVENTLOG_PATH \n", "os.environ[\"OUTPUT_PATH\"] = OUTPUT_PATH\n", "\n", "CONSOLE_OUTPUT_PATH = os.path.join(OUTPUT_PATH, 'console_output.log')\n", "CONSOLE_ERROR_PATH = os.path.join(OUTPUT_PATH, 'console_error.log')\n", "\n", "os.environ['CONSOLE_OUTPUT_PATH'] = CONSOLE_OUTPUT_PATH\n", "os.environ['CONSOLE_ERROR_PATH'] = CONSOLE_ERROR_PATH\n", "\n", "print(f'Console output will be stored at {CONSOLE_OUTPUT_PATH} and errors will be stored at {CONSOLE_ERROR_PATH}')\n" ] }, { "cell_type": "markdown", "metadata": { "execution": { "iopub.execute_input": "2024-10-24T21:27:00.924906Z", "iopub.status.busy": "2024-10-24T21:27:00.924587Z", "iopub.status.idle": "2024-10-24T21:27:00.928129Z", "shell.execute_reply": "2024-10-24T21:27:00.927454Z", "shell.execute_reply.started": "2024-10-24T21:27:00.924879Z" } }, "source": [ "## Run Qualification Tool" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "693b5ee0-7500-43f3-b3e2-717fd5468aa8", "showTitle": true, "title": "Run Qualification Tool" }, "tags": [] }, "outputs": [], "source": [ "%%sh\n", "spark_rapids qualification --platform emr --eventlogs \"$EVENTLOG_PATH\" -o \"$OUTPUT_PATH\" --verbose > \"$CONSOLE_OUTPUT_PATH\" 2> \"$CONSOLE_ERROR_PATH\"" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": 
"f83af6c8-5a79-4a46-965b-38a4cb621877", "showTitle": false, "title": "" } }, "source": [ "## Console Output\n", "Console output shows the top candidates and their estimated GPU speedup.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "c61527b7-a21a-492c-bab8-77f83dc5cabf", "showTitle": true, "title": "Show Console Output" }, "scrolled": true, "tags": [] }, "outputs": [], "source": [ "%%sh\n", "cat $CONSOLE_OUTPUT_PATH" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "f3c68b28-fc62-40ae-8528-799f3fc7507e", "showTitle": true, "title": "Show Logs" }, "scrolled": true, "tags": [] }, "outputs": [], "source": [ "%%sh\n", "cat $CONSOLE_ERROR_PATH" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "05f96ca1-1b08-494c-a12b-7e6cc3dcc546", "showTitle": true, "title": "Parse Output" }, "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "import re\n", "import shutil\n", "import os\n", "\n", "\n", "def extract_file_info(console_output_path, output_base_path):\n", " try:\n", " with open(console_output_path, 'r') as file:\n", " stdout_text = file.read()\n", "\n", " # Extract log file location\n", " location_match = re.search(r\"Location: (.+)\", stdout_text)\n", " if not location_match:\n", " raise ValueError(\n", " \"Log file location not found in the provided text.\")\n", "\n", " log_file_location = location_match.group(1)\n", "\n", " # Extract qualification output folder\n", " qual_match = re.search(r\"qual_[^/]+(?=\\.log)\", log_file_location)\n", " if not qual_match:\n", " raise ValueError(\n", " \"Output folder not found in the log file location.\")\n", "\n", " output_folder_name = qual_match.group(0)\n", " output_folder = os.path.join(output_base_path, output_folder_name)\n", " return output_folder, log_file_location\n", "\n", " except Exception as e:\n", " raise RuntimeError(f\"Cannot parse console output. Reason: {e}\")\n", "\n", "\n", "def copy_logs(destination_folder, *log_files):\n", " try:\n", " log_folder = os.path.join(destination_folder, \"logs\")\n", " os.makedirs(log_folder, exist_ok=True)\n", "\n", " for log_file in log_files:\n", " if os.path.exists(log_file):\n", " shutil.copy2(log_file, log_folder)\n", " else:\n", " print(f\"Log file not found: {log_file}\")\n", " except Exception as e:\n", " raise RuntimeError(f\"Cannot copy logs to output. 
Reason: {e}\")\n", "\n", "\n", "try:\n", " output_folder, log_file_location = extract_file_info(\n", " CONSOLE_OUTPUT_PATH, OUTPUT_PATH)\n", " jar_output_folder = os.path.join(output_folder,\n", " \"rapids_4_spark_qualification_output\")\n", " print(f\"Output folder detected {output_folder}\")\n", " copy_logs(output_folder, log_file_location, CONSOLE_OUTPUT_PATH,\n", " CONSOLE_ERROR_PATH)\n", " print(f\"Logs successfully copied to {output_folder}\")\n", "except Exception as e:\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download Output" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "8c65adcd-a933-482e-a50b-d40fa8f50e16", "showTitle": true, "title": "Download Output" }, "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "import shutil\n", "import os\n", "import subprocess\n", "from IPython.display import HTML, display\n", "from urllib.parse import urlparse\n", "\n", "def display_error_message(error_message, exception):\n", " error_message_html = f\"\"\"\n", "
\n", " Error: {error_message}.\n", "
\n", " Exception: {exception}\n", "
\n", " \"\"\"\n", " display(HTML(error_message_html))\n", "\n", "def copy_file_to_s3(local_file: str, bucket: str, destination_folder_name: str):\n", " try:\n", " file_name = os.path.basename(local_file)\n", " s3_path = f\"s3://{bucket}/{destination_folder_name}/{file_name}\"\n", " subprocess.run([\"aws\", \"s3\", \"cp\", local_file, s3_path], check=True, capture_output=True, text=True)\n", " return construct_download_url(file_name, bucket, destination_folder_name)\n", " except subprocess.CalledProcessError as e:\n", " raise Exception(f\"Error copying file to S3: {e.stderr}\") from e\n", "\n", "def get_default_aws_region():\n", " try:\n", " return subprocess.check_output(\n", " \"aws configure list | grep region | awk '{print $2}'\",\n", " shell=True,\n", " text=True\n", " ).strip()\n", " except subprocess.CalledProcessError:\n", " return \"Error: Unable to retrieve the region.\"\n", "\n", "def construct_download_url(file_name: str, bucket_name: str, destination_folder_name: str):\n", " region = get_default_aws_region()\n", " return f\"https://{region}.console.aws.amazon.com/s3/object/{bucket_name}?region={region}&prefix={destination_folder_name}/{file_name}\"\n", "\n", "def create_download_link(source_folder, bucket_name, destination_folder_name):\n", " folder_to_compress = os.path.join(\"/tmp\", os.path.basename(source_folder))\n", " local_zip_file_path = shutil.make_archive(folder_to_compress, 'zip', source_folder)\n", " download_url = copy_file_to_s3(local_zip_file_path, bucket_name, destination_folder_name)\n", "\n", " download_button_html = f\"\"\"\n", " \n", "\n", "
\n", " Zipped output file created at {download_url}\n", "
\n", "
\n", " Download Output\n", "
\n", " \"\"\"\n", " display(HTML(download_button_html))\n", "\n", "try:\n", " current_working_directory = os.getcwd()\n", " parsed_s3_output_path = urlparse(S3_OUTPUT_PATH)\n", " bucket_name = parsed_s3_output_path.netloc\n", " destination_path = os.path.join(parsed_s3_output_path.path.strip(\"/\"), DEST_FOLDER_NAME.strip(\"/\"))\n", " create_download_link(output_folder, bucket_name, destination_path)\n", " \n", "except Exception as e:\n", " error_msg = f\"Failed to create download link for {output_folder}\"\n", " display_error_message(error_msg, e)" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "bbe50fde-0bd6-4281-95fd-6a1ec6f17ab2", "showTitle": false, "title": "" } }, "source": [ "\n", "## Summary\n", "\n", "The report provides a comprehensive overview of the entire application execution, estimated speedup, including unsupported operators and non-SQL operations. By default, the applications and queries are sorted in descending order based on the following fields:\n", "\n", "- Estimated GPU Speedup Category\n", "- Estimated GPU Speedup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "b8bca4a6-16d8-4b60-ba7b-9aff64bdcaa1", "showTitle": true, "title": "qualification_summary.csv" }, "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "def millis_to_human_readable(millis):\n", " seconds = int(millis) / 1000\n", " if seconds < 60:\n", " return f\"{seconds:.2f} sec\"\n", " else:\n", " minutes = seconds / 60\n", " if minutes < 60:\n", " return f\"{minutes:.2f} min\"\n", " else:\n", " hours = minutes / 60\n", " return f\"{hours:.2f} hr\"\n", "\n", "try: \n", " # Read qualification summary \n", " summary_output = pd.read_csv(os.path.join(output_folder, \"qualification_summary.csv\"))\n", " summary_output = summary_output.drop(columns=[\"Unnamed: 0\"]).rename_axis('Index').reset_index()\n", " summary_output['Estimated GPU Duration'] = summary_output['Estimated GPU Duration'].apply(millis_to_human_readable)\n", " summary_output['App Duration'] = summary_output['App Duration'].apply(millis_to_human_readable)\n", " \n", " summary_output = summary_output[[\n", " 'App Name', 'App ID', 'Estimated GPU Speedup Category', 'Estimated GPU Speedup', \n", " 'Estimated GPU Duration', 'App Duration'\n", " ]]\n", " \n", " # Read cluster information\n", " cluster_df = pd.read_json(os.path.join(output_folder, \"app_metadata.json\"))\n", " cluster_df['Recommended GPU Cluster'] = cluster_df['clusterInfo'].apply(\n", " lambda x: f\"{x['recommendedCluster']['numWorkerNodes']} x {x['recommendedCluster']['workerNodeType']}\"\n", " )\n", " cluster_df['App ID'] = cluster_df['appId']\n", " cluster_df = cluster_df[['App ID', 'Recommended GPU Cluster']]\n", " \n", " # Merge the results\n", " results = pd.merge(summary_output, cluster_df, on='App ID', how='left')\n", " display(results)\n", "except Exception as e:\n", " error_msg = \"Unable to show summary\"\n", " display_error_message(error_msg, e)" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": {}, "inputWidgets": {}, "nuid": "73b5e0b0-3a96-4cc6-8e6c-840e4b0d9d43", "showTitle": false, "title": "" } }, "source": [ "\n", "## Application Status\n", "\n", "The report show the status of each eventlog file that was provided\n" ] 
}, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "c9ffbfdb-dbb6-4736-b9cb-2ac457cc6714", "showTitle": true, "title": "rapids_4_spark_qualification_output_status.csv" }, "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "try:\n", " status_output = pd.read_csv(\n", " os.path.join(jar_output_folder,\n", " \"rapids_4_spark_qualification_output_status.csv\"))\n", "\n", " # Set options to display the full content of the DataFrame\n", " pd.set_option('display.max_rows', None) # Show all rows\n", " pd.set_option('display.max_columns', None) # Show all columns\n", " pd.set_option('display.width', None) # Adjust column width to fit the display\n", " pd.set_option('display.max_colwidth', None) # Display full content of each column\n", "\n", " display(status_output)\n", "except Exception as e:\n", " error_msg = \"Unable to show Application Status\"\n", " display_error_message(error_msg, e) \n", " \n", " " ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "6756159b-30ca-407a-ab6b-9c29ced01ea6", "showTitle": false, "title": "" } }, "source": [ "## Stages Output\n", "\n", "For each stage used in SQL operations, the Qualification tool generates the following information:\n", "\n", "1. App ID\n", "2. Stage ID\n", "3. Average Speedup Factor: The average estimated speed-up of all the operators in the given stage.\n", "4. Stage Task Duration: The amount of time spent in tasks of SQL DataFrame operations for the given stage.\n", "5. Unsupported Task Duration: The sum of task durations for the unsupported operators. For more details, see [Supported Operators](https://nvidia.github.io/spark-rapids/docs/supported_ops.html).\n", "6. Stage Estimated: Indicates if the stage duration had to be estimated (True or False).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "cdde6177-db5f-434a-995b-776678a64a3a", "showTitle": true, "title": "rapids_4_spark_qualification_output_stages.csv" }, "jupyter": { "source_hidden": true }, "scrolled": true, "tags": [] }, "outputs": [], "source": [ "try:\n", " stages_output = pd.read_csv(\n", " os.path.join(jar_output_folder,\n", " \"rapids_4_spark_qualification_output_stages.csv\"))\n", " display(stages_output)\n", "except Exception as e:\n", " error_msg = \"Unable to show stage output\"\n", " display_error_message(error_msg, e) " ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "4d7ce219-ae75-4a0c-a78c-4e7f25b8cd6f", "showTitle": false, "title": "" } }, "source": [ "## Execs Output\n", "\n", "The Qualification tool generates a report of the “Exec” in the “SparkPlan” or “Executor Nodes” along with the estimated acceleration on the GPU. Please refer to the [Supported Operators guide](https://nvidia.github.io/spark-rapids/docs/supported_ops.html) for more details on limitations on UDFs and unsupported operators.\n", "\n", "1. App ID\n", "2. SQL ID\n", "3. Exec Name: Example: Filter, HashAggregate\n", "4. Expression Name\n", "5. 
Task Speedup Factor: The average acceleration of the operators based on the original CPU duration of the operator divided by the GPU duration. The tool uses historical queries and benchmarks to estimate a speed-up at an individual operator level to calculate how much a specific operator would accelerate on GPU.\n", "6. Exec Duration: Wall-clock time measured from when the operator starts until it is completed.\n", "7. SQL Node ID\n", "8. Exec Is Supported: Indicates whether the Exec is supported by RAPIDS. Refer to the Supported Operators section for details.\n", "9. Exec Stages: An array of stage IDs.\n", "10. Exec Children\n", "11. Exec Children Node IDs\n", "12. Exec Should Remove: Indicates whether the Op is removed from the migrated plan.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "application/vnd.databricks.v1+cell": { "cellMetadata": { "byteLimit": 2048000, "rowLimit": 10000 }, "inputWidgets": {}, "nuid": "998b0c51-0cb6-408e-a01a-d1f5b1a61e1f", "showTitle": true, "title": "rapids_4_spark_qualification_output_execs.csv" }, "jupyter": { "source_hidden": true }, "scrolled": true, "tags": [] }, "outputs": [], "source": [ "try:\n", " execs_output = pd.read_csv(\n", " os.path.join(jar_output_folder,\n", " \"rapids_4_spark_qualification_output_execs.csv\"))\n", " display(execs_output)\n", "except Exception as e:\n", " error_msg = \"Unable to show Execs output\"\n", " display_error_message(error_msg, e) " ] } ], "metadata": { "application/vnd.databricks.v1+notebook": { "dashboards": [ { "elements": [], "globalVars": {}, "guid": "", "layoutOption": { "grid": true, "stack": true }, "nuid": "91c1bfb2-695a-4e5c-8a25-848a433108dc", "origId": 2173122769183715, "title": "Executive View", "version": "DashboardViewV1", "width": 1600 }, { "elements": [], "globalVars": {}, "guid": "", "layoutOption": { "grid": true, "stack": true }, "nuid": "62243296-4562-4f06-90ac-d7a609f19c16", "origId": 2173122769183716, "title": "App View", "version": "DashboardViewV1", "width": 1920 }, { "elements": [], "globalVars": {}, "guid": "", "layoutOption": { "grid": true, "stack": true }, "nuid": "854f9c75-5977-42aa-b3dd-c680b8331f19", "origId": 2173122769183722, "title": "Untitled", "version": "DashboardViewV1", "width": 1024 } ], "environmentMetadata": null, "language": "python", "notebookMetadata": { "mostRecentlyExecutedCommandWithImplicitDF": { "commandId": 2173122769183704, "dataframes": [ "_sqldf" ] }, "pythonIndentUnit": 2, "widgetLayout": [ { "breakBefore": false, "name": "Eventlog Path", "width": 778 }, { "breakBefore": false, "name": "Output Path", "width": 302 } ] }, "notebookName": "[RAPIDS Accelerator for Apache Spark] Qualification Tool Notebook Template" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 4 }
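
Both EMR notebook templates above wrap the `spark_rapids` CLI from `spark-rapids-user-tools`. For reference, here is a minimal, illustrative sketch of the same qualification workflow as a standalone Python script; the `qual_*` folder glob and the selected summary columns are assumptions based on the notebook code, not a documented interface.

```python
# Illustrative sketch (not part of the notebooks): run the qualification tool from a
# plain Python script and load its summary table, assuming spark-rapids-user-tools
# is installed and the output layout matches what the notebook above parses.
import glob
import os
import subprocess

import pandas as pd

EVENTLOG_PATH = "s3://my-bucket/path/to/eventlog"  # or /local/path/to/eventlog
OUTPUT_PATH = "/tmp"

# Same command the notebook's %%sh cell runs.
subprocess.run(
    ["spark_rapids", "qualification",
     "--platform", "emr",
     "--eventlogs", EVENTLOG_PATH,
     "-o", OUTPUT_PATH,
     "--verbose"],
    check=True,
)

# Assumption: results land in a qual_* folder under OUTPUT_PATH, as implied by the
# notebook's console-output parsing; pick the most recent one.
qual_dirs = sorted(glob.glob(os.path.join(OUTPUT_PATH, "qual_*")))
if qual_dirs:
    summary = pd.read_csv(os.path.join(qual_dirs[-1], "qualification_summary.csv"))
    print(summary[["App Name", "App ID", "Estimated GPU Speedup"]])
```

The profiling notebook follows the same pattern with `spark_rapids profiling` and a `prof_*` output folder containing `rapids_4_spark_profile`.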