[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: GaryShen2008\n\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is.\n\n**Steps/Code to reproduce bug**\nPlease provide a list of steps or a code sample to reproduce the issue.\nAvoid posting private or sensitive data.\n\n**Expected behavior**\nA clear and concise description of what you expected to happen.\n\n**Environment details (please complete the following information)**\n - Environment location: [Standalone, YARN, Kubernetes, Cloud(specify cloud provider)]\n - Spark configuration settings related to the issue"
  },
  {
    "path": ".github/workflows/add-to-project.yml",
    "content": "# Copyright (c) 2024-2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nname: Add new issues and pull requests to project\n\non:\n  issues:\n    types:\n      - opened\n  pull_request_target:\n    types:\n      - opened\n\njobs:\n  Add-to-project:\n    if: github.repository_owner == 'NVIDIA' # avoid adding issues from forks\n    runs-on: ubuntu-latest\n    steps:\n      - name: add-to-project\n        uses: NVIDIA/spark-rapids-common/add-to-project@main\n        with:\n          token: ${{ secrets.PROJECT_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/license-header-check.yml",
    "content": "# Copyright (c) 2024-2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# A workflow to check copyright/license header\nname: license header check\n\non:\n  pull_request:\n    types: [opened, synchronize, reopened]\n\njobs:\n  license-header-check:\n    runs-on: ubuntu-latest\n    if: \"!contains(github.event.pull_request.title, '[bot]')\"\n    steps:\n      - name: Get checkout depth\n        run: |\n          echo \"PR_FETCH_DEPTH=$(( ${{ github.event.pull_request.commits }} + 10 ))\" >> $GITHUB_ENV\n\n      - name: Checkout code\n        uses: NVIDIA/spark-rapids-common/checkout@main\n        with:\n          fetch-depth: ${{ env.PR_FETCH_DEPTH }}\n\n      - name: license-header-check\n        uses: NVIDIA/spark-rapids-common/license-header-check@main\n        with:\n          included_file_patterns: |\n            *.sh,\n            *.java,\n            *.py,\n            *.pbtxt,\n            *Dockerfile*,\n            *Jenkinsfile*,\n            *.yml,\n            *.yaml,\n            *.cpp,\n            *.hpp,\n            *.txt,\n            *.cu,\n            *.scala,\n            *.ini,\n            *.xml\n"
  },
  {
    "path": ".github/workflows/markdown-links-check/markdown-links-check-config.json",
    "content": "{\n  \"ignorePatterns\": [\n    {\n      \"pattern\": \"/docs\"\n    },\n    {\n      \"pattern\": \"/datasets\"\n    },\n    {\n      \"pattern\": \"/dockerfile\"\n    },\n    {\n      \"pattern\": \"/examples\"\n    },\n    {\n      \"pattern\": \"^http://localhost\"\n    },\n    {\n      \"pattern\": \"^http://spark-master\"\n    },\n    {\n      \"pattern\": \"^http://spark-worker\"\n    },\n    {\n      \"pattern\": \"^http://spark-connect-server\"\n    }\n  ],\n  \"timeout\": \"15s\",\n  \"retryOn429\": true,\n  \"retryCount\":30,\n  \"aliveStatusCodes\": [200, 403]\n} \n"
  },
  {
    "path": ".github/workflows/markdown-links-check.yml",
    "content": "# Copyright (c) 2022-2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# A workflow to check if PR got broken hyperlinks\nname: Check Markdown links\n\non:\n  pull_request:\n    types: [opened, synchronize, reopened]\n\njobs:\n  markdown-link-check:\n    runs-on: ubuntu-latest\n    steps:\n    - name: work around permission issue\n      run: git config --global --add safe.directory /github/workspace\n    - name: checkout code\n      uses: NVIDIA/spark-rapids-common/checkout@main\n    - name: markdown link check\n      uses: NVIDIA/spark-rapids-common/markdown-link-check@main\n      with:\n        max-depth: -1\n        use-verbose-mode: 'yes'\n        config-file: '.github/workflows/markdown-links-check/markdown-links-check-config.json'\n        base-branch: 'main'"
  },
  {
    "path": ".github/workflows/shell-check.yml",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# A workflow to check shell script syntax\nname: shell check\n\non:\n  pull_request:\n    types: [opened, synchronize, reopened]\n\njobs:\n  shell-check:\n    runs-on: ubuntu-latest\n    if: \"!contains(github.event.pull_request.title, '[bot]')\"\n    steps:\n      - name: Checkout code\n        uses: NVIDIA/spark-rapids-common/checkout@main\n\n      - name: Run ShellCheck\n        uses: NVIDIA/spark-rapids-common/shell-check@main\n        with:\n          excluded_codes:\n            SC2164,\n            SC2076,\n            SC2054\n\n          # codes explanation:\n          # SC2164: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.\n          # SC2076: Remove quotes from right-hand side of =~ to match as a regex rather than literally.\n          # SC2054: Use spaces, not commas, to separate array elements.\n"
  },
  {
    "path": ".github/workflows/signoff-check.yml",
    "content": "# Copyright (c) 2021-2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# A workflow to check if PR got sign-off\nname: signoff check\n\non:\n  pull_request_target:\n    types: [opened, synchronize, reopened]\n\njobs:\n  signoff-check:\n    runs-on: ubuntu-latest\n    steps:\n      - name: signoff\n        uses: NVIDIA/spark-rapids-common/signoff-check@main\n        with:\n          owner: ${{ github.repository_owner }}\n          repo: spark-rapids-examples\n          pull_number: ${{ github.event.number }}\n          token: ${{ secrets.GITHUB_TOKEN }}\n"
  },
  {
    "path": ".gitignore",
    "content": "*#*#\n*.#*\n*.iml\n*.ipr\n*.iws\n*.pyc\n*.pyo\n*.swp\n*~\n.DS_Store\n.cache\n.classpath\n.ensime\n.ensime_cache/\n.ensime_lucene\n.generated-mima*\n.idea/\n.idea_modules/\n.project\n.pydevproject\n.scala_dependencies\n.settings\nhs_err*.log\ntarget\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to Spark Examples\n\n### Sign your work\n\nWe require that all contributors sign-off on their commits. \n\nThis certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license.\n\nAny contribution which contains commits that are not signed off will not be accepted.\n\nTo sign off on a commit use the `--signoff` (or `-s`) option when committing your changes:\n\n```shell\ngit commit -s -m \"Add cool feature.\"\n```\n\nThis will append the following to your commit message:\n\n```\nSigned-off-by: Your Name <your@email.com>\n```\n\nThe sign-off is a simple line at the end of the explanation for the patch. \n\nYour signature certifies that you wrote the patch or otherwise have the right to pass it on as an open-source patch. \n\nUse your real name, no pseudonyms or anonymous contributions.  \n\nIf you set your `user.name` and `user.email` git configs, you can sign your commit automatically with `git commit -s`.\n\n\nThe signoff means you certify the below (from [developercertificate.org](https://developercertificate.org)):\n\n```\nDeveloper Certificate of Origin\nVersion 1.1\n\nCopyright (C) 2004, 2006 The Linux Foundation and its contributors.\n1 Letterman Drive\nSuite D4700\nSan Francisco, CA, 94129\n\nEveryone is permitted to copy and distribute verbatim copies of this\nlicense document, but changing it is not allowed.\n\n\nDeveloper's Certificate of Origin 1.1\n\nBy making a contribution to this project, I certify that:\n\n(a) The contribution was created in whole or in part by me and I\n    have the right to submit it under the open source license\n    indicated in the file; or\n\n(b) The contribution is based upon previous work that, to the best\n    of my knowledge, is covered under an appropriate open source\n    license and I have the right under that license to submit that\n    work with modifications, whether created in whole or in part\n    by me, under the 
same open source license (unless I am\n    permitted to submit under a different license), as indicated\n    in the file; or\n\n(c) The contribution was provided directly to me by some other\n    person who certified (a), (b) or (c) and I have not modified\n    it.\n\n(d) I understand and agree that this project and the contribution\n    are public and that a record of the contribution (including all\n    personal information I submit with it, including my sign-off) is\n    maintained indefinitely and may be redistributed consistent with\n    this project or the open source license(s) involved.\n```\n\nNote: This section `Sign your work` is derived from [https://github.com/NVIDIA/spark-rapids](https://github.com/NVIDIA/spark-rapids)\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"{}\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2018 NVIDIA Corporation\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# spark-rapids-examples\n\nThis is the [RAPIDS Accelerator for Apache Spark](https://nvidia.github.io/spark-rapids/) examples repo.\nRAPIDS Accelerator for Apache Spark accelerates Spark applications with no code changes.\nYou can download the latest version of RAPIDS Accelerator [here](https://nvidia.github.io/spark-rapids/docs/download.html).\nThis repo contains examples and applications that showcases the performance and benefits of using \nRAPIDS Accelerator in data processing and machine learning pipelines. \nThere are broadly five categories of examples in this repo: \n1. [SQL/Dataframe](./examples/SQL+DF-Examples) \n2. [Spark XGBoost](./examples/XGBoost-Examples) \n3. [Machine Learning/Deep Learning](./examples/ML+DL-Examples) \n4. [RAPIDS UDF](./examples/UDF-Examples)\n5. [Databricks Tools demo notebooks](./tools/databricks)\n\nFor more information on each of the examples please look into respective categories.\n\nHere is the list of notebooks in this repo:\n\n|   | Category  | Notebook Name | Description\n| ------------- | ------------- | ------------- | -------------\n| 1 | SQL/DF | Microbenchmark | Spark SQL operations such as expand, hash aggregate, windowing, and cross joins with up to 20x performance benefits\n| 2 | SQL/DF | Customer Churn | Data federation for modeling customer Churn with a sample telco customer data\n| 3 | XGBoost | Agaricus (Scala) | Uses XGBoost classifier function to create model that can accurately differentiate between edible and poisonous mushrooms with the [agaricus dataset](https://archive.ics.uci.edu/ml/datasets/mushroom)\n| 4 | XGBoost | Mortgage (Scala) | End-to-end ETL + XGBoost example to predict mortgage default with [Fannie Mae Single-Family Loan Performance Data](https://capitalmarkets.fanniemae.com/credit-risk-transfer/single-family-credit-risk-transfer/fannie-mae-single-family-loan-performance-data)\n| 5 | XGBoost | Taxi (Scala) | End-to-end ETL + XGBoost example to predict taxi trip fare amount 
with [NYC taxi trips data set](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page)\n| 6 | ML/DL | PCA | [Spark-Rapids-ML](https://github.com/NVIDIA/spark-rapids-ml) based PCA example to train and transform with a synthetic dataset\n| 7 | ML/DL | DL Inference | Several notebooks demonstrating distributed model inference on Spark using the `predict_batch_udf` across various frameworks: PyTorch, HuggingFace, vLLM, and TensorFlow\n| 8 | SQL/DF + MLlib | GPU-Accelerated Spark Connect | End-to-end SQL/DF + MLlib acceleration to predict mortgage default with [Fannie Mae Single-Family Loan Performance Data](https://capitalmarkets.fanniemae.com/credit-risk-transfer/single-family-credit-risk-transfer/fannie-mae-single-family-loan-performance-data) using the lightweight Spark Connect integration for Apache Spark 4.0+\n| 9 | SQL/DF | [TPC-DS](https://www.tpc.org/tpcds/) Scale Factor 10 | Comparison of Spark SQL CPU vs GPU. Easy to run locally and on Google Colab\n\nHere is the list of Apache Spark applications (Scala and PySpark) that \ncan be built for running on GPU with RAPIDS Accelerator in this repo:\n\n|   | Category  | Notebook Name | Description\n| ------------- | ------------- | ------------- | -------------\n| 1 | XGBoost | Agaricus (Scala) | Uses XGBoost classifier function to create model that can accurately differentiate between edible and poisonous mushrooms with the [agaricus dataset](https://archive.ics.uci.edu/ml/datasets/mushroom)\n| 2 | XGBoost | Mortgage (Scala) | End-to-end ETL + XGBoost example to predict mortgage default with [Fannie Mae Single-Family Loan Performance Data](https://capitalmarkets.fanniemae.com/credit-risk-transfer/single-family-credit-risk-transfer/fannie-mae-single-family-loan-performance-data)\n| 3 | XGBoost | Taxi (Scala) | End-to-end ETL + XGBoost example to predict taxi trip fare amount with [NYC taxi trips data set](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page)\n| 4 | ML/DL | PCA | 
[Spark-Rapids-ML](https://github.com/NVIDIA/spark-rapids-ml) based PCA example to train and transform with a synthetic dataset\n| 5 | UDF | URL Decode | Decodes URL-encoded strings using the [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy/)\n| 6 | UDF | URL Encode | URL-encodes strings using the [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy/)\n| 7 | UDF | [CosineSimilarity](./examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/java/CosineSimilarity.java) | Computes the cosine similarity between two float vectors using [native code](./examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src)\n| 8 | UDF | [StringWordCount](./examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/hive/StringWordCount.java)  | Implements a Hive simple UDF using [native code](./examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src) to count words in strings\n"
  },
  {
    "path": "dockerfile/Dockerfile",
    "content": "# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\nFROM nvidia/cuda:11.8.0-devel-ubuntu18.04\nARG spark_uid=185\n\n# Install java dependencies \nRUN apt-get update && apt-get install -y --no-install-recommends openjdk-8-jdk openjdk-8-jre\nENV JAVA_HOME /usr/lib/jvm/java-1.8.0-openjdk-amd64\nENV PATH $PATH:/usr/lib/jvm/java-1.8.0-openjdk-amd64/jre/bin:/usr/lib/jvm/java-1.8.0-openjdk-amd64/bin\n\n# Before building the docker image, first build and make a Spark distribution following\n# the instructions in http://spark.apache.org/docs/latest/building-spark.html.\n# If this docker file is being used in the context of building your images from a Spark\n# distribution, the docker build command should be invoked from the top level directory\n# of the Spark distribution. 
E.g.:\n# docker build -t spark:latest -f kubernetes/dockerfiles/spark/Dockerfile .\n\nRUN set -ex && \\\n    ln -s /lib /lib64 && \\\n    mkdir -p /opt/spark && \\\n    mkdir -p /opt/spark/examples && \\\n    mkdir -p /opt/spark/work-dir && \\\n    touch /opt/spark/RELEASE && \\\n    rm /bin/sh && \\\n    ln -sv /bin/bash /bin/sh && \\\n    echo \"auth required pam_wheel.so use_uid\" >> /etc/pam.d/su && \\\n    chgrp root /etc/passwd && chmod ug+rw /etc/passwd\n\nENV DEBIAN_FRONTEND noninteractive\nRUN apt-get update && apt-get install -y --no-install-recommends apt-utils \\\n && apt-get install -y --no-install-recommends python libgomp1 \\\n && rm -rf /var/lib/apt/lists/*\n\nCOPY jars /opt/spark/jars\nCOPY bin /opt/spark/bin\nCOPY sbin /opt/spark/sbin\nCOPY kubernetes/dockerfiles/spark/entrypoint.sh /opt/\nCOPY examples /opt/spark/examples\nCOPY kubernetes/tests /opt/spark/tests\nCOPY data /opt/spark/data\n\nENV SPARK_HOME /opt/spark\n\nWORKDIR /opt/spark/work-dir\nRUN chmod g+w /opt/spark/work-dir\n\nENV TINI_VERSION v0.18.0\nADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini /sbin/tini\nRUN chmod +rx /sbin/tini\n\nENTRYPOINT [ \"/opt/entrypoint.sh\" ]\n\n# Specify the User that the actual main process will run as\nUSER ${spark_uid}\n\n"
  },
  {
    "path": "dockerfile/gpu_executor_template.yaml",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\napiVersion: v1\nkind: Pod\nspec:\n  containers:\n    - name: executor\n      resources:\n        limits:\n          nvidia.com/gpu: 1\n\n"
  },
  {
    "path": "docs/get-started/xgboost-examples/building-sample-apps/python.md",
    "content": "# Build XGBoost Python Examples\n\n## Build\n\nFollow these steps to package the Python zip file:\n\n``` bash\ngit clone https://github.com/NVIDIA/spark-rapids-examples.git\ncd spark-rapids-examples/scripts/building\nsh python_build.sh\n```\n\n\n## Files Required by PySpark\n\nTwo files are required by PySpark:\n\n+ *samples.zip*\n  \n  the package including all example code. \n  Executing the above build commands generates the samples.zip file in 'spark-rapids-examples/examples/XGBoost-Examples' folder\n\n+ *main.py*\n  \n  entrypoint for PySpark, you can find it in 'spark-rapids-examples/examples/XGBoost-Examples' folder\n"
  },
  {
    "path": "docs/get-started/xgboost-examples/building-sample-apps/scala.md",
    "content": "# Build XGBoost Scala Examples\n\nThe examples rely on [XGBoost](https://github.com/dmlc/xgboost).\n\n## Build\n\nFollow these steps to build the Scala jars:\n\n``` bash\ngit clone https://github.com/NVIDIA/spark-rapids-examples.git\ncd spark-rapids-examples/examples/XGBoost-Examples\nmvn package\n```\n\n## The generated Jars\n\nLet's assume LATEST_VERSION is **0.2.3**. The build process will generate two jars as below,\n\n+ *aggregator/target/sample_xgboost_apps-${LATEST_VERSION}.jar*\n  \n  only classes for the examples are included, so it should be submitted to spark together with other dependent jars\n\n+ *aggregator/target/sample_xgboost_apps-${LATEST_VERSION}-jar-with-dependencies.jar*\n  \n  both classes for the examples and the classes from dependent jars are included except cudf and rapids.\n\n"
  },
  {
    "path": "docs/get-started/xgboost-examples/csp/aws/ec2.md",
    "content": "# Get Started with XGBoost4J-Spark 3.0 on AWS EC2\n\nThis is a getting started guide to Spark 3.2+ on AWS EC2. At the end of this guide, the reader will be able to run a sample Apache Spark application that runs on NVIDIA GPUs on AWS EC2.\n\nFor more details of AWS EC2 and get started, please check the [AWS document](https://aws.amazon.com/ec2/getting-started/).\n\n## Configure and Launch AWS EC2\n\nGo to AWS Management Console select a region, e.g. Oregon, and click EC2 service.\n\n### Step 1:  Launch New Instance\n\nClick \"Launch instance\" at the EC2 Management Console, and select \"Launch instance\".\n\n![Step 1:  Launch New Instance](pics/ec2_step1.png)\n\n### Step 2:  Configure Instance\n\n#### Step 2.1: Choose an Amazon Machine Image(AMI)\n\nSearch for \"deep learning base ami\", choose \"Deep Learning Base AMI (Ubuntu 18.04)\". Click \"Select\".\n\n![Step 2.1: Choose an Amazon Machine Image(AMI)](pics/ec2_step2-1.png)\n\n#### Step 2.2: Choose an Instance Type\n\nChoose type \"p3.2xlarge\". Click \"Next: Configure Instance Details\" at right buttom.\n\n![Step 2.1: Choose an Instance Type](pics/ec2_step2-2.png)\n\n#### Step 2.3: Configure Instance Detials\n\nDo not need to change anything here, make sure \"Number of instances\" is 1. Click \"Next: Add Storage\" at right buttom.\n\n![Step 2.3: Configure Instance Detials](pics/ec2_step2-3.png)\n\n#### Step 2.4: Add Storage\n\nChange the root disk size based on your needed, also you can add ebs volume by clicking \"Add New Volume\". In this sample, we use default 50G. Click \"Next: Add Tag\" at right buttom.\n\nFor more details of AWS EBS please check the [AWS document](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AmazonEBS.html).\n\n![Step 2.4: Add Storage](pics/ec2_step2-4.png)\n\n#### Step 2.5: Add Tags\n\nYou can add tag here or skip. In this sample, we will skip it. 
Click \"Next: Configure Security Group\" at right buttom.\n\n#### Step 2.6: Configure Security Group\n\nFor convenience, in this sample, we open all ports. You can add your own rules.\n\nCreate a new security group and select type as \"All traffic\". Click \"Review and Launch\" at right buttom.\n\nFor more details of security group, please check the [AWS document](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-security-groups.html).\n\n![Step 2.6: Configure Security Group](pics/ec2_step2-6.png)\n\n#### Step 2.7: Review Instance Launch\n\nReview your configuration. Click \"Launch\" at right buttom. Choose the key-pair you have and launch instances.\n\nReturn \"instances | EC2 Managemnt Console\", you can find your instance running. (It may take a few minutes to initialize)\n\n![Step 2.7: Review Instance Launch](pics/ec2_step2-7.png)\n\n## Launch EC2 and Configure Spark 3.2+\n\n### Step 1:  Launch EC2\n\nCopy \"Public DNS (IPv4)\" of your instance \nUse ssh with your private key to launch the EC2 machine as user \"ubuntu\"\n\n``` bash\nssh -i \"key.pem\" ubuntu@xxxx.region.compute.amazonaws.com\n```\n\n### Step 2: Download Spark package\n\nDownload spark package and set environment variable.\n\n``` bash\n# download the spark\nwget https://dlcdn.apache.org/spark/spark-3.2.1/spark-3.2.1-bin-hadoop3.2.tgz\ntar zxf spark-3.2.1-bin-hadoop3.2.tgz\nexport SPARK_HOME=/your/spark/spark-3.2.1-bin-hadoop3.2\n```\n\n### Step 3: Download jars for S3A (optional)\n\nIf your dataset is on S3, you should download below jar files to enable the accessing of S3. 
In this sample, we will use data on S3.\nThe jars should under $SPARK_HOME/jars\n\n``` bash\ncd $SPARK_HOME/jars\nwget https://github.com/JodaOrg/joda-time/releases/download/v2.10.5/joda-time-2.10.5.jar\nwget https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.2.0/hadoop-aws-3.2.0.jar\nwget https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk/1.11.687/aws-java-sdk-1.11.687.jar\nwget https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-core/1.11.687/aws-java-sdk-core-1.11.687.jar\nwget https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-dynamodb/1.11.687/aws-java-sdk-dynamodb-1.11.687.jar\nwget https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-s3/1.11.687/aws-java-sdk-s3-1.11.687.jar\n```\n\n### Step 4: Start Spark Standalone\n\n#### Step 4.1: Edit spark-default.conf\n\ncd $SPARK_HOME/conf and edit spark-defaults.conf\n\nBy default, thers is only spark-defaults.conf.template in $SPARK_HOME/conf, you could edit it and rename to spark-defaults.conf\nYou can find getGpusResources.sh in $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh\n\n``` bash\nspark.worker.resource.gpu.amount 1\nspark.worker.resource.gpu.discoveryScript /path/to/getGpusResources.sh\n```\n\nThe gpu.amount should be <= the number of GPUs the worker has.\n\n#### Step 4.2: Start Spark Standalone\n\nStart Spark. Default master-spark-URL is spark://$HOSTNAME:7077 . \n\n``` bash\n$SPARK_HOME/sbin/start-master.sh\n$SPARK_HOME/sbin/start-slave.sh <master-spark-URL>\n```\n\n## Launch XGBoost-Spark examples on Spark 3.2+\n\n### Step 1: Download Jars\n\nMake sure you have prepared the necessary packages and dataset by following this [guide](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md)\n\nCopy rapids jars to `$SPARK_HOME/jars`\n\n``` bash\ncp $RAPIDS_JAR $SPARK_HOME/jars/\n```\n\n### Step 2: Create sample running script\n\nCreate running run.sh script with below content, make sure change the paths in it to your own. 
Also your aws key/secret.\n\n``` bash\n#!/bin/bash\nexport SPARK_HOME=/your/path/to/spark-3.2.1-bin-hadoop3.2\n\nexport PATH=$SPARK_HOME/bin:$SPARK_HOME/sbin:$PATH\n\nexport TOTAL_CORES=8\nexport NUM_EXECUTORS=1\nexport NUM_EXECUTOR_CORES=$((${TOTAL_CORES}/${NUM_EXECUTORS}))\n\nexport S3A_CREDS_USR=your_aws_key\n\nexport S3A_CREDS_PSW=your_aws_secret\n\nspark-submit --master spark://$HOSTNAME:7077 \\\n        --deploy-mode client \\\n        --driver-memory 10G \\\n        --executor-memory 22G \\\n        --conf spark.hadoop.fs.s3a.impl=org.apache.hadoop.fs.s3a.S3AFileSystem \\\n        --conf spark.hadoop.fs.s3a.access.key=$S3A_CREDS_USR \\\n        --conf spark.hadoop.fs.s3a.secret.key=$S3A_CREDS_PSW \\\n        --conf spark.executor.memoryOverhead=28G \\\n        --conf spark.cores.max=$TOTAL_CORES \\\n        --conf spark.executor.cores=$NUM_EXECUTOR_CORES \\\n        --conf spark.task.cpus=$NUM_EXECUTOR_CORES \\\n        --conf spark.sql.files.maxPartitionBytes=4294967296 \\\n        --conf spark.yarn.maxAppAttempts=1 \\\n        --conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n        --conf spark.rapids.memory.gpu.pooling.enabled=false \\\n        --conf spark.executor.resource.gpu.amount=1 \\\n        --conf spark.task.resource.gpu.amount=1 \\\n        --class com.nvidia.spark.examples.mortgage.GPUMain \\\n        ${SAMPLE_JAR} \\\n        -num_workers=${NUM_EXECUTORS} \\\n        -format=csv \\\n        -dataPath=\"train::your-train-data-path\" \\\n        -dataPath=\"trans::your-eval-data-path\" \\\n        -numRound=100 -max_depth=8 -nthread=$NUM_EXECUTOR_CORES -showFeatures=0 \\\n        -tree_method=gpu_hist\n```\n\n### Step 3: Submit Sample job\n\nRun run.sh\n\n``` bash\n./run.sh\n```\n\nAfter running successfully, the job will print an accuracy benchmark for model prediction.  \n"
  },
  {
    "path": "docs/get-started/xgboost-examples/csp/databricks/databricks.md",
    "content": "Get Started with XGBoost4J-Spark on Databricks\n======================================================\n\nThis is a getting started guide to XGBoost4J-Spark on Databricks. At the end of this guide, the reader will be able to run a sample Apache Spark application that runs on NVIDIA GPUs on Databricks.\n\nPrerequisites\n-------------\n\n    * Apache Spark 3.x running in Databricks Runtime 10.4 ML or 11.3 ML with GPU\n    * AWS: 10.4 LTS ML (GPU, Scala 2.12, Spark 3.2.1) or 11.3 LTS ML (GPU, Scala 2.12, Spark 3.3.0)\n    * Azure: 10.4 LTS ML (GPU, Scala 2.12, Spark 3.2.1) or 11.3 LTS ML (GPU, Scala 2.12, Spark 3.3.0)\n\nThe number of GPUs per node dictates the number of Spark executors that can run in that node. Each executor should only be allowed to run 1 task at any given time.\n   \nStart A Databricks Cluster\n--------------------------\nBefore creating the cluster, we will need to create an [initialization script](https://docs.databricks.com/clusters/init-scripts.html) for the \ncluster to install the RAPIDS jars. Databricks recommends storing all cluster-scoped init scripts using workspace files. \nEach user has a Home directory configured under the /Users directory in the workspace. \nNavigate to your home directory in the UI and select **Create** > **File** from the menu, \ncreate an `init.sh` scripts with contents:   \n   ```bash\n   #!/bin/bash\n   sudo wget -O /databricks/jars/rapids-4-spark_2.12-26.02.0.jar https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar\n   ```\n1. Select the Databricks Runtime Version from one of the supported runtimes specified in the\n   Prerequisites section.\n2. Choose the number of workers that matches the number of GPUs you want to use.\n3. Select a worker type. On AWS, use nodes with 1 GPU each such as `p3.2xlarge` or `g4dn.xlarge`.\n   For Azure, choose GPU nodes such as Standard_NC6s_v3. For GCP, choose N1 or A2 instance types with GPUs. \n4. 
Select the driver type. Generally this can be set to be the same as the worker.\n5. Click the “Edit” button, then navigate down to the “Advanced Options” section. Select the “Init Scripts” tab in \n   the advanced options section, and paste the workspace path to the initialization script:`/Users/user@domain/init.sh`, then click “Add”.\n   ![Init Script](../../../../img/databricks/initscript.png)\n6. Now select the “Spark” tab, and paste the following config options into the Spark Config section.\n   Change the config values based on the workers you choose. See Apache Spark\n   [configuration](https://spark.apache.org/docs/latest/configuration.html) and RAPIDS Accelerator\n   for Apache Spark [descriptions](https://nvidia.github.io/spark-rapids/docs/configs.html) for each config.\n\n    The\n    [`spark.task.resource.gpu.amount`](https://spark.apache.org/docs/latest/configuration.html#scheduling)\n    configuration is defaulted to 1 by Databricks. That means that only 1 task can run on an\n    executor with 1 GPU, which is limiting, especially on the reads and writes from Parquet. Set\n    this to 1/(number of cores per executor) which will allow multiple tasks to run in parallel just\n    like the CPU side. Having the value smaller is fine as well.\n    Note: Please remove the `spark.task.resource.gpu.amount` config for a single-node Databricks \n    cluster because Spark local mode does not support GPU scheduling.\n   \n    ```bash\n    spark.plugins com.nvidia.spark.SQLPlugin\n    spark.task.resource.gpu.amount 0.1\n    spark.rapids.memory.pinnedPool.size 2G\n    spark.rapids.sql.concurrentGpuTasks 2\n    ```\n\n    ![Spark Config](../../../../img/databricks/sparkconfig.png)\n\n    If running Pandas UDFs with GPU support from the plugin, at least three additional options\n    as below are required. The `spark.python.daemon.module` option is to choose the right daemon module\n    of python for Databricks. 
On Databricks, the python runtime requires different parameters than the\n    Spark one, so a dedicated python demon module `rapids.daemon_databricks` is created and should\n    be specified here. Set the config\n    [`spark.rapids.sql.python.gpu.enabled`](https://nvidia.github.io/spark-rapids/docs/configs.html#sql.python.gpu.enabled) to `true` to\n    enable GPU support for python. Add the path of the plugin jar (supposing it is placed under\n    `/databricks/jars/`) to the `spark.executorEnv.PYTHONPATH` option. For more details please go to\n    [GPU Scheduling For Pandas UDF](https://nvidia.github.io/spark-rapids/docs/additional-functionality/rapids-udfs.html#gpu-support-for-pandas-udf)\n\n    ```bash\n    spark.rapids.sql.python.gpu.enabled true\n    spark.python.daemon.module rapids.daemon_databricks\n    spark.executorEnv.PYTHONPATH /databricks/jars/rapids-4-spark_2.12-26.02.0.jar:/databricks/spark/python\n    ```\n   Note that since python memory pool require installing the cudf library, so you need to install cudf library in \n   each worker nodes `pip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com` or disable python memory pool\n   `spark.rapids.python.memory.gpu.pooling.enabled=false`.\n   \n7. Click `Create Cluster`, it is now enabled for GPU-accelerated Spark.\n\nInstall the xgboost4j_spark jar in the cluster\n---------------------------\n\n1. See [Libraries](https://docs.databricks.com/user-guide/libraries.html) for how to install jars from DBFS\n2. Go to \"Libraries\" tab under your cluster and install dbfs:/FileStore/jars/${XGBOOST4J_SPARK_JAR} in your cluster by selecting the \"DBFS\" option for installing jars\n\nThese steps will ensure you are able to import xgboost libraries in python notebooks.\n\nImport the GPU Mortgage Example Notebook\n---------------------------\n\n1. See [Managing Notebooks](https://docs.databricks.com/user-guide/notebooks/notebook-manage.html) on how to import a notebook.\n2. 
Import the example notebook: [XGBoost4j-Spark mortgage notebook](../../../../../examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-gpu.ipynb)\n3. Inside the mortgage example notebook, update the data paths from \n\"/data/datasets/mortgage-small/train\" to \"dbfs:/FileStore/tables/mortgage/csv/train/mortgage_train_merged.csv\"\n\"/data/datasets/mortgage-small/eval\" to \"dbfs:/FileStore/tables/mortgage/csv/test/mortgage_eval_merged.csv\"\n\nThe example notebook comes with the following configuration, you can adjust this according to your setup.\nSee supported configuration options here: [xgboost parameters](../../../../../examples/XGBoost-Examples/app-parameters/supported_xgboost_parameters_python.md)\n\n``` bash\nparams = { \n    'eta': 0.1,\n    'gamma': 0.1,\n    'missing': 0.0,\n    'treeMethod': 'gpu_hist',\n    'maxDepth': 10, \n    'maxLeaves': 256,\n    'growPolicy': 'depthwise',\n    'minChildWeight': 30.0,\n    'lambda_': 1.0,\n    'scalePosWeight': 2.0,\n    'subsample': 1.0,\n    'nthread': 1,\n    'numRound': 100,\n    'numWorkers': 1,\n}\n```\n\n4. Run all the cells in the notebook.\n\n5. View the results\nIn the cell 5 (Training), 7 (Transforming) and 8 (Accuracy of Evaluation) you will see the output.\n\n```\n--------------\n==> Benchmark: \nTraining takes 6.48 seconds\n--------------\n\n--------------\n==> Benchmark: Transformation takes 3.2 seconds\n\n--------------\n\n------Accuracy of Evaluation------\nAccuracy is 0.9980699597729774\n\n```\n\nLimitations\n-------------\n\n1. When selecting GPU nodes, Databricks UI requires the driver node to be a GPU node. However you \n   can use Databricks API to create a cluster with CPU driver node.\n   Outside of Databricks the plugin can operate with the driver as a CPU node and workers as GPU nodes.\n\n2. Cannot spin off multiple executors on a multi-GPU node. 
\n\n   Even though it is possible to set `spark.executor.resource.gpu.amount=1` in the in Spark \n   Configuration tab, Databricks overrides this to `spark.executor.resource.gpu.amount=N` \n   (where N is the number of GPUs per node). This will result in failed executors when starting the\n   cluster.\n\n3. Parquet rebase mode is set to \"LEGACY\" by default.\n\n   The following Spark configurations are set to `LEGACY` by default on Databricks:\n   \n   ```\n   spark.sql.legacy.parquet.datetimeRebaseModeInWrite\n   spark.sql.legacy.parquet.int96RebaseModeInWrite\n   ```\n   \n   These settings will cause a CPU fallback for Parquet writes involving dates and timestamps.\n   If you do not need `LEGACY` write semantics, set these configs to `EXCEPTION` which is\n   the default value in Apache Spark 3.0 and higher.\n\n4. Databricks makes changes to the runtime without notification.\n\n    Databricks makes changes to existing runtimes, applying patches, without notification.\n    [Issue-3098](https://github.com/NVIDIA/spark-rapids/issues/3098) is one example of this.  We run\n    regular integration tests on the Databricks environment to catch these issues and fix them once\n    detected."
  },
  {
    "path": "docs/get-started/xgboost-examples/csp/databricks/init.sh",
    "content": "#!/bin/bash\n# Copyright (c) 2025-2026, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\nsudo rm -f /databricks/jars/spark--maven-trees--ml--10.x--xgboost-gpu--ml.dmlc--xgboost4j-gpu_2.12--ml.dmlc__xgboost4j-gpu_2.12__1.5.2.jar\nsudo rm -f /databricks/jars/spark--maven-trees--ml--10.x--xgboost-gpu--ml.dmlc--xgboost4j-spark-gpu_2.12--ml.dmlc__xgboost4j-spark-gpu_2.12__1.5.2.jar\n\nsudo wget -O /databricks/jars/rapids-4-spark_2.12-26.02.0.jar https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar\nsudo wget -O /databricks/jars/xgboost4j-gpu_2.12-1.7.1.jar https://repo1.maven.org/maven2/ml/dmlc/xgboost4j-gpu_2.12/1.7.1/xgboost4j-gpu_2.12-1.7.1.jar\nsudo wget -O /databricks/jars/xgboost4j-spark-gpu_2.12-1.7.1.jar https://repo1.maven.org/maven2/ml/dmlc/xgboost4j-spark-gpu_2.12/1.7.1/xgboost4j-spark-gpu_2.12-1.7.1.jar\nls -ltr\n\nmkdir -p /dbfs/FileStore/tables/\ncd /dbfs/FileStore/tables/\n# Note that this is just a dummy dataset for quickly hands on, please refer the instructions to download the full dataset:\n# https://github.com/NVIDIA/spark-rapids-examples/blob/main/docs/get-started/xgboost-examples/dataset/mortgage.md\nwget -O mortgage.zip https://rapidsai-data.s3.us-east-2.amazonaws.com/spark/mortgage.zip\nls\nunzip -o mortgage.zip\npwd\nls -ltr mortgage/csv/*"
  },
  {
    "path": "docs/get-started/xgboost-examples/csp/dataproc/gcp.md",
    "content": "# Getting started pyspark+xgboost with RAPIDS Accelerator on GCP Dataproc\n [Google Cloud Dataproc](https://cloud.google.com/dataproc) is Google Cloud's fully managed Apache\n Spark and Hadoop service. Please make sure to install gcloud CLI by following \n this [guide](https://cloud.google.com/sdk/docs/install) before getting started.\n \n## Create a Dataproc Cluster using T4's\n* One 16-core master node and 2 32-core worker nodes\n* Two NVIDIA T4 for each worker node\n\n```bash\n    export REGION=[Your Preferred GCP Region]\n    export GCS_BUCKET=[Your GCS Bucket]\n    export CLUSTER_NAME=[Your Cluster Name]\n    export NUM_GPUS=2\n    export NUM_WORKERS=2\n\ngcloud dataproc clusters create $CLUSTER_NAME  \\\n    --region=$REGION \\\n    --image-version=2.0-ubuntu18 \\\n    --master-machine-type=n2-standard-16 \\\n    --num-workers=$NUM_WORKERS \\\n    --worker-accelerator=type=nvidia-tesla-t4,count=$NUM_GPUS \\\n    --worker-machine-type=n1-highmem-32\\\n    --num-worker-local-ssds=4 \\\n    --initialization-actions=gs://goog-dataproc-initialization-actions-${REGION}/spark-rapids/spark-rapids.sh \\\n    --optional-components=JUPYTER,ZEPPELIN \\\n    --metadata=rapids-runtime=SPARK \\\n    --bucket=$GCS_BUCKET \\\n    --enable-component-gateway \\\n    --subnet=default\n```\n\nExplanation of parameters:\n* NUM_GPUS = number of GPUs to attach to each worker node in the cluster\n* NUM_WORKERS = number of Spark worker nodes in the cluster\n\nThis takes around 10-15 minutes to complete.  
You can navigate to the Dataproc clusters tab in the\nGoogle Cloud Console to see the progress.\n\n![Dataproc Cluster](../../../../img/GCP/dataproc-cluster.png)\n\nIf you'd like to further accelerate init time to 4-5 minutes, create a custom Dataproc image using\n[this](#build-custom-dataproc-image-to-accelerate-cluster-init-time) guide.\n\n\n## Get Application Files, Jar and Dataset\n\nBash into the master node and make sure you have prepared the necessary packages and dataset by following this [guide](../../prepare-package-data/preparation-python.md).\n\nNote: Since there is no maven CLI in master node, so we need to manually install.\n``` bash\ngcloud compute ssh your-name@your-cluster-m --zone your-zone\nsudo apt-get install maven -y\n```\n\nThen create a directory in HDFS, and run below commands,\n\n``` bash\n[xgboost4j_spark_python]$ hadoop fs -mkdir /tmp/xgboost4j_spark_python\n[xgboost4j_spark_python]$ hadoop fs -copyFromLocal ${SPARK_XGBOOST_DIR}/mortgage/* /tmp/xgboost4j_spark_python\n```\n\n## Preparing libraries\nPlease make sure to install the XGBoost, cudf-cu11, numpy libraries on all nodes before running XGBoost application.\n``` bash\npip install xgboost\npip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com\npip install numpy\npip install scikit-learn\n```\nYou can also create an isolated python environment by using [Virtualenv](https://virtualenv.pypa.io/en/latest/),\nand then directly pass/unpack the archive file and enable the environment on executors\nby leveraging the --archives option or spark.archives configuration.\n``` bash\n# create an isolated python environment and install libraries\npython -m venv pyspark_venv\nsource pyspark_venv/bin/activate\npip install xgboost\npip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com\npip install numpy\npip install scikit-learn\npip install venv-pack\nvenv-pack -o pyspark_venv.tar.gz\n\n# enable archive python environment on executors\nexport PYSPARK_DRIVER_PYTHON=python # Do not 
set in cluster modes.\nexport PYSPARK_PYTHON=./environment/bin/python\nspark-submit --archives pyspark_venv.tar.gz#environment app.py\n```\n## Run jupyter notebooks on Dataproc \n\nBash into the master node and start up the notebook.\n```\njupyter notebook --ip=0.0.0.0 --port=8124 --no-browser\n```\n\nIf you want to remote access the notebook from local, please reserve an external static IP address first:\n1. Access the IP addresses page through the navigation menu: `VPC network` -> `IP addresses`\n![dataproc img2](../../../../img/GCP/dataproc-img2.png)\n2. Click the `RESERVE EXTERNAL STATIC ADDRESS` button\n![dataproc img3](../../../../img/GCP/dataproc-img3.png)\n3. Attached the static address to the master node of your cluster\n![dataproc img4](../../../../img/GCP/dataproc-img4.png)\n4. Then you can access and run the notebook from the browser in local using the reserved address.  \n![dataproc img5](../../../../img/GCP/dataproc-img5.png)\n\nThen you can run the [notebook](../../../../../examples/XGBoost-Examples/mortgage/notebooks/python/mortgage-gpu.ipynb) and get the benchmark results.\n![dataproc img6](../../../../img/GCP/dataproc-img6.png)\n\n## Build custom dataproc image to accelerate cluster init time\nIn order to accelerate cluster init time to 3-4 minutes, we need to build a custom Dataproc image\nthat already has NVIDIA drivers and CUDA toolkit installed, with RAPIDS deployed. The custom image\ncould also be used in an air gap environment. 
In this section, we will be using [these instructions\nfrom GCP](https://cloud.google.com/dataproc/docs/guides/dataproc-images) to create a custom image.\n\nCurrently, we can directly download the [spark-rapids.sh](https://github.com/GoogleCloudDataproc/initialization-actions/tree/master/spark-rapids)\nscript to create the Dataproc image:\n\nGoogle provides a `generate_custom_image.py` script that:\n- Launches a temporary Compute Engine VM instance with the specified Dataproc base image.\n- Then runs the customization script inside the VM instance to install custom packages and/or\nupdate configurations.\n- After the customization script finishes, it shuts down the VM instance and creates a Dataproc\n  custom image from the disk of the VM instance.\n- The temporary VM is deleted after the custom image is created.\n- The custom image is saved and can be used to create Dataproc clusters.\n\nDownload `spark-rapids.sh` in this repo.  The script uses\nGoogle's `generate_custom_image.py` script.  
This step may take 20-25 minutes to complete.\n\n```bash\ngit clone https://github.com/GoogleCloudDataproc/custom-images\ncd custom-images\n\nexport CUSTOMIZATION_SCRIPT=/path/to/spark-rapids.sh\nexport ZONE=[Your Preferred GCP Zone]\nexport GCS_BUCKET=[Your GCS Bucket]\nexport IMAGE_NAME=sample-20-ubuntu18-gpu-t4\nexport DATAPROC_VERSION=2.0-ubuntu18\nexport GPU_NAME=nvidia-tesla-t4\nexport GPU_COUNT=1\n\npython generate_custom_image.py \\\n    --image-name $IMAGE_NAME \\\n    --dataproc-version $DATAPROC_VERSION \\\n    --customization-script $CUSTOMIZATION_SCRIPT \\\n    --no-smoke-test \\\n    --zone $ZONE \\\n    --gcs-bucket $GCS_BUCKET \\\n    --machine-type n1-standard-4 \\\n    --accelerator type=$GPU_NAME,count=$GPU_COUNT \\\n    --disk-size 200 \\\n    --subnet default \n```\n\nSee [here](https://cloud.google.com/dataproc/docs/guides/dataproc-images#running_the_code) for more\ndetails on `generate_custom_image.py` script arguments and\n[here](https://cloud.google.com/dataproc/docs/concepts/versioning/dataproc-versions) for dataproc\nversion description.\n\nThe image `sample-20-ubuntu18-gpu-t4` is now ready and can be viewed in the GCP console under\n`Compute Engine > Storage > Images`. The next step is to launch the cluster using this new image\nand new initialization actions (that do not install NVIDIA drivers since we are already past that\nstep).\n\nMove this to your own bucket. 
Let's launch the cluster:\n\n```bash \nexport REGION=[Your Preferred GCP Region]\nexport GCS_BUCKET=[Your GCS Bucket]\nexport CLUSTER_NAME=[Your Cluster Name]\nexport NUM_GPUS=1\nexport NUM_WORKERS=2\n\ngcloud dataproc clusters create $CLUSTER_NAME  \\\n    --region=$REGION \\\n    --image=sample-20-ubuntu18-gpu-t4 \\\n    --master-machine-type=n1-standard-4 \\\n    --num-workers=$NUM_WORKERS \\\n    --worker-accelerator=type=nvidia-tesla-t4,count=$NUM_GPUS \\\n    --worker-machine-type=n1-standard-4 \\\n    --num-worker-local-ssds=1 \\\n    --optional-components=JUPYTER,ZEPPELIN \\\n    --metadata=rapids-runtime=SPARK \\\n    --bucket=$GCS_BUCKET \\\n    --enable-component-gateway \\\n    --subnet=default \n```\n\nThe new cluster should be up and running within 3-4 minutes!\n\n"
  },
  {
    "path": "docs/get-started/xgboost-examples/dataset/mortgage.md",
    "content": "# How to download the Mortgage dataset\n\n\n\n## Steps to download the data\n\n1. Go to the [Fannie Mae](https://capitalmarkets.fanniemae.com/credit-risk-transfer/single-family-credit-risk-transfer/fannie-mae-single-family-loan-performance-data) website\n2. Click on [Single-Family Loan Performance Data](https://datadynamics.fanniemae.com/data-dynamics/?&_ga=2.181456292.2043790680.1657122341-289272350.1655822609#/reportMenu;category=HP)\n    * Register as a new user if you are using the website for the first time\n    * Use the credentials to login\n3. Select [HP](https://datadynamics.fanniemae.com/data-dynamics/#/reportMenu;category=HP)\n4. Click on  **Download Data** and choose *Single-Family Loan Performance Data*\n5. You will find a tabular list of 'Acquisition and Performance' files sorted based on year and quarter. Click on the file to download `Eg: 2017Q1.zip`\n6. Unzip the downloaded file to extract the csv file `Eg: 2017Q1.csv`\n7. Copy only the csv files to a new folder for the ETL to read\n\n## Notes\n1. Refer to the [Loan Performance Data Tutorial](https://capitalmarkets.fanniemae.com/media/9066/display) for more details. \n2. Note that *Single-Family Loan Performance Data* has 2 components. However, the Mortgage ETL requires only the first one (primary dataset)\n    * Primary Dataset:  Acquisition and Performance Files\n    * HARP Dataset\n3. Use the [Resources](https://datadynamics.fanniemae.com/data-dynamics/#/resources/HP) section to know more about the dataset"
  },
  {
    "path": "docs/get-started/xgboost-examples/notebook/python-notebook.md",
    "content": "Get Started with pyspark+XGBoost with Jupyter Notebook\n===================================================================\n\nThis is a getting started guide to XGBoost4J-Spark using an [Jupyter notebook](https://jupyter.org/). \nAt the end of this guide, you will be able to run a sample notebook that runs on NVIDIA GPUs.\n\nBefore you begin, please ensure that you have setup a Spark Cluster(Standalone or YARN).\nYou should change `--master` config according to your cluster architecture. For example, set `--master yarn` for spark on YARN.\n\nIt is assumed that the `SPARK_MASTER` and `SPARK_HOME` environment variables are defined and point to the Spark Master URL (e.g. `spark://localhost:7077`),\nand the home directory for Apache Spark respectively.\n\n1. Make sure you have [Jupyter notebook installed](https://jupyter.org/install.html).\n\n   If you install it with conda, please make sure your Python version is consistent.\n\n2. Prepare packages and dataset.\n\n    Make sure you have prepared the necessary packages and dataset by following this [guide](../prepare-package-data/preparation-python.md)\n\n3. 
Launch the notebook:\n\n   Note: For ETL jobs, Set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`.\n\n    For ETL:\n\n    ``` bash\n    PYSPARK_DRIVER_PYTHON=jupyter       \\\n    PYSPARK_DRIVER_PYTHON_OPTS=notebook \\\n    pyspark                             \\\n    --master ${SPARK_MASTER}            \\\n    --jars ${RAPIDS_JAR}\\\n    --py-files ${SAMPLE_ZIP}      \\\n    --conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n    --conf spark.executor.resource.gpu.amount=1 \\\n    --conf spark.executor.cores=10 \\\n    --conf spark.task.resource.gpu.amount=0.1 \\\n    --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \\\n    --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \\\n    --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh\n    ```\n\n    For XGBoost:\n\n    ``` bash\n    PYSPARK_DRIVER_PYTHON=jupyter       \\\n    PYSPARK_DRIVER_PYTHON_OPTS=notebook \\\n    pyspark                             \\\n    --master ${SPARK_MASTER}            \\\n    --jars ${RAPIDS_JAR}\\\n    --py-files ${SAMPLE_ZIP}      \\\n    --conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n    --conf spark.rapids.memory.gpu.pool=NONE \\\n    --conf spark.executor.resource.gpu.amount=1 \\\n    --conf spark.executor.cores=10 \\\n    --conf spark.task.resource.gpu.amount=1 \\\n    --conf spark.sql.execution.arrow.maxRecordsPerBatch=200000 \\\n    --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \\\n    --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh\n    ```\n\n4. Launch ETL Part \n\n- Mortgage ETL Notebook: [Python](../../../../examples/XGBoost-Examples/mortgage/notebooks/python/MortgageETL.ipynb)\n- Taxi ETL Notebook: [Python](../../../../examples/XGBoost-Examples/taxi/notebooks/python/taxi-ETL.ipynb)\n- Note: Agaricus does not have ETL part.\n"
  },
  {
    "path": "docs/get-started/xgboost-examples/notebook/spylon.md",
    "content": "Get Started with XGBoost4J-Spark with Spylon Kernel Jupyter Notebook\n===================================================================\n\nThis is a getting started guide to XGBoost4J-Spark using a [Spylon Kernel](https://pypi.org/project/spylon-kernel/) Jupyter notebook. \nAt the end of this guide, the reader will be able to run a sample notebook that runs on NVIDIA GPUs.\n\nBefore you begin, please ensure that you have setup \na [Spark Standalone Cluster](/docs/get-started/xgboost-examples/on-prem-cluster/standalone-scala.md).\n\nIt is assumed that the `SPARK_MASTER` and `SPARK_HOME` environment variables are defined and point to the Spark Master URL, \nand the home directory for Apache Spark respectively.\n\n1. Install Jupyter Notebook with spylon-kernel.\n   ``` bash\n   # Install notebook and spylon-kernel (Scala kernel for Jupyter Notebook), https://pypi.org/project/spylon-kernel/\n   # You can use spylon-kernel as Scala kernel for Jupyter Notebook. Do this when you want to work with Spark in Scala with a bit of Python code mixed in.\n   RUN pip3 install jupyter notebook spylon-kernel\n   RUN python -m spylon_kernel install\n   # Latest version breaks nbconvert: https://github.com/ipython/ipykernel/issues/422\n   RUN pip3 install ipykernel==5.1.1\n   ```\n2. Start Jupyter Notebook. 
\n<!-- markdown-link-check-disable -->\nYou can access the web UI at http://your_ip:your_port with your password.\n<!-- markdown-link-check-enable -->    \n    ``` bash\n    export JUPYTER_CONFIG_FILE=~/.jupyter/jupyter_notebook_config.py\n    \n    rm -rf `dirname $JUPYTER_CONFIG_FILE` && mkdir -p `dirname $JUPYTER_CONFIG_FILE` && echo \"\"\"\n    c.NotebookApp.ip='*'\n    c.NotebookApp.password = your_hashed_password\n    c.NotebookApp.open_browser = False\n    c.NotebookApp.port = your_port\n    \"\"\" > $JUPYTER_CONFIG_FILE\n \n    jupyter notebook --allow-root --notebook-dir=$WORKSPACE --config=$JUPYTER_CONFIG_FILE &\n    ```\n3. Prepare packages and dataset.\n\n    Make sure you have prepared the necessary packages and dataset by following this [guide](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md)\n\n4. Run scala notebook (e.g. [mortgage-gpu.ipynb](../../../../examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-gpu.ipynb))\n\n    ``` bash\n    # Suppose your Scala file is $WORKSPACE/mortgage-gpu.ipynb\n \n    jupyter nbconvert --to notebook --stdout --execute $WORKSPACE/mortgage-gpu.ipynb\n     \n    # -------you will see output looks like ----------------\n    # { \n    #   \"cells\": [\n    #   {\n    #    \"cell_type\": \"code\",\n    #    \"execution_count\": 1,\n    #    \"id\": \"5ca1ae16\",\n    #    \"metadata\": {\n    #     ........\n    #     ........\n    #     ........\n    #   \"language_info\": {\n    #    \"codemirror_mode\": \"text/x-scala\",\n    #    \"file_extension\": \".scala\",\n    #    \"help_links\": [\n    #     {\n    #      \"text\": \"MetaKernel Magics\",\n    #      \"url\": \"https://metakernel.readthedocs.io/en/latest/source/README.html\"\n    #     }\n    #    ],\n    #    \"mimetype\": \"text/x-scala\",\n    #    \"name\": \"scala\",\n    #    \"pygments_lexer\": \"scala\",\n    #    \"version\": \"0.4.1\"\n    #   }\n    #  },\n    #  
\"nbformat\": 4,\n    #  \"nbformat_minor\": 5\n    # }\n    ```\n    You can also run python notebook with Spylon Kernel\n    ``` bash\n    # restart Jupyter Notebook\n  \n    export PYSPARK_DRIVER_PYTHON=jupyter\n    export PYSPARK_DRIVER_PYTHON_OPTS=\"notebook --allow-root --notebook-dir=$WORKSPACE --config=$JUPYTER_CONFIG_FILE\"\n    pyspark &\n     \n    # Suppose your python file is $WORKSPACE/mortgage-gpu.ipynb\n    jupyter nbconvert --to notebook--stdout --execute $WORKSPACE/mortgage-gpu.ipynb\n    ```\n   \n5. Launch ETL Part \n- Mortgage ETL Notebook: [Scala](../../../../examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-ETL.ipynb) or\n  [Python](../../../../examples/XGBoost-Examples/mortgage/notebooks/python/MortgageETL.ipynb)\n- Taxi ETL Notebook: [Scala](../../../../examples/XGBoost-Examples/taxi/notebooks/scala/taxi-ETL.ipynb) or\n  [Python](../../../../examples/XGBoost-Examples/taxi/notebooks/python/taxi-ETL.ipynb)\n- Note: Agaricus does not have ETL part.\n   \n6. Launch XGBoost Part\n- Mortgage XGBoost Notebook: [Scala](../../../../examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-gpu.ipynb) \n- Taxi XGBoost Notebook: [Scala](../../../../examples/XGBoost-Examples/taxi/notebooks/scala/taxi-gpu.ipynb)\n- Agaricus XGBoost Notebook: [Scala](../../../../examples/XGBoost-Examples/agaricus/notebooks/scala/agaricus-gpu.ipynb) "
  },
  {
    "path": "docs/get-started/xgboost-examples/notebook/toree.md",
    "content": "Get Started with XGBoost4J-Spark with Apache Toree Jupyter Notebook\n===================================================================\n\nThis is a getting started guide to XGBoost4J-Spark using an [Apache Toree](https://toree.apache.org/) Jupyter notebook. \nAt the end of this guide, you will be able to run a sample notebook that runs on NVIDIA GPUs.\n\nBefore you begin, please ensure that you have setup a Spark Cluster(Standalone or YARN).\nYou should change `--master` config according to your cluster architecture. For example, set `--master yarn` for spark on YARN.\n\nIt is assumed that the `SPARK_MASTER` and `SPARK_HOME` environment variables are defined and point to the Spark Master URL (e.g. `spark://localhost:7077`),\nand the home directory for Apache Spark respectively.\n\n1. Make sure you have jupyter notebook and [sbt](https://www.scala-sbt.org/1.x/docs/Installing-sbt-on-Linux.html) installed first.\n2. Build the 'toree' locally to support scala 2.12, and install it.\n\n    ``` bash\n    # Download toree\n    wget https://github.com/apache/incubator-toree/archive/refs/tags/v0.5.0-incubating-rc4.tar.gz\n    tar -xvzf v0.5.0-incubating-rc4.tar.gz\n    # Build the Toree pip package.\n    cd incubator-toree-0.5.0-incubating-rc4\n    make pip-release\n    # Install Toree\n    pip install dist/toree-pip/toree-0.5.0.tar.gz\n    ```\n3. Prepare packages and dataset.\n\n    Make sure you have prepared the necessary packages and dataset by following this [guide](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md)\n\n4. 
Install a new kernel with gpu enabled and launch the notebook\n\n    Note: For ETL jobs, Set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`.\n\n    For ETL:\n    ``` bash\n    jupyter toree install                                \\\n    --spark_home=${SPARK_HOME}                             \\\n    --user                                          \\\n    --toree_opts='--nosparkcontext'                         \\\n    --kernel_name=\"ETL-Spark\"                         \\\n    --spark_opts='--master ${SPARK_MASTER} \\\n      --jars ${RAPIDS_JAR},${SAMPLE_JAR}       \\\n      --conf spark.plugins=com.nvidia.spark.SQLPlugin  \\\n      --conf spark.executor.extraClassPath=${RAPIDS_JAR} \\\n      --conf spark.executor.cores=10 \\\n      --conf spark.task.resource.gpu.amount=0.1 \\\n      --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \\\n      --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh'\n    ```\n\n    For XGBoost:\n     ``` bash\n    jupyter toree install                                \\\n    --spark_home=${SPARK_HOME}                             \\\n    --user                                          \\\n    --toree_opts='--nosparkcontext'                         \\\n    --kernel_name=\"XGBoost-Spark\"                         \\\n    --spark_opts='--master ${SPARK_MASTER} \\\n      --jars ${RAPIDS_JAR},${SAMPLE_JAR}       \\\n      --conf spark.plugins=com.nvidia.spark.SQLPlugin  \\\n      --conf spark.executor.extraClassPath=${RAPIDS_JAR} \\\n      --conf spark.rapids.memory.gpu.pool=NONE \\\n      --conf spark.executor.resource.gpu.amount=1 \\\n      --conf spark.executor.cores=10 \\\n      --conf spark.task.resource.gpu.amount=1 \\\n      --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \\\n      --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh'\n    ```\n\n    Launch the notebook:\n\n    ``` bash\n    jupyter notebook\n    ```\n\n4. 
Launch ETL Part \n- Mortgage ETL Notebook: [Scala](../../../../examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-ETL.ipynb)\n- Taxi ETL Notebook: [Scala](../../../../examples/XGBoost-Examples/taxi/notebooks/scala/taxi-ETL.ipynb)\n- Note: Agaricus does not have ETL part.\n   \n5. Launch XGBoost Part\n- Mortgage XGBoost Notebook: [Scala](../../../../examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-gpu.ipynb)\n- Taxi XGBoost Notebook: [Scala](../../../../examples/XGBoost-Examples/taxi/notebooks/scala/taxi-gpu.ipynb)\n- Agaricus XGBoost Notebook: [Scala](../../../../examples/XGBoost-Examples/agaricus/notebooks/scala/agaricus-gpu.ipynb)"
  },
  {
    "path": "docs/get-started/xgboost-examples/on-prem-cluster/kubernetes-scala.md",
    "content": "Get Started with XGBoost4J-Spark on Kubernetes\n==============================================\nThis is a getting started guide to deploy XGBoost4J-Spark package on a Kubernetes cluster. At the end of this guide,\nthe reader will be able to run a sample Apache Spark XGBoost application on NVIDIA GPU Kubernetes cluster.\n\nPrerequisites\n-------------\n\n* Apache Spark 3.2.0+ (e.g.: Spark 3.2.0)\n* Hardware Requirements\n  * NVIDIA Pascal™ GPU architecture or better\n  * Multi-node clusters with homogenous GPU configuration\n* Software Requirements\n  * Ubuntu 20.04, 22.04/CentOS7, Rocky Linux 8\n  * CUDA 11.0+\n  * NVIDIA driver compatible with your CUDA\n  * NCCL 2.7.8+\n* [Kubernetes cluster with NVIDIA GPUs](https://docs.nvidia.com/datacenter/cloud-native/kubernetes/install-k8s.html)\n  * See official [Spark on Kubernetes](https://spark.apache.org/docs/latest/running-on-kubernetes.html#prerequisites) \n    instructions for detailed spark-specific cluster requirements\n* kubectl installed and configured in the job submission environment\n  * Required for managing jobs and retrieving logs\n\nBuild a GPU Spark Docker Image\n------------------------------\n\nBuild a GPU Docker image with Spark resources in it, this Docker image must be accessible by each node in the Kubernetes cluster.\n\n1. Locate your Spark installations. If you don't have one, you can [download](https://spark.apache.org/downloads.html) from Apache and unzip it.\n2. `export SPARK_HOME=<path to spark>`\n3. [Download the Dockerfile](/dockerfile/Dockerfile) into `${SPARK_HOME}`. (Here CUDA 11.0 is used as an example in the Dockerfile,\n   you may need to update it for other CUDA versions.)\n4. __(OPTIONAL)__ install any additional library jars into the `${SPARK_HOME}/jars` directory.\n    * Most public cloud file systems are not natively supported -- pulling data and jar files from S3, GCS, etc. require installing additional libraries.\n5. 
Build and push the docker image.\n\n``` bash\nexport SPARK_HOME=<path to spark>\nexport SPARK_DOCKER_IMAGE=<gpu spark docker image repo and name>\nexport SPARK_DOCKER_TAG=<spark docker image tag>\n\npushd ${SPARK_HOME}\nwget https://github.com/NVIDIA/spark-rapids-examples/raw/branch-25.08/dockerfile/Dockerfile\n\n# Optionally install additional jars into ${SPARK_HOME}/jars/\n\ndocker build . -t ${SPARK_DOCKER_IMAGE}:${SPARK_DOCKER_TAG}\ndocker push ${SPARK_DOCKER_IMAGE}:${SPARK_DOCKER_TAG}\npopd\n```\n\nGet Jars and Dataset\n-------------------------------\n\nMake sure you have prepared the necessary packages and dataset by following this [guide](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md).\n\nMake sure that data and jars are accessible by each node of the Kubernetes cluster \nvia [Kubernetes volumes](https://spark.apache.org/docs/latest/running-on-kubernetes.html#using-kubernetes-volumes), \non cluster filesystems like HDFS, or in [object stores like S3 and GCS](https://spark.apache.org/docs/2.3.0/cloud-integration.html). \nNote that using [application dependencies](https://spark.apache.org/docs/latest/running-on-kubernetes.html#dependency-management) from \nthe submission client’s local file system is not yet supported.\n\n#### Note: \n1. Mortgage and Taxi jobs have ETLs to generate the processed data. \n2. For convenience, a subset of [Taxi](/datasets/) dataset is made available in this repo that can be readily used for launching XGBoost job. Use [ETL](#etl) to generate larger datasets for training and testing. \n3. Agaricus does not have an ETL process, it is combined with XGBoost as there is just a filter operation.\n\nSave Kubernetes Template Resources\n----------------------------------\n\nWhen using Spark on Kubernetes the driver and executor pods can be launched with pod templates. In the XGBoost4J-Spark use case,\nthese template yaml files are used to allocate and isolate specific GPUs to each pod. 
The following is a barebones template file to allocate 1 GPU per pod.\n\n```\napiVersion: v1\nkind: Pod\nspec:\n  containers:\n    - name: gpu-example\n      resources:\n        limits:\n          nvidia.com/gpu: 1\n```\n\nThis 1 GPU template file should be sufficient for all XGBoost jobs because each executor should only run 1 task on a single GPU.\nSave this yaml file to the local environment of the machine you are submitting jobs from, \nyou will need to provide a path to it as an argument in your spark-submit command. \nWithout the template file a pod will see every GPU on the cluster node it is allocated on and can attempt\nto execute using a GPU which is already in use -- causing undefined behavior and errors.\n\n<span id=\"etl\">Launch Mortgage or Taxi ETL Part</span>\n---------------------------\nUse the ETL app to process raw Mortgage data. You can either use this ETLed data to split into training and evaluation data or run the ETL on different subsets of the dataset to produce training and evaluation datasets. 
\n\nNote: For ETL jobs, Set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`.\n\nRun spark-submit\n\n``` bash\n${SPARK_HOME}/bin/spark-submit \\\n   --conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n   --conf spark.executor.resource.gpu.amount=1 \\\n   --conf spark.executor.cores=10 \\\n   --conf spark.task.resource.gpu.amount=0.1 \\\n   --conf spark.rapids.sql.incompatibleDateFormats.enabled=true \\\n   --conf spark.rapids.sql.csv.read.double.enabled=true \\\n   --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \\\n   --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \\\n   --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh \\\n   --jars ${RAPIDS_JAR}                                           \\\n   --master <k8s://ip:port or k8s://URL>                                                                  \\\n   --deploy-mode ${SPARK_DEPLOY_MODE}                                             \\\n   --num-executors ${SPARK_NUM_EXECUTORS}                                         \\\n   --driver-memory ${SPARK_DRIVER_MEMORY}                                         \\\n   --executor-memory ${SPARK_EXECUTOR_MEMORY}                                     \\\n   --class com.nvidia.spark.examples.mortgage.ETLMain  \\\n   $SAMPLE_JAR \\\n   -format=csv \\\n   -dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\" \\\n   -dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/train/\" \\\n   -dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n\n# if generating eval data, change the data path to eval\n# -dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\"\n# -dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/\"\n# -dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n# if running Taxi ETL benchmark, change the class and data path params to\n# -class com.nvidia.spark.examples.taxi.ETLMain  \n# -dataPath=\"raw::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n# 
-dataPath=\"out::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n```\n\nLaunch XGBoost Part on GPU\n---------------------------\n\nVariables required to run spark-submit command:\n\n``` bash\n# Variables dependent on how data was made accessible to each node\n# Make sure to include relevant spark-submit configuration arguments\n# location where data was saved\nexport DATA_PATH=<path to data directory> \n\n# Variables independent of how data was made accessible to each node\n# kubernetes master URL, used as the spark master for job submission\nexport SPARK_MASTER=<k8s://ip:port or k8s://URL>\n\n# local path to the template file saved in the previous step\nexport TEMPLATE_PATH=${HOME}/gpu_executor_template.yaml\n\n# spark docker image location\nexport SPARK_DOCKER_IMAGE=<spark docker image repo and name>\nexport SPARK_DOCKER_TAG=<spark docker image tag>\n\n# kubernetes service account to launch the job with\nexport K8S_ACCOUNT=<kubernetes service account name>\n\n# spark deploy mode, cluster mode recommended for spark on kubernetes\nexport SPARK_DEPLOY_MODE=cluster\n\n# run a single executor for this example to limit the number of spark tasks and\n# partitions to 1 as currently this number must match the number of input files\nexport SPARK_NUM_EXECUTORS=1\n\n# spark driver memory\nexport SPARK_DRIVER_MEMORY=4g\n\n# spark executor memory\nexport SPARK_EXECUTOR_MEMORY=8g\n\n# example class to use\nexport EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.Main\n# or change to com.nvidia.spark.examples.taxi.Main to run Taxi Xgboost benchmark\n# or change to com.nvidia.spark.examples.agaricus.Main to run Agaricus Xgboost benchmark\n\n# tree construction algorithm\nexport TREE_METHOD=gpu_hist\n```\n\nRun spark-submit:\n\n``` bash\n${SPARK_HOME}/bin/spark-submit                                                          \\\n  --conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n  --conf spark.rapids.memory.gpu.pool=NONE \\\n  --conf spark.executor.resource.gpu.amount=1 \\\n  --conf 
spark.task.resource.gpu.amount=1 \\\n  --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \\\n  --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh \\\n  --jars ${RAPIDS_JAR}                           \\\n  --master ${SPARK_MASTER}                                                              \\\n  --deploy-mode ${SPARK_DEPLOY_MODE}                                                    \\\n  --class ${EXAMPLE_CLASS}                                                              \\\n  --conf spark.executor.instances=${SPARK_NUM_EXECUTORS}                                \\\n  --conf spark.kubernetes.authenticate.driver.serviceAccountName=${K8S_ACCOUNT}         \\\n  --conf spark.kubernetes.container.image=${SPARK_DOCKER_IMAGE}:${SPARK_DOCKER_TAG}     \\\n  --conf spark.kubernetes.driver.podTemplateFile=${TEMPLATE_PATH}                       \\\n  --conf spark.kubernetes.executor.podTemplateFile=${TEMPLATE_PATH}                     \\\n  --conf spark.kubernetes.authenticate.driver.serviceAccountName=spark                  \\\n  ${SAMPLE_JAR}                                                                        \\\n  -dataPath=train::${SPARK_XGBOOST_DIR}/mortgage/output/train/                   \\\n  -dataPath=trans::${SPARK_XGBOOST_DIR}/mortgage/output/eval/                    \\\n  -format=parquet                                                                \\\n  -numWorkers=${SPARK_NUM_EXECUTORS}                                                    \\\n  -treeMethod=${TREE_METHOD}                                                            \\\n  -numRound=100                                                                         \\\n  -maxDepth=8                   \n  \n   # Please make sure to change the class and data path while running Taxi or Agaricus benchmark                                                       \n                                                \n```\n\nRetrieve the logs using the driver's pod name that is 
printed to `stdout` by spark-submit \n```\nexport POD_NAME=<kubernetes pod name>\nkubectl logs -f ${POD_NAME}\n```\n\nIn the driver log, you should see timings* (in seconds), and the accuracy metric (taking Mortgage as an example):\n```\n--------------\n==> Benchmark: Elapsed time for [Mortgage GPU train csv stub Unknown Unknown Unknown]: 30.132s\n--------------\n\n--------------\n==> Benchmark: Elapsed time for [Mortgage GPU transform csv stub Unknown Unknown Unknown]: 22.352s\n--------------\n\n--------------\n==> Benchmark: Accuracy for [Mortgage GPU Accuracy csv stub Unknown Unknown Unknown]: 0.9869451418401349\n--------------\n```\n\n\\* Kubernetes logs may not be nicely formatted since `stdout` and `stderr` are not kept separately.\n\n\\* The timings in this Getting Started guide are only for illustrative purposes. \nPlease see our [release announcement](https://medium.com/rapids-ai/nvidia-gpus-and-apache-spark-one-step-closer-2d99e37ac8fd) for official benchmarks.\n"
  },
  {
    "path": "docs/get-started/xgboost-examples/on-prem-cluster/standalone-python.md",
    "content": "Get Started with XGBoost4J-Spark on an Apache Spark Standalone Cluster\n======================================================================\nThis is a getting started guide to XGBoost4J-Spark on an Apache Spark 3.2+ Standalone Cluster.\nAt the end of this guide, the user can run a sample Apache Spark Python application that runs on NVIDIA GPUs.\n\nPrerequisites\n-------------\n\n* Apache Spark 3.2.0+ (e.g.: Spark 3.2.0)\n* Hardware Requirements\n  * NVIDIA Pascal™ GPU architecture or better\n  * Multi-node clusters with homogenous GPU configuration\n* Software Requirements\n  * Ubuntu 20.04, 22.04/CentOS7, Rocky Linux 8\n  * CUDA 11.5+\n  * NVIDIA driver compatible with your CUDA\n  * NCCL 2.7.8+\n  * Python 3.8 or 3.9\n  * NumPy\n  * XGBoost 1.7.0+\n  * cudf-cu11  \n\nThe number of GPUs in each host dictates the number of Spark executors that can run there.\nAdditionally, cores per Spark executor and cores per Spark task must match, such that each executor can run 1 task at any given time.\n\nFor example, if each host has 4 GPUs, there should be 4 or fewer executors running on each host,\nand each executor should run at most 1 task (e.g.: a total of 4 tasks running on 4 GPUs).\n\nIn Spark Standalone mode, the default configuration is for an executor to take up all the cores assigned to each Spark Worker.\nIn this example, we will limit the number of cores to 1, to match our dataset.\nPlease see https://spark.apache.org/docs/latest/spark-standalone.html for more documentation regarding Standalone configuration.\n\nWe use `SPARK_HOME` environment variable to point to the Apache Spark cluster.\nAnd here are the steps to enable the GPU resources discovery for Spark 3.2+.\n\n1. Copy the spark config file from template\n\n    ``` bash\n    cd ${SPARK_HOME}/conf/\n    cp spark-defaults.conf.template spark-defaults.conf\n    ```\n\n2. 
Add the following configs to the file `spark-defaults.conf`.\n\n   The number in the first config should **NOT** be larger than the actual number of the GPUs on current host.\n   This example uses 1 as below for one GPU on the host.\n\n    ```bash\n    spark.worker.resource.gpu.amount 1\n    spark.worker.resource.gpu.discoveryScript ${SPARK_HOME}/examples/src/main/scripts/getGpusResources.sh\n    ```\n3. Install the XGBoost, cudf-cu11, numpy libraries on all nodes before running XGBoost application.\n\n``` bash\npip install xgboost\npip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com\npip install numpy\npip install scikit-learn\n```\n\nGet Application Files, Jar and Dataset\n-------------------------------\n\nMake sure you have prepared the necessary packages and dataset by following this [guide](../prepare-package-data/preparation-python.md)\n\n\n#### Note: \n1. Mortgage and Taxi jobs have ETLs to generate the processed data.\n2. For convenience, a subset of [Taxi](/datasets/) dataset is made available in this repo that can be readily used for launching XGBoost job. Use [ETL](standalone-python.md#launch-mortgage-or-taxi-etl-part) to generate larger datasets for training and testing.\n3. Agaricus does not have an ETL process, it is combined with XGBoost as there is just a filter operation.\n\n\nLaunch a Standalone Spark Cluster\n---------------------------------\n\n1. Copy required jars to `$SPARK_HOME/jars` folder.\n\n    ``` bash\n    cp ${RAPIDS_JAR} $SPARK_HOME/jars/\n    ```\n\n2. Start the Spark Master process.\n\n    ``` bash\n    ${SPARK_HOME}/sbin/start-master.sh\n    ```\n\n    Note the hostname or ip address of the Master host, so that it can be given to each Worker process, in this example the Master and Worker will run on the same host.\n\n3. 
Start a spark slave process.\n\n    ``` bash\n    export SPARK_MASTER=spark://`hostname -f`:7077\n    export SPARK_CORES_PER_WORKER=1\n\n    ${SPARK_HOME}/sbin/start-slave.sh ${SPARK_MASTER} -c ${SPARK_CORES_PER_WORKER}\n    ```\n\n    Note that in this example the Master and Worker processes are both running on the same host. This is not a requirement, as long as all hosts that are used to run the Spark app have access to the dataset.\n\nLaunch Mortgage or Taxi ETL Part\n---------------------------\nUse the ETL app to process raw Mortgage data. You can either use this ETLed data to split into training and evaluation data or run the ETL on different subsets of the dataset to produce training and evaluation datasets.\n\nNote: For ETL jobs, Set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`.\n### ETL on GPU\n``` bash\n${SPARK_HOME}/bin/spark-submit \\\n    --master spark://$HOSTNAME:7077 \\\n    --executor-memory 32G \\\n    --conf spark.executor.resource.gpu.amount=1 \\\n    --conf spark.executor.cores=10 \\\n    --conf spark.task.resource.gpu.amount=0.1 \\\n    --conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n    --conf spark.rapids.sql.incompatibleDateFormats.enabled=true \\\n    --conf spark.rapids.sql.csv.read.double.enabled=true \\\n    --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \\\n    --py-files ${SAMPLE_ZIP} \\\n    main.py \\\n    --mainClass='com.nvidia.spark.examples.mortgage.etl_main' \\\n    --format=csv \\\n    --dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\" \\\n    --dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/train/\" \\\n    --dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n\n# if generating eval data, change the data path to eval\n# --dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\"\n# --dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/\"\n# --dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n# if running Taxi ETL benchmark, change the 
class and data path params to\n# -class com.nvidia.spark.examples.taxi.ETLMain  \n# -dataPath=\"raw::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n# -dataPath=\"out::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n```\n### ETL on CPU\n```bash\n${SPARK_HOME}/bin/spark-submit \\\n    --master spark://$HOSTNAME:7077 \\\n    --executor-memory 32G \\\n    --conf spark.executor.instances=1 \\\n    --py-files ${SAMPLE_ZIP} \\\n    main.py \\\n    --mainClass='com.nvidia.spark.examples.mortgage.etl_main' \\\n    --format=csv \\\n    --dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\" \\\n    --dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/train/\" \\\n    --dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n\n# if generating eval data, change the data path to eval\n# --dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\"\n# --dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/\"\n# --dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n# if running Taxi ETL benchmark, change the class and data path params to\n# -class com.nvidia.spark.examples.taxi.ETLMain  \n# -dataPath=\"raw::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n# -dataPath=\"out::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n```\n\nLaunch XGBoost Part on GPU\n---------------------------\n\nVariables required to run spark-submit command:\n\n``` bash\n# this is the same master host we defined while launching the cluster\nexport SPARK_MASTER=spark://`hostname -f`:7077\n\n# Currently the number of tasks and executors must match the number of input files.\n# For this example, we will set these such that we have 1 executor, with 1 core per executor\n\n## take up the whole worker\nexport SPARK_CORES_PER_EXECUTOR=${SPARK_CORES_PER_WORKER}\n\n## run 1 executor\nexport SPARK_NUM_EXECUTORS=1\n\n## cores/executor * num_executors, which in this case is also 1, limits\n## the number of cores given to the application\nexport TOTAL_CORES=$((SPARK_CORES_PER_EXECUTOR * SPARK_NUM_EXECUTORS))\n\n# spark driver 
memory\nexport SPARK_DRIVER_MEMORY=4g\n\n# spark executor memory\nexport SPARK_EXECUTOR_MEMORY=8g\n\n# example class to use\nexport EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.main\n# or change to com.nvidia.spark.examples.taxi.main to run Taxi Xgboost benchmark\n# or change to com.nvidia.spark.examples.agaricus.main to run Agaricus Xgboost benchmark\n\n# tree construction algorithm\nexport TREE_METHOD=gpu_hist\n\n# if you enable archive python environment\nexport PYSPARK_DRIVER_PYTHON=python\nexport PYSPARK_PYTHON=./environment/bin/python\n```\n\nRun spark-submit:\n\n``` bash\n${SPARK_HOME}/bin/spark-submit                                                  \\\n --conf spark.plugins=com.nvidia.spark.SQLPlugin                       \\\n --conf spark.rapids.memory.gpu.pool=NONE                     \\\n --conf spark.executor.resource.gpu.amount=1                           \\\n --conf spark.task.resource.gpu.amount=1                              \\\n --master ${SPARK_MASTER}                                                       \\\n --driver-memory ${SPARK_DRIVER_MEMORY}                                         \\\n --executor-memory ${SPARK_EXECUTOR_MEMORY}                                     \\\n --conf spark.cores.max=${TOTAL_CORES}                                          \\\n --archives your_pyspark_venv.tar.gz#environment     #if you enabled archive python environment \\\n --jars ${RAPIDS_JAR}    \\\n --py-files ${SAMPLE_ZIP}                   \\\n ${MAIN_PY}                                                     \\\n --mainClass=${EXAMPLE_CLASS}                                                   \\\n --dataPath=train::${SPARK_XGBOOST_DIR}/mortgage/output/train/      \\\n --dataPath=trans::${SPARK_XGBOOST_DIR}/mortgage/output/eval/      \\\n --format=parquet                                 \\\n --numWorkers=${SPARK_NUM_EXECUTORS}                                            \\\n --treeMethod=${TREE_METHOD}                                                    \\\n 
--numRound=100                                                                 \\\n --maxDepth=8\n\n # Change the format to csv if your input file is CSV format.\n # Please make sure to change the class and data path while running Taxi or Agaricus benchmark  \n```\n\nIn the `stdout` log on driver side, you should see timings<sup>*</sup> (in seconds), and the accuracy metric:\n\n```\n----------------------------------------------------------------------------------------------------\nTraining takes 14.65 seconds\n\n----------------------------------------------------------------------------------------------------\nTransformation takes 12.21 seconds\n\n----------------------------------------------------------------------------------------------------\nAccuracy is 0.9873692247091792\n```\n\nLaunch XGBoost Part on CPU\n---------------------------\n\nIf you are running this example after running the GPU example above, please set these variables,\nto set both training and testing to run on the CPU exclusively:\n\n``` bash\n# this is the same master host we defined while launching the cluster\nexport SPARK_MASTER=spark://`hostname -f`:7077\n\n# Currently the number of tasks and executors must match the number of input files.\n# For this example, we will set these such that we have 1 executor, with 1 core per executor\n\n## take up the whole worker\nexport SPARK_CORES_PER_EXECUTOR=${SPARK_CORES_PER_WORKER}\n\n## run 1 executor\nexport SPARK_NUM_EXECUTORS=1\n\n## cores/executor * num_executors, which in this case is also 1, limits\n## the number of cores given to the application\nexport TOTAL_CORES=$((SPARK_CORES_PER_EXECUTOR * SPARK_NUM_EXECUTORS))\n\n# spark driver memory\nexport SPARK_DRIVER_MEMORY=4g\n\n# spark executor memory\nexport SPARK_EXECUTOR_MEMORY=8g\n\n# example class to use\nexport EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.main\n# Please make sure to change the class while running Taxi or Agaricus benchmark    \n\n# tree construction 
algorithm\nexport TREE_METHOD=hist\n\n# if you enable archive python environment\nexport PYSPARK_DRIVER_PYTHON=python\nexport PYSPARK_PYTHON=./environment/bin/python\n```\n\nThis is the same command as for the GPU example, repeated for convenience:\n\n``` bash\n${SPARK_HOME}/bin/spark-submit                                                  \\\n --master ${SPARK_MASTER}                                                       \\\n --driver-memory ${SPARK_DRIVER_MEMORY}                                         \\\n --executor-memory ${SPARK_EXECUTOR_MEMORY}                                     \\\n --conf spark.cores.max=${TOTAL_CORES}                                          \\\n --archives your_pyspark_venv.tar.gz#environment     #if you enabled archive python environment \\\n --jars ${RAPIDS_JAR}     \\\n --py-files ${SAMPLE_ZIP}                       \\\n ${SPARK_PYTHON_ENTRYPOINT}                                                     \\\n --mainClass=${EXAMPLE_CLASS}                                                   \\\n --dataPath=train::${DATA_PATH}/mortgage/output/train/      \\\n --dataPath=trans::${DATA_PATH}/mortgage/output/eval/         \\\n --format=parquet                                                               \\\n --numWorkers=${SPARK_NUM_EXECUTORS}                                            \\\n --treeMethod=${TREE_METHOD}                                                    \\\n --numRound=100                                                                 \\\n --maxDepth=8\n\n # Change the format to csv if your input file is CSV format.\n # Please make sure to change the class and data path while running Taxi or Agaricus benchmark  \n \n```\n\nIn the `stdout` log on driver side, you should see timings<sup>*</sup> (in seconds), and the accuracy metric:\n\n```\n----------------------------------------------------------------------------------------------------\nTraining takes 225.7 
seconds\n\n----------------------------------------------------------------------------------------------------\nTransformation takes 36.26 seconds\n\n----------------------------------------------------------------------------------------------------\nAccuracy is 0.9873709530950067\n```\n\n<sup>*</sup> The timings in this Getting Started guide are only illustrative.\nPlease see our [release announcement](https://medium.com/rapids-ai/nvidia-gpus-and-apache-spark-one-step-closer-2d99e37ac8fd) for official benchmarks.\n"
  },
  {
    "path": "docs/get-started/xgboost-examples/on-prem-cluster/standalone-scala.md",
    "content": "Get Started with XGBoost4J-Spark on an Apache Spark Standalone Cluster\n======================================================================\n\nThis is a getting-started guide to XGBoost on an Apache Spark 3.2+ Standalone Cluster. At the end of this guide,\nthe user can run a sample Apache Spark application that runs on NVIDIA GPUs.\n\nPrerequisites\n-------------\n\n* Apache Spark 3.2.0+ Standalone Cluster (e.g.: Spark 3.2.0)\n* Hardware Requirements\n  * NVIDIA Pascal™ GPU architecture or better\n  * Multi-node clusters with homogenous GPU configuration\n* Software Requirements\n  * Ubuntu 20.04, 22.04/CentOS7, Rocky Linux 8\n  * CUDA 11.0+\n  * NVIDIA driver compatible with your CUDA\n  * NCCL 2.7.8+\n  \nThe number of GPUs in each host dictates the number of Spark executors that can run there. Additionally,\ncores per Spark executor and cores per Spark task must match, such that each executor can run 1 task at any given time.\n\nFor example, if each host has 4 GPUs, there should be 4 or fewer executors running on each host,\nand each executor should run at most 1 task (e.g.: a total of 4 tasks running on 4 GPUs).\n\nIn Spark Standalone mode, the default configuration is for an executor to take up all the cores assigned to each Spark Worker.\nIn this example, we will limit the number of cores to 1, to match our dataset.\nPlease see https://spark.apache.org/docs/latest/spark-standalone.html for more documentation regarding Standalone configuration.\n\nWe use `SPARK_HOME` environment variable to point to the Apache Spark cluster.\nAnd here are steps to enable the GPU resources discovery for Spark 3.2+.\n\n1. Copy the spark configure file from template.\n\n    ``` bash\n    cd ${SPARK_HOME}/conf/\n    cp spark-defaults.conf.template spark-defaults.conf\n    ```\n\n2. 
Add the following configs to the file `spark-defaults.conf`.\n  \n    The number in the first config should NOT be larger than the actual number of the GPUs on current host.\n   This example uses 1 as below for one GPU on the host.\n\n    ``` bash\n    spark.worker.resource.gpu.amount 1\n    spark.worker.resource.gpu.discoveryScript ${SPARK_HOME}/examples/src/main/scripts/getGpusResources.sh\n    ```\n\nGet Jars and Dataset\n-------------------------------\n\nMake sure you have prepared the necessary packages and dataset \nby following this [guide](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md)\n\n#### Note: \n1. Mortgage and Taxi jobs have ETLs to generate the processed data. \n2. For convenience, a subset of [Taxi](/datasets/) dataset is made available in this repo that can be readily used for launching XGBoost job. Use [ETL](#etl) to generate larger datasets for training and testing. \n3. Agaricus does not have an ETL process, it is combined with XGBoost as there is just a filter operation.\n\n\nLaunch a Standalone Spark Cluster\n---------------------------------\n\n1. Copy required jars to `$SPARK_HOME/jars` folder.\n\n    ``` bash\n    cp $RAPIDS_JAR $SPARK_HOME/jars/\n    ```\n\n2. Start the Spark Master process.\n\n    ``` bash\n    ${SPARK_HOME}/sbin/start-master.sh\n    ```\n\n    Note the hostname or ip address of the Master host, so that it can be given to each Worker process,\n    in this example the Master and Worker will run on the same host.\n\n3. Start a Spark slave process.\n\n    ``` bash\n    export SPARK_MASTER=spark://`hostname -f`:7077\n    export SPARK_CORES_PER_WORKER=1\n\n    ${SPARK_HOME}/sbin/start-slave.sh ${SPARK_MASTER} -c ${SPARK_CORES_PER_WORKER} \n    ```\n\n    Note that in this example the Master and Worker processes are both running on the same host. 
\n    This is not a requirement, as long as all hosts that are used to run the Spark app have access to the dataset.\n\n<span id=\"etl\">Launch Mortgage or Taxi ETL Part</span>\n---------------------------\n\nUse the ETL app to process raw Mortgage data. You can either use this ETLed data to split into training and evaluation data or run the ETL on different subsets of the dataset to produce training and evaluation datasets.\nRun spark-submit\n\nNote: For ETL jobs, Set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`.\n\n### ETL on GPU \n``` bash\n${SPARK_HOME}/bin/spark-submit \\\n    --master spark://$HOSTNAME:7077 \\\n    --executor-memory 32G \\\n    --conf spark.executor.resource.gpu.amount=1 \\\n    --conf spark.executor.cores=10 \\\n    --conf spark.task.resource.gpu.amount=0.1 \\\n    --conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n    --conf spark.rapids.sql.incompatibleDateFormats.enabled=true \\\n    --conf spark.rapids.sql.csv.read.double.enabled=true \\\n    --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \\\n    --class com.nvidia.spark.examples.mortgage.ETLMain  \\\n    $SAMPLE_JAR \\\n    -format=csv \\\n    -dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\" \\\n    -dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/train/\" \\\n    -dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n\n# if generating eval data, change the data path to eval \n# -dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\"\n# -dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/\"\n# -dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n# if running Taxi ETL benchmark, change the class and data path params to\n# -class com.nvidia.spark.examples.taxi.ETLMain  \n# -dataPath=\"raw::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n# -dataPath=\"out::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n```\n\n### ETL on CPU\n\n```bash\n${SPARK_HOME}/bin/spark-submit \\\n--master spark://$HOSTNAME:7077 
\\\n--executor-memory 32G \\\n--conf spark.executor.instances=1 \\\n--conf spark.sql.broadcastTimeout=700 \\\n--class com.nvidia.spark.examples.mortgage.ETLMain  \\\n$SAMPLE_JAR \\\n-format=csv \\\n-dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\" \\\n-dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/train/\" \\\n-dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n\n# if generating eval data, change the data path to eval \n# -dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\"\n# -dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/\"\n# if running Taxi ETL benchmark, change the class and data path params to\n# -class com.nvidia.spark.examples.taxi.ETLMain  \n# -dataPath=\"raw::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n# -dataPath=\"out::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n```\n\nLaunch XGBoost Part on GPU\n---------------------------\n\nVariables required to run spark-submit command:\n\n``` bash\n# this is the same master host we defined while launching the cluster\nexport SPARK_MASTER=spark://`hostname -f`:7077\n\n# Currently the number of tasks and executors must match the number of input files.\n# For this example, we will set these such that we have 1 executor, with 1 core per executor\n\n## take up the whole worker\nexport SPARK_CORES_PER_EXECUTOR=${SPARK_CORES_PER_WORKER}\n\n## run 1 executor\nexport SPARK_NUM_EXECUTORS=1\n\n## cores/executor * num_executors, which in this case is also 1, limits\n## the number of cores given to the application\nexport TOTAL_CORES=$((SPARK_CORES_PER_EXECUTOR * SPARK_NUM_EXECUTORS))\n\n# spark driver memory\nexport SPARK_DRIVER_MEMORY=4g\n\n# spark executor memory\nexport SPARK_EXECUTOR_MEMORY=8g\n\n# example class to use\nexport EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.Main\n# or change to com.nvidia.spark.examples.taxi.Main to run Taxi Xgboost benchmark\n# or change to com.nvidia.spark.examples.agaricus.Main to run Agaricus Xgboost benchmark\n\n# tree construction algorithm\nexport 
TREE_METHOD=gpu_hist\n```\n\nRun spark-submit:\n\n``` bash\n${SPARK_HOME}/bin/spark-submit                                                  \\\n --conf spark.plugins=com.nvidia.spark.SQLPlugin                       \\\n --conf spark.rapids.memory.gpu.pool=NONE                     \\\n --conf spark.executor.resource.gpu.amount=1                           \\\n --conf spark.task.resource.gpu.amount=1                              \\\n --master ${SPARK_MASTER}                                                       \\\n --driver-memory ${SPARK_DRIVER_MEMORY}                                         \\\n --executor-memory ${SPARK_EXECUTOR_MEMORY}                                     \\\n --conf spark.cores.max=${TOTAL_CORES}                                          \\\n --class ${EXAMPLE_CLASS}                                                       \\\n ${SAMPLE_JAR}                                                                 \\\n -dataPath=train::${SPARK_XGBOOST_DIR}/mortgage/output/train/      \\\n -dataPath=trans::${SPARK_XGBOOST_DIR}/mortgage/output/eval/          \\\n -format=parquet                                                                    \\\n -numWorkers=${SPARK_NUM_EXECUTORS}                                             \\\n -treeMethod=${TREE_METHOD}                                                     \\\n -numRound=100                                                                  \\\n -maxDepth=8                      \n # Please make sure to change the class and data path while running Taxi or Agaricus benchmark                                              \n```\n\nIn `stdout` log on driver side, you should see timings<sup>*</sup> (in seconds), \nand the accuracy metric(take Mortgage as example):\n\n```\n--------------\n==> Benchmark: Elapsed time for [Mortgage GPU train csv stub Unknown Unknown Unknown]: 26.572s\n--------------\n\n--------------\n==> Benchmark: Elapsed time for [Mortgage GPU transform csv stub Unknown Unknown Unknown]: 
10.323s\n--------------\n\n--------------\n==> Benchmark: Accuracy for [Mortgage GPU Accuracy csv stub Unknown Unknown Unknown]: 0.9869227318579323\n--------------\n```\n\nLaunch XGBoost Part on CPU\n---------------------------\n\nIf you are running this example after running the GPU example above, please set these variables, \nto set both training and testing to run on the CPU exclusively:\n\n``` bash\n# this is the same master host we defined while launching the cluster\nexport SPARK_MASTER=spark://`hostname -f`:7077\n\n# Currently the number of tasks and executors must match the number of input files.\n# For this example, we will set these such that we have 1 executor, with 1 core per executor\n\n## take up the whole worker\nexport SPARK_CORES_PER_EXECUTOR=${SPARK_CORES_PER_WORKER}\n\n## run 1 executor\nexport SPARK_NUM_EXECUTORS=1\n\n## cores/executor * num_executors, which in this case is also 1, limits\n## the number of cores given to the application\nexport TOTAL_CORES=$((SPARK_CORES_PER_EXECUTOR * SPARK_NUM_EXECUTORS))\n\n# spark driver memory\nexport SPARK_DRIVER_MEMORY=4g\n\n# spark executor memory\nexport SPARK_EXECUTOR_MEMORY=8g\n\n# example class to use\nexport EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.Main\n# Please make sure to change the class while running Taxi or Agaricus benchmark     \n\n# tree construction algorithm\nexport TREE_METHOD=hist\n```\n\nThis is the same command as for the GPU example, repeated for convenience:\n\n```bash\n${SPARK_HOME}/bin/spark-submit                                                  \\\n --master ${SPARK_MASTER}                                                       \\\n --driver-memory ${SPARK_DRIVER_MEMORY}                                         \\\n --executor-memory ${SPARK_EXECUTOR_MEMORY}                                     \\\n --conf spark.cores.max=${TOTAL_CORES}                                          \\\n --class ${EXAMPLE_CLASS}                                                       \\\n 
${SAMPLE_JAR}                                                                 \\\n -dataPath=train::${SPARK_XGBOOST_DIR}/mortgage/output/train/      \\\n -dataPath=trans::${SPARK_XGBOOST_DIR}/mortgage/output/eval/          \\\n -format=parquet                                                                    \\\n -numWorkers=${SPARK_NUM_EXECUTORS}                                             \\\n -treeMethod=${TREE_METHOD}                                                     \\\n -numRound=100                                                                  \\\n -maxDepth=8                  \n \n # Please make sure to change the class and data path while running Taxi or Agaricus benchmark                                                       \n```\n\nIn the `stdout` log on driver side, you should see timings<sup>*</sup> (in seconds), and the accuracy metric(take Mortgage as example):\n\n```\n--------------\n==> Benchmark: Elapsed time for [Mortgage CPU train csv stub Unknown Unknown Unknown]: 305.535s\n--------------\n\n--------------\n==> Benchmark: Elapsed time for [Mortgage CPU transform csv stub Unknown Unknown Unknown]: 52.867s\n--------------\n\n--------------\n==> Benchmark: Accuracy for [Mortgage CPU Accuracy csv stub Unknown Unknown Unknown]: 0.9872234894511343\n--------------\n```\n\n<sup>*</sup> The timings in this Getting Started guide are only for illustrative purposes. \nPlease see our [release announcement](https://medium.com/rapids-ai/nvidia-gpus-and-apache-spark-one-step-closer-2d99e37ac8fd) \nfor official benchmarks.\n"
  },
  {
    "path": "docs/get-started/xgboost-examples/on-prem-cluster/yarn-python.md",
    "content": "Get Started with XGBoost4J-Spark on Apache Hadoop YARN\n======================================================\nThis is a getting started guide to XGBoost4J-Spark on Apache Hadoop YARN supporting GPU scheduling.\nAt the end of this guide, the reader will be able to run a sample Apache Spark Python application that runs on NVIDIA GPUs.\n\nPrerequisites\n-------------\n\n* Apache Spark 3.2.0+ running on YARN supporting GPU scheduling. (e.g.: Spark 3.2.0, Hadoop-Yarn 3.3.0)\n* Hardware Requirements\n  * NVIDIA Pascal™ GPU architecture or better\n  * Multi-node clusters with homogenous GPU configuration\n* Software Requirements\n  * Ubuntu 20.04, 22.04/CentOS7, Rocky Linux 8\n  * CUDA 11.5+\n  * NVIDIA driver compatible with your CUDA\n  * NCCL 2.7.8+\n  * Python 3.8 or 3.9\n  * NumPy\n  * XGBoost 1.7.0+\n  * cudf-cu11  \n  \nThe number of GPUs per NodeManager dictates the number of Spark executors that can run in that NodeManager. \nAdditionally, cores per Spark executor and cores per Spark task must match, such that each executor can run 1 task at any given time.\n\nFor example: if each NodeManager has 4 GPUs, there should be 4 or fewer executors running on each NodeManager, \nand each executor should run 1 task (e.g.: A total of 4 tasks running on 4 GPUs). In order to achieve this, \nyou may need to adjust `spark.task.cpus` and `spark.executor.cores` to match (both set to 1 by default).\n\nAdditionally, we recommend adjusting `executor-memory` to divide host memory evenly amongst the number of GPUs in each NodeManager,\nsuch that Spark will schedule as many executors as there are GPUs in each NodeManager.\n\nWe use `SPARK_HOME` environment variable to point to the Apache Spark cluster. 
\nAnd as to how to enable GPU scheduling and isolation for Yarn,\nplease refer to [here](https://hadoop.apache.org/docs/r3.1.0/hadoop-yarn/hadoop-yarn-site/UsingGpus.html).\n\nPlease make sure to install the XGBoost, cudf-cu11, numpy libraries on all nodes before running XGBoost application.\n``` bash\npip install xgboost\npip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com\npip install numpy\npip install scikit-learn\n```\nYou can also create an isolated python environment by using [Virtualenv](https://virtualenv.pypa.io/en/latest/),\nand then directly pass/unpack the archive file and enable the environment on executors\nby leveraging the --archives option or spark.archives configuration.\n``` bash\n# create an isolated python environment and install libraries\npython -m venv pyspark_venv\nsource pyspark_venv/bin/activate\npip install xgboost\npip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com\npip install numpy\npip install scikit-learn\nvenv-pack -o pyspark_venv.tar.gz\n\n# enable archive python environment on executors\nexport PYSPARK_DRIVER_PYTHON=python # Do not set in cluster modes.\nexport PYSPARK_PYTHON=./environment/bin/python\nspark-submit --archives pyspark_venv.tar.gz#environment app.py\n```\n\nGet Application Files, Jar and Dataset\n-------------------------------\n\nMake sure you have prepared the necessary packages and dataset by following this [guide](../prepare-package-data/preparation-python.md)\n\nThen create a directory in HDFS, and run below commands,\n\n``` bash\n[xgboost4j_spark_python]$ hadoop fs -mkdir /tmp/xgboost4j_spark_python\n[xgboost4j_spark_python]$ hadoop fs -copyFromLocal ${SPARK_XGBOOST_DIR}/mortgage/* /tmp/xgboost4j_spark_python\n```\n\nLaunch Mortgage or Taxi ETL Part\n---------------------------\n\nUse the ETL app to process raw Mortgage data. 
You can either use this ETLed data to split into training and evaluation data or run the ETL on different subsets of the dataset to produce training and evaluation datasets.\n\nNote: For ETL jobs, Set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`.\n\n``` bash\n# location where data was downloaded\nexport DATA_PATH=hdfs:/tmp/xgboost4j_spark_python/\n\n${SPARK_HOME}/bin/spark-submit \\\n    --master yarn \\\n    --deploy-mode cluster \\\n    --conf spark.executor.cores=10 \\\n    --conf spark.task.resource.gpu.amount=0.1 \\\n    --conf spark.rapids.sql.incompatibleDateFormats.enabled=true \\\n    --conf spark.rapids.sql.csv.read.double.enabled=true \\\n    --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \\\n    --jars ${RAPIDS_JAR}\\\n    ${MAIN_PY} \\\n    --mainClass='com.nvidia.spark.examples.mortgage.etl_main' \\\n    --format=csv \\\n    --dataPath=\"data::${DATA_PATH}/mortgage/data/mortgage/input/\" \\\n    --dataPath=\"out::${DATA_PATH}/mortgage/data/mortgage/output/train/\" \\\n    --dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n\n# if generating eval data, change the data path to eval\n# --dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\"\n# --dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/\"\n# --dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n# if running Taxi ETL benchmark, change the class and data path params to\n# -class com.nvidia.spark.examples.taxi.ETLMain  \n# -dataPath=\"raw::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n# -dataPath=\"out::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n```\n\nLaunch XGBoost Part on GPU\n---------------------------\n\nVariables required to run spark-submit command:\n\n``` bash\n# location where data was downloaded\nexport DATA_PATH=hdfs:/tmp/xgboost4j_spark_python\n\n# spark deploy mode (see Apache Spark documentation for more information)\nexport SPARK_DEPLOY_MODE=cluster\n\n# run a single executor for this example to limit the number of 
spark tasks and\n# partitions to 1 as currently this number must match the number of input files\nexport SPARK_NUM_EXECUTORS=1\n\n# spark driver memory\nexport SPARK_DRIVER_MEMORY=4g\n\n# spark executor memory\nexport SPARK_EXECUTOR_MEMORY=8g\n\n# python entrypoint\nexport SPARK_PYTHON_ENTRYPOINT=${LIBS_PATH}/main.py\n\n# example class to use\nexport EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.main\n# or change to com.nvidia.spark.examples.taxi.main to run Taxi Xgboost benchmark\n# or change to com.nvidia.spark.examples.agaricus.main to run Agaricus Xgboost benchmark\n\n# tree construction algorithm\nexport TREE_METHOD=gpu_hist\n\n# if you enable archive python environment\nexport PYSPARK_DRIVER_PYTHON=python\nexport PYSPARK_PYTHON=./environment/bin/python\n```\n\nRun spark-submit:\n\n``` bash\n${SPARK_HOME}/bin/spark-submit                                                  \\\n --conf spark.plugins=com.nvidia.spark.SQLPlugin                       \\\n --conf spark.rapids.memory.gpu.pool=NONE                     \\\n --conf spark.executor.resource.gpu.amount=1                           \\\n --conf spark.task.resource.gpu.amount=1                              \\\n --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh        \\\n --files ${SPARK_HOME}/examples/src/main/scripts/getGpusResources.sh            \\\n --master yarn                                                                  \\\n --deploy-mode ${SPARK_DEPLOY_MODE}                                             \\\n --archives your_pyspark_venv.tar.gz#environment     #if you enabled archive python environment \\\n --num-executors ${SPARK_NUM_EXECUTORS}                                         \\\n --driver-memory ${SPARK_DRIVER_MEMORY}                                         \\\n --executor-memory ${SPARK_EXECUTOR_MEMORY}                                     \\\n --jars ${RAPIDS_JAR}        \\\n --py-files ${SAMPLE_ZIP}                   \\\n ${MAIN_PY}                                  
                   \\\n --mainClass=${EXAMPLE_CLASS}                                                   \\\n --dataPath=train::${DATA_PATH}/mortgage/out/train/      \\\n --dataPath=trans::${DATA_PATH}/mortgage/out/eval/        \\\n --format=parquet                                                                   \\\n --numWorkers=${SPARK_NUM_EXECUTORS}                                            \\\n --treeMethod=${TREE_METHOD}                                                    \\\n --numRound=100                                                                 \\\n --maxDepth=8\n\n# Change the format to csv if your input file is CSV format.\n# Please make sure to change the class and data path while running Taxi or Agaricus benchmark  \n```\n\nIn the `stdout` driver log, you should see timings<sup>*</sup> (in seconds), and the accuracy metric:\n\n```\n----------------------------------------------------------------------------------------------------\nTraining takes 10.75 seconds\n\n----------------------------------------------------------------------------------------------------\nTransformation takes 4.38 seconds\n\n----------------------------------------------------------------------------------------------------\nAccuracy is 0.997544753891\n```\n\nLaunch XGBoost Part on CPU\n---------------------------\n\nIf you are running this example after running the GPU example above, please set these variables, to set both training and testing to run on the CPU exclusively:\n\n``` bash\n# location where data was downloaded\nexport DATA_PATH=hdfs:/tmp/xgboost4j_spark_python/\n\n# spark deploy mode (see Apache Spark documentation for more information)\nexport SPARK_DEPLOY_MODE=cluster\n\n# run a single executor for this example to limit the number of spark tasks and\n# partitions to 1 as currently this number must match the number of input files\nexport SPARK_NUM_EXECUTORS=1\n\n# spark driver memory\nexport SPARK_DRIVER_MEMORY=4g\n\n# spark executor memory\nexport 
SPARK_EXECUTOR_MEMORY=8g\n\n# example class to use\nexport EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.main\n# or change to com.nvidia.spark.examples.taxi.main to run Taxi Xgboost benchmark\n# or change to com.nvidia.spark.examples.agaricus.main to run Agaricus Xgboost benchmark\n\n# tree construction algorithm\nexport TREE_METHOD=hist\n\n# if you enable archive python environment\nexport PYSPARK_DRIVER_PYTHON=python\nexport PYSPARK_PYTHON=./environment/bin/python\n```\n\nThis is the same command as for the GPU example, repeated for convenience:\n\n``` bash\n${SPARK_HOME}/bin/spark-submit                                                  \\\n --master yarn                                                                  \\\n --archives your_pyspark_venv.tar.gz#environment     #if you enabled archive python environment \\\n --deploy-mode ${SPARK_DEPLOY_MODE}                                             \\\n --num-executors ${SPARK_NUM_EXECUTORS}                                         \\\n --driver-memory ${SPARK_DRIVER_MEMORY}                                         \\\n --executor-memory ${SPARK_EXECUTOR_MEMORY}                                     \\\n --jars ${RAPIDS_JAR}        \\\n --py-files ${SAMPLE_ZIP}                                  \\\n ${MAIN_PY}                                                     \\\n --mainClass=${EXAMPLE_CLASS}                                                   \\\n --dataPath=train::${DATA_PATH}/mortgage/output/train/       \\\n --dataPath=trans::${DATA_PATH}/mortgage/output/eval/         \\\n --format=parquet                                                               \\\n --numWorkers=${SPARK_NUM_EXECUTORS}                                            \\\n --treeMethod=${TREE_METHOD}                                                    \\\n --numRound=100                                                                 \\\n --maxDepth=8\n \n # Please make sure to change the class and data path while running Taxi or Agaricus 
benchmark  \n```\n\nIn the `stdout` driver log, you should see timings<sup>*</sup> (in seconds), and the accuracy metric:\n\n```\n----------------------------------------------------------------------------------------------------\nTraining takes 10.76 seconds\n\n----------------------------------------------------------------------------------------------------\nTransformation takes 1.25 seconds\n\n----------------------------------------------------------------------------------------------------\nAccuracy is 0.998526852335\n```\n\n<sup>*</sup> The timings in this Getting Started guide are only for illustrative purpose.\nPlease see our [release announcement](https://medium.com/rapids-ai/nvidia-gpus-and-apache-spark-one-step-closer-2d99e37ac8fd) for official benchmarks.\n"
  },
  {
    "path": "docs/get-started/xgboost-examples/on-prem-cluster/yarn-scala.md",
    "content": "Get Started with XGBoost4J-Spark on Apache Hadoop YARN\n======================================================\n\nThis is a getting started guide to XGBoost4J-Spark on Apache Hadoop YARN supporting GPU scheduling. \nAt the end of this guide, the reader will be able to run a sample Apache Spark application that runs on NVIDIA GPUs.\n\nPrerequisites\n-------------\n\n* Apache Spark 3.2.0+ running on YARN supporting GPU scheduling. (e.g.: Spark 3.2.0, Hadoop-Yarn 3.3.0)\n* Hardware Requirements\n  * NVIDIA Pascal™ GPU architecture or better\n  * Multi-node clusters with homogenous GPU configuration\n* Software Requirements\n  * Ubuntu 20.04, 22.04/CentOS7, Rocky Linux 8\n  * CUDA 11.0+\n  * NVIDIA driver compatible with your CUDA\n  * NCCL 2.7.8+\n\nThe number of GPUs per NodeManager dictates the number of Spark executors that can run in that NodeManager. \nAdditionally, cores per Spark executor and cores per Spark task must match, such that each executor can run 1 task at any given time.\n\nFor example: if each NodeManager has 4 GPUs, there should be 4 or fewer executors running on each NodeManager, \nand each executor should run 1 task (e.g.: A total of 4 tasks running on 4 GPUs). 
In order to achieve this, \nyou may need to adjust `spark.task.cpus` and `spark.executor.cores` to match (both set to 1 by default).\nAdditionally, we recommend adjusting `executor-memory` to divide host memory evenly amongst the number of GPUs in each NodeManager,\nsuch that Spark will schedule as many executors as there are GPUs in each NodeManager.\n\nWe use `SPARK_HOME` environment variable to point to the Apache Spark cluster.\nAnd as to how to enable GPU scheduling and isolation for Yarn, \nplease refer to [here](https://hadoop.apache.org/docs/r3.1.0/hadoop-yarn/hadoop-yarn-site/UsingGpus.html).\n\nGet Jars and Dataset\n-------------------------------\n\nMake sure you have prepared the necessary packages and dataset by following this [guide](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md)\n\n#### Note: \n1. Mortgage and Taxi jobs have ETLs to generate the processed data.\n2. For convenience, a subset of [Taxi](/datasets/) dataset is made available in this repo that can be readily used for launching XGBoost job. Use [ETL](#etl) to generate larger datasets for training and testing. \n3. Agaricus does not have an ETL process, it is combined with XGBoost as there is just a filter operation.\n\nCreate a directory in HDFS, and copy:\n\n``` bash\n[xgboost4j_spark]$ hadoop fs -mkdir /tmp/xgboost4j_spark\n[xgboost4j_spark]$ hadoop fs -copyFromLocal ${SPARK_XGBOOST_DIR}/mortgage/* /tmp/xgboost4j_spark\n```\n\n<span id=\"etl\">Launch Mortgage or Taxi ETL Part</span>\n---------------------------\n\nUse the ETL app to process raw Mortgage data. 
You can either use this ETLed data to split into training and evaluation data or run the ETL on different subsets of the dataset to produce training and evaluation datasets.\n\nNote: For ETL jobs, Set `spark.task.resource.gpu.amount` to `1/spark.executor.cores`.\n\n\nRun spark-submit\n\n``` bash\n${SPARK_HOME}/bin/spark-submit \\\n   --conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n   --conf spark.executor.resource.gpu.amount=1 \\\n   --conf spark.executor.cores=10 \\\n   --conf spark.task.resource.gpu.amount=0.1 \\\n   --conf spark.rapids.sql.incompatibleDateFormats.enabled=true \\\n   --conf spark.rapids.sql.csv.read.double.enabled=true \\\n   --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \\\n   --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \\\n   --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh \\\n   --jars ${RAPIDS_JAR}                                           \\\n   --master yarn                                                                  \\\n   --deploy-mode ${SPARK_DEPLOY_MODE}                                             \\\n   --num-executors ${SPARK_NUM_EXECUTORS}                                         \\\n   --driver-memory ${SPARK_DRIVER_MEMORY}                                         \\\n   --executor-memory ${SPARK_EXECUTOR_MEMORY}                                     \\\n   --class com.nvidia.spark.examples.mortgage.ETLMain  \\\n   $SAMPLE_JAR \\\n   -format=csv \\\n   -dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\" \\\n   -dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/train/\" \\\n   -dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n\n# if generating eval data, change the data path to eval \n# -dataPath=\"data::${SPARK_XGBOOST_DIR}/mortgage/input/\"\n# -dataPath=\"out::${SPARK_XGBOOST_DIR}/mortgage/output/eval/\"\n# -dataPath=\"tmp::${SPARK_XGBOOST_DIR}/mortgage/output/tmp/\"\n# if running Taxi ETL benchmark, change the class and data 
path params to\n# -class com.nvidia.spark.examples.taxi.ETLMain  \n# -dataPath=\"raw::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n# -dataPath=\"out::${SPARK_XGBOOST_DIR}/taxi/your-path\"\n```\n\nLaunch XGBoost Part on GPU\n---------------------------\n\nVariables required to run spark-submit command:\n\n``` bash\n# location where data was downloaded \nexport DATA_PATH=hdfs:/tmp/xgboost4j_spark/data\n\n# spark deploy mode (see Apache Spark documentation for more information) \nexport SPARK_DEPLOY_MODE=cluster\n\n# run a single executor for this example to limit the number of spark tasks and\n# partitions to 1 as currently this number must match the number of input files\nexport SPARK_NUM_EXECUTORS=1\n\n# spark driver memory\nexport SPARK_DRIVER_MEMORY=4g\n\n# spark executor memory\nexport SPARK_EXECUTOR_MEMORY=8g\n\n# example class to use\nexport EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.Main\n# or change to com.nvidia.spark.examples.taxi.Main to run Taxi Xgboost benchmark\n# or change to com.nvidia.spark.examples.agaricus.Main to run Agaricus Xgboost benchmark\n\n# tree construction algorithm\nexport TREE_METHOD=gpu_hist\n```\n\nRun spark-submit:\n\n``` bash\n${SPARK_HOME}/bin/spark-submit                                                  \\\n --conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n --conf spark.rapids.memory.gpu.pool=NONE \\\n --conf spark.executor.resource.gpu.amount=1 \\\n --conf spark.task.resource.gpu.amount=1 \\\n --conf spark.executor.resource.gpu.discoveryScript=./getGpusResources.sh \\\n --files $SPARK_HOME/examples/src/main/scripts/getGpusResources.sh \\\n --jars ${RAPIDS_JAR}                                           \\\n --master yarn                                                                  \\\n --deploy-mode ${SPARK_DEPLOY_MODE}                                             \\\n --num-executors ${SPARK_NUM_EXECUTORS}                                         \\\n --driver-memory ${SPARK_DRIVER_MEMORY}                                   
      \\\n --executor-memory ${SPARK_EXECUTOR_MEMORY}                                     \\\n --class ${EXAMPLE_CLASS}                                                       \\\n ${SAMPLE_JAR}                                                                 \\\n -dataPath=train::${SPARK_XGBOOST_DIR}/mortgage/output/train/                   \\\n -dataPath=trans::${SPARK_XGBOOST_DIR}/mortgage/output/eval/                    \\\n -format=parquet                                                                \\\n -numWorkers=${SPARK_NUM_EXECUTORS}                                             \\\n -treeMethod=${TREE_METHOD}                                                     \\\n -numRound=100                                                                  \\\n -maxDepth=8                                                                    \n  # Please make sure to change the class and data path while running Taxi or Agaricus benchmark   \n```\n\nIn the `stdout` driver log, you should see timings<sup>*</sup> (in seconds), and the accuracy metric(take Mortgage as example):\n\n```\n--------------\n==> Benchmark: Elapsed time for [Mortgage GPU train csv stub Unknown Unknown Unknown]: 29.642s\n--------------\n\n--------------\n==> Benchmark: Elapsed time for [Mortgage GPU transform csv stub Unknown Unknown Unknown]: 21.272s\n--------------\n\n--------------\n==> Benchmark: Accuracy for [Mortgage GPU Accuracy csv stub Unknown Unknown Unknown]: 0.9874184013493451\n--------------\n```\n\nLaunch XGBoost Part on CPU\n---------------------------\n\nIf you are running this example after running the GPU example above, please set these variables, to set both training and testing to run on the CPU exclusively:\n\n``` bash\n# location where data was downloaded \nexport DATA_PATH=hdfs:/tmp/xgboost4j_spark/data\n\n# spark deploy mode (see Apache Spark documentation for more information) \nexport SPARK_DEPLOY_MODE=cluster\n\n# run a single executor for this example to limit the number of 
spark tasks and\n# partitions to 1 as currently this number must match the number of input files\nexport SPARK_NUM_EXECUTORS=1\n\n# spark driver memory\nexport SPARK_DRIVER_MEMORY=4g\n\n# spark executor memory\nexport SPARK_EXECUTOR_MEMORY=8g\n\n# example class to use\nexport EXAMPLE_CLASS=com.nvidia.spark.examples.mortgage.Main\n# Please make sure to change the class while running Taxi or Agaricus benchmark   \n\n# tree construction algorithm\nexport TREE_METHOD=hist\n```\n\nThis is the same command as for the GPU example, repeated for convenience:\n\n``` bash\n${SPARK_HOME}/bin/spark-submit                                                  \\\n --master yarn                                                                  \\\n --deploy-mode ${SPARK_DEPLOY_MODE}                                             \\\n --num-executors ${SPARK_NUM_EXECUTORS}                                         \\\n --driver-memory ${SPARK_DRIVER_MEMORY}                                         \\\n --executor-memory ${SPARK_EXECUTOR_MEMORY}                                     \\\n --class ${EXAMPLE_CLASS}                                                       \\\n ${SAMPLE_JAR}                                                                 \\\n -dataPath=train::${SPARK_XGBOOST_DIR}/mortgage/output/train/                   \\\n -dataPath=trans::${SPARK_XGBOOST_DIR}/mortgage/output/eval/                    \\\n -format=parquet                                                                \\\n -numWorkers=${SPARK_NUM_EXECUTORS}                                             \\\n -treeMethod=${TREE_METHOD}                                                     \\\n -numRound=100                                                                  \\\n -maxDepth=8                            \n   \n  # Please make sure to change the class and data path while running Taxi or Agaricus benchmark                                                       \n                                      \n```\n\nIn the 
`stdout` driver log, you should see timings<sup>*</sup> (in seconds), and the accuracy metric(take Mortgage as example):\n\n```\n--------------\n==> Benchmark: Elapsed time for [Mortgage CPU train csv stub Unknown Unknown Unknown]: 286.398s\n--------------\n\n--------------\n==> Benchmark: Elapsed time for [Mortgage CPU transform csv stub Unknown Unknown Unknown]: 49.836s\n--------------\n\n--------------\n==> Benchmark: Accuracy for [Mortgage CPU Accuracy csv stub Unknown Unknown Unknown]: 0.9873709530950067\n--------------\n```\n\n<sup>*</sup> The timings in this Getting Started guide are only for illustrative purpose.\nPlease see our [release announcement](https://medium.com/rapids-ai/nvidia-gpus-and-apache-spark-one-step-closer-2d99e37ac8fd) for official benchmarks.\n"
  },
  {
    "path": "docs/get-started/xgboost-examples/prepare-package-data/preparation-python.md",
    "content": "## Prepare packages and dataset for pyspark\n\nFor simplicity export the location to these jars. All examples assume the packages and dataset will be placed in the `/opt/xgboost` directory:\n\n### Download the jars\n\nDownload the RAPIDS Accelerator for Apache Spark plugin jar\n  * [RAPIDS Spark Package](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)\n\n### Build XGBoost Python Examples\n\nFollowing this [guide](/docs/get-started/xgboost-examples/building-sample-apps/python.md), you can get *samples.zip* and *main.py* and copy them to `/opt/xgboost`\n\n### Download dataset\n\nYou need to copy the dataset to `/opt/xgboost`. Use the following links to download the data.\n1. [Mortgage dataset](/docs/get-started/xgboost-examples/dataset/mortgage.md)\n2. [Taxi dataset](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page)\n3. [Agaricus dataset](https://github.com/dmlc/xgboost/tree/master/demo/data)\n"
  },
  {
    "path": "docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md",
    "content": "## Prepare packages and dataset for scala\n\nFor simplicity export the location to these jars. All examples assume the packages and dataset will be placed in the `/opt/xgboost` directory:\n\n### Download the jars\n\n1. Download the RAPIDS Accelerator for Apache Spark plugin jar\n   * [RAPIDS Spark Package](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)\n\n### Build XGBoost Scala Examples\n\nFollowing this [guide](/docs/get-started/xgboost-examples/building-sample-apps/scala.md), you can get *sample_xgboost_apps-0.2.3-jar-with-dependencies.jar* and copy it to `/opt/xgboost`\n\n### Download dataset\n\nYou need to copy the dataset to `/opt/xgboost`. Use the following links to download the data.\n1. [Mortgage dataset](/docs/get-started/xgboost-examples/dataset/mortgage.md)\n2. [Taxi dataset](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page)\n3. [Agaricus dataset](https://github.com/dmlc/xgboost/tree/master/demo/data)\n"
  },
  {
    "path": "docs/trouble-shooting/xgboost-examples-trouble-shooting.md",
    "content": "## XGBoost\n\n### 1. NCCL errors\n\nXGBoost supports distributed GPU training which depends on NCCL2 available at [this link](https://developer.nvidia.com/nccl). NCCL auto-detects which network interfaces to use for inter-node communication. If some interfaces are in state up, however are not able to communicate between nodes, NCCL may try to use them anyway and therefore fail during the init functions or **even hang**.\n\nTo track NCCL error, User needs to enable NCCL_DEBUG when submitting spark application by \n\n``` xml\n--conf spark.executorEnv.NCCL_DEBUG=INFO\n```\n\nSometimes, Node tries to connect to another node which selects an inappropriate interface, which may cause xgboost task hang. To fix this kind of issue, User needs to specify an appropriate interface for the node by NCCL_SOCKET_IFNAME\n\n``` xml\n--conf spark.executorEnv.NCCL_SOCKET_IFNAME=eth0\n```"
  },
  {
    "path": "examples/MIG-Support/README.md",
    "content": "# Multi-Instance GPU (MIG) support in Apache Hadoop YARN\n\nThere are multiple solutions for MIG scheduling on YARN that you can choose based on your environment and\ndeployment requirements:\n\n- [YARN 3.3.0+ MIG GPU Plugin](/examples/MIG-Support/device-plugins/gpu-mig) for adding a Java-based plugin for MIG\non top of the Pluggable Device Framework\n- [YARN 3.1.2 until YARN 3.3.0 MIG GPU Support](/examples/MIG-Support/resource-types/gpu-mig) for\npatching and rebuilding YARN code base to support MIG devices.\n- [YARN 3.1.2+ MIG GPU Support without modifying YARN / Device Plugin Code](/examples/MIG-Support/yarn-unpatched)\nrelying on installing nvidia CLI wrappers written in `bash`, but unlike the solutions above without\nany Java code changes.\n\n## Limitations and Caveats\n\nNote that there are some common caveats for the solutions above.\n\n### Single MIG GPU per Container\n\nPlease see the [MIG Application Considerations](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#app-considerations)\nand [CUDA Device Enumeration](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html#cuda-visible-devices).\n\nIt is important to note that CUDA 11 only supports enumeration of a single MIG instance.\nIt is recommended that you configure YARN to only allow a single GPU be requested. See\nthe YARN config `yarn.resource-types.nvidia/miggpu.maximum-allocation` for the\n[Pluggable Device Framework](/examples/MIG-Support/device-plugins/gpu-mig) solution and\n`yarn.resource-types.yarn.io/gpu.maximum-allocation` for the remainder of MIG Support options above, respectively.\n\n### Metrics\nSome metrics are not and cannot be broken down by MIG device. 
For example, `utilization` is the\naggregate utilization of the parent GPU, and there is no attribution of `temperature` to a\nparticular MIG device.\n\n### GPU index / address as reported by Apache Spark in logs and UI\n\nWith YARN isolation using NVIDIA Container Runtime ensuring a single visible device\nper Docker container running a Spark Executor, each Executor will see a disjoint list comprising\na single device.\nTherefore, the user will end up observing index 0 being used by all executors. However, they refer\nto different GPU/MIG instances. You can verify this by running something like the following on a\nYARN worker node host OS:\n\n```bash\nfor cid in $(sudo docker ps -q); do sudo docker exec $cid bash -c \"printenv | grep VISIBLE; nvidia-smi -L\"; done\nNVIDIA_VISIBLE_DEVICES=3\nGPU 0: NVIDIA A30 (UUID: GPU-05aa99be-b706-0dc1-ab62-dd12f2227b7d)\n  MIG 1g.6gb      Device  0: (UUID: MIG-70dc024a-e8d7-587c-81dd-57ad493b1d91)\nNVIDIA_VISIBLE_DEVICES=1\nGPU 0: NVIDIA A30 (UUID: GPU-05aa99be-b706-0dc1-ab62-dd12f2227b7d)\n  MIG 1c.2g.12gb  Device  0: (UUID: MIG-54cc2421-6f2d-59e9-b074-20707aadd71e)\nNVIDIA_VISIBLE_DEVICES=2\nGPU 0: NVIDIA A30 (UUID: GPU-05aa99be-b706-0dc1-ab62-dd12f2227b7d)\n  MIG 1g.6gb      Device  0: (UUID: MIG-7e5552bf-d328-57a8-b091-0720d4530ffb)\nNVIDIA_VISIBLE_DEVICES=0\nGPU 0: NVIDIA A30 (UUID: GPU-05aa99be-b706-0dc1-ab62-dd12f2227b7d)\n  MIG 1c.2g.12gb  Device  0: (UUID: MIG-e6af58f0-9af8-594f-825e-74d23e1a68c1)\n```\n\n\n\n\n"
  },
  {
    "path": "examples/MIG-Support/device-plugins/gpu-mig/README.md",
    "content": "# NVIDIA GPU Plugin for YARN with MIG support for YARN 3.3.0+\n\nThis plugin adds support for GPUs with [MIG](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/) on YARN. The built-in YARN GPU plugin does not support MIG enabled GPUs.\nThis plugin also works with GPUs without MIG or GPUs with MIG disabled but the limitation section still applies. It supports heterogenous environments where\nthere may be some MIG enabled GPUs and some without MIG. If you are not using MIG enabled GPUs, you should use the built-in YARN GPU plugin.\n\n## Compatibility\n\nIt works with Apache YARN 3.3.0+ versions that support the [Pluggable Device Framework](https://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/PluggableDeviceFramework.html). This plugin requires YARN to be configured with Docker using the NVIDIA Container Toolkit (nvidia-docker2).\n\n## Limitations\n\nPlease see the [MIG Application Considerations](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#app-considerations)\nand [CUDA Device Enumeration](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html#cuda-visible-devices).\n\nIt is important to note that CUDA 11 only supports enumeration of a single MIG instance. This means that this plugin\nonly supports 1 GPU per container and the plugin will throw an exception by default if you request more.\nIt is recommended that you configure YARN to only allow a single GPU be requested. See the yarn config:\n```\n yarn.resource-types.nvidia/miggpu.maximum-allocation\n```\nSee [YARN Resource Configuration](https://hadoop.apache.org/docs/r3.3.1/hadoop-yarn/hadoop-yarn-site/ResourceModel.html) for more details.\nIf you do not configure the maximum allocation and someone requests multiple GPUs, the default behavior is to throw an exception. The user\nvisible exception is not very useful, as the real exception will be in the nodemanager logs. 
See the [Configuration](#configuration) section for options\nif it throws an exception.\n\n## Building From Source\n\n```\nmvn package \n```\n\nThis will create a jar `target/yarn-gpu-mig-plugin-1.0.0.jar`. This jar can be installed on your YARN cluster as a plugin.\n\n## Installation\n\nThese instructions assume YARN is already installed and configured with Docker enabled using the NVIDIA Container Toolkit (nvidia-docker2).\nEnable and configure your [GPUs with MIG](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html) on all of the nodes it applies to.\n\nInstall the jar into your Hadoop Cluster, see the [Test and Use Your Own Plugin](https://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/DevelopYourOwnDevicePlugin.html)\nsection. This recommends installing it in something like `$HADOOP_COMMON_HOME/share/hadoop/yarn`.\n\nConfigure the device plugin, see the YARN documentation on [Pluggable Device Framework](https://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/PluggableDeviceFramework.html).\n\nAfter enabling the framework, enable the plugin in `yarn-site.xml`:\n\n```\n<property>\n  <name>yarn.nodemanager.pluggable-device-framework.device-classes</name>\n  <value>com.nvidia.spark.NvidiaGPUMigPluginForRuntimeV2</value>\n</property>\n\n```\n\nConfigure YARN to have the new resource type by modifying the `resource-types.xml` file to include:\n\n```\n<property>\n  <name>yarn.resource-types</name>\n  <value>nvidia/miggpu</value>\n</property>\n```\n\nRestart YARN to pick up any configuration changes.\n\n## Configuration\n\nTo change the behavior of throwing when the user allocates multiple GPUs, you can either set a config in the `yarn-site.xml` or set\nan environment variable when launching the Spark application. 
The environment variable will take precedence if both are set.\nIn either case, `true` means to throw if a user requests multiple GPUs (this is the default), `false`\nmeans it won't throw and if the container is allocated with multiple MIG devices from the same\nGPU, it is up to the application to know how to use them.\n\nConfig for `yarn-site.xml`:\n```\n<property>\n  <name>com.nvidia.spark.NvidiaGPUMigPluginForRuntimeV2.throwOnMultipleGPUs</name>\n  <value>true</value>\n</property>\n```\n\nEnvironment variable for Spark application:\n```\n--conf spark.executorEnv.NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS=true\n```\n\n## Using with Apache Spark on YARN\nSpark supports [scheduling GPUs and other custom resources on YARN](http://spark.apache.org/docs/latest/running-on-yarn.html#resource-allocation-and-configuration-overview). There are 2 options for using this plugin with Spark to allocate GPUs with MIG support: \n\n- Use Spark 3.2.1 or newer and remap the standard Spark `gpu` resource (i.e.: `spark.executor.resource.gpu.amount`) to be the new MIG GPU resource type using:\n```\n--conf spark.yarn.resourceGpuDeviceName=nvidia/miggpu\n```\nThis means users don't have to change their configs if they were already using the `gpu` resource type.\n\n- Spark applications specify the `nvidia/miggpu` resource type instead of the `gpu` resource type. For this the user has to change the resource\ntype to `nvidia/miggpu`, update the discovery script, and specify an extra YARN config(`spark.yarn.executor.resource.nvidia/miggpu.amount`).\nThe command would be something like below (update the amounts according to your setup):\n```\n --conf spark.executor.resource.nvidia/miggpu.amount=1 --conf spark.executor.resource.nvidia/miggpu.discoveryScript=./getMIGGpus --conf spark.task.resource.nvidia/miggpu.amount=0.25 --files ./getMIGGpus --conf spark.yarn.executor.resource.nvidia/miggpu.amount=1\n```\nNote the getMIGGpus discovery script is in the `scripts` directory in this repo. 
It just changes the resource name returned to match\n`nvidia/miggpu`.\n\n## Testing\nRun a Spark application using the [Rapids Accelerator for Apache Spark](https://nvidia.github.io/spark-rapids/) and request GPUs\nfrom YARN and verify they use the MIG enabled GPUs.\n"
  },
  {
    "path": "examples/MIG-Support/device-plugins/gpu-mig/pom.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!--\n  Copyright (c) 2021, NVIDIA CORPORATION.\n\n  Licensed under the Apache License, Version 2.0 (the \"License\");\n  you may not use this file except in compliance with the License.\n  You may obtain a copy of the License at\n\n     http://www.apache.org/licenses/LICENSE-2.0\n\n  Unless required by applicable law or agreed to in writing, software\n  distributed under the License is distributed on an \"AS IS\" BASIS,\n  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n  See the License for the specific language governing permissions and\n  limitations under the License.\n-->\n<project xmlns=\"http://maven.apache.org/POM/4.0.0\"\n         xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n         xsi:schemaLocation=\"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd\">\n    <modelVersion>4.0.0</modelVersion>\n\n    <groupId>com.nvidia</groupId>\n    <artifactId>yarn-gpu-mig-plugin</artifactId>\n    <name>YARN Device Plugin that supports MIG</name>\n    <description>The root project of the YARN Device Plugin that supports MIG</description>\n    <version>1.0.0</version>\n    <packaging>jar</packaging>\n\n    <licenses>\n        <license>\n            <name>Apache License, Version 2.0</name>\n            <url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>\n            <distribution>repo</distribution>\n        </license>\n    </licenses>\n\n    <properties>\n        <yarn.version>3.3.6</yarn.version>\n        <java.version>1.8</java.version>\n        <maven.compiler.version>3.8.1</maven.compiler.version>\n        <maven.jar.plugin.version>3.2.0</maven.jar.plugin.version>\n        <junit.version>4.13.1</junit.version>\n        <mockito.core.version>3.4.6</mockito.core.version>\n    </properties>\n    <dependencies>\n        <dependency>\n            <groupId>org.apache.hadoop</groupId>\n            
<artifactId>hadoop-yarn-server-nodemanager</artifactId>\n            <version>${yarn.version}</version>\n            <scope>provided</scope>\n        </dependency>\n        <dependency>\n            <groupId>junit</groupId>\n            <artifactId>junit</artifactId>\n            <version>${junit.version}</version>\n            <scope>test</scope>\n        </dependency>\n        <dependency>\n            <groupId>org.mockito</groupId>\n            <artifactId>mockito-core</artifactId>\n            <version>${mockito.core.version}</version>\n            <scope>test</scope>\n        </dependency>\n    </dependencies>\n\n    <build>\n        <plugins>\n            <plugin>\n                <groupId>org.apache.maven.plugins</groupId>\n                <artifactId>maven-compiler-plugin</artifactId>\n                <version>${maven.compiler.version}</version>\n                <configuration>\n                    <source>${java.version}</source>\n                    <target>${java.version}</target>\n                </configuration>\n            </plugin>\n            <plugin>\n                <groupId>org.apache.maven.plugins</groupId>\n                <artifactId>maven-jar-plugin</artifactId>\n                <version>${maven.jar.plugin.version}</version>\n                <executions>\n                    <execution>\n                        <id>default-jar</id>\n                        <phase>package</phase>\n                        <goals>\n                            <goal>jar</goal>\n                        </goals>\n                    </execution>\n                </executions>\n            </plugin>\n        </plugins>\n    </build>\n</project>\n"
  },
  {
    "path": "examples/MIG-Support/device-plugins/gpu-mig/scripts/getMIGGPUs",
    "content": "#!/usr/bin/env bash\n\n# Copyright (c) 2021, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n# This script is a basic example script to get resource information about NVIDIA MIG GPUs.\n# It works with the NVIDIA GPU Plugin for YARN with MIG support and is expected to be run\n# in a container where the nvidia-docker-v2 plugin has taken care of mapping the MIG\n# devices. This is the same as the Apache Spark script, except the resource name is changed\n# to match the new plugin.\n#\n# It assumes the drivers are properly installed and the nvidia-smi command is available.\n# It is not guaranteed to work on all setups so please test and customize as needed\n# for your environment. It can be passed into SPARK via the config\n# spark.{driver/executor}.resource.gpu.discoveryScript to allow the driver or executor to discover\n# the GPUs it was allocated. It assumes you are running within an isolated container where the\n# GPUs are allocated exclusively to that driver or executor.\n# It outputs a JSON formatted string that is expected by the\n# spark.{driver/executor}.resource.gpu.discoveryScript config.\n#\n# Example output: {\"name\": \"nvidia/miggpu\", \"addresses\":[\"0\"]}\n\nADDRS=`nvidia-smi --query-gpu=index --format=csv,noheader | sed -e ':a' -e 'N' -e'$!ba' -e 's/\\n/\",\"/g'`\necho {\\\"name\\\": \\\"nvidia/miggpu\\\", \\\"addresses\\\":[\\\"$ADDRS\\\"]}\n"
  },
  {
    "path": "examples/MIG-Support/device-plugins/gpu-mig/src/main/java/com/nvidia/spark/NvidiaGPUMigPluginForRuntimeV2.java",
    "content": "/*\n * Copyright (c) 2021, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark;\n\nimport java.util.regex.Matcher;\nimport java.util.regex.Pattern;\n\nimport org.apache.hadoop.conf.Configuration;\nimport org.apache.hadoop.util.Shell;\nimport org.apache.hadoop.yarn.exceptions.YarnException;\nimport org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;\nimport org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;\nimport org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;\nimport org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;\nimport org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;\nimport org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;\nimport org.slf4j.Logger;\nimport org.slf4j.LoggerFactory;\n\nimport java.io.File;\nimport java.io.IOException;\nimport java.util.HashMap;\nimport java.util.Map;\nimport java.util.Set;\nimport java.util.TreeSet;\n\n/**\n * Nvidia GPU plugin supporting both Nvidia container runtime v2.\n * It supports discovering and allocating MIG devices. Currently, with CUDA 11,\n * only enumeration of a single MIG instance is supported. This means that\n * this plugin officially only supports 1 GPU per container and by default\n * will throw an exception if more are requested. 
The behavior of throwing\n * an exception is configurable by either setting the environment variable\n * {@code NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS} or by setting the YARN config\n * {@code com.nvidia.spark.NvidiaGPUMigPluginForRuntimeV2.throwOnMultipleGPUs}\n * to false.\n */\npublic class NvidiaGPUMigPluginForRuntimeV2 implements DevicePlugin,\n        DevicePluginScheduler {\n    public static final Logger LOG = LoggerFactory.getLogger(\n            NvidiaGPUMigPluginForRuntimeV2.class);\n\n    public static final String NV_RESOURCE_NAME = \"nvidia/miggpu\";\n\n    private NvidiaCommandExecutor shellExecutor = new NvidiaCommandExecutor();\n\n    private Map<String, String> environment = new HashMap<>();\n\n    // If this environment is set, use it directly\n    private static final String ENV_BINARY_PATH = \"NVIDIA_SMI_PATH\";\n\n    private static final String DEFAULT_BINARY_NAME = \"nvidia-smi\";\n\n    private static final String DEV_NAME_PREFIX = \"nvidia\";\n\n    private static final String THROW_MULTI_CONF =\n            \"com.nvidia.spark.NvidiaGPUMigPluginForRuntimeV2.throwOnMultipleGPUs\";\n\n    private static final String THROW_MULTI_ENV = \"NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS\";\n\n    private Boolean shouldThrowOnMultipleGPUFromConf =\n        new Configuration().getBoolean(THROW_MULTI_CONF, true);\n    private String shouldThrowOnMultipleGPUFromEnv = null;\n\n    private String pathOfGpuBinary = null;\n\n    // command should not run more than 10 sec.\n    private static final int MAX_EXEC_TIMEOUT_MS = 10 * 1000;\n\n    // When executable path not set, try to search default dirs\n    // By default search /usr/bin, /bin, and /usr/local/nvidia/bin (when\n    // launched by nvidia-docker.\n    private static final String[] DEFAULT_BINARY_SEARCH_DIRS = new String[]{\n            \"/usr/bin\", \"/bin\", \"/usr/local/nvidia/bin\"};\n\n    // device id -> mig id, populated during discovery and used when launching\n    // containers\n    private 
Map<Integer, String> migDevices = new HashMap<>();\n\n    private String migInfoOutput = null;\n\n\n    @Override\n    public DeviceRegisterRequest getRegisterRequestInfo() throws Exception {\n        return DeviceRegisterRequest.Builder.newInstance()\n                .setResourceName(NV_RESOURCE_NAME).build();\n    }\n\n    @Override\n    public Set<Device> getDevices() throws Exception {\n        shellExecutor.searchBinary();\n        TreeSet<Device> r = new TreeSet<>();\n        String output;\n        try {\n            output = shellExecutor.getDeviceInfo();\n            String[] lines = output.trim().split(\"\\n\");\n            int id = 0;\n            for (String oneLine : lines) {\n                String[] tokensEachLine = oneLine.split(\",\");\n                if (tokensEachLine.length != 3) {\n                    throw new Exception(\"Cannot parse the output to get the MIG enabled info. \"\n                            + \"output: \" + oneLine + \" expected index,pci.bus_id,mig.mode.current\");\n                }\n                String minorNumber = tokensEachLine[0].trim();\n                String busId = tokensEachLine[1].trim();\n                String migMode = tokensEachLine[2].trim();\n                String majorNumber = getMajorNumber(DEV_NAME_PREFIX\n                        + minorNumber);\n\n                if (majorNumber != null) {\n                    if (migMode.equalsIgnoreCase(\"enabled\")) {\n                        if (migInfoOutput == null) {\n                            // we get the mig info for all the GPUs on the host so only get it once\n                            migInfoOutput = shellExecutor.getDeviceMigInfo();\n                            if (migInfoOutput == null) {\n                                throw new Exception(\"MIG device enabled but no device info found\");\n                            }\n                        }\n                        String[] linesMig = migInfoOutput.trim().split(\"\\n\");\n                     
   Integer minorNumInt = Integer.parseInt(minorNumber);\n                        Integer migDevCount = 0;\n                        Integer numMigOutputLines = linesMig.length;\n                        for (int idmig = 0; idmig < numMigOutputLines; idmig++) {\n                            // first line should start with GPU\n                            // GPU 0: NVIDIA A30 (UUID: GPU-e7076666-0544-e103-4f65-a047fc18269e)\n                            // MIG 1g.6gb      Device  0: (UUID: MIG-de9876e2-eef7-5b5a-9701-db694ffe8a77)\n                            if (linesMig[idmig].startsWith(\"GPU \" + minorNumInt) && numMigOutputLines > (idmig + 1)) {\n                                // process any MIG devices, this expects all the lines to be MIG devices until\n                                // we find one that starts with GPU\n                                String nextLine = linesMig[++idmig].trim();\n                                String regex = \"MIG (.+)Device\\\\s+(\\\\d+):\\\\s+\\\\(UUID:(.*)\\\\)\";\n                                Pattern pattern = Pattern.compile(regex);\n                                while (nextLine.startsWith(\"MIG\")) {\n                                    Matcher matcher = pattern.matcher(nextLine);\n                                    while (matcher.find()) {\n                                        String devId = matcher.group(2);\n                                        migDevices.put(id, devId);\n                                        migDevCount++;\n                                        r.add(Device.Builder.newInstance()\n                                                .setId(id)\n                                                .setMajorNumber(Integer.parseInt(majorNumber))\n                                                .setMinorNumber(minorNumInt)\n                                                .setBusID(busId)\n                                                .setDevPath(\"/dev/\" + DEV_NAME_PREFIX + minorNumber)\n           
                                     .setHealthy(true)\n                                                .setStatus(devId)\n                                                .build());\n                                        id++;\n                                        if (++idmig < numMigOutputLines) {\n                                            nextLine = linesMig[idmig].trim();\n                                        } else {\n                                            nextLine = \"\";\n                                        }\n                                    }\n                                }\n                                idmig = numMigOutputLines;\n                            }\n                        }\n                        if (migDevCount < 1) {\n                            throw new IOException(\"Error finding MIG devices on GPU with \" +\n                                \"MIG enabled: \" + migInfoOutput);\n                        }\n                        LOG.info(\"Added GPU \" + majorNumber + \":\" + minorNumInt +\n                            \" with MIG Enabled, found \" + migDevCount + \" MIG devices\");\n                    } else {\n                        Integer majorNumInt = Integer.parseInt(majorNumber);\n                        Integer minorNumInt = Integer.parseInt(minorNumber);\n                        r.add(Device.Builder.newInstance()\n                                .setId(id)\n                                .setMajorNumber(majorNumInt)\n                                .setMinorNumber(minorNumInt)\n                                .setBusID(busId)\n                                .setDevPath(\"/dev/\" + DEV_NAME_PREFIX + minorNumber)\n                                .setHealthy(true)\n                                .build());\n                        LOG.info(\"Added GPU \" + majorNumInt + \":\" + minorNumInt);\n                        id++;\n                    }\n                }\n            }\n            return r;\n 
       } catch (IOException e) {\n            LOG.debug(\"Failed to get output from {}\", pathOfGpuBinary);\n            throw new YarnException(e);\n        }\n    }\n\n    private Boolean shouldThrowOnMultipleGPUs() {\n        // env setting takes highest priority if it is set\n        if (shouldThrowOnMultipleGPUFromEnv != null) {\n            return Boolean.parseBoolean(shouldThrowOnMultipleGPUFromEnv);\n        }\n        return shouldThrowOnMultipleGPUFromConf;\n    }\n\n    @Override\n    public DeviceRuntimeSpec onDevicesAllocated(Set<Device> allocatedDevices,\n                                                YarnRuntimeType yarnRuntime) throws Exception {\n        LOG.debug(\"Generating runtime spec for allocated devices: {}, {}\",\n                allocatedDevices, yarnRuntime.getName());\n        if (allocatedDevices.size() > 1 && shouldThrowOnMultipleGPUs()) {\n            throw new YarnException(\"Allocating more than 1 GPU per container is\" +\n                    \" not supported with use of MIG!\");\n        }\n        if (yarnRuntime == YarnRuntimeType.RUNTIME_DOCKER) {\n            String nvidiaRuntime = \"nvidia\";\n            String nvidiaVisibleDevices = \"NVIDIA_VISIBLE_DEVICES\";\n            StringBuffer gpuMinorNumbersSB = new StringBuffer();\n            for (Device device : allocatedDevices) {\n                Integer minorNum = device.getMinorNumber();\n                Integer id = device.getId();\n                if (migDevices.containsKey(id)) {\n                    gpuMinorNumbersSB.append(minorNum + \":\" + migDevices.get(id) + \",\");\n                } else {\n                    gpuMinorNumbersSB.append(minorNum + \",\");\n                }\n            }\n            String minorNumbers = gpuMinorNumbersSB.toString();\n            LOG.info(\"Nvidia Docker v2 assigned GPU: \" + minorNumbers);\n            String deviceStr = minorNumbers.substring(0, minorNumbers.length() - 1);\n            return 
DeviceRuntimeSpec.Builder.newInstance()\n                    .addEnv(nvidiaVisibleDevices, deviceStr)\n                    .setContainerRuntime(nvidiaRuntime)\n                    .build();\n        }\n        return null;\n    }\n\n    @Override\n    public void onDevicesReleased(Set<Device> releasedDevices) throws Exception {\n        // do nothing\n    }\n\n    // Get major number from device name.\n    private String getMajorNumber(String devName) {\n        String output = null;\n        // output \"major:minor\" in hex\n        try {\n            LOG.debug(\"Get major numbers from /dev/{}\", devName);\n            output = shellExecutor.getMajorMinorInfo(devName);\n            String[] strs = output.trim().split(\":\");\n            output = Integer.toString(Integer.parseInt(strs[0], 16));\n        } catch (IOException e) {\n            String msg =\n                    \"Failed to get major number from reading /dev/\" + devName;\n            LOG.warn(msg);\n        } catch (NumberFormatException e) {\n            LOG.error(\"Failed to parse device major number from stat output\");\n            output = null;\n        }\n        return output;\n    }\n\n    @Override\n    public Set<Device> allocateDevices(Set<Device> availableDevices, int count,\n                                       Map<String, String> envs) {\n        Set<Device> allocation = new TreeSet<>();\n        String envShouldThrow = envs.get(THROW_MULTI_ENV);\n        if (envShouldThrow != null) {\n            shouldThrowOnMultipleGPUFromEnv = envShouldThrow;\n        }\n        // Only officially support 1 GPU per container so don't worry about topology\n        // scheduling.\n        basicSchedule(allocation, count, availableDevices);\n        return allocation;\n    }\n\n    public void basicSchedule(Set<Device> allocation, int count,\n                              Set<Device> availableDevices) {\n        // Basic scheduling\n        // allocate all available\n        if (count == 
availableDevices.size()) {\n            allocation.addAll(availableDevices);\n            return;\n        }\n        int number = 0;\n        for (Device d : availableDevices) {\n            allocation.add(d);\n            number++;\n            if (number == count) {\n                break;\n            }\n        }\n    }\n\n    /**\n     * A shell wrapper class easy for test.\n     */\n    public class NvidiaCommandExecutor {\n\n        public String getDeviceInfo() throws IOException {\n            return Shell.execCommand(environment,\n                    new String[]{pathOfGpuBinary, \"--query-gpu=index,pci.bus_id,mig.mode.current\",\n                            \"--format=csv,noheader\"}, MAX_EXEC_TIMEOUT_MS);\n        }\n\n        public String getDeviceMigInfo() throws IOException {\n            return Shell.execCommand(environment,\n                    new String[]{pathOfGpuBinary, \"-L\"}, MAX_EXEC_TIMEOUT_MS);\n        }\n\n        public String getMajorMinorInfo(String devName) throws IOException {\n            // output \"major:minor\" in hex\n            Shell.ShellCommandExecutor shexec = new Shell.ShellCommandExecutor(\n                    new String[]{\"stat\", \"-c\", \"%t:%T\", \"/dev/\" + devName});\n            shexec.execute();\n            return shexec.getOutput();\n        }\n\n        public void searchBinary() throws Exception {\n            if (pathOfGpuBinary != null) {\n                LOG.info(\"Skip searching, the NVIDIA gpu binary is already set: \"\n                        + pathOfGpuBinary);\n                return;\n            }\n            // search env for the binary\n            String envBinaryPath = System.getenv(ENV_BINARY_PATH);\n            if (null != envBinaryPath) {\n                if (new File(envBinaryPath).exists()) {\n                    pathOfGpuBinary = envBinaryPath;\n                    LOG.info(\"Use NVIDIA gpu binary: \" + pathOfGpuBinary);\n                    return;\n                }\n            }\n 
           LOG.debug(\"Search binary..\");\n            // search if binary exists in default folders\n            File binaryFile;\n            boolean found = false;\n            for (String dir : DEFAULT_BINARY_SEARCH_DIRS) {\n                binaryFile = new File(dir, DEFAULT_BINARY_NAME);\n                if (binaryFile.exists()) {\n                    found = true;\n                    pathOfGpuBinary = binaryFile.getAbsolutePath();\n                    LOG.info(\"Found binary:\" + pathOfGpuBinary);\n                    break;\n                }\n            }\n            if (!found) {\n                LOG.error(\"No binary found from env variable: \"\n                        + ENV_BINARY_PATH + \" or path \"\n                        + DEFAULT_BINARY_SEARCH_DIRS.toString());\n                throw new Exception(\"No binary found for \"\n                        + NvidiaGPUMigPluginForRuntimeV2.class);\n            }\n        }\n    }\n\n    // visible for testing\n    public void setPathOfGpuBinary(String pOfGpuBinary) {\n        this.pathOfGpuBinary = pOfGpuBinary;\n    }\n\n    // visible for testing\n    public void setShellExecutor(NvidiaCommandExecutor shellExecutor) {\n        this.shellExecutor = shellExecutor;\n    }\n\n    // visible for testing\n    public void setMigDevices(Map<Integer, String> migDevices) {\n        this.migDevices = migDevices;\n    }\n\n    // visible for testing\n    public void setShouldThrowOnMultipleGPUFromConf(Boolean shouldThrow) {\n        this.shouldThrowOnMultipleGPUFromConf = shouldThrow;\n    }\n}\n"
  },
  {
    "path": "examples/MIG-Support/device-plugins/gpu-mig/src/test/java/com/nvidia/spark/TestNvidiaGPUMigPluginForRuntimeV2.java",
    "content": "/*\n * Copyright (c) 2021, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\npackage com.nvidia.spark;\n\nimport org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;\nimport org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;\nimport org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;\nimport org.junit.Assert;\nimport org.junit.Test;\nimport org.slf4j.Logger;\nimport org.slf4j.LoggerFactory;\n\nimport java.util.HashMap;\nimport java.util.Map;\nimport java.util.Set;\nimport java.util.TreeSet;\n\nimport static org.mockito.Mockito.mock;\nimport static org.mockito.Mockito.when;\n\n/**\n * Test case for NvidiaGPUMigPluginForRuntimeV2 device plugin.\n */\npublic class TestNvidiaGPUMigPluginForRuntimeV2 {\n\n    private static final Logger LOG =\n            LoggerFactory.getLogger(TestNvidiaGPUMigPluginForRuntimeV2.class);\n\n    @Test\n    public void testGetNvidiaDevices() throws Exception {\n        NvidiaGPUMigPluginForRuntimeV2.NvidiaCommandExecutor mockShell =\n                mock(NvidiaGPUMigPluginForRuntimeV2.NvidiaCommandExecutor.class);\n        String deviceInfoShellOutput =\n                \"0, 00000000:04:00.0, [N/A]\\n\" +\n                \"1, 00000000:82:00.0, Enabled\";\n        String majorMinorNumber0 = \"c3:0\";\n        String majorMinorNumber1 = \"c3:1\";\n        String deviceMigInfoShellOutput =\n                \"GPU 0: NVIDIA 
A100 80GB PCIe (UUID: GPU-aa72194b-fdd4-24b0-f659-17c929f46267)\\n\" +\n                \"  MIG 1g.10gb     Device  0: (UUID: MIG-aa2c982c-48a9-5046-b7f8-aa4732879e02)\\n\" +\n                \"GPU 1: NVIDIA A100 80GB PCIe (UUID: GPU-aa7153bf-c0ba-00ef-cdce-f861c34172f6)\\n\" +\n                \"  MIG 1g.10gb     Device  0: (UUID: MIG-aa59d467-ba39-5d0a-a085-66af03246526)\\n\" +\n                \"  MIG 1g.10gb     Device  1: (UUID: MIG-aad5cb29-8e6f-510a-8352-8e18f483dc74)\";\n\n        when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput);\n        when(mockShell.getDeviceMigInfo()).thenReturn(deviceMigInfoShellOutput);\n        when(mockShell.getMajorMinorInfo(\"nvidia0\"))\n                .thenReturn(majorMinorNumber0);\n        when(mockShell.getMajorMinorInfo(\"nvidia1\"))\n                .thenReturn(majorMinorNumber1);\n        NvidiaGPUMigPluginForRuntimeV2 plugin = new NvidiaGPUMigPluginForRuntimeV2();\n        plugin.setShellExecutor(mockShell);\n        plugin.setPathOfGpuBinary(\"/fake/nvidia-smi\");\n\n        Set<Device> expectedDevices = new TreeSet<>();\n        expectedDevices.add(Device.Builder.newInstance()\n                .setId(0).setHealthy(true)\n                .setBusID(\"00000000:04:00.0\")\n                .setDevPath(\"/dev/nvidia0\")\n                .setMajorNumber(195)\n                .setStatus(\"0\")\n                .setMinorNumber(0).build());\n        expectedDevices.add(Device.Builder.newInstance()\n                .setId(1).setHealthy(true)\n                .setBusID(\"00000000:82:00.0\")\n                .setDevPath(\"/dev/nvidia1\")\n                .setMajorNumber(195)\n                .setStatus(\"0\")\n                .setMinorNumber(1).build());\n        expectedDevices.add(Device.Builder.newInstance()\n                .setId(2).setHealthy(true)\n                .setBusID(\"00000000:82:00.0\")\n                .setDevPath(\"/dev/nvidia1\")\n                .setMajorNumber(195)\n                
.setStatus(\"1\")\n                .setMinorNumber(1).build());\n        Set<Device> devices = plugin.getDevices();\n        Assert.assertEquals(expectedDevices, devices);\n    }\n\n    @Test(expected = Exception.class)\n    public void testOnDeviceAllocatedMultiGPU() throws Exception {\n        NvidiaGPUMigPluginForRuntimeV2 plugin = new NvidiaGPUMigPluginForRuntimeV2();\n        Set<Device> allocatedDevices = new TreeSet<>();\n\n        DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices,\n                YarnRuntimeType.RUNTIME_DEFAULT);\n        Assert.assertNull(spec);\n\n        // allocate one device\n        allocatedDevices.add(Device.Builder.newInstance()\n                .setId(0).setHealthy(true)\n                .setBusID(\"00000000:04:00.0\")\n                .setDevPath(\"/dev/nvidia0\")\n                .setMajorNumber(195)\n                .setMinorNumber(0).build());\n        spec = plugin.onDevicesAllocated(allocatedDevices,\n                YarnRuntimeType.RUNTIME_DOCKER);\n        Assert.assertEquals(\"nvidia\", spec.getContainerRuntime());\n        Assert.assertEquals(\"0\", spec.getEnvs().get(\"NVIDIA_VISIBLE_DEVICES\"));\n\n        // two device allowed\n        allocatedDevices.add(Device.Builder.newInstance()\n                .setId(0).setHealthy(true)\n                .setBusID(\"00000000:82:00.0\")\n                .setDevPath(\"/dev/nvidia1\")\n                .setMajorNumber(195)\n                .setMinorNumber(1).build());\n        spec = plugin.onDevicesAllocated(allocatedDevices,\n                YarnRuntimeType.RUNTIME_DOCKER);\n    }\n\n    @Test\n    public void testMultiGPUsEnvPrecedence() throws Exception {\n        NvidiaGPUMigPluginForRuntimeV2 plugin = new NvidiaGPUMigPluginForRuntimeV2();\n        Set<Device> allocatedDevices = new TreeSet<>();\n\n        DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices,\n                YarnRuntimeType.RUNTIME_DEFAULT);\n        
Assert.assertNull(spec);\n\n        // allocate one device\n        allocatedDevices.add(Device.Builder.newInstance()\n                .setId(0).setHealthy(true)\n                .setBusID(\"00000000:04:00.0\")\n                .setDevPath(\"/dev/nvidia0\")\n                .setMajorNumber(195)\n                .setMinorNumber(0).build());\n\n        // two device allowed\n        allocatedDevices.add(Device.Builder.newInstance()\n                .setId(0).setHealthy(true)\n                .setBusID(\"00000000:82:00.0\")\n                .setDevPath(\"/dev/nvidia1\")\n                .setMajorNumber(195)\n                .setMinorNumber(1).build());\n\n        // test that env variable takes precedence\n        plugin.setShouldThrowOnMultipleGPUFromConf(true);\n        Map<String, String> envs = new HashMap<>();\n        envs.put(\"NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS\", \"false\");\n        // note the allocated devices don't matter here, just the env passed in\n        plugin.allocateDevices(allocatedDevices, 2, envs);\n        spec = plugin.onDevicesAllocated(allocatedDevices,\n                YarnRuntimeType.RUNTIME_DOCKER);\n        Assert.assertEquals(\"nvidia\", spec.getContainerRuntime());\n        Assert.assertEquals(\"0,1\", spec.getEnvs().get(\"NVIDIA_VISIBLE_DEVICES\"));\n    }\n\n    @Test\n    public void testMultiGPUsConf() throws Exception {\n        NvidiaGPUMigPluginForRuntimeV2 plugin = new NvidiaGPUMigPluginForRuntimeV2();\n        Set<Device> allocatedDevices = new TreeSet<>();\n\n        DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices,\n                YarnRuntimeType.RUNTIME_DEFAULT);\n        Assert.assertNull(spec);\n\n        // allocate one device\n        allocatedDevices.add(Device.Builder.newInstance()\n                .setId(0).setHealthy(true)\n                .setBusID(\"00000000:04:00.0\")\n                .setDevPath(\"/dev/nvidia0\")\n                .setMajorNumber(195)\n                
.setMinorNumber(0).build());\n\n        // two device allowed\n        allocatedDevices.add(Device.Builder.newInstance()\n                .setId(0).setHealthy(true)\n                .setBusID(\"00000000:82:00.0\")\n                .setDevPath(\"/dev/nvidia1\")\n                .setMajorNumber(195)\n                .setMinorNumber(1).build());\n\n        // test that the conf setting disables throwing when no env variable is set\n        plugin.setShouldThrowOnMultipleGPUFromConf(false);\n        spec = plugin.onDevicesAllocated(allocatedDevices,\n                YarnRuntimeType.RUNTIME_DOCKER);\n        Assert.assertEquals(\"nvidia\", spec.getContainerRuntime());\n        Assert.assertEquals(\"0,1\", spec.getEnvs().get(\"NVIDIA_VISIBLE_DEVICES\"));\n    }\n\n    @Test\n    public void testOnDeviceAllocatedMig() throws Exception {\n        NvidiaGPUMigPluginForRuntimeV2 plugin = new NvidiaGPUMigPluginForRuntimeV2();\n        Set<Device> allocatedDevices = new TreeSet<>();\n\n        DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices,\n                YarnRuntimeType.RUNTIME_DEFAULT);\n        Assert.assertNull(spec);\n\n        Map<Integer, String> testMigDevices = new HashMap<>();\n        testMigDevices.put(0, \"0\");\n        plugin.setMigDevices(testMigDevices);\n\n        // allocate one device\n        allocatedDevices.add(Device.Builder.newInstance()\n                .setId(0).setHealthy(true)\n                .setBusID(\"00000000:04:00.0\")\n                .setDevPath(\"/dev/nvidia0\")\n                .setMajorNumber(195)\n                .setMinorNumber(0).build());\n        spec = plugin.onDevicesAllocated(allocatedDevices,\n                YarnRuntimeType.RUNTIME_DOCKER);\n        Assert.assertEquals(\"nvidia\", spec.getContainerRuntime());\n        Assert.assertEquals(\"0:0\", spec.getEnvs().get(\"NVIDIA_VISIBLE_DEVICES\"));\n    }\n\n    @Test\n    public void testOnDeviceAllocatedNoMig() throws Exception {\n        NvidiaGPUMigPluginForRuntimeV2 plugin = new 
NvidiaGPUMigPluginForRuntimeV2();\n        Set<Device> allocatedDevices = new TreeSet<>();\n\n        DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices,\n                YarnRuntimeType.RUNTIME_DEFAULT);\n        Assert.assertNull(spec);\n\n        // allocate one device\n        allocatedDevices.add(Device.Builder.newInstance()\n                .setId(0).setHealthy(true)\n                .setBusID(\"00000000:04:00.0\")\n                .setDevPath(\"/dev/nvidia0\")\n                .setMajorNumber(195)\n                .setMinorNumber(0).build());\n        spec = plugin.onDevicesAllocated(allocatedDevices,\n                YarnRuntimeType.RUNTIME_DOCKER);\n        Assert.assertEquals(\"nvidia\", spec.getContainerRuntime());\n        Assert.assertEquals(\"0\", spec.getEnvs().get(\"NVIDIA_VISIBLE_DEVICES\"));\n    }\n}\n"
  },
  {
    "path": "examples/MIG-Support/resource-types/gpu-mig/README.md",
    "content": "# NVIDIA Support for GPU for YARN with MIG support for YARN 3.1.2 until YARN 3.3.0\n\nThis adds support for GPUs with [MIG](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/) on YARN for versions prior to\nYARN 3.3.0 which don't support the pluggable device framework. Use the [GPU Plugin for YARN with MIG support](../../device-plugins/gpu-mig/README.md)\nfor YARN 3.3.0 and newer versions. The built-in YARN GPU plugin does not support MIG enabled GPUs. This patch\nworks with GPUs without MIG or GPUs with MIG disabled but the limitation section still applies. It supports heterogeneous\nenvironments where there may be some MIG enabled GPUs and some without MIG. This requires patching YARN and rebuilding it.\n\n## Compatibility\n\nRequires YARN 3.1.2 or newer that supports GPU scheduling. See the [supported versions](#supported-versions) section below for specific versions supported.\nMIG support requires YARN to be configured with Docker and using the NVIDIA Container Toolkit (nvidia-docker2).\n\n## Limitations\n\nPlease see the [MIG Application Considerations](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#app-considerations)\nand [CUDA Device Enumeration](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html#cuda-visible-devices).\n\nIt is important to note that CUDA 11 only supports enumeration of a single MIG instance. This means that with this patch \nand MIG support enabled, it only supports 1 GPU per container and will throw an exception by default if you request more.\nIt is recommended that you configure YARN to only allow a single GPU to be requested. 
See the yarn config:\n```\n yarn.resource-types.yarn.io/gpu.maximum-allocation\n```\nSee [YARN Resource Configuration](https://hadoop.apache.org/docs/r3.1.2/hadoop-yarn/hadoop-yarn-site/ResourceModel.html) for more details.\nIf you do not configure the maximum allocation and someone requests multiple GPUs, the default behavior is to throw an exception.\nSee the [Configuration](#configuration) section for options if it throws an exception.\n\n## Supported Versions\nThere are different patches available depending on the YARN version you are using:\n\n- YARN 3.1.2 use patch `yarn312MIG.patch`\n- YARN versions 3.1.3 to 3.1.5 (git hash cd7c34f9b4005d27886f73e58bef88e706fcccf9 since 3.1.5 was not released when this was tested) use `yarn313to315MIG.patch`\n- YARN 3.2.0, no patch is currently available, backport patch for YARN 3.2.1 or contact us.\n- YARN 3.2.1 and 3.2.3 use patch `yarn321to323MIG.patch`\n\n## Building\nApply the patch to your YARN version and build it like you would normally for your deployment.\n\nFor example:\n```\npatch -p1 < yarn312MIG.patch\nmvn clean package -Pdist -Dtar -DskipTests\n```\n\nRun unit tests:\n```\nmvn test -Pdist -Dtar -Dtest=TestGpuDiscoverer\nmvn test -Pdist -Dtar -Dtest=TestNvidiaDockerV2CommandPlugin\n```\n\n## Installation\n\nThese instructions assume YARN is already installed and configured with GPU Scheduling enabled using Docker and the NVIDIA Container Toolkit (nvidia-docker2).\nSee [Using GPU on YARN](https://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/UsingGpus.html) if you need more information. 
\n\nEnable and configure your [GPUs with MIG](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html) on all of the nodes it applies to.\n\nInstall the new YARN version built with the patch on your YARN Cluster.\n\nEnable the MIG GPU support in the Hadoop configuration files:\n\n```\n<property>\n  <name>yarn.nodemanager.resource-plugins.gpu.use-mig-enabled</name>\n  <value>true</value>\n</property>\n\n```\n\nRestart YARN if needed to pick up any configuration changes.\n\n## Configuration\n\nThe default behavior of the GPU resource plugin on YARN is to use `auto` discovery mode of GPUs on each nodemanager.\nIt also allows you to manually allow certain gpu devices. This configuration was extended to support MIG devices.\n`yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices` configuration can be used to manually specify devices.\nGPU devices are identified by their minor device number, index, and optionally MIG device index. A common approach to get\nminor device number of GPUs is using nvidia-smi -q and search Minor Number output and optionally MIG device indices.\nThe format is index:minor_number[:mig_index][,index:minor_number...]. An example of manual specification is\n\"0:0,1:1:0,1:1:1,2:2\" to allow YARN NodeManager to manage GPU devices with indices 0/1/2 and minor number 0/1/2\nwhere GPU index 1 has 2 MIG enabled devices with indices 0/1.\n```\n<property>\n  <name>yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices</name>\n  <value>0:0,1:1:0,1:1:1,2:2</value>\n</property>\n```\n\nThe behavior of throwing when the user allocates multiple GPUs can be controlled by setting an environment variable\nwhen the Spark application is launched. 
Setting it to `true` means to throw if a user requests multiple GPUs (this is the default), `false`\nmeans it won't throw and if the container is allocated with multiple MIG devices from the same GPU, it is up to the\napplication to know how to use them.\n\nEnvironment variable for Spark application:\n```\n--conf spark.executorEnv.NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS=false\n```\n\n## Testing\nRun a Spark application using the [Rapids Accelerator for Apache Spark](https://nvidia.github.io/spark-rapids/) and request GPUs\nfrom YARN and verify they use the MIG enabled GPUs.\n"
  },
  {
    "path": "examples/MIG-Support/resource-types/gpu-mig/yarn312MIG.patch",
    "content": "diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java\nindex 36fafefdbc4..e37d0a3a685 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java\n@@ -1574,6 +1574,10 @@ public static boolean isAclEnabled(Configuration conf) {\n   @Private\n   public static final String AUTOMATICALLY_DISCOVER_GPU_DEVICES = \"auto\";\n \n+  @Private\n+  public static final String USE_MIG_ENABLED_GPUS =\n+          NM_GPU_RESOURCE_PREFIX + \"use-mig-enabled\";\n+\n   /**\n    * This setting controls where to how to invoke GPU binaries\n    */\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java\nindex 26fd9050742..e84b920dcee 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java\n@@ -34,6 +34,12 @@ public AssignedGpuDevice(int index, int minorNumber,\n     this.containerId = containerId.toString();\n   }\n \n+  public AssignedGpuDevice(int index, int minorNumber,\n+                           int 
migIndex, ContainerId containerId) {\n+    super(index, minorNumber, migIndex);\n+    this.containerId = containerId.toString();\n+  }\n+\n   public String getContainerId() {\n     return containerId;\n   }\n@@ -49,6 +55,7 @@ public boolean equals(Object obj) {\n     }\n     AssignedGpuDevice other = (AssignedGpuDevice) obj;\n     return index == other.index && minorNumber == other.minorNumber\n+        && migDeviceIndex == other.migDeviceIndex\n         && containerId.equals(other.containerId);\n   }\n \n@@ -68,12 +75,16 @@ public int compareTo(Object obj) {\n     if (0 != result) {\n       return result;\n     }\n-    return containerId.compareTo(other.containerId);\n+    result = containerId.compareTo(other.containerId);\n+    if (0 != result) {\n+      return result;\n+    }\n+    return Integer.compare(migDeviceIndex, other.migDeviceIndex);\n   }\n \n   @Override\n   public int hashCode() {\n     final int prime = 47;\n-    return prime * (prime * index + minorNumber) + containerId.hashCode();\n+    return prime * (prime * index + minorNumber + migDeviceIndex) + containerId.hashCode();\n   }\n }\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java\nindex bce1d9fa480..3cb42d3c58f 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java\n@@ -26,6 +26,7 @@\n public class GpuDevice implements 
Serializable, Comparable {\n   protected int index;\n   protected int minorNumber;\n+  protected int migDeviceIndex = -1;\n   private static final long serialVersionUID = -6812314470754667710L;\n \n   public GpuDevice(int index, int minorNumber) {\n@@ -33,6 +34,12 @@ public GpuDevice(int index, int minorNumber) {\n     this.minorNumber = minorNumber;\n   }\n \n+  public GpuDevice(int index, int minorNumber, int migIndex) {\n+    this.index = index;\n+    this.minorNumber = minorNumber;\n+    this.migDeviceIndex = migIndex;\n+  }\n+\n   public int getIndex() {\n     return index;\n   }\n@@ -41,13 +48,17 @@ public int getMinorNumber() {\n     return minorNumber;\n   }\n \n+  public int getMIGIndex() {\n+    return migDeviceIndex;\n+  }\n+\n   @Override\n   public boolean equals(Object obj) {\n     if (obj == null || !(obj instanceof GpuDevice)) {\n       return false;\n     }\n     GpuDevice other = (GpuDevice) obj;\n-    return index == other.index && minorNumber == other.minorNumber;\n+    return index == other.index && minorNumber == other.minorNumber && migDeviceIndex == other.migDeviceIndex;\n   }\n \n   @Override\n@@ -62,17 +73,21 @@ public int compareTo(Object obj) {\n     if (0 != result) {\n       return result;\n     }\n-    return Integer.compare(minorNumber, other.minorNumber);\n+    result = Integer.compare(minorNumber, other.minorNumber);\n+    if (0 != result) {\n+      return result;\n+    }\n+    return Integer.compare(migDeviceIndex, other.migDeviceIndex);\n   }\n \n   @Override\n   public int hashCode() {\n     final int prime = 47;\n-    return prime * index + minorNumber;\n+    return prime * index + minorNumber + migDeviceIndex;\n   }\n \n   @Override\n   public String toString() {\n-    return \"(index=\" + index + \",minor_number=\" + minorNumber + \")\";\n+    return \"(index=\" + index + \",minor_number=\" + minorNumber + \",mig_index=\" + migDeviceIndex + \")\";\n   }\n }\ndiff --git 
a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java\nindex 6e3cf1315ce..55f7379d4cc 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java\n@@ -30,6 +30,7 @@\n import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformation;\n import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformationParser;\n import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuDeviceInformation;\n+import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuMigDevice;\n import org.slf4j.Logger;\n import org.slf4j.LoggerFactory;\n \n@@ -149,6 +150,10 @@ public synchronized GpuDeviceInformation getGpuDeviceInformation()\n         YarnConfiguration.NM_GPU_ALLOWED_DEVICES,\n         YarnConfiguration.AUTOMATICALLY_DISCOVER_GPU_DEVICES);\n \n+    Boolean useMIGEnabledGPUs = conf.getBoolean(\n+            YarnConfiguration.USE_MIG_ENABLED_GPUS, false);\n+    LOG.info(\"Use MIG enabled is: \" + useMIGEnabledGPUs);\n+\n     List<GpuDevice> gpuDevices = new ArrayList<>();\n \n     if (allowedDevicesStr.equals(\n@@ -171,21 +176,45 @@ public synchronized GpuDeviceInformation getGpuDeviceInformation()\n              i++) {\n           List<PerGpuDeviceInformation> gpuInfos =\n               lastDiscoveredGpuInformation.getGpus();\n-          gpuDevices.add(new 
GpuDevice(i, gpuInfos.get(i).getMinorNumber()));\n+          if (useMIGEnabledGPUs &&\n+              gpuInfos.get(i).getMIGMode().getCurrentMigMode().equalsIgnoreCase(\"enabled\")) {\n+            LOG.info(\"GPU id \" + i + \" has MIG mode enabled.\");\n+            for (PerGpuMigDevice dev: gpuInfos.get(i).getMIGDevices()) {\n+              gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber(), dev.getMigDeviceIndex()));\n+            }\n+          } else {\n+            gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber()));\n+          }\n         }\n+        LOG.info(\"Discovered GPU devices: \" + gpuDevices);\n       }\n     } else{\n       for (String s : allowedDevicesStr.split(\",\")) {\n         if (s.trim().length() > 0) {\n           String[] kv = s.trim().split(\":\");\n-          if (kv.length != 2) {\n-            throw new YarnException(\n-                \"Illegal format, it should be index:minor_number format, now it=\"\n-                    + s);\n+          if (useMIGEnabledGPUs) {\n+            if (kv.length != 2 && kv.length != 3) {\n+              throw new YarnException(\n+                      \"Illegal format, it should be index:minor_number or index:minor_number:mig_device_id\" +\n+                              \" format, now it=\" + s);\n+            }\n+            if (kv.length == 3) {\n+              // assumes this is MIG enabled device\n+              gpuDevices.add(\n+                      new GpuDevice(Integer.parseInt(kv[0]), Integer.parseInt(kv[1]), Integer.parseInt(kv[2])));\n+            } else {\n+              gpuDevices.add(\n+                      new GpuDevice(Integer.parseInt(kv[0]), Integer.parseInt(kv[1])));\n+            }\n+          } else {\n+            if (kv.length != 2) {\n+              throw new YarnException(\n+                      \"Illegal format, it should be index:minor_number format, now it=\"\n+                              + s);\n+            }\n+            gpuDevices.add(\n+  
                  new GpuDevice(Integer.parseInt(kv[0]), Integer.parseInt(kv[1])));\n           }\n-\n-          gpuDevices.add(\n-              new GpuDevice(Integer.parseInt(kv[0]), Integer.parseInt(kv[1])));\n         }\n       }\n       LOG.info(\"Allowed GPU devices:\" + gpuDevices);\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java\nindex 051afd6c561..996cb58ac45 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java\n@@ -36,7 +36,7 @@ public static DockerCommandPlugin createGpuDockerCommandPlugin(\n     }\n     // nvidia-docker2\n     if (impl.equals(YarnConfiguration.NVIDIA_DOCKER_V2)) {\n-      return new NvidiaDockerV2CommandPlugin();\n+      return new NvidiaDockerV2CommandPlugin(conf);\n     }\n \n     throw new YarnException(\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java\nindex ff25eb6ced6..c2cc0e5a2d1 100644\n--- 
a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java\n@@ -21,7 +21,9 @@\n import com.google.common.annotations.VisibleForTesting;\n import org.apache.commons.logging.Log;\n import org.apache.commons.logging.LogFactory;\n+import org.apache.hadoop.conf.Configuration;\n import org.apache.hadoop.yarn.api.records.ResourceInformation;\n+import org.apache.hadoop.yarn.conf.YarnConfiguration;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu.GpuResourceAllocator;\n@@ -45,8 +47,12 @@\n \n   private String nvidiaRuntime = \"nvidia\";\n   private String nvidiaVisibleDevices = \"NVIDIA_VISIBLE_DEVICES\";\n+  private String nvidiaMigThrowOnMultiGpus = \"NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS\";\n+  private Boolean isMigEnabled = false;\n \n-  public NvidiaDockerV2CommandPlugin() {}\n+  public NvidiaDockerV2CommandPlugin(Configuration conf) {\n+    isMigEnabled = conf.getBoolean(YarnConfiguration.USE_MIG_ENABLED_GPUS, false);\n+  }\n \n   private Set<GpuDevice> getAssignedGpus(Container container) {\n     ResourceMappings resourceMappings = container.getResourceMappings();\n@@ -84,10 +90,23 @@ public synchronized void updateDockerRunCommand(\n       return;\n     }\n     Map<String, String> environment = new HashMap<>();\n+    if (isMigEnabled && assignedResources.size() > 1) {\n+      Map<String, String> existingEnv = container.getLaunchContext().getEnvironment();\n+      Boolean 
shouldThrowOnMultipleGpus = Boolean.parseBoolean(\n+              existingEnv.getOrDefault(nvidiaMigThrowOnMultiGpus, \"true\"));\n+      if (shouldThrowOnMultipleGpus) {\n+        throw new ContainerExecutionException(\"Allocating more than 1 GPU per container is \" +\n+                \"not supported with use of MIG!\");\n+      }\n+    }\n     String gpuIndexList = \"\";\n     for (GpuDevice gpuDevice : assignedResources) {\n-      gpuIndexList = gpuIndexList + gpuDevice.getIndex() + \",\";\n-      LOG.info(\"nvidia docker2 assigned gpu index: \" + gpuDevice.getIndex());\n+      String deviceIndex = String.valueOf(gpuDevice.getIndex());\n+      if (gpuDevice.getMIGIndex() != -1) {\n+        deviceIndex = gpuDevice.getIndex() + \":\" + gpuDevice.getMIGIndex();\n+      }\n+      gpuIndexList = gpuIndexList + deviceIndex + \",\";\n+      LOG.info(\"nvidia docker2 assigned gpu index: \" + deviceIndex);\n     }\n     dockerRunCommand.addRuntime(nvidiaRuntime);\n     environment.put(nvidiaVisibleDevices,\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java\nindex 25c2e3a1f1d..15cb7eac10a 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java\n@@ -22,8 +22,10 @@\n import org.apache.hadoop.classification.InterfaceStability;\n \n import javax.xml.bind.annotation.XmlElement;\n+import 
javax.xml.bind.annotation.XmlElementWrapper;\n import javax.xml.bind.annotation.XmlRootElement;\n import javax.xml.bind.annotation.adapters.XmlAdapter;\n+import java.util.List;\n \n /**\n  * Capture single GPU device information such as memory size, temperature,\n@@ -38,6 +40,8 @@\n   private String uuid = \"N/A\";\n   private int minorNumber = -1;\n \n+  private List<PerGpuMigDevice> migDevices;\n+  private PerGpuMigMode migMode;\n   private PerGpuUtilizations gpuUtilizations;\n   private PerGpuMemoryUsage gpuMemoryUsage;\n   private PerGpuTemperature temperature;\n@@ -108,6 +112,25 @@ public void setUuid(String uuid) {\n     this.uuid = uuid;\n   }\n \n+  @XmlElement(name = \"mig_mode\")\n+  public PerGpuMigMode getMIGMode() {\n+    return migMode;\n+  }\n+\n+  public void setMIGMode(PerGpuMigMode mode) {\n+    this.migMode = mode;\n+  }\n+\n+  @XmlElementWrapper( name = \"mig_devices\" )\n+  @XmlElement(name = \"mig_device\")\n+  public List<PerGpuMigDevice> getMIGDevices() {\n+    return migDevices;\n+  }\n+\n+  public void setMIGDevices(List<PerGpuMigDevice> devices) {\n+    this.migDevices = devices;\n+  }\n+\n   @XmlElement(name = \"product_name\")\n   public String getProductName() {\n     return productName;\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java\nnew file mode 100644\nindex 00000000000..4ce7cec6e55\n--- /dev/null\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java\n@@ -0,0 +1,48 @@\n+/**\n+ * Licensed to the Apache Software Foundation (ASF) under one\n+ * or more contributor license agreements.  
See the NOTICE file\n+ * distributed with this work for additional information\n+ * regarding copyright ownership.  The ASF licenses this file\n+ * to you under the Apache License, Version 2.0 (the\n+ * \"License\"); you may not use this file except in compliance\n+ * with the License.  You may obtain a copy of the License at\n+ *\n+ *     http://www.apache.org/licenses/LICENSE-2.0\n+ *\n+ * Unless required by applicable law or agreed to in writing, software\n+ * distributed under the License is distributed on an \"AS IS\" BASIS,\n+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n+ * See the License for the specific language governing permissions and\n+ * limitations under the License.\n+ */\n+\n+package org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu;\n+\n+import org.apache.hadoop.classification.InterfaceAudience;\n+import org.apache.hadoop.classification.InterfaceStability;\n+\n+import javax.xml.bind.annotation.XmlElement;\n+import javax.xml.bind.annotation.XmlRootElement;\n+\n+/**\n+ * GPU MIG Device Information\n+ */\n+@InterfaceAudience.Private\n+@InterfaceStability.Unstable\n+@XmlRootElement(name = \"mig_device\")\n+public class PerGpuMigDevice {\n+  private int index;\n+\n+  /**\n+   * MIG device index\n+   * @return MIG device index\n+   */\n+  @XmlElement(name = \"index\")\n+  public int getMigDeviceIndex() {\n+    return index;\n+  }\n+\n+  public void setMigDeviceIndex(int index) {\n+    this.index = index;\n+  }\n+}\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java\nnew file mode 100644\nindex 00000000000..b706df2c3bb\n--- /dev/null\n+++ 
b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java\n@@ -0,0 +1,48 @@\n+/**\n+ * Licensed to the Apache Software Foundation (ASF) under one\n+ * or more contributor license agreements.  See the NOTICE file\n+ * distributed with this work for additional information\n+ * regarding copyright ownership.  The ASF licenses this file\n+ * to you under the Apache License, Version 2.0 (the\n+ * \"License\"); you may not use this file except in compliance\n+ * with the License.  You may obtain a copy of the License at\n+ *\n+ *     http://www.apache.org/licenses/LICENSE-2.0\n+ *\n+ * Unless required by applicable law or agreed to in writing, software\n+ * distributed under the License is distributed on an \"AS IS\" BASIS,\n+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n+ * See the License for the specific language governing permissions and\n+ * limitations under the License.\n+ */\n+\n+package org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu;\n+\n+import org.apache.hadoop.classification.InterfaceAudience;\n+import org.apache.hadoop.classification.InterfaceStability;\n+\n+import javax.xml.bind.annotation.XmlElement;\n+import javax.xml.bind.annotation.XmlRootElement;\n+\n+/**\n+ * GPU MIG Mode\n+ */\n+@InterfaceAudience.Private\n+@InterfaceStability.Unstable\n+@XmlRootElement(name = \"mig_mode\")\n+public class PerGpuMigMode {\n+  private String currentMigMode;\n+\n+  /**\n+   * Current MIG mode\n+   * @return MIG mode enabled or disabled\n+   */\n+  @XmlElement(name = \"current_mig\")\n+  public String getCurrentMigMode() {\n+    return currentMigMode;\n+  }\n+\n+  public void setCurrentMigMode(String migMode) {\n+    this.currentMigMode = migMode;\n+  }\n+}\ndiff --git 
a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java\nindex 4abb633a69a..404930d00c2 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java\n@@ -138,4 +138,47 @@ public void getNumberOfUsableGpusFromConfig() throws YarnException {\n     Assert.assertTrue(2 == usableGpuDevices.get(2).getMinorNumber());\n     Assert.assertTrue(4 == usableGpuDevices.get(3).getMinorNumber());\n   }\n+\n+  @Test\n+  public void getNumberOfUsableGpusFromConfigMIG() throws YarnException {\n+    Configuration conf = new Configuration(false);\n+    conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, \"true\");\n+\n+    // Illegal format\n+    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, \"0:0,1:1:2,2:2:0,3\");\n+    GpuDiscoverer plugin = new GpuDiscoverer();\n+    try {\n+      plugin.initialize(conf);\n+      plugin.getGpusUsableByYarn();\n+      Assert.fail(\"Illegal format, should fail.\");\n+    } catch (YarnException e) {\n+      // Expected\n+    }\n+\n+    // Valid format\n+    conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, \"0:0,1:1:0,1:1:2,2:2:0,3:4\");\n+    plugin = new GpuDiscoverer();\n+    plugin.initialize(conf);\n+\n+    List<GpuDevice> usableGpuDevices = plugin.getGpusUsableByYarn();\n+    Assert.assertEquals(5, usableGpuDevices.size());\n+\n+    Assert.assertTrue(0 == 
usableGpuDevices.get(0).getIndex());\n+    Assert.assertTrue(1 == usableGpuDevices.get(1).getIndex());\n+    Assert.assertTrue(1 == usableGpuDevices.get(2).getIndex());\n+    Assert.assertTrue(2 == usableGpuDevices.get(3).getIndex());\n+    Assert.assertTrue(3 == usableGpuDevices.get(4).getIndex());\n+\n+    Assert.assertTrue(0 == usableGpuDevices.get(0).getMinorNumber());\n+    Assert.assertTrue(1 == usableGpuDevices.get(1).getMinorNumber());\n+    Assert.assertTrue(1 == usableGpuDevices.get(2).getMinorNumber());\n+    Assert.assertTrue(2 == usableGpuDevices.get(3).getMinorNumber());\n+    Assert.assertTrue(4 == usableGpuDevices.get(4).getMinorNumber());\n+\n+    Assert.assertTrue(-1 == usableGpuDevices.get(0).getMIGIndex());\n+    Assert.assertTrue(0 == usableGpuDevices.get(1).getMIGIndex());\n+    Assert.assertTrue(2 == usableGpuDevices.get(2).getMIGIndex());\n+    Assert.assertTrue(0 == usableGpuDevices.get(3).getMIGIndex());\n+    Assert.assertTrue(-1 == usableGpuDevices.get(4).getMIGIndex());\n+  }\n }\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java\nindex b0b523360ef..798a95cb009 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java\n@@ -20,10 +20,14 @@\n \n import 
com.google.common.collect.ImmutableList;\n import com.google.common.collect.Sets;\n+import org.apache.hadoop.conf.Configuration;\n+import org.apache.hadoop.yarn.conf.YarnConfiguration;\n+import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;\n import org.apache.hadoop.yarn.api.records.ResourceInformation;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;\n+import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;\n import org.junit.Assert;\n import org.junit.Test;\n \n@@ -69,7 +73,13 @@ private boolean commandlinesEquals(Map<String, List<String>> cli1,\n       extends NvidiaDockerV2CommandPlugin {\n     private boolean requestsGpu = false;\n \n-    MyNvidiaDockerV2CommandPlugin() {}\n+    MyNvidiaDockerV2CommandPlugin() {\n+      super(new Configuration());\n+    }\n+\n+    MyNvidiaDockerV2CommandPlugin(Configuration conf) {\n+      super(conf);\n+    }\n \n     public void setRequestsGpu(boolean r) {\n       requestsGpu = r;\n@@ -127,4 +137,118 @@ public void testPlugin() throws Exception {\n     // runtime should exist\n     Assert.assertTrue(newCommandLine.containsKey(\"runtime\"));\n   }\n-}\n\\ No newline at end of file\n+\n+  @Test\n+  public void testPluginMIG() throws Exception {\n+    DockerRunCommand runCommand = new DockerRunCommand(\"container_1\", \"user\",\n+        \"fakeimage\");\n+\n+    Map<String, List<String>> originalCommandline = copyCommandLine(\n+        runCommand.getDockerCommandWithArguments());\n+\n+    Configuration conf = new Configuration();\n+    conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, \"true\");\n+    MyNvidiaDockerV2CommandPlugin\n+        commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf);\n+\n+    Container nmContainer = 
mock(Container.class);\n+    ResourceMappings resourceMappings = new ResourceMappings();\n+    when(nmContainer.getResourceMappings()).thenReturn(resourceMappings);\n+\n+    // Assign GPU resource\n+    ResourceMappings.AssignedResources assigned =\n+        new ResourceMappings.AssignedResources();\n+    assigned.updateAssignedResources(\n+        ImmutableList.of(new GpuDevice(0, 0, 0)));\n+    resourceMappings.addAssignedResources(ResourceInformation.GPU_URI,\n+        assigned);\n+\n+    commandPlugin.setRequestsGpu(true);\n+    commandPlugin.updateDockerRunCommand(runCommand, nmContainer);\n+    Map<String, List<String>> newCommandLine =\n+        runCommand.getDockerCommandWithArguments();\n+\n+    // Command line will be updated\n+    Assert.assertFalse(commandlinesEquals(originalCommandline, newCommandLine));\n+    // NVIDIA_VISIBLE_DEVICES will be set\n+    Assert.assertTrue(\n+        runCommand.getEnv().get(\"NVIDIA_VISIBLE_DEVICES\").equals(\"0:0\"));\n+    // runtime should exist\n+    Assert.assertTrue(newCommandLine.containsKey(\"runtime\"));\n+  }\n+\n+  @Test(expected = ContainerExecutionException.class)\n+  public void testPluginMIGThrowsMulti() throws Exception {\n+    DockerRunCommand runCommand = new DockerRunCommand(\"container_1\", \"user\",\n+        \"fakeimage\");\n+\n+    Map<String, List<String>> originalCommandline = copyCommandLine(\n+        runCommand.getDockerCommandWithArguments());\n+\n+    Configuration conf = new Configuration();\n+    conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, \"true\");\n+    MyNvidiaDockerV2CommandPlugin\n+        commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf);\n+\n+    Container nmContainer = mock(Container.class);\n+    ResourceMappings resourceMappings = new ResourceMappings();\n+    Map<String, String> env = new HashMap<>();\n+    env.put(\"NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS\", \"true\");\n+    when(nmContainer.getResourceMappings()).thenReturn(resourceMappings);\n+    
ContainerLaunchContext launchCtx = mock(ContainerLaunchContext.class);\n+    when(nmContainer.getLaunchContext()).thenReturn(launchCtx);\n+    when(launchCtx.getEnvironment()).thenReturn(env);\n+\n+    // Assign GPU resource\n+    ResourceMappings.AssignedResources assigned =\n+        new ResourceMappings.AssignedResources();\n+    assigned.updateAssignedResources(\n+        ImmutableList.of(new GpuDevice(0, 0, 0), new GpuDevice(1, 1, 2)));\n+    resourceMappings.addAssignedResources(ResourceInformation.GPU_URI,\n+        assigned);\n+\n+    commandPlugin.setRequestsGpu(true);\n+    commandPlugin.updateDockerRunCommand(runCommand, nmContainer);\n+  }\n+\n+  @Test\n+  public void testPluginMIGNoThrowsMulti() throws Exception {\n+    DockerRunCommand runCommand = new DockerRunCommand(\"container_1\", \"user\",\n+        \"fakeimage\");\n+\n+    Map<String, List<String>> originalCommandline = copyCommandLine(\n+        runCommand.getDockerCommandWithArguments());\n+\n+    Configuration conf = new Configuration();\n+    conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, \"true\");\n+    MyNvidiaDockerV2CommandPlugin\n+        commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf);\n+\n+    Container nmContainer = mock(Container.class);\n+    ResourceMappings resourceMappings = new ResourceMappings();\n+    Map<String, String> env = new HashMap<>();\n+    env.put(\"NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS\", \"false\");\n+    when(nmContainer.getResourceMappings()).thenReturn(resourceMappings);\n+    ContainerLaunchContext launchCtx = mock(ContainerLaunchContext.class);\n+    when(nmContainer.getLaunchContext()).thenReturn(launchCtx);\n+    when(launchCtx.getEnvironment()).thenReturn(env);\n+\n+    // Assign GPU resource\n+    ResourceMappings.AssignedResources assigned =\n+        new ResourceMappings.AssignedResources();\n+    assigned.updateAssignedResources(\n+        ImmutableList.of(new GpuDevice(0, 0, 0), new GpuDevice(1, 1, 2)));\n+    
resourceMappings.addAssignedResources(ResourceInformation.GPU_URI,\n+        assigned);\n+\n+    commandPlugin.setRequestsGpu(true);\n+    commandPlugin.updateDockerRunCommand(runCommand, nmContainer);\n+    Map<String, List<String>> newCommandLine =\n+        runCommand.getDockerCommandWithArguments();\n+    // NVIDIA_VISIBLE_DEVICES will be set\n+    Assert.assertTrue(\n+        runCommand.getEnv().get(\"NVIDIA_VISIBLE_DEVICES\").equals(\"0:0,1:2\"));\n+    // runtime should exist\n+    Assert.assertTrue(newCommandLine.containsKey(\"runtime\"));\n+  }\n+}\n"
  },
  {
    "path": "examples/MIG-Support/resource-types/gpu-mig/yarn313to315MIG.patch",
    "content": "diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java\nindex 737baee70bb..0e113036a80 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java\n@@ -1655,6 +1655,10 @@ public static boolean isAclEnabled(Configuration conf) {\n   @Private\n   public static final String AUTOMATICALLY_DISCOVER_GPU_DEVICES = \"auto\";\n \n+  @Private\n+  public static final String USE_MIG_ENABLED_GPUS =\n+          NM_GPU_RESOURCE_PREFIX + \"use-mig-enabled\";\n+\n   /**\n    * This setting controls where to how to invoke GPU binaries\n    */\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java\nindex 26fd9050742..e84b920dcee 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java\n@@ -34,6 +34,12 @@ public AssignedGpuDevice(int index, int minorNumber,\n     this.containerId = containerId.toString();\n   }\n \n+  public AssignedGpuDevice(int index, int minorNumber,\n+                           int 
migIndex, ContainerId containerId) {\n+    super(index, minorNumber, migIndex);\n+    this.containerId = containerId.toString();\n+  }\n+\n   public String getContainerId() {\n     return containerId;\n   }\n@@ -49,6 +55,7 @@ public boolean equals(Object obj) {\n     }\n     AssignedGpuDevice other = (AssignedGpuDevice) obj;\n     return index == other.index && minorNumber == other.minorNumber\n+        && migDeviceIndex == other.migDeviceIndex\n         && containerId.equals(other.containerId);\n   }\n \n@@ -68,12 +75,16 @@ public int compareTo(Object obj) {\n     if (0 != result) {\n       return result;\n     }\n-    return containerId.compareTo(other.containerId);\n+    result = containerId.compareTo(other.containerId);\n+    if (0 != result) {\n+      return result;\n+    }\n+    return Integer.compare(migDeviceIndex, other.migDeviceIndex);\n   }\n \n   @Override\n   public int hashCode() {\n     final int prime = 47;\n-    return prime * (prime * index + minorNumber) + containerId.hashCode();\n+    return prime * (prime * index + minorNumber + migDeviceIndex) + containerId.hashCode();\n   }\n }\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java\nindex bce1d9fa480..3cb42d3c58f 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java\n@@ -26,6 +26,7 @@\n public class GpuDevice implements 
Serializable, Comparable {\n   protected int index;\n   protected int minorNumber;\n+  protected int migDeviceIndex = -1;\n   private static final long serialVersionUID = -6812314470754667710L;\n \n   public GpuDevice(int index, int minorNumber) {\n@@ -33,6 +34,12 @@ public GpuDevice(int index, int minorNumber) {\n     this.minorNumber = minorNumber;\n   }\n \n+  public GpuDevice(int index, int minorNumber, int migIndex) {\n+    this.index = index;\n+    this.minorNumber = minorNumber;\n+    this.migDeviceIndex = migIndex;\n+  }\n+\n   public int getIndex() {\n     return index;\n   }\n@@ -41,13 +48,17 @@ public int getMinorNumber() {\n     return minorNumber;\n   }\n \n+  public int getMIGIndex() {\n+    return migDeviceIndex;\n+  }\n+\n   @Override\n   public boolean equals(Object obj) {\n     if (obj == null || !(obj instanceof GpuDevice)) {\n       return false;\n     }\n     GpuDevice other = (GpuDevice) obj;\n-    return index == other.index && minorNumber == other.minorNumber;\n+    return index == other.index && minorNumber == other.minorNumber && migDeviceIndex == other.migDeviceIndex;\n   }\n \n   @Override\n@@ -62,17 +73,21 @@ public int compareTo(Object obj) {\n     if (0 != result) {\n       return result;\n     }\n-    return Integer.compare(minorNumber, other.minorNumber);\n+    result = Integer.compare(minorNumber, other.minorNumber);\n+    if (0 != result) {\n+      return result;\n+    }\n+    return Integer.compare(migDeviceIndex, other.migDeviceIndex);\n   }\n \n   @Override\n   public int hashCode() {\n     final int prime = 47;\n-    return prime * index + minorNumber;\n+    return prime * index + minorNumber + migDeviceIndex;\n   }\n \n   @Override\n   public String toString() {\n-    return \"(index=\" + index + \",minor_number=\" + minorNumber + \")\";\n+    return \"(index=\" + index + \",minor_number=\" + minorNumber + \",mig_index=\" + migDeviceIndex + \")\";\n   }\n }\ndiff --git 
a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java\nindex 9d61b91a1f2..d775aab0226 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java\n@@ -26,6 +26,8 @@\n public final class GpuDeviceSpecificationException extends YarnException {\n   private static final String VALID_FORMAT_MESSAGE = \"The valid format \" +\n       \"should be: index:minor_number\";\n+  private static final String VALID_MIG_FORMAT_MESSAGE = VALID_FORMAT_MESSAGE +\n+      \"or with MIG enabled: index:minor_number:mig_index\";\n \n   private GpuDeviceSpecificationException(String message) {\n     super(message);\n@@ -57,12 +59,25 @@ public static GpuDeviceSpecificationException createWithWrongValueSpecified(\n     return new GpuDeviceSpecificationException(message);\n   }\n \n+  public static GpuDeviceSpecificationException createWithWrongValueSpecifiedMIG(\n+      String device, String configValue) {\n+    final String message = createIllegalFormatMessageMIG(device, configValue);\n+    return new GpuDeviceSpecificationException(message);\n+  }\n+\n   public static GpuDeviceSpecificationException createWithDuplicateValueSpecified(\n       String device, String configValue) {\n     final String message = createDuplicateFormatMessage(device, 
configValue);\n     return new GpuDeviceSpecificationException(message);\n   }\n \n+  private static String createIllegalFormatMessageMIG(String device,\n+      String configValue) {\n+    return String.format(\"Illegal format of individual GPU device: %s, \" +\n+            \"the whole config value was: '%s'! \" + VALID_MIG_FORMAT_MESSAGE,\n+        device, configValue);\n+  }\n+\n   private static String createIllegalFormatMessage(String device,\n       String configValue) {\n     return String.format(\"Illegal format of individual GPU device: %s, \" +\n@@ -79,4 +94,4 @@ private static String createDuplicateFormatMessage(String device,\n             \"! Current value of the configuration is: %s\",\n         device, configValue);\n   }\n-}\n\\ No newline at end of file\n+}\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java\nindex ce767229e50..c74651b41df 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java\n@@ -31,6 +31,7 @@\n import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformation;\n import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformationParser;\n import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuDeviceInformation;\n+import 
org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuMigDevice;\n import org.slf4j.Logger;\n import org.slf4j.LoggerFactory;\n \n@@ -69,6 +70,7 @@\n   private GpuDeviceInformation lastDiscoveredGpuInformation = null;\n \n   private List<GpuDevice> gpuDevicesFromUser;\n+  private Boolean useMIGEnabledGPUs = false;\n \n   private void validateConfOrThrowException() throws YarnException {\n     if (conf == null) {\n@@ -194,8 +196,17 @@ private boolean IsAutoDiscoveryEnabled() {\n       for (int i = 0; i < numberOfGpus; i++) {\n         List<PerGpuDeviceInformation> gpuInfos =\n             lastDiscoveredGpuInformation.getGpus();\n-        gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber()));\n+        if (useMIGEnabledGPUs &&\n+            gpuInfos.get(i).getMIGMode().getCurrentMigMode().equalsIgnoreCase(\"enabled\")) {\n+          LOG.info(\"GPU id \" + i + \" has MIG mode enabled.\");\n+          for (PerGpuMigDevice dev: gpuInfos.get(i).getMIGDevices()) {\n+            gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber(), dev.getMigDeviceIndex()));\n+          }\n+        } else {\n+          gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber()));\n+        }\n       }\n+      LOG.info(\"Discovered GPU devices: \" + gpuDevices);\n     }\n     return gpuDevices;\n   }\n@@ -218,18 +229,39 @@ private boolean IsAutoDiscoveryEnabled() {\n     for (String device : devices.split(\",\")) {\n       if (device.trim().length() > 0) {\n         String[] splitByColon = device.trim().split(\":\");\n-        if (splitByColon.length != 2) {\n-          throw GpuDeviceSpecificationException.\n-              createWithWrongValueSpecified(device, devices);\n-        }\n-\n-        GpuDevice gpuDevice = parseGpuDevice(device, splitByColon, devices);\n-        if (!gpuDevices.contains(gpuDevice)) {\n-          gpuDevices.add(gpuDevice);\n+        if (useMIGEnabledGPUs) {\n+          if (splitByColon.length != 2 && splitByColon.length != 
3) {\n+            throw GpuDeviceSpecificationException.\n+                createWithWrongValueSpecifiedMIG(device, devices);\n+          }\n+\n+          GpuDevice gpuDevice;\n+          if (splitByColon.length == 3) {\n+            gpuDevice = parseGpuMIGDevice(device, splitByColon, devices);\n+          } else {\n+            gpuDevice = parseGpuDevice(device, splitByColon, devices);\n+          }\n+          if (!gpuDevices.contains(gpuDevice)) {\n+            gpuDevices.add(gpuDevice);\n+          } else {\n+            throw GpuDeviceSpecificationException\n+                .createWithDuplicateValueSpecified(device, devices);\n+          }\n         } else {\n-          throw GpuDeviceSpecificationException\n-              .createWithDuplicateValueSpecified(device, devices);\n+          if (splitByColon.length != 2) {\n+            throw GpuDeviceSpecificationException.\n+                createWithWrongValueSpecified(device, devices);\n+          }\n+\n+          GpuDevice gpuDevice = parseGpuDevice(device, splitByColon, devices);\n+          if (!gpuDevices.contains(gpuDevice)) {\n+            gpuDevices.add(gpuDevice);\n+          } else {\n+            throw GpuDeviceSpecificationException\n+                .createWithDuplicateValueSpecified(device, devices);\n+          }\n         }\n+\n       }\n     }\n     LOG.info(\"Allowed GPU devices:\" + gpuDevices);\n@@ -237,6 +269,19 @@ private boolean IsAutoDiscoveryEnabled() {\n     return gpuDevices;\n   }\n \n+  private GpuDevice parseGpuMIGDevice(String device, String[] splitByColon,\n+      String allowedDevicesStr) throws YarnException {\n+    try {\n+      int index = Integer.parseInt(splitByColon[0]);\n+      int minorNumber = Integer.parseInt(splitByColon[1]);\n+      int migIndex = Integer.parseInt(splitByColon[2]);\n+      return new GpuDevice(index, minorNumber, migIndex);\n+    } catch (NumberFormatException e) {\n+      throw GpuDeviceSpecificationException.\n+          
createWithWrongValueSpecified(device, allowedDevicesStr, e);\n+    }\n+  }\n+\n   private GpuDevice parseGpuDevice(String device, String[] splitByColon,\n       String allowedDevicesStr) throws YarnException {\n     try {\n@@ -268,6 +313,9 @@ public synchronized void initialize(Configuration config)\n         LOG.warn(msg);\n       }\n     }\n+    useMIGEnabledGPUs = conf.getBoolean(YarnConfiguration.USE_MIG_ENABLED_GPUS, false);\n+    LOG.info(\"Use MIG enabled is: \" + useMIGEnabledGPUs);\n+\n   }\n \n   private void lookUpAutoDiscoveryBinary(Configuration config)\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java\nindex 051afd6c561..996cb58ac45 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java\n@@ -36,7 +36,7 @@ public static DockerCommandPlugin createGpuDockerCommandPlugin(\n     }\n     // nvidia-docker2\n     if (impl.equals(YarnConfiguration.NVIDIA_DOCKER_V2)) {\n-      return new NvidiaDockerV2CommandPlugin();\n+      return new NvidiaDockerV2CommandPlugin(conf);\n     }\n \n     throw new YarnException(\ndiff --git 
a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java\nindex ff25eb6ced6..c2cc0e5a2d1 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java\n@@ -21,7 +21,9 @@\n import com.google.common.annotations.VisibleForTesting;\n import org.apache.commons.logging.Log;\n import org.apache.commons.logging.LogFactory;\n+import org.apache.hadoop.conf.Configuration;\n import org.apache.hadoop.yarn.api.records.ResourceInformation;\n+import org.apache.hadoop.yarn.conf.YarnConfiguration;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu.GpuResourceAllocator;\n@@ -45,8 +47,12 @@\n \n   private String nvidiaRuntime = \"nvidia\";\n   private String nvidiaVisibleDevices = \"NVIDIA_VISIBLE_DEVICES\";\n+  private String nvidiaMigThrowOnMultiGpus = \"NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS\";\n+  private Boolean isMigEnabled = false;\n \n-  public NvidiaDockerV2CommandPlugin() {}\n+  public NvidiaDockerV2CommandPlugin(Configuration conf) {\n+    isMigEnabled = conf.getBoolean(YarnConfiguration.USE_MIG_ENABLED_GPUS, false);\n+  }\n \n   
private Set<GpuDevice> getAssignedGpus(Container container) {\n     ResourceMappings resourceMappings = container.getResourceMappings();\n@@ -84,10 +90,23 @@ public synchronized void updateDockerRunCommand(\n       return;\n     }\n     Map<String, String> environment = new HashMap<>();\n+    if (isMigEnabled && assignedResources.size() > 1) {\n+      Map<String, String> existingEnv = container.getLaunchContext().getEnvironment();\n+      Boolean shouldThrowOnMultipleGpus = Boolean.parseBoolean(\n+              existingEnv.getOrDefault(nvidiaMigThrowOnMultiGpus, \"true\"));\n+      if (shouldThrowOnMultipleGpus) {\n+        throw new ContainerExecutionException(\"Allocating more than 1 GPU per container is \" +\n+                \"not supported with use of MIG!\");\n+      }\n+    }\n     String gpuIndexList = \"\";\n     for (GpuDevice gpuDevice : assignedResources) {\n-      gpuIndexList = gpuIndexList + gpuDevice.getIndex() + \",\";\n-      LOG.info(\"nvidia docker2 assigned gpu index: \" + gpuDevice.getIndex());\n+      String deviceIndex = String.valueOf(gpuDevice.getIndex());\n+      if (gpuDevice.getMIGIndex() != -1) {\n+        deviceIndex = gpuDevice.getIndex() + \":\" + gpuDevice.getMIGIndex();\n+      }\n+      gpuIndexList = gpuIndexList + deviceIndex + \",\";\n+      LOG.info(\"nvidia docker2 assigned gpu index: \" + deviceIndex);\n     }\n     dockerRunCommand.addRuntime(nvidiaRuntime);\n     environment.put(nvidiaVisibleDevices,\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java\nindex 11ff2a4c49c..939ed46aac7 100644\n--- 
a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java\n@@ -22,8 +22,10 @@\n import org.apache.hadoop.classification.InterfaceStability;\n \n import javax.xml.bind.annotation.XmlElement;\n+import javax.xml.bind.annotation.XmlElementWrapper;\n import javax.xml.bind.annotation.XmlRootElement;\n import javax.xml.bind.annotation.adapters.XmlAdapter;\n+import java.util.List;\n \n /**\n  * Capture single GPU device information such as memory size, temperature,\n@@ -37,6 +39,8 @@\n   private String uuid = \"N/A\";\n   private int minorNumber = -1;\n \n+  private List<PerGpuMigDevice> migDevices;\n+  private PerGpuMigMode migMode;\n   private PerGpuUtilizations gpuUtilizations;\n   private PerGpuMemoryUsage gpuMemoryUsage;\n   private PerGpuTemperature temperature;\n@@ -107,6 +111,25 @@ public void setUuid(String uuid) {\n     this.uuid = uuid;\n   }\n \n+  @XmlElement(name = \"mig_mode\")\n+  public PerGpuMigMode getMIGMode() {\n+    return migMode;\n+  }\n+\n+  public void setMIGMode(PerGpuMigMode mode) {\n+    this.migMode = mode;\n+  }\n+\n+  @XmlElementWrapper( name = \"mig_devices\" )\n+  @XmlElement(name = \"mig_device\")\n+  public List<PerGpuMigDevice> getMIGDevices() {\n+    return migDevices;\n+  }\n+\n+  public void setMIGDevices(List<PerGpuMigDevice> devices) {\n+    this.migDevices = devices;\n+  }\n+\n   @XmlElement(name = \"product_name\")\n   public String getProductName() {\n     return productName;\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java 
b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java\nnew file mode 100644\nindex 00000000000..4ce7cec6e55\n--- /dev/null\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java\n@@ -0,0 +1,48 @@\n+/**\n+ * Licensed to the Apache Software Foundation (ASF) under one\n+ * or more contributor license agreements.  See the NOTICE file\n+ * distributed with this work for additional information\n+ * regarding copyright ownership.  The ASF licenses this file\n+ * to you under the Apache License, Version 2.0 (the\n+ * \"License\"); you may not use this file except in compliance\n+ * with the License.  You may obtain a copy of the License at\n+ *\n+ *     http://www.apache.org/licenses/LICENSE-2.0\n+ *\n+ * Unless required by applicable law or agreed to in writing, software\n+ * distributed under the License is distributed on an \"AS IS\" BASIS,\n+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n+ * See the License for the specific language governing permissions and\n+ * limitations under the License.\n+ */\n+\n+package org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu;\n+\n+import org.apache.hadoop.classification.InterfaceAudience;\n+import org.apache.hadoop.classification.InterfaceStability;\n+\n+import javax.xml.bind.annotation.XmlElement;\n+import javax.xml.bind.annotation.XmlRootElement;\n+\n+/**\n+ * GPU MIG Device Information\n+ */\n+@InterfaceAudience.Private\n+@InterfaceStability.Unstable\n+@XmlRootElement(name = \"mig_device\")\n+public class PerGpuMigDevice {\n+  private int index;\n+\n+  /**\n+   * MIG device index\n+   * @return MIG device index\n+   */\n+  @XmlElement(name = \"index\")\n+  public int getMigDeviceIndex() {\n+    return index;\n+  }\n+\n+  public void 
setMigDeviceIndex(int index) {\n+    this.index = index;\n+  }\n+}\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java\nnew file mode 100644\nindex 00000000000..b706df2c3bb\n--- /dev/null\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java\n@@ -0,0 +1,48 @@\n+/**\n+ * Licensed to the Apache Software Foundation (ASF) under one\n+ * or more contributor license agreements.  See the NOTICE file\n+ * distributed with this work for additional information\n+ * regarding copyright ownership.  The ASF licenses this file\n+ * to you under the Apache License, Version 2.0 (the\n+ * \"License\"); you may not use this file except in compliance\n+ * with the License.  
You may obtain a copy of the License at\n+ *\n+ *     http://www.apache.org/licenses/LICENSE-2.0\n+ *\n+ * Unless required by applicable law or agreed to in writing, software\n+ * distributed under the License is distributed on an \"AS IS\" BASIS,\n+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n+ * See the License for the specific language governing permissions and\n+ * limitations under the License.\n+ */\n+\n+package org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu;\n+\n+import org.apache.hadoop.classification.InterfaceAudience;\n+import org.apache.hadoop.classification.InterfaceStability;\n+\n+import javax.xml.bind.annotation.XmlElement;\n+import javax.xml.bind.annotation.XmlRootElement;\n+\n+/**\n+ * GPU MIG Mode\n+ */\n+@InterfaceAudience.Private\n+@InterfaceStability.Unstable\n+@XmlRootElement(name = \"mig_mode\")\n+public class PerGpuMigMode {\n+  private String currentMigMode;\n+\n+  /**\n+   * Current MIG mode\n+   * @return MIG mode enabled or disabled\n+   */\n+  @XmlElement(name = \"current_mig\")\n+  public String getCurrentMigMode() {\n+    return currentMigMode;\n+  }\n+\n+  public void setCurrentMigMode(String migMode) {\n+    this.currentMigMode = migMode;\n+  }\n+}\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java\nindex f0f100c1f8b..02b213b6734 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java\n+++ 
b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java\n@@ -372,6 +372,37 @@ public void testGetNumberOfUsableGpusFromConfig() throws YarnException {\n     assertEquals(4, usableGpuDevices.get(3).getMinorNumber());\n   }\n \n+  @Test\n+  public void testGetNumberOfUsableGpusFromConfigMIG() throws YarnException {\n+    Configuration conf = createConfigWithAllowedDevices(\"0:0,1:1:0,1:1:3,2:2,3:4\");\n+    conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, \"true\");\n+    GpuDiscoverer discoverer = new GpuDiscoverer();\n+    discoverer.initialize(conf);\n+\n+    List<GpuDevice> usableGpuDevices = discoverer.getGpusUsableByYarn();\n+    assertEquals(5, usableGpuDevices.size());\n+\n+    assertEquals(0, usableGpuDevices.get(0).getIndex());\n+    assertEquals(0, usableGpuDevices.get(0).getMinorNumber());\n+    assertEquals(-1, usableGpuDevices.get(0).getMIGIndex());\n+\n+    assertEquals(1, usableGpuDevices.get(1).getIndex());\n+    assertEquals(1, usableGpuDevices.get(1).getMinorNumber());\n+    assertEquals(0, usableGpuDevices.get(1).getMIGIndex());\n+\n+    assertEquals(1, usableGpuDevices.get(2).getIndex());\n+    assertEquals(1, usableGpuDevices.get(2).getMinorNumber());\n+    assertEquals(3, usableGpuDevices.get(2).getMIGIndex());\n+\n+    assertEquals(2, usableGpuDevices.get(3).getIndex());\n+    assertEquals(2, usableGpuDevices.get(3).getMinorNumber());\n+    assertEquals(-1, usableGpuDevices.get(3).getMIGIndex());\n+\n+    assertEquals(3, usableGpuDevices.get(4).getIndex());\n+    assertEquals(4, usableGpuDevices.get(4).getMinorNumber());\n+    assertEquals(-1, usableGpuDevices.get(4).getMIGIndex());\n+  }\n+\n   @Test\n   public void testGetNumberOfUsableGpusFromConfigDuplicateValues()\n       throws YarnException {\n@@ -512,4 +543,5 @@ public void testScriptNotCalled() throws YarnException {\n \n     verify(gpuSpy, 
never()).getGpuDeviceInformation();\n   }\n+\n }\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java\nindex b0b523360ef..798a95cb009 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java\n@@ -20,10 +20,14 @@\n \n import com.google.common.collect.ImmutableList;\n import com.google.common.collect.Sets;\n+import org.apache.hadoop.conf.Configuration;\n+import org.apache.hadoop.yarn.conf.YarnConfiguration;\n+import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;\n import org.apache.hadoop.yarn.api.records.ResourceInformation;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;\n+import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;\n import org.junit.Assert;\n import org.junit.Test;\n \n@@ -69,7 +73,13 @@ private boolean commandlinesEquals(Map<String, List<String>> cli1,\n       extends NvidiaDockerV2CommandPlugin {\n     private boolean requestsGpu = false;\n \n-    MyNvidiaDockerV2CommandPlugin() {}\n+    
MyNvidiaDockerV2CommandPlugin() {\n+      super(new Configuration());\n+    }\n+\n+    MyNvidiaDockerV2CommandPlugin(Configuration conf) {\n+      super(conf);\n+    }\n \n     public void setRequestsGpu(boolean r) {\n       requestsGpu = r;\n@@ -127,4 +137,118 @@ public void testPlugin() throws Exception {\n     // runtime should exist\n     Assert.assertTrue(newCommandLine.containsKey(\"runtime\"));\n   }\n-}\n\\ No newline at end of file\n+\n+  @Test\n+  public void testPluginMIG() throws Exception {\n+    DockerRunCommand runCommand = new DockerRunCommand(\"container_1\", \"user\",\n+        \"fakeimage\");\n+\n+    Map<String, List<String>> originalCommandline = copyCommandLine(\n+        runCommand.getDockerCommandWithArguments());\n+\n+    Configuration conf = new Configuration();\n+    conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, \"true\");\n+    MyNvidiaDockerV2CommandPlugin\n+        commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf);\n+\n+    Container nmContainer = mock(Container.class);\n+    ResourceMappings resourceMappings = new ResourceMappings();\n+    when(nmContainer.getResourceMappings()).thenReturn(resourceMappings);\n+\n+    // Assign GPU resource\n+    ResourceMappings.AssignedResources assigned =\n+        new ResourceMappings.AssignedResources();\n+    assigned.updateAssignedResources(\n+        ImmutableList.of(new GpuDevice(0, 0, 0)));\n+    resourceMappings.addAssignedResources(ResourceInformation.GPU_URI,\n+        assigned);\n+\n+    commandPlugin.setRequestsGpu(true);\n+    commandPlugin.updateDockerRunCommand(runCommand, nmContainer);\n+    Map<String, List<String>> newCommandLine =\n+        runCommand.getDockerCommandWithArguments();\n+\n+    // Command line will be updated\n+    Assert.assertFalse(commandlinesEquals(originalCommandline, newCommandLine));\n+    // NVIDIA_VISIBLE_DEVICES will be set\n+    Assert.assertTrue(\n+        runCommand.getEnv().get(\"NVIDIA_VISIBLE_DEVICES\").equals(\"0:0\"));\n+    // runtime 
should exist\n+    Assert.assertTrue(newCommandLine.containsKey(\"runtime\"));\n+  }\n+\n+  @Test(expected = ContainerExecutionException.class)\n+  public void testPluginMIGThrowsMulti() throws Exception {\n+    DockerRunCommand runCommand = new DockerRunCommand(\"container_1\", \"user\",\n+        \"fakeimage\");\n+\n+    Map<String, List<String>> originalCommandline = copyCommandLine(\n+        runCommand.getDockerCommandWithArguments());\n+\n+    Configuration conf = new Configuration();\n+    conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, \"true\");\n+    MyNvidiaDockerV2CommandPlugin\n+        commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf);\n+\n+    Container nmContainer = mock(Container.class);\n+    ResourceMappings resourceMappings = new ResourceMappings();\n+    Map<String, String> env = new HashMap<>();\n+    env.put(\"NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS\", \"true\");\n+    when(nmContainer.getResourceMappings()).thenReturn(resourceMappings);\n+    ContainerLaunchContext launchCtx = mock(ContainerLaunchContext.class);\n+    when(nmContainer.getLaunchContext()).thenReturn(launchCtx);\n+    when(launchCtx.getEnvironment()).thenReturn(env);\n+\n+    // Assign GPU resource\n+    ResourceMappings.AssignedResources assigned =\n+        new ResourceMappings.AssignedResources();\n+    assigned.updateAssignedResources(\n+        ImmutableList.of(new GpuDevice(0, 0, 0), new GpuDevice(1, 1, 2)));\n+    resourceMappings.addAssignedResources(ResourceInformation.GPU_URI,\n+        assigned);\n+\n+    commandPlugin.setRequestsGpu(true);\n+    commandPlugin.updateDockerRunCommand(runCommand, nmContainer);\n+  }\n+\n+  @Test\n+  public void testPluginMIGNoThrowsMulti() throws Exception {\n+    DockerRunCommand runCommand = new DockerRunCommand(\"container_1\", \"user\",\n+        \"fakeimage\");\n+\n+    Map<String, List<String>> originalCommandline = copyCommandLine(\n+        runCommand.getDockerCommandWithArguments());\n+\n+    Configuration conf = new 
Configuration();\n+    conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, \"true\");\n+    MyNvidiaDockerV2CommandPlugin\n+        commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf);\n+\n+    Container nmContainer = mock(Container.class);\n+    ResourceMappings resourceMappings = new ResourceMappings();\n+    Map<String, String> env = new HashMap<>();\n+    env.put(\"NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS\", \"false\");\n+    when(nmContainer.getResourceMappings()).thenReturn(resourceMappings);\n+    ContainerLaunchContext launchCtx = mock(ContainerLaunchContext.class);\n+    when(nmContainer.getLaunchContext()).thenReturn(launchCtx);\n+    when(launchCtx.getEnvironment()).thenReturn(env);\n+\n+    // Assign GPU resource\n+    ResourceMappings.AssignedResources assigned =\n+        new ResourceMappings.AssignedResources();\n+    assigned.updateAssignedResources(\n+        ImmutableList.of(new GpuDevice(0, 0, 0), new GpuDevice(1, 1, 2)));\n+    resourceMappings.addAssignedResources(ResourceInformation.GPU_URI,\n+        assigned);\n+\n+    commandPlugin.setRequestsGpu(true);\n+    commandPlugin.updateDockerRunCommand(runCommand, nmContainer);\n+    Map<String, List<String>> newCommandLine =\n+        runCommand.getDockerCommandWithArguments();\n+    // NVIDIA_VISIBLE_DEVICES will be set\n+    Assert.assertTrue(\n+        runCommand.getEnv().get(\"NVIDIA_VISIBLE_DEVICES\").equals(\"0:0,1:2\"));\n+    // runtime should exist\n+    Assert.assertTrue(newCommandLine.containsKey(\"runtime\"));\n+  }\n+}\n"
  },
  {
    "path": "examples/MIG-Support/resource-types/gpu-mig/yarn321to323MIG.patch",
    "content": "diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java\nindex ad4d87daa1a..95259b1d956 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/conf/YarnConfiguration.java\n@@ -1716,6 +1716,10 @@ public static boolean isAclEnabled(Configuration conf) {\n   @Private\n   public static final String AUTOMATICALLY_DISCOVER_GPU_DEVICES = \"auto\";\n \n+  @Private\n+  public static final String USE_MIG_ENABLED_GPUS =\n+          NM_GPU_RESOURCE_PREFIX + \"use-mig-enabled\";\n+\n   /**\n    * This setting controls where to how to invoke GPU binaries\n    */\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java\nindex 26fd9050742..e84b920dcee 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/AssignedGpuDevice.java\n@@ -34,6 +34,12 @@ public AssignedGpuDevice(int index, int minorNumber,\n     this.containerId = containerId.toString();\n   }\n \n+  public AssignedGpuDevice(int index, int minorNumber,\n+                           int 
migIndex, ContainerId containerId) {\n+    super(index, minorNumber, migIndex);\n+    this.containerId = containerId.toString();\n+  }\n+\n   public String getContainerId() {\n     return containerId;\n   }\n@@ -49,6 +55,7 @@ public boolean equals(Object obj) {\n     }\n     AssignedGpuDevice other = (AssignedGpuDevice) obj;\n     return index == other.index && minorNumber == other.minorNumber\n+        && migDeviceIndex == other.migDeviceIndex\n         && containerId.equals(other.containerId);\n   }\n \n@@ -68,12 +75,16 @@ public int compareTo(Object obj) {\n     if (0 != result) {\n       return result;\n     }\n-    return containerId.compareTo(other.containerId);\n+    result = containerId.compareTo(other.containerId);\n+    if (0 != result) {\n+      return result;\n+    }\n+    return Integer.compare(migDeviceIndex, other.migDeviceIndex);\n   }\n \n   @Override\n   public int hashCode() {\n     final int prime = 47;\n-    return prime * (prime * index + minorNumber) + containerId.hashCode();\n+    return prime * (prime * index + minorNumber + migDeviceIndex) + containerId.hashCode();\n   }\n }\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java\nindex bce1d9fa480..3cb42d3c58f 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDevice.java\n@@ -26,6 +26,7 @@\n public class GpuDevice implements 
Serializable, Comparable {\n   protected int index;\n   protected int minorNumber;\n+  protected int migDeviceIndex = -1;\n   private static final long serialVersionUID = -6812314470754667710L;\n \n   public GpuDevice(int index, int minorNumber) {\n@@ -33,6 +34,12 @@ public GpuDevice(int index, int minorNumber) {\n     this.minorNumber = minorNumber;\n   }\n \n+  public GpuDevice(int index, int minorNumber, int migIndex) {\n+    this.index = index;\n+    this.minorNumber = minorNumber;\n+    this.migDeviceIndex = migIndex;\n+  }\n+\n   public int getIndex() {\n     return index;\n   }\n@@ -41,13 +48,17 @@ public int getMinorNumber() {\n     return minorNumber;\n   }\n \n+  public int getMIGIndex() {\n+    return migDeviceIndex;\n+  }\n+\n   @Override\n   public boolean equals(Object obj) {\n     if (obj == null || !(obj instanceof GpuDevice)) {\n       return false;\n     }\n     GpuDevice other = (GpuDevice) obj;\n-    return index == other.index && minorNumber == other.minorNumber;\n+    return index == other.index && minorNumber == other.minorNumber && migDeviceIndex == other.migDeviceIndex;\n   }\n \n   @Override\n@@ -62,17 +73,21 @@ public int compareTo(Object obj) {\n     if (0 != result) {\n       return result;\n     }\n-    return Integer.compare(minorNumber, other.minorNumber);\n+    result = Integer.compare(minorNumber, other.minorNumber);\n+    if (0 != result) {\n+      return result;\n+    }\n+    return Integer.compare(migDeviceIndex, other.migDeviceIndex);\n   }\n \n   @Override\n   public int hashCode() {\n     final int prime = 47;\n-    return prime * index + minorNumber;\n+    return prime * index + minorNumber + migDeviceIndex;\n   }\n \n   @Override\n   public String toString() {\n-    return \"(index=\" + index + \",minor_number=\" + minorNumber + \")\";\n+    return \"(index=\" + index + \",minor_number=\" + minorNumber + \",mig_index=\" + migDeviceIndex + \")\";\n   }\n }\ndiff --git 
a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java\nindex 9d61b91a1f2..ffc2a4c19af 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDeviceSpecificationException.java\n@@ -26,6 +26,8 @@\n public final class GpuDeviceSpecificationException extends YarnException {\n   private static final String VALID_FORMAT_MESSAGE = \"The valid format \" +\n       \"should be: index:minor_number\";\n+  private static final String VALID_MIG_FORMAT_MESSAGE = VALID_FORMAT_MESSAGE +\n+      \" or with MIG enabled: index:minor_number:mig_index\";\n \n   private GpuDeviceSpecificationException(String message) {\n     super(message);\n@@ -57,12 +59,31 @@ public static GpuDeviceSpecificationException createWithWrongValueSpecified(\n     return new GpuDeviceSpecificationException(message);\n   }\n \n+  public static GpuDeviceSpecificationException createWithWrongValueSpecifiedMIG(\n+      String device, String configValue, Exception cause) {\n+    final String message = createIllegalFormatMessageMIG(device, configValue);\n+    return new GpuDeviceSpecificationException(message, cause);\n+  }\n+\n+  public static GpuDeviceSpecificationException createWithWrongValueSpecifiedMIG(\n+      String device, String configValue) {\n+    final String message = 
createIllegalFormatMessageMIG(device, configValue);\n+    return new GpuDeviceSpecificationException(message);\n+  }\n+\n   public static GpuDeviceSpecificationException createWithDuplicateValueSpecified(\n       String device, String configValue) {\n     final String message = createDuplicateFormatMessage(device, configValue);\n     return new GpuDeviceSpecificationException(message);\n   }\n \n+  private static String createIllegalFormatMessageMIG(String device,\n+      String configValue) {\n+    return String.format(\"Illegal format of individual GPU device: %s, \" +\n+            \"the whole config value was: '%s'! \" + VALID_MIG_FORMAT_MESSAGE,\n+        device, configValue);\n+  }\n+\n   private static String createIllegalFormatMessage(String device,\n       String configValue) {\n     return String.format(\"Illegal format of individual GPU device: %s, \" +\n@@ -79,4 +100,4 @@ private static String createDuplicateFormatMessage(String device,\n             \"! Current value of the configuration is: %s\",\n         device, configValue);\n   }\n-}\n\\ No newline at end of file\n+}\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java\nindex f710ff0bccd..1517e12599a 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDiscoverer.java\n@@ -36,6 +36,7 @@\n import 
org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformation;\n import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.GpuDeviceInformationParser;\n import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuDeviceInformation;\n+import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu.PerGpuMigDevice;\n import org.slf4j.Logger;\n import org.slf4j.LoggerFactory;\n \n@@ -70,6 +71,7 @@\n   private GpuDeviceInformation lastDiscoveredGpuInformation = null;\n \n   private List<GpuDevice> gpuDevicesFromUser;\n+  private Boolean useMIGEnabledGPUs = false;\n \n   private void validateConfOrThrowException() throws YarnException {\n     if (conf == null) {\n@@ -188,8 +190,17 @@ private boolean isAutoDiscoveryEnabled() {\n       for (int i = 0; i < numberOfGpus; i++) {\n         List<PerGpuDeviceInformation> gpuInfos =\n             lastDiscoveredGpuInformation.getGpus();\n-        gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber()));\n+        if (useMIGEnabledGPUs &&\n+            gpuInfos.get(i).getMIGMode().getCurrentMigMode().equalsIgnoreCase(\"enabled\")) {\n+          LOG.info(\"GPU id \" + i + \" has MIG mode enabled.\");\n+          for (PerGpuMigDevice dev: gpuInfos.get(i).getMIGDevices()) {\n+            gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber(), dev.getMigDeviceIndex()));\n+          }\n+        } else {\n+          gpuDevices.add(new GpuDevice(i, gpuInfos.get(i).getMinorNumber()));\n+        }\n       }\n+      LOG.info(\"Discovered GPU devices: \" + gpuDevices);\n     }\n     return gpuDevices;\n   }\n@@ -212,28 +223,56 @@ private boolean isAutoDiscoveryEnabled() {\n     for (String device : devices.split(\",\")) {\n       if (device.trim().length() > 0) {\n         String[] splitByColon = device.trim().split(\":\");\n-        if (splitByColon.length != 2) {\n-          throwIfNecessary(GpuDeviceSpecificationException\n-              .createWithWrongValueSpecified(device, devices), 
conf);\n-          LOG.warn(\"Wrong GPU specification string {}, ignored\", device);\n-        }\n \n-        GpuDevice gpuDevice;\n-        try {\n-          gpuDevice = parseGpuDevice(splitByColon);\n-        } catch (NumberFormatException e) {\n-          throwIfNecessary(GpuDeviceSpecificationException\n-              .createWithWrongValueSpecified(device, devices, e), conf);\n-          LOG.warn(\"Cannot parse GPU device numbers: {}\", device);\n-          continue;\n-        }\n+        if (useMIGEnabledGPUs) {\n+          if (splitByColon.length != 2 && splitByColon.length != 3) {\n+            throwIfNecessary(GpuDeviceSpecificationException\n+                .createWithWrongValueSpecifiedMIG(device, devices), conf);\n+            LOG.warn(\"Wrong GPU specification string {}, ignored\", device);\n+          }\n+          GpuDevice gpuDevice;\n+          try {\n+            if (splitByColon.length == 3) {\n+              gpuDevice = parseGpuMIGDevice(splitByColon);\n+            } else {\n+              gpuDevice = parseGpuDevice(splitByColon);\n+            }\n+          } catch (NumberFormatException e) {\n+            throwIfNecessary(GpuDeviceSpecificationException\n+                .createWithWrongValueSpecifiedMIG(device, devices, e), conf);\n+            LOG.warn(\"Cannot parse GPU device numbers: {}\", device);\n+            continue;\n+          }\n+          if (!gpuDevices.contains(gpuDevice)) {\n+            gpuDevices.add(gpuDevice);\n+          } else {\n+            throw GpuDeviceSpecificationException\n+                .createWithDuplicateValueSpecified(device, devices);\n+          }\n \n-        if (!gpuDevices.contains(gpuDevice)) {\n-          gpuDevices.add(gpuDevice);\n         } else {\n-          throwIfNecessary(GpuDeviceSpecificationException\n-              .createWithDuplicateValueSpecified(device, devices), conf);\n-          LOG.warn(\"CPU device is duplicated: {}\", device);\n+          if (splitByColon.length != 2) {\n+       
     throwIfNecessary(GpuDeviceSpecificationException\n+                .createWithWrongValueSpecified(device, devices), conf);\n+            LOG.warn(\"Wrong GPU specification string {}, ignored\", device);\n+          }\n+          GpuDevice gpuDevice;\n+          try {\n+            gpuDevice = parseGpuDevice(splitByColon);\n+          } catch (NumberFormatException e) {\n+            throwIfNecessary(GpuDeviceSpecificationException\n+                .createWithWrongValueSpecified(device, devices, e), conf);\n+            LOG.warn(\"Cannot parse GPU device numbers: {}\", device);\n+            continue;\n+          }\n+\n+          if (!gpuDevices.contains(gpuDevice)) {\n+            gpuDevices.add(gpuDevice);\n+          } else {\n+            throwIfNecessary(GpuDeviceSpecificationException\n+                .createWithDuplicateValueSpecified(device, devices), conf);\n+            LOG.warn(\"CPU device is duplicated: {}\", device);\n+          }\n         }\n       }\n     }\n@@ -248,6 +287,12 @@ private GpuDevice parseGpuDevice(String[] splitByColon) {\n     return new GpuDevice(index, minorNumber);\n   }\n \n+  private GpuDevice parseGpuMIGDevice(String[] splitByColon) {\n+      int index = Integer.parseInt(splitByColon[0]);\n+      int minorNumber = Integer.parseInt(splitByColon[1]);\n+      int migIndex = Integer.parseInt(splitByColon[2]);\n+      return new GpuDevice(index, minorNumber, migIndex);\n+  }\n \n   public synchronized void initialize(Configuration config,\n       NvidiaBinaryHelper nvidiaHelper) throws YarnException {\n@@ -269,6 +314,9 @@ public synchronized void initialize(Configuration config,\n         LOG.warn(msg);\n       }\n     }\n+    useMIGEnabledGPUs = conf.getBoolean(YarnConfiguration.USE_MIG_ENABLED_GPUS, false);\n+    LOG.info(\"Use MIG enabled is: \" + useMIGEnabledGPUs);\n+\n   }\n \n   private void lookUpAutoDiscoveryBinary(Configuration config)\ndiff --git 
a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java\nindex 051afd6c561..996cb58ac45 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/GpuDockerCommandPluginFactory.java\n@@ -36,7 +36,7 @@ public static DockerCommandPlugin createGpuDockerCommandPlugin(\n     }\n     // nvidia-docker2\n     if (impl.equals(YarnConfiguration.NVIDIA_DOCKER_V2)) {\n-      return new NvidiaDockerV2CommandPlugin();\n+      return new NvidiaDockerV2CommandPlugin(conf);\n     }\n \n     throw new YarnException(\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java\nindex ff25eb6ced6..c2cc0e5a2d1 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java\n+++ 
b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV2CommandPlugin.java\n@@ -21,7 +21,9 @@\n import com.google.common.annotations.VisibleForTesting;\n import org.apache.commons.logging.Log;\n import org.apache.commons.logging.LogFactory;\n+import org.apache.hadoop.conf.Configuration;\n import org.apache.hadoop.yarn.api.records.ResourceInformation;\n+import org.apache.hadoop.yarn.conf.YarnConfiguration;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu.GpuResourceAllocator;\n@@ -45,8 +47,12 @@\n \n   private String nvidiaRuntime = \"nvidia\";\n   private String nvidiaVisibleDevices = \"NVIDIA_VISIBLE_DEVICES\";\n+  private String nvidiaMigThrowOnMultiGpus = \"NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS\";\n+  private Boolean isMigEnabled = false;\n \n-  public NvidiaDockerV2CommandPlugin() {}\n+  public NvidiaDockerV2CommandPlugin(Configuration conf) {\n+    isMigEnabled = conf.getBoolean(YarnConfiguration.USE_MIG_ENABLED_GPUS, false);\n+  }\n \n   private Set<GpuDevice> getAssignedGpus(Container container) {\n     ResourceMappings resourceMappings = container.getResourceMappings();\n@@ -84,10 +90,23 @@ public synchronized void updateDockerRunCommand(\n       return;\n     }\n     Map<String, String> environment = new HashMap<>();\n+    if (isMigEnabled && assignedResources.size() > 1) {\n+      Map<String, String> existingEnv = container.getLaunchContext().getEnvironment();\n+      Boolean shouldThrowOnMultipleGpus = Boolean.parseBoolean(\n+              existingEnv.getOrDefault(nvidiaMigThrowOnMultiGpus, \"true\"));\n+      if (shouldThrowOnMultipleGpus) {\n+        throw new 
ContainerExecutionException(\"Allocating more than 1 GPU per container is \" +\n+                \"not supported with use of MIG!\");\n+      }\n+    }\n     String gpuIndexList = \"\";\n     for (GpuDevice gpuDevice : assignedResources) {\n-      gpuIndexList = gpuIndexList + gpuDevice.getIndex() + \",\";\n-      LOG.info(\"nvidia docker2 assigned gpu index: \" + gpuDevice.getIndex());\n+      String deviceIndex = String.valueOf(gpuDevice.getIndex());\n+      if (gpuDevice.getMIGIndex() != -1) {\n+        deviceIndex = gpuDevice.getIndex() + \":\" + gpuDevice.getMIGIndex();\n+      }\n+      gpuIndexList = gpuIndexList + deviceIndex + \",\";\n+      LOG.info(\"nvidia docker2 assigned gpu index: \" + deviceIndex);\n     }\n     dockerRunCommand.addRuntime(nvidiaRuntime);\n     environment.put(nvidiaVisibleDevices,\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java\nindex 11ff2a4c49c..939ed46aac7 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuDeviceInformation.java\n@@ -22,8 +22,10 @@\n import org.apache.hadoop.classification.InterfaceStability;\n \n import javax.xml.bind.annotation.XmlElement;\n+import javax.xml.bind.annotation.XmlElementWrapper;\n import javax.xml.bind.annotation.XmlRootElement;\n import javax.xml.bind.annotation.adapters.XmlAdapter;\n+import java.util.List;\n \n /**\n  * Capture single GPU device information such 
as memory size, temperature,\n@@ -37,6 +39,8 @@\n   private String uuid = \"N/A\";\n   private int minorNumber = -1;\n \n+  private List<PerGpuMigDevice> migDevices;\n+  private PerGpuMigMode migMode;\n   private PerGpuUtilizations gpuUtilizations;\n   private PerGpuMemoryUsage gpuMemoryUsage;\n   private PerGpuTemperature temperature;\n@@ -107,6 +111,25 @@ public void setUuid(String uuid) {\n     this.uuid = uuid;\n   }\n \n+  @XmlElement(name = \"mig_mode\")\n+  public PerGpuMigMode getMIGMode() {\n+    return migMode;\n+  }\n+\n+  public void setMIGMode(PerGpuMigMode mode) {\n+    this.migMode = mode;\n+  }\n+\n+  @XmlElementWrapper( name = \"mig_devices\" )\n+  @XmlElement(name = \"mig_device\")\n+  public List<PerGpuMigDevice> getMIGDevices() {\n+    return migDevices;\n+  }\n+\n+  public void setMIGDevices(List<PerGpuMigDevice> devices) {\n+    this.migDevices = devices;\n+  }\n+\n   @XmlElement(name = \"product_name\")\n   public String getProductName() {\n     return productName;\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java\nnew file mode 100644\nindex 00000000000..4ce7cec6e55\n--- /dev/null\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigDevice.java\n@@ -0,0 +1,48 @@\n+/**\n+ * Licensed to the Apache Software Foundation (ASF) under one\n+ * or more contributor license agreements.  See the NOTICE file\n+ * distributed with this work for additional information\n+ * regarding copyright ownership.  
The ASF licenses this file\n+ * to you under the Apache License, Version 2.0 (the\n+ * \"License\"); you may not use this file except in compliance\n+ * with the License.  You may obtain a copy of the License at\n+ *\n+ *     http://www.apache.org/licenses/LICENSE-2.0\n+ *\n+ * Unless required by applicable law or agreed to in writing, software\n+ * distributed under the License is distributed on an \"AS IS\" BASIS,\n+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n+ * See the License for the specific language governing permissions and\n+ * limitations under the License.\n+ */\n+\n+package org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu;\n+\n+import org.apache.hadoop.classification.InterfaceAudience;\n+import org.apache.hadoop.classification.InterfaceStability;\n+\n+import javax.xml.bind.annotation.XmlElement;\n+import javax.xml.bind.annotation.XmlRootElement;\n+\n+/**\n+ * GPU MIG Device Information\n+ */\n+@InterfaceAudience.Private\n+@InterfaceStability.Unstable\n+@XmlRootElement(name = \"mig_device\")\n+public class PerGpuMigDevice {\n+  private int index;\n+\n+  /**\n+   * MIG device index\n+   * @return MIG device index\n+   */\n+  @XmlElement(name = \"index\")\n+  public int getMigDeviceIndex() {\n+    return index;\n+  }\n+\n+  public void setMigDeviceIndex(int index) {\n+    this.index = index;\n+  }\n+}\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java\nnew file mode 100644\nindex 00000000000..b706df2c3bb\n--- /dev/null\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/webapp/dao/gpu/PerGpuMigMode.java\n@@ -0,0 +1,48 
@@\n+/**\n+ * Licensed to the Apache Software Foundation (ASF) under one\n+ * or more contributor license agreements.  See the NOTICE file\n+ * distributed with this work for additional information\n+ * regarding copyright ownership.  The ASF licenses this file\n+ * to you under the Apache License, Version 2.0 (the\n+ * \"License\"); you may not use this file except in compliance\n+ * with the License.  You may obtain a copy of the License at\n+ *\n+ *     http://www.apache.org/licenses/LICENSE-2.0\n+ *\n+ * Unless required by applicable law or agreed to in writing, software\n+ * distributed under the License is distributed on an \"AS IS\" BASIS,\n+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n+ * See the License for the specific language governing permissions and\n+ * limitations under the License.\n+ */\n+\n+package org.apache.hadoop.yarn.server.nodemanager.webapp.dao.gpu;\n+\n+import org.apache.hadoop.classification.InterfaceAudience;\n+import org.apache.hadoop.classification.InterfaceStability;\n+\n+import javax.xml.bind.annotation.XmlElement;\n+import javax.xml.bind.annotation.XmlRootElement;\n+\n+/**\n+ * GPU MIG Mode\n+ */\n+@InterfaceAudience.Private\n+@InterfaceStability.Unstable\n+@XmlRootElement(name = \"mig_mode\")\n+public class PerGpuMigMode {\n+  private String currentMigMode;\n+\n+  /**\n+   * Current MIG mode\n+   * @return MIG mode enabled or disabled\n+   */\n+  @XmlElement(name = \"current_mig\")\n+  public String getCurrentMigMode() {\n+    return currentMigMode;\n+  }\n+\n+  public void setCurrentMigMode(String migMode) {\n+    this.currentMigMode = migMode;\n+  }\n+}\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java 
b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java\nindex 8261895b2a9..6c1f500009c 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestGpuDiscoverer.java\n@@ -373,6 +373,37 @@ public void testGetNumberOfUsableGpusFromConfig() throws YarnException {\n     assertEquals(4, usableGpuDevices.get(3).getMinorNumber());\n   }\n \n+  @Test\n+  public void testGetNumberOfUsableGpusFromConfigMIG() throws YarnException {\n+    Configuration conf = createConfigWithAllowedDevices(\"0:0,1:1:0,1:1:3,2:2,3:4\");\n+    conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, \"true\");\n+    GpuDiscoverer discoverer = new GpuDiscoverer();\n+    discoverer.initialize(conf, binaryHelper);\n+\n+    List<GpuDevice> usableGpuDevices = discoverer.getGpusUsableByYarn();\n+    assertEquals(5, usableGpuDevices.size());\n+\n+    assertEquals(0, usableGpuDevices.get(0).getIndex());\n+    assertEquals(0, usableGpuDevices.get(0).getMinorNumber());\n+    assertEquals(-1, usableGpuDevices.get(0).getMIGIndex());\n+\n+    assertEquals(1, usableGpuDevices.get(1).getIndex());\n+    assertEquals(1, usableGpuDevices.get(1).getMinorNumber());\n+    assertEquals(0, usableGpuDevices.get(1).getMIGIndex());\n+\n+    assertEquals(1, usableGpuDevices.get(2).getIndex());\n+    assertEquals(1, usableGpuDevices.get(2).getMinorNumber());\n+    assertEquals(3, usableGpuDevices.get(2).getMIGIndex());\n+\n+    assertEquals(2, usableGpuDevices.get(3).getIndex());\n+    assertEquals(2, usableGpuDevices.get(3).getMinorNumber());\n+    
assertEquals(-1, usableGpuDevices.get(3).getMIGIndex());\n+\n+    assertEquals(3, usableGpuDevices.get(4).getIndex());\n+    assertEquals(4, usableGpuDevices.get(4).getMinorNumber());\n+    assertEquals(-1, usableGpuDevices.get(4).getMIGIndex());\n+  }\n+\n   @Test\n   public void testGetNumberOfUsableGpusFromConfigDuplicateValues()\n       throws YarnException {\n@@ -513,4 +544,4 @@ public void testScriptNotCalled() throws YarnException, IOException {\n \n     verify(gpuSpy, never()).getGpuDeviceInformation();\n   }\n-}\n\\ No newline at end of file\n+}\ndiff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java\nindex b0b523360ef..798a95cb009 100644\n--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java\n+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/TestNvidiaDockerV2CommandPlugin.java\n@@ -20,10 +20,14 @@\n \n import com.google.common.collect.ImmutableList;\n import com.google.common.collect.Sets;\n+import org.apache.hadoop.conf.Configuration;\n+import org.apache.hadoop.yarn.conf.YarnConfiguration;\n+import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;\n import org.apache.hadoop.yarn.api.records.ResourceInformation;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;\n import 
org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;\n import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;\n+import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;\n import org.junit.Assert;\n import org.junit.Test;\n \n@@ -69,7 +73,13 @@ private boolean commandlinesEquals(Map<String, List<String>> cli1,\n       extends NvidiaDockerV2CommandPlugin {\n     private boolean requestsGpu = false;\n \n-    MyNvidiaDockerV2CommandPlugin() {}\n+    MyNvidiaDockerV2CommandPlugin() {\n+      super(new Configuration());\n+    }\n+\n+    MyNvidiaDockerV2CommandPlugin(Configuration conf) {\n+      super(conf);\n+    }\n \n     public void setRequestsGpu(boolean r) {\n       requestsGpu = r;\n@@ -127,4 +137,118 @@ public void testPlugin() throws Exception {\n     // runtime should exist\n     Assert.assertTrue(newCommandLine.containsKey(\"runtime\"));\n   }\n-}\n\\ No newline at end of file\n+\n+  @Test\n+  public void testPluginMIG() throws Exception {\n+    DockerRunCommand runCommand = new DockerRunCommand(\"container_1\", \"user\",\n+        \"fakeimage\");\n+\n+    Map<String, List<String>> originalCommandline = copyCommandLine(\n+        runCommand.getDockerCommandWithArguments());\n+\n+    Configuration conf = new Configuration();\n+    conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, \"true\");\n+    MyNvidiaDockerV2CommandPlugin\n+        commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf);\n+\n+    Container nmContainer = mock(Container.class);\n+    ResourceMappings resourceMappings = new ResourceMappings();\n+    when(nmContainer.getResourceMappings()).thenReturn(resourceMappings);\n+\n+    // Assign GPU resource\n+    ResourceMappings.AssignedResources assigned =\n+        new ResourceMappings.AssignedResources();\n+    assigned.updateAssignedResources(\n+        ImmutableList.of(new GpuDevice(0, 0, 0)));\n+    
resourceMappings.addAssignedResources(ResourceInformation.GPU_URI,\n+        assigned);\n+\n+    commandPlugin.setRequestsGpu(true);\n+    commandPlugin.updateDockerRunCommand(runCommand, nmContainer);\n+    Map<String, List<String>> newCommandLine =\n+        runCommand.getDockerCommandWithArguments();\n+\n+    // Command line will be updated\n+    Assert.assertFalse(commandlinesEquals(originalCommandline, newCommandLine));\n+    // NVIDIA_VISIBLE_DEVICES will be set\n+    Assert.assertTrue(\n+        runCommand.getEnv().get(\"NVIDIA_VISIBLE_DEVICES\").equals(\"0:0\"));\n+    // runtime should exist\n+    Assert.assertTrue(newCommandLine.containsKey(\"runtime\"));\n+  }\n+\n+  @Test(expected = ContainerExecutionException.class)\n+  public void testPluginMIGThrowsMulti() throws Exception {\n+    DockerRunCommand runCommand = new DockerRunCommand(\"container_1\", \"user\",\n+        \"fakeimage\");\n+\n+    Map<String, List<String>> originalCommandline = copyCommandLine(\n+        runCommand.getDockerCommandWithArguments());\n+\n+    Configuration conf = new Configuration();\n+    conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, \"true\");\n+    MyNvidiaDockerV2CommandPlugin\n+        commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf);\n+\n+    Container nmContainer = mock(Container.class);\n+    ResourceMappings resourceMappings = new ResourceMappings();\n+    Map<String, String> env = new HashMap<>();\n+    env.put(\"NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS\", \"true\");\n+    when(nmContainer.getResourceMappings()).thenReturn(resourceMappings);\n+    ContainerLaunchContext launchCtx = mock(ContainerLaunchContext.class);\n+    when(nmContainer.getLaunchContext()).thenReturn(launchCtx);\n+    when(launchCtx.getEnvironment()).thenReturn(env);\n+\n+    // Assign GPU resource\n+    ResourceMappings.AssignedResources assigned =\n+        new ResourceMappings.AssignedResources();\n+    assigned.updateAssignedResources(\n+        ImmutableList.of(new GpuDevice(0, 0, 
0), new GpuDevice(1, 1, 2)));\n+    resourceMappings.addAssignedResources(ResourceInformation.GPU_URI,\n+        assigned);\n+\n+    commandPlugin.setRequestsGpu(true);\n+    commandPlugin.updateDockerRunCommand(runCommand, nmContainer);\n+  }\n+\n+  @Test\n+  public void testPluginMIGNoThrowsMulti() throws Exception {\n+    DockerRunCommand runCommand = new DockerRunCommand(\"container_1\", \"user\",\n+        \"fakeimage\");\n+\n+    Map<String, List<String>> originalCommandline = copyCommandLine(\n+        runCommand.getDockerCommandWithArguments());\n+\n+    Configuration conf = new Configuration();\n+    conf.set(YarnConfiguration.USE_MIG_ENABLED_GPUS, \"true\");\n+    MyNvidiaDockerV2CommandPlugin\n+        commandPlugin = new MyNvidiaDockerV2CommandPlugin(conf);\n+\n+    Container nmContainer = mock(Container.class);\n+    ResourceMappings resourceMappings = new ResourceMappings();\n+    Map<String, String> env = new HashMap<>();\n+    env.put(\"NVIDIA_MIG_PLUGIN_THROW_ON_MULTIPLE_GPUS\", \"false\");\n+    when(nmContainer.getResourceMappings()).thenReturn(resourceMappings);\n+    ContainerLaunchContext launchCtx = mock(ContainerLaunchContext.class);\n+    when(nmContainer.getLaunchContext()).thenReturn(launchCtx);\n+    when(launchCtx.getEnvironment()).thenReturn(env);\n+\n+    // Assign GPU resource\n+    ResourceMappings.AssignedResources assigned =\n+        new ResourceMappings.AssignedResources();\n+    assigned.updateAssignedResources(\n+        ImmutableList.of(new GpuDevice(0, 0, 0), new GpuDevice(1, 1, 2)));\n+    resourceMappings.addAssignedResources(ResourceInformation.GPU_URI,\n+        assigned);\n+\n+    commandPlugin.setRequestsGpu(true);\n+    commandPlugin.updateDockerRunCommand(runCommand, nmContainer);\n+    Map<String, List<String>> newCommandLine =\n+        runCommand.getDockerCommandWithArguments();\n+    // NVIDIA_VISIBLE_DEVICES will be set\n+    Assert.assertTrue(\n+        
runCommand.getEnv().get(\"NVIDIA_VISIBLE_DEVICES\").equals(\"0:0,1:2\"));\n+    // runtime should exist\n+    Assert.assertTrue(newCommandLine.containsKey(\"runtime\"));\n+  }\n+}\n"
  },
  {
    "path": "examples/MIG-Support/yarn-unpatched/README.md",
    "content": "# MIG Support for Spark on YARN using unmodified versions of Apache Hadoop 3.1.2+\n\nThis document describes a solution for utilizing MIG with YARN when upgrading to a recent 3.3+\nversion or patching older versions of Apache Hadoop is not feasible. Please refer to the corresponding\nalternatives for more information:\n- [Device Plugins README](../device-plugins/gpu-mig/README.md)\n- [YARN patch README](../resource-types/gpu-mig/README.md)\n\n## Introduction\n\nWe provide a set of scripts that wrap the original `nvidia-smi` from the NVIDIA GPU Driver and `nvidia-container-cli`\nincluded in [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-docker).\n\n`nvidia-smi` is a wrapper script that parses the XML output of `nvidia-smi -q -x` used by YARN\nto discover GPUs. It replaces MIG-enabled GPUs with the list of `<gpu>` elements corresponding to every\n`<mig_device>` element of the GPU with additional annotation to construct the MIG identifier for\n`nvidia-container-cli`. 
This reverse mapping is performed by the modified `nvidia` Docker runtime using\n`nvidia-container-cli-wrapper.sh`.\n\n## Requirements\n\nPlease see the [MIG Application Considerations](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#app-considerations)\nand [CUDA Device Enumeration](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html#cuda-visible-devices).\n\nNote that this method only works with drivers >= R470 (470.42.01+).\n\n## Installation\n\nThese instructions assume YARN is already installed and configured with GPU Scheduling enabled\nusing Docker and the NVIDIA Container Toolkit (nvidia-docker2).\nSee [Using GPU on YARN](https://hadoop.apache.org/docs/r3.1.2/hadoop-yarn/hadoop-yarn-site/UsingGpus.html) if\nyou need more information.\n\nEnable and configure your [GPUs with MIG](https://docs.nvidia.com/datacenter/tesla/mig-user-guide/index.html) on all of the nodes\nit applies to.\n\nDownload the contents of [scripts](./scripts/) to every YARN NodeManager (worker) machine\nto some location, for example: `/usr/local/yarn-mig-scripts`. Make sure that the scripts\nare executable by the docker daemon user (i.e., `root`), and YARN NM service user (typically `yarn`). 
Note that the scripts\nleave the original outputs untouched if the environment variable `MIG_AS_GPU_ENABLED` is not 1.\n\n### YARN Configuration\n#### Customizing yarn-env.sh\n\nIn `$YARN_CONF_DIR/yarn-env.sh`\n- Add `export MIG_AS_GPU_ENABLED=1` to enable replacing of MIG-enabled GPUs with a list\nof MIG devices as if they are physical GPUs.\n- Customize `REAL_NVIDIA_SMI_PATH` value if nvidia-smi is not at the default location\n`/usr/bin/nvidia-smi`.\n- Add `ENABLE_NON_MIG_GPUS=0` if you want to prevent discovery of physical GPUs that are not subdivided in MIGs.\nDefault is ENABLE_NON_MIG_GPUS=1 and physical GPUs in the MIG-Disabled state are listed along with MIG sub-devices on the node.\n\nModify the following config `$YARN_CONF_DIR/yarn-site.xml`:\n```xml\n<property>\n  <name>yarn.nodemanager.resource-plugins.gpu.path-to-discovery-executables</name>\n  <value>/usr/local/yarn-mig-scripts/</value>\n</property>\n```\n\nBy default, `yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices` is set to `auto` and\n`/usr/local/yarn-mig-scripts/nvidia-smi` will be called by YARN to discover GPUs.\n\nIf you disable the default automatic GPU discovery, you can manually\nspecify the list of MIG instances to use by setting\n`yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices` to the list of\n0-based indices corresponding to the desired `<gpu>` elements in the output of\n\n```bash\nMIG_AS_GPU_ENABLED=1 /usr/local/yarn-mig-scripts/nvidia-smi -q -x\n```\n\nIn other words, if you want to allow MIG 1:2 and 2:0 and they are listed as 3rd and 5th `<gpu>`\nelements the value for `yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices` should be\n\"2,4\".\n\n### NVIDIA Docker Runtime Configuration\n\nModify section `[nvidia-container-cli]` in `/etc/nvidia-container-runtime/config.toml`:\n```toml\npath = \"/usr/local/yarn-mig-scripts/nvidia-container-cli-wrapper.sh\"\nenvironment = [ \"MIG_AS_GPU_ENABLED=1\",  \"REAL_NVIDIA_SMI_PATH=/if/non-default/path/nvidia-smi\" 
]\n```\n\nNote that the values for `MIG_AS_GPU_ENABLED`, `REAL_NVIDIA_SMI_PATH`, `ENABLE_NON_MIG_GPUS` should be\nidentical to the ones specified in `yarn-env.sh`.\n\n"
  },
  {
    "path": "examples/MIG-Support/yarn-unpatched/scripts/mig2gpu.sh",
    "content": "#!/bin/bash\n\n# Copyright (c) 2022, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nset -e\n\n# This file contains the logic for parsing and manipulating the well-formed\n# pretty-printed XML output generated by nvidia-smi. It replaces each MIG-enabled gpu element\n# with a list of gpu elements corresponding to its configured MIG devices.\n# If there is at least one MIG-enabled GPU, the output for non-MIG GPUs is suppressed by default. However,\n# this can be overridden using ENABLE_NON_MIG_GPUS=1.\n\n# XML fragments are viewed and manipulated using bash arrays of lines. Each element of interest is tracked by\n# a start offset into the line array pointing to the line with the opening tag and the end offset,\n# which is the line number past the closing tag.\n#\n# NOTE: this is not a real XML parser, but it is sufficient to handle XML without nested\n# tags mixed on the same line. 
When making changes try to avoid non-bash dependencies.\n\n# Include both MIG and non-MIG devices by default\n# Set ENABLE_NON_MIG_GPUS=0 to discover only GPU devices with the current MIG mode Disabled\nENABLE_NON_MIG_GPUS=${ENABLE_NON_MIG_GPUS:-1}\n\n# If setting YARN up to use Cgroups without official YARN support,\n# enabling this tells the script to use the NVIDIA capabilities access\n# device number for the minor number so that the YARN Cgroup code\n# denies access to MIG devices properly.\nENABLE_MIG_GPUS_FOR_CGROUPS=${ENABLE_MIG_GPUS_FOR_CGROUPS:-0}\n\n# For stored input test: NVIDIA_SMI_QX=./src/resources/tom-nvidia-smi-xq.xml\n# For live input test: NVIDIA_SMI_QX=/dev/stdin\nNVIDIA_SMI_QX=\"${NVIDIA_SMI_QX:-\"/dev/stdin\"}\"\n\nmig2gpu_inputLines=()\n\n# buffer global output here\nmig2gpu_out=()\n\nmig2gpu_migEnabled=0\n\nmig2gpu_driverVersion=\"INVALID_DRIVER_VERSION\"\n\n# buffer non-MIG GPU output here\nmig2gpu_nonMigGpu_out=()\nmig2gpu_migGpu_out=()\n\n# Slice of original XML defining the current GPU element\nmig2gpu_gpu_lineNumberStart=-1\nmig2gpu_gpu_lineNumberEnd=-1\n\n# Slice of original XML defining the current MIG element\nmig2gpu_mig_lineNumberStart=-1\nmig2gpu_mig_lineNumberEnd=-1\nmig2gpu_migIndex=-1\n\n# Parent GPU context for MIG\nmig2gpu_gpuIdx=-1\nmig2gpu_migGpuInstanceId=-1\nmig2gpu_migComputeInstanceUuid=-1\nmig2gpu_productName=\"INVALID_GPU_PRODUCT_NAME\"\nmig2gpu_gpuUuid=\"INVALID_GPU_UUID\"\nmig2gpu_gpuMinorNumber=\"INVALID_GPU_MINOR_NUMBER\"\nmig2gpu_gpu_utilization_lineNumberStart=-1\nmig2gpu_gpu_utilization_lineNumberEnd=-1\nmig2gpu_gpu_temperature_lineNumberStart=-1\nmig2gpu_gpu_temperature_lineNumberEnd=-1\n\n# The function to replace a MIG-enabled GPU with the \"fake\" GPU device elements\n# corresponding to MIG devices contained within the given GPU element\n#\n# The minimum GPU content YARN needs from GPU for parse to succeed:\n#\n# <nvidia_smi_log>\n#         <driver_version>495.29.05</driver_version>\n#         <gpu 
id=\"00000000:17:00.0\">\n#                 <product_name>Quadro RTX 6000</product_name>\n#                 <uuid>GPU-903720f4-f8d1-11e0-3b2f-4bd740b2f424</uuid>\n#                 <minor_number>0</minor_number>\n#                 <fb_memory_usage>\n#                         <used>673 MiB</used>\n#                         <free>&23547 MiB</free>\n#                 </fb_memory_usage>\n#                 <utilization>\n#                         <gpu_util>&23 %</gpu_util>\n#                 </utilization>\n#                 <temperature>\n#                         <gpu_temp>38 C</gpu_temp>\n#                         <gpu_temp_max_threshold>94 C</gpu_temp_max_threshold>\n#                         <gpu_temp_slow_threshold>91 C</gpu_temp_slow_threshold>\n#                 </temperature>\n#         </gpu>\n# </nvidia_smi_log>\n#\n# A MIG device looks like this:\n# <mig_device>\n#     <index>0</index>\n#     <gpu_instance_id>3</gpu_instance_id>\n#     <compute_instance_id>0</compute_instance_id>\n#     <device_attributes>\n#         <shared>\n#             <multiprocessor_count>14</multiprocessor_count>\n#             <copy_engine_count>1</copy_engine_count>\n#             <encoder_count>0</encoder_count>\n#             <decoder_count>1</decoder_count>\n#             <ofa_count>0</ofa_count>\n#             <jpg_count>0</jpg_count>\n#         </shared>\n#     </device_attributes>\n#     <ecc_error_count>\n#         <volatile_count>\n#             <sram_uncorrectable>0</sram_uncorrectable>\n#         </volatile_count>\n#     </ecc_error_count>\n#     <fb_memory_usage>\n#         <total>6016 MiB</total>\n#         <used>3 MiB</used>\n#         <free>6012 MiB</free>\n#     </fb_memory_usage>\n#     <bar1_memory_usage>\n#         <total>8191 MiB</total>\n#         <used>0 MiB</used>\n#         <free>8191 MiB</free>\n#     </bar1_memory_usage>\n# </mig_device>\n#\n# To satisfy the minimum parseable GPU element, we need to\n# 1) add a <product_name> element, parent's orginal text 
+ MIG + index\n# 2) add a <uuid> element accoring to https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#cuda-gi\n#    MIG-<parent gpu uuid>/<gpu instance id>/<compute instance id>\n# 3) add parent's <minor_number> 0 (don't care)\n# 4) use MIG's own <fb_memory_usage> element unchanged\n# 5) copy <utilization> element from parent\n# 6) copy <temperature> element from parent\n#\n# To enable bidirectional translation <mig_device> to/from fake <gpu>\n# 7) add a <_mig2gpu_device_id> element: \"<parent gpu index>:<mig index>\", e.g. 0:0\n\n\nfunction processParentGpuGlobals {\n    local lineNumber\n\n    # increment 0-based GPU iteration order index\n    mig2gpu_gpuIdx=$((mig2gpu_gpuIdx+1))\n\n    for ((lineNumber=mig2gpu_gpu_lineNumberStart; lineNumber<mig2gpu_gpu_lineNumberEnd; lineNumber++)); do\n        local line=\"${mig2gpu_inputLines[$lineNumber]}\"\n\n        case \"$line\" in\n\n            $'\\t'*'<current_mig>'*'</current_mig>')\n                if [[ \"$line\" =~ '<current_mig>Enabled</current_mig>' ]]; then\n                    mig2gpu_migEnabled=1\n                else\n                    mig2gpu_migEnabled=0\n                fi\n                ;;\n\n            $'\\t'*'<product_name>'*)\n                if [[ \"$line\" =~ $'\\t\\t<product_name>'(.*)'</product_name>' ]]; then\n                    mig2gpu_productName=\"${BASH_REMATCH[1]}\"\n                fi\n                ;;\n\n            $'\\t'*'<uuid>'*)\n                if [[ \"$line\" =~ $'\\t\\t<uuid>'(.*)'</uuid>' ]]; then\n                    mig2gpu_gpuUuid=\"${BASH_REMATCH[1]}\"\n                fi\n                ;;\n\n            $'\\t'*'<minor_number>'*)\n                mig2gpu_gpuMinorNumber=\"$line\"\n                ;;\n\n            $'\\t'*'<utilization>'*)\n                mig2gpu_gpu_utilization_lineNumberStart=\"$lineNumber\"\n                ;;\n\n            $'\\t'*'</utilization>'*)\n                mig2gpu_gpu_utilization_lineNumberEnd=$((lineNumber+1))\n                
;;\n\n            $'\\t'*'<temperature>'*)\n                mig2gpu_gpu_temperature_lineNumberStart=\"$lineNumber\"\n                ;;\n\n            $'\\t'*'</temperature>'*)\n                mig2gpu_gpu_temperature_lineNumberEnd=$((lineNumber+1))\n                ;;\n        esac\n    done\n}\n\n\nfunction addOriginalGpuIndexAsDeviceId {\n    local afterUuidLineStart=$((mig2gpu_gpu_lineNumberStart+3))\n    local afterUuidGpuLength=$((mig2gpu_gpu_lineNumberEnd-afterUuidLineStart))\n    mig2gpu_nonMigGpu_out+=( \"${mig2gpu_inputLines[@]:$mig2gpu_gpu_lineNumberStart:3}\" )\n    mig2gpu_nonMigGpu_out+=( $'\\t\\t'\"<_mig2gpu_device_id>$mig2gpu_gpuIdx</_mig2gpu_device_id>\")\n    mig2gpu_nonMigGpu_out+=( \"${mig2gpu_inputLines[@]:$afterUuidLineStart:$afterUuidGpuLength}\" )\n}\n\n\nfunction replaceParentGpuWithMigs {\n\n    for ((lineNumber=mig2gpu_gpu_lineNumberStart; lineNumber<mig2gpu_gpu_lineNumberEnd; lineNumber++)); do\n        local line=\"${mig2gpu_inputLines[$lineNumber]}\"\n\n        case \"$line\" in\n\n            $'\\t'*'<mig_device>'*)\n                mig2gpu_mig_lineNumberStart=$lineNumber\n                ;;\n\n            $'\\t'*'<index>'*)\n                if [[ \"$line\" =~ $'\\t'*'<index>'(.*)'</index>' ]]; then\n                    mig2gpu_migIndex=\"${BASH_REMATCH[1]}\"\n                fi\n                ;;\n\n            $'\\t'*'_instance_id>'*)\n                if [[ \"$line\" =~ $'\\t'*'<gpu_instance_id>'(.*)'</gpu_instance_id>' ]]; then\n                    mig2gpu_migGpuInstanceId=\"${BASH_REMATCH[1]}\"\n                elif [[ \"$line\" =~ $'\\t'*'<compute_instance_id>'(.*)'</compute_instance_id>' ]]; then\n                    mig2gpu_migComputeInstanceId=\"${BASH_REMATCH[1]}\"\n                fi\n                ;;\n\n            $'\\t'*'<fb_memory_usage>'*)\n                local fbMemoryUsage_lineNumberStart=$lineNumber\n                ;;\n\n            $'\\t'*'</fb_memory_usage>'*)\n                local 
fbMemoryUsage_lineNumberEnd=$((lineNumber+1))\n                local fbMemryUsageLength=$((fbMemoryUsage_lineNumberEnd-fbMemoryUsage_lineNumberStart))\n                local fbMemoryUsage=(\"${mig2gpu_inputLines[@]:$fbMemoryUsage_lineNumberStart:fbMemryUsageLength}\")\n                local migFbMemoryUsage=(\"${fbMemoryUsage[@]//$'\\t\\t\\t'/$'\\t\\t'}\")\n                ;;\n\n            $'\\t'*'</mig_device>'*)\n                mig2gpu_mig_lineNumberEnd=$((lineNumber+1))\n\n                # <gpu id=\"...\">\n                mig2gpu_migGpu_out+=(\"${mig2gpu_inputLines[$mig2gpu_gpu_lineNumberStart]}\")\n                mig2gpu_migGpu_out+=($'\\t\\t'\"<product_name>$mig2gpu_productName (MIG)</product_name>\")\n\n                # We don't really use it since driver-dependent\n                # but R450 & R460 form is more useful for debugging\n                # https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#cuda-visible-devices\n                #\n                local migUuid=\"MIG-$mig2gpu_gpuUuid/$mig2gpu_migGpuInstanceId/$mig2gpu_migComputeInstanceId\"\n                mig2gpu_migGpu_out+=($'\\t\\t'\"<uuid>$migUuid</uuid>\")\n\n                # https://github.com/NVIDIA/nvidia-container-runtime#nvidia_visible_devices\n                # The scheme <GPU Device Index>:<MIG Device Index> is not annotated with any\n                # driver version caveats, so adding this for stability and simplicity\n                local migDeviceId=\"$mig2gpu_gpuIdx:$mig2gpu_migIndex\"\n                mig2gpu_migGpu_out+=($'\\t\\t'\"<_mig2gpu_device_id>$migDeviceId</_mig2gpu_device_id>\")\n\n                # if using this with CGROUP workaround we need the minor number to be from nvidia-caps access\n                if [[ \"$ENABLE_MIG_GPUS_FOR_CGROUPS\" == 1 ]]; then\n                    mig_minor_dev_num=`cat /proc/driver/nvidia-caps/mig-minors | grep gpu$mig2gpu_gpuIdx/gi$mig2gpu_migGpuInstanceId/access | cut -d ' ' -f 2`\n                    
mig2gpu_migGpu_out+=($'\\t\\t'\"<minor_number>$mig_minor_dev_num</minor_number>\")\n                else\n                    mig2gpu_migGpu_out+=(\"$mig2gpu_gpuMinorNumber\")\n                fi\n                mig2gpu_migGpu_out+=(\"${migFbMemoryUsage[@]}\")\n\n                local gpuUtilizationLength=$((mig2gpu_gpu_utilization_lineNumberEnd - mig2gpu_gpu_utilization_lineNumberStart))\n                local gpuUtilization=(\"${mig2gpu_inputLines[@]:$mig2gpu_gpu_utilization_lineNumberStart:gpuUtilizationLength}\")\n                mig2gpu_migGpu_out+=(\"${gpuUtilization[@]}\")\n\n                local gpuTemperatureLength=$((mig2gpu_gpu_temperature_lineNumberEnd - mig2gpu_gpu_temperature_lineNumberStart))\n                mig2gpu_migGpu_out+=(\"${mig2gpu_inputLines[@]:$mig2gpu_gpu_temperature_lineNumberStart:$gpuTemperatureLength}\")\n\n                # </gpu>\n                mig2gpu_migGpu_out+=(\"${mig2gpu_inputLines[$((mig2gpu_gpu_lineNumberEnd-1))]}\")\n                ;;\n        esac\n    done\n}\n\n\nfunction processGpuElement {\n    processParentGpuGlobals\n\n    if [[ \"$mig2gpu_migEnabled\" != \"1\" ]]; then\n        addOriginalGpuIndexAsDeviceId\n    else\n        # scan gpu element lines twice because the mig section appears before\n        # the info needed from parent\n        replaceParentGpuWithMigs\n    fi\n}\n\n\nfunction mig2gpuMain {\n    local line\n    local lineNumber\n\n    # simplified regex-free parser relying on the fact\n    # that nvidia-smi output is pretty-printed with tabs\n    while IFS= read -r line; do\n        lineNumber=${#mig2gpu_inputLines[@]}\n        mig2gpu_inputLines+=(\"$line\")\n\n        case \"$line\" in\n\n            # document-level tags\n            '<'*)\n                mig2gpu_out+=(\"$line\")\n                ;;\n\n            $'\\t<gpu '*)\n                # start of a new GPU element\n                mig2gpu_gpu_lineNumberStart=\"$lineNumber\"\n                ;;\n\n            $'\\t</gpu'*)\n           
     # end of a GPU element\n                mig2gpu_gpu_lineNumberEnd=$((lineNumber+1))\n                processGpuElement\n                ;;\n\n            $'\\t<driver_version>'*)\n                mig2gpu_driverVersion=\"$line\"\n                ;;\n\n            *)\n                # ignore infeasible\n                ;;\n\n        esac\n    done < \"$NVIDIA_SMI_QX\"\n\n    for outLine in \"${mig2gpu_out[@]}\"; do\n        printf '%s\\n' \"$outLine\"\n        if [[ \"$outLine\" =~ '<nvidia_smi_log>' ]]; then\n            printf '%s\\n' \"$mig2gpu_driverVersion\"\n            printf '%s\\n' \"${mig2gpu_migGpu_out[@]}\"\n\n            # output non-MIG only if ENABLE_NON_MIG_GPUS is set\n            # https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#cuda-visible-devices\n            # currently mixing MIG and non-MIG GPUs is not supported by the driver\n            # \"Note that these constraints may be relaxed in future NVIDIA driver releases for MIG\"\n            if [[ \"${#mig2gpu_migGpu_out[@]}\" == \"0\" || \"$ENABLE_NON_MIG_GPUS\" == \"1\" ]]; then\n                printf '%s\\n' \"${mig2gpu_nonMigGpu_out[@]}\"\n            fi\n        fi\n    done\n}\n\n\nmig2gpuMain\n"
  },
  {
    "path": "examples/MIG-Support/yarn-unpatched/scripts/nvidia-container-cli-wrapper.sh",
    "content": "#!/bin/bash\n\n# Copyright (c) 2022-2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# This script is executed by the `nvidia` Docker runtime on the host before creating the container.\n# It intercepts the device assigned by YARN, a 0-based index and converts it to a pair\n# GPU device index:MIG device index that is stored in the _mig2gpu_device_id element\n# by mig2gpu.sh in the nvidia-smi-wrapper.sh which limits all the processes within the container\n# to the corresponding MIG Compute Instance https://github.com/NVIDIA/nvidia-container-runtime#nvidia_visible_devices.\n\n# customize in /etc/nvidia-container-runtime/config.toml\n# [nvidia-container-cli]\n# environment = [ \"VAR1=VAL1\", \"VAR2=VAL2\" ]\nREAL_NVIDIA_CONTAINER_CLI_PATH=${REAL_NVIDIA_CONTAINER_CLI_PATH:-\"/usr/bin/nvidia-container-cli\"}\nREAL_NVIDIA_SMI_PATH=${REAL_NVIDIA_SMI_PATH:-\"/usr/bin/nvidia-smi\"}\nMIG_AS_GPU_ENABLED=${MIG_AS_GPU_ENABLED:-\"0\"}\n\nTHIS_PATH=\"$(readlink -f $0)\"\nTHIS_DIR=\"$(dirname $THIS_PATH)\"\n\nif [[ \"$MIG_AS_GPU_ENABLED\" == \"1\" ]]; then\n    realArgs=()\n    for arg in \"$@\"; do\n        case \"$arg\" in\n\n            \"--device=\"*)\n                nvcli_migDeviceIds=()\n                # map CSV of indexes 0,3,10 to ,0,3,10,\n                # so we can do an easy \"contains\" test\n                # the device N is included if deviceArgWithLeadingTrailingComma\n                # matches =~ \",N,\"\n             
   deviceArgWithLeadingTrailingComma=\",${arg#*=},\"\n                current_gpu_idx=-1\n                while read -r line; do\n                    case \"$line\" in\n\n                        # found the device id constructed in mig2gpu.sh with the original nvidia-smi enumeration\n                        # gpu index, mig index\n                        *\"<_mig2gpu_device_id>\"*)\n                            current_gpu_idx=$((current_gpu_idx+1))\n                            if [[ \"$deviceArgWithLeadingTrailingComma\" =~ \",${current_gpu_idx},\" && \"$line\" =~ '<_mig2gpu_device_id>'(.*)'</_mig2gpu_device_id>' ]]; then\n                                nvcli_migDeviceIds+=(\"${BASH_REMATCH[1]}\")\n                            fi\n                            ;;\n\n                    esac\n                done < <(\"$REAL_NVIDIA_SMI_PATH\" -q -x | \"$THIS_DIR/mig2gpu.sh\")\n                # make sure the above redirect into the while read loop does not use the here-string (<<<) method because different\n                # versions of bash materialize newlines differently in the string. Older versions treat it as a single\n                # line and newer versions leave it as a multiline string. Here it needs to be a multiline.\n\n                if (( ${#nvcli_migDeviceIds[@]} )); then\n                    migDeviceIdsCsv=$(IFS=','; echo \"${nvcli_migDeviceIds[*]}\")\n                    realArgs+=(\"--device=$migDeviceIdsCsv\")\n                else\n                    realArgs+=(\"$arg\")\n                fi\n\n                ;;\n\n            *)\n                realArgs+=(\"$arg\")\n                ;;\n\n        esac\n    done\n    \"$REAL_NVIDIA_CONTAINER_CLI_PATH\" \"${realArgs[@]}\"\nelse\n    \"$REAL_NVIDIA_CONTAINER_CLI_PATH\" \"$@\"\nfi\n"
  },
  {
    "path": "examples/MIG-Support/yarn-unpatched/scripts/nvidia-smi",
    "content": "#!/bin/bash\n\n# Copyright (c) 2022, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# This script is designed as a drop-in replacement for YARN node manager's automatic\n# MIG-aware GPU discovery. YARN config\n# yarn.nodemanager.resource-plugins.gpu.path-to-discovery-executables\n# should point to this script on NM host, e.g\n# <property>\n#   <name>yarn.nodemanager.resource-plugins.gpu.path-to-discovery-executables</name>\n#   <value>/usr/local/yarn-mig-scripts/</value>\n# </property>\n\n# customize in yarn-env.sh\nREAL_NVIDIA_SMI_PATH=${REAL_NVIDIA_SMI_PATH:-\"/usr/bin/nvidia-smi\"}\nMIG_AS_GPU_ENABLED=${MIG_AS_GPU_ENABLED:-\"0\"}\n\nTHIS_PATH=\"$(readlink -f $0)\"\nTHIS_DIR=\"$(dirname $THIS_PATH)\"\n\nfor arg in \"$@\"; do\n    case \"$arg\" in\n\n        \"-q\"|\"--query\")\n            QUERY_ARG=1\n            ;;\n\n        \"-x\"|\"--xml-format\")\n            XML_FORMAT_ARG=1\n            ;;\n\n    esac\ndone\n\nif [[ \"$MIG_AS_GPU_ENABLED\" == \"1\" && \"$XML_FORMAT_ARG\" == \"1\" && \"$QUERY_ARG\" == \"1\" ]]; then\n    \"$REAL_NVIDIA_SMI_PATH\" \"$@\" | \"$THIS_DIR/mig2gpu.sh\"\nelse\n    \"$REAL_NVIDIA_SMI_PATH\" \"$@\"\nfi\n"
  },
  {
    "path": "examples/ML+DL-Examples/Optuna-Spark/README.md",
    "content": "<img src=\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\" width=\"110px\">\n\n# Distributed Hyperparameter Tuning\n\nThese examples demonstrate distributed hyperparameter tuning with [Optuna](https://optuna.readthedocs.io/en/stable/index.html) on Apache Spark, accelerated with [RAPIDS](https://rapids.ai/) on GPU. We showcase how to set up and tune XGBoost on GPU, with deployment on Spark Standalone or Databricks clusters. \n\n## Contents:\n- [Overview](#overview)\n  - [Examples](#examples)\n- [Running Optuna on Spark Standalone](#running-optuna-on-spark-standalone)\n  - [Setup Database for Optuna](#1-setup-database-for-optuna)\n  - [Setup Optuna Python Environment](#2-setup-optuna-python-environment)\n  - [Start Standalone Cluster and Run](#3-start-standalone-cluster-and-run)\n- [Running Optuna on Databricks](#running-optuna-on-databricks)\n  - [Upload Init Script and Notebook](#1-upload-init-script-and-notebook)\n  - [Create Cluster](#2-create-cluster)\n  - [Run Notebook](#3-run-notebook)\n- [Benchmarks](#benchmarks)\n- [How Does it Work?](#how-does-it-work)\n  - [Implementation Notes](#implementation-notes)\n\n---\n\n## Overview\n\nOptuna is a lightweight Python library for hyperparameter tuning, integrating state-of-the-art hyperparameter optimization algorithms.  \n\nAt a high level, we optimize hyperparameters in three steps:\n1. Wrap model training with an `objective` function that returns a loss metric.\n2. In each `trial`, suggest hyperparameters based on previous results.\n3. Create a `study` object, which executes the optimization and stores the trial results.\n\n**Local example**: tuning XGBoost with Optuna (from [Optuna docs](https://optuna.org/#code_examples)):\n```python\nimport xgboost as xgb\nimport optuna\n\n# 1. Define an objective function to be maximized.\ndef objective(trial):\n    ...\n\n    # 2. 
Suggest values of the hyperparameters using a trial object.\n    param = {\n        \"objective\": \"binary:logistic\",\n        \"booster\": trial.suggest_categorical(\"booster\", [\"gbtree\", \"gblinear\", \"dart\"]),\n        \"lambda\": trial.suggest_float(\"lambda\", 1e-8, 1.0, log=True),\n        \"alpha\": trial.suggest_float(\"alpha\", 1e-8, 1.0, log=True),\n        \"subsample\": trial.suggest_float(\"subsample\", 0.2, 1.0),\n        \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.2, 1.0),\n    }\n\n    booster = xgb.train(param, dtrain)\n    ...\n    return accuracy\n\n# 3. Create a study object and optimize the objective function.\nstudy = optuna.create_study(direction='maximize')\nstudy.optimize(objective, n_trials=100)\n```\n\nTo run **distributed tuning** on Spark, we take the following steps:\n1. Each worker receives a copy of the same dataset. \n2. Each worker runs a subset of the trials in parallel.\n3. Workers write trial results and receive new hyperparameters using a shared database. \n\n### Examples\n\nWe provide **2 notebooks**, with differences in the backend/implementation. See [implementation notes](#implementation-notes) for more details.\n\n- `optuna-joblibspark.ipynb`: \n  - Uses the [Joblib Spark backend](https://github.com/joblib/joblib-spark) to distribute tasks on the Spark cluster.\n  - Implements *Worker-I/O*, where each worker reads the full dataset from a specified filepath (e.g., distributed file system).\n  - Builds on [this Databricks example](https://docs.databricks.com/en/machine-learning/automl-hyperparam-tuning/optuna.html). \n- `optuna-dataframe.ipynb`: \n  - Uses Spark dataframes to distribute tasks on the cluster. 
\n  - Implements *Spark-I/O*, where Spark reads the dataset from a specified filepath, then duplicates and repartitions it so that each worker task is mapped onto a copy of the dataset.\n  - Dataframe operations are accelerated on GPU with the [Spark-RAPIDS Accelerator](https://nvidia.github.io/spark-rapids/).\n\n## Running Optuna on Spark Standalone\n\n### 1. Setup Database for Optuna\n\nOptuna offers an RDBStorage option which allows for the persistence of experiments across different machines and processes, thereby enabling Optuna tasks to be distributed.\n\nThis section will walk you through setting up MySQL as the backend for RDBStorage in Optuna.\n\nWe highly recommend installing MySQL on the driver node. This setup eliminates concerns regarding MySQL connectivity between worker nodes and the driver, simplifying the management of database connections.  \n(For Databricks, the installation is handled by the init script).\n\n1. Install MySql:\n\n``` shell\nsudo apt install mysql-server\n```\n\n2. Configure MySQL bind address:\n\nin `/etc/mysql/mysql.conf.d/mysqld.cnf`\n\n``` shell\nbind-address    = YOUR_DRIVER_HOST_IP\nmysqlx-bind-address = YOUR_DRIVER_HOST_IP\n```\n\n3. Restart MySQL:\n\n``` shell\nsudo systemctl restart mysql.service\n```\n\n4. Setup user:\n\n```shell\nsudo mysql\n```\n\n``` mysql\nmysql> CREATE USER 'optuna_user'@'%' IDENTIFIED BY 'optuna_password';\nQuery OK, 0 rows affected (0.01 sec)\n\nmysql> GRANT ALL PRIVILEGES ON *.* TO 'optuna_user'@'%' WITH GRANT OPTION;\nQuery OK, 0 rows affected (0.01 sec)\n\nmysql> FLUSH PRIVILEGES;\nQuery OK, 0 rows affected (0.01 sec)\n\nmysql> EXIT;\nBye\n```\n\nCreate a database for Optuna:\n\n``` shell\nmysql -u optuna_user -p -e \"CREATE DATABASE IF NOT EXISTS optuna\"\n```\n\nTroubleshooting:  \n> If you encounter   \n`\"ERROR 2002 (HY000): Can't connect to local MySQL server through socket '/tmp/mysql.sock' (2)\"`,  \ntry the command:  \n`ln -s /var/run/mysqld/mysqld.sock /tmp/mysql.sock`\n\n### 2. 
Setup Optuna Python Environment\n\nInstall the MySQL client and create a conda environment with the required libraries.  \nWe use [RAPIDS](https://docs.rapids.ai/install/#get-rapids) for GPU-accelerated ETL. See the [docs](https://docs.rapids.ai/install/#get-rapids) for version selection.\n``` shell\nsudo apt install libmysqlclient-dev\n\nconda create -n optuna-spark -c rapidsai -c conda-forge -c nvidia  \\\n    cudf=26.02 cuml=26.02 python=3.10 'cuda-version>=12.0,<=12.5'\nconda activate optuna-spark\npip install mysqlclient\npip install optuna joblib joblibspark ipywidgets\n```\n\n### 3. Start Standalone Cluster and Run\n\nConfigure your standalone cluster settings. \nThis example just creates a local cluster with a single GPU worker:\n```shell\nexport SPARK_HOME=/path/to/spark\nexport SPARK_WORKER_OPTS=\"-Dspark.worker.resource.gpu.amount=1  \\\n    -Dspark.worker.resource.gpu.discoveryScript=$SPARK_HOME/examples/src/main/scripts/getGpusResources.sh\"\nexport MASTER=spark://$(hostname):7077; export SPARK_WORKER_INSTANCES=1; export CORES_PER_WORKER=8\n\n${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m 16G ${MASTER}\n```\n\nYou can now run the notebook using the `optuna-spark` Python kernel!  \nThe notebook contains instructions to attach to the standalone cluster.\n\n\n## Running Optuna on Databricks\n\n### 1. Upload Init Script and Notebook\n\n- Make sure your [Databricks CLI](https://docs.databricks.com/en/dev-tools/cli/tutorial.html) is configured for your Databricks workspace.\n- Copy the desired notebook into your Databricks workspace. 
For example:\n    ```shell\n    databricks workspace import /Users/someone@example.com/optuna/optuna-joblibspark.ipynb --format JUPYTER --file optuna-joblibspark.ipynb\n    ```\n- Copy the init script ```databricks/init_optuna.sh```:\n    ```shell\n    databricks workspace import /Users/someone@example.com/optuna/init_optuna.sh --format AUTO --file databricks/init_optuna.sh\n    ```\n\n### 2. Create Cluster\n\n*For Databricks Azure*: Use the cluster startup script, which is configured to create a 4 node GPU cluster:\n```shell\nexport INIT_PATH=/Users/someone@example.com/optuna/init_optuna.sh\ncd databricks\nchmod +x start_cluster.sh\n./start_cluster.sh\n```\n\nOr, create a cluster via the web UI:\n- Go to `Compute > Create compute` and set the desired cluster settings.    \n- Under `Advanced Options > Init Scripts`, upload the init script from your workspace.\n- Under `Advanced Options > Spark > Environment variables`, set `LIBCUDF_CUFILE_POLICY=OFF`.\n- Make sure to use a GPU cluster and include task GPU resources.\n\nThe init script will install the required libraries on all nodes, including RAPIDS and the Spark-RAPIDS plugin for GPU-accelerated ETL. On the driver, it will setup the MySQL server backend. \n\n### 3. Run Notebook\n\nLocate the notebook in your workspace and click on `Connect` to attach it to the cluster. The notebook is ready to run!\n\n## Benchmarks\n\nThe graph below shows running times comparing distributed (8 GPUs) vs. single GPU hyperparameter tuning with 100 trials on synthetic regression datasets.  \n\n![Databricks benchmarking results](images/runtimes.png)\n\n## How does it work?\n\nThe Optuna tasks will be serialized into bytes and distributed to Spark workers to run. 
The Optuna task on the executor side loads the Optuna study from RDBStorage, and then runs its set of trials.\n\nDuring tuning, the Optuna tasks send intermediate results back to RDBStorage to persist, and ask for the parameters from RDBStorage sampled by Optuna on the driver to run next.\n\n**Using JoblibSpark**: each Optuna task is a Spark application that has only 1 job, 1 stage, 1 task, and the Spark application will be submitted on the local threads. Here the parameter `n_jobs` configures the Spark backend to limit how many Spark applications are submitted at the same time.  \n\nThus Optuna with JoblibSpark uses Spark application level parallelism, rather than task-level parallelism. For larger datasets, ensure that a single XGBoost task can run on a single node without any CPU/GPU OOM.  \n\nApplication parallelism with JoblibSpark:  \n\n![Optuna on JoblibSpark](images/optuna.svg)\n\n### Implementation Notes\n\n###### Data I/O:\nSince each worker requires the full dataset to perform hyperparameter tuning, there are two strategies to get the data into worker memory:\n  - **Worker I/O**: *each worker reads the dataset* from the filepath once the task has begun. In practice, this requires the dataset to be written to a distributed file system accessible to all workers prior to tuning. The `optuna-joblibspark` notebook demonstrates this.\n  - **Spark I/O**: Spark reads the dataset and *creates a copy of the dataset for each worker*, then maps the tuning task onto each copy. In practice, this enables the code to be chained to other Dataframe operations (e.g. ETL stages) without the intermediate step of writing to DBFS, at the cost of some overhead during duplication. The `optuna-dataframe` notebook demonstrates this.\n    - To achieve this, we coalesce the input Dataframe to a single partition, and recursively self-union until we have the desired number of copies (number of workers). 
Thus each partition will contain a duplicate of the entire dataset, and the Optuna task can be mapped directly onto the partitions.\n\n\n###### Misc:\n- Please be aware that Optuna studies will continue where they left off from previous trials; delete and recreate the study if you would like to start anew.\n- Optuna in distributed mode is **non-deterministic** (see [this link](https://optuna.readthedocs.io/en/stable/faq.html#how-can-i-obtain-reproducible-optimization-results)), as trials are executed asynchronously by executors. Deterministic behavior can be achieved using Spark barriers to coordinate reads/writes to the database.\n- Reading data with GPU using cuDF requires disabling [GPUDirect Storage](https://docs.rapids.ai/api/cudf/nightly/user_guide/io/io/#magnum-io-gpudirect-storage-integration), i.e., setting the environment variable `LIBCUDF_CUFILE_POLICY=OFF`, to be compatible with the Databricks file system. Without GDS, cuDF will use a CPU bounce buffer when reading files, but all parsing and decoding will still be accelerated by the GPU. \n- Note that the storage doesn’t store the state of the instance of samplers and pruners. To resume a study with a sampler whose seed argument is specified, [the sampler can be pickled](https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/001_rdb.html#resume-study) and returned to the driver alongside the results. \n"
  },
  {
    "path": "examples/ML+DL-Examples/Optuna-Spark/optuna-examples/databricks/init_optuna.sh",
    "content": "#!/bin/bash\n#\n# Copyright (c) 2025-2026, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\nset -x\n\nsudo rm -r /var/lib/apt/lists/*\nsudo apt clean && sudo apt update --fix-missing -y\n\nif [[ $DB_IS_DRIVER = \"TRUE\" ]]; then\n    # setup database for optuna on driver\n\n    # install mysql server\n    sudo apt install -y mysql-server\n\n    if [[ ! -f \"/etc/mysql/mysql.conf.d/mysqld.cnf\" ]]; then\n        sudo apt remove --purge mysql\\*\n        sudo apt clean && sudo apt update --fix-missing -y\n        sudo apt install -y mysql-server\n    fi\n\n    if [[ ! 
-f \"/etc/mysql/mysql.conf.d/mysqld.cnf\" ]]; then\n        echo \"ERROR: MYSQL installation failed\"\n        exit 1\n    fi\n\n    # configure mysql\n    BIND_ADDRESS=$DB_DRIVER_IP\n    MYSQL_CONFIG_FILE=\"/etc/mysql/mysql.conf.d/mysqld.cnf\"\n    sudo sed -i \"s/^bind-address\\s*=.*/bind-address = $BIND_ADDRESS/\" \"$MYSQL_CONFIG_FILE\"\n    sudo sed -i \"s/^mysqlx-bind-address\\s*=.*/mysqlx-bind-address = $BIND_ADDRESS/\" \"$MYSQL_CONFIG_FILE\"\n    sudo systemctl restart mysql.service\n\n    # setup user\n    OPTUNA_USER=\"optuna_user\"\n    OPTUNA_PASSWORD=\"optuna_password\"\n    sudo mysql -u root -e \"\n        CREATE USER IF NOT EXISTS '$OPTUNA_USER'@'%' IDENTIFIED BY '$OPTUNA_PASSWORD';\n        GRANT ALL PRIVILEGES ON *.* TO '$OPTUNA_USER'@'%' WITH GRANT OPTION;\n        FLUSH PRIVILEGES;\"  \nfi\n\n\n# rapids import\nSPARK_RAPIDS_VERSION=26.02.0\ncurl -L https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/${SPARK_RAPIDS_VERSION}/rapids-4-spark_2.12-${SPARK_RAPIDS_VERSION}.jar -o \\\n    /databricks/jars/rapids-4-spark_2.12-${SPARK_RAPIDS_VERSION}.jar\n\n# setup cuda: install cudatoolkit 11.8 via runfile approach\nwget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run\nsh cuda_11.8.0_520.61.05_linux.run --silent --toolkit\n# reset symlink and update library loading paths\nrm /usr/local/cuda\nln -s /usr/local/cuda-11.8 /usr/local/cuda\n\nsudo /databricks/python3/bin/pip3 install \\\n    --extra-index-url=https://pypi.nvidia.com \\\n    \"cudf-cu11==25.02.*\" \"cuml-cu11==25.02.*\"\n\n# setup python environment\nsudo apt clean && sudo apt update --fix-missing -y\nsudo apt install pkg-config\nsudo apt install -y libmysqlclient-dev\nsudo /databricks/python3/bin/pip3 install --upgrade pip\nsudo /databricks/python3/bin/pip3 install mysqlclient xgboost\nsudo /databricks/python3/bin/pip3 install optuna joblib joblibspark\n\nif [[ $DB_IS_DRIVER = \"TRUE\" ]]; then\n    # create optuna database 
and study\n    sudo mysql -u $OPTUNA_USER -p$OPTUNA_PASSWORD -e \"CREATE DATABASE IF NOT EXISTS optuna;\"\nfi\nset +x\n"
  },
  {
    "path": "examples/ML+DL-Examples/Optuna-Spark/optuna-examples/databricks/start_cluster.sh",
    "content": "#!/bin/bash\n#\n# Copyright (c) 2025-2026, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\nif [[ -z ${INIT_PATH} ]]; then\n    echo \"Please export INIT_PATH per README.md\"\n    exit 1\nfi\n\njson_config=$(cat <<EOF\n{\n    \"cluster_name\": \"optuna-xgboost-gpu\",\n    \"spark_version\": \"13.3.x-gpu-ml-scala2.12\",\n    \"spark_conf\": {\n        \"spark.task.resource.gpu.amount\": \"1\",\n        \"spark.executorEnv.PYTHONPATH\": \"/databricks/jars/rapids-4-spark_2.12-26.02.0.jar:/databricks/spark/python:/databricks/python3\",\n        \"spark.executor.cores\": \"8\",\n        \"spark.rapids.memory.gpu.minAllocFraction\": \"0.0001\",\n        \"spark.plugins\": \"com.nvidia.spark.SQLPlugin\",\n        \"spark.locality.wait\": \"0s\",\n        \"spark.sql.cache.serializer\": \"com.nvidia.spark.ParquetCachedBatchSerializer\",\n        \"spark.rapids.memory.gpu.pooling.enabled\": \"false\",\n        \"spark.executor.resource.gpu.amount\": \"1\",\n        \"spark.rapids.sql.explain\": \"NONE\",\n        \"spark.sql.execution.sortBeforeRepartition\": \"false\",\n        \"spark.rapids.sql.python.gpu.enabled\": \"true\",\n        \"spark.rapids.memory.pinnedPool.size\": \"2G\",\n        \"spark.task.maxFailures\": \"1\",\n        \"spark.python.daemon.module\": \"rapids.daemon_databricks\",\n        \"spark.rapids.sql.batchSizeBytes\": \"512m\",\n        \"spark.sql.adaptive.enabled\": \"false\",\n        
\"spark.rapids.sql.format.parquet.reader.type\": \"MULTITHREADED\",\n        \"spark.sql.execution.arrow.pyspark.enabled\": \"true\",\n        \"spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel\": \"20\",\n        \"spark.sql.files.maxPartitionBytes\": \"512m\",\n        \"spark.rapids.sql.multiThreadedRead.numThreads\": \"20\",\n        \"spark.rapids.sql.concurrentGpuTasks\": \"2\"\n    },\n    \"node_type_id\": \"Standard_NC8as_T4_v3\",\n    \"driver_node_type_id\": \"Standard_NC8as_T4_v3\",\n    \"spark_env_vars\": {\n        \"LIBCUDF_CUFILE_POLICY\": \"OFF\"\n    },\n    \"autotermination_minutes\": 60,\n    \"enable_elastic_disk\": true,\n    \"init_scripts\": [\n        {\n            \"workspace\": {\n                \"destination\": \"${INIT_PATH}\"\n            }\n        }\n    ],\n    \"runtime_engine\": \"STANDARD\",\n    \"num_workers\": 4\n}\nEOF\n)\n\ndatabricks clusters create --json \"$json_config\""
  },
  {
    "path": "examples/ML+DL-Examples/Optuna-Spark/optuna-examples/optuna-dataframe.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#\\n\",\n    \"# Copyright (c) 2024, NVIDIA CORPORATION.\\n\",\n    \"#\\n\",\n    \"# Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n\",\n    \"# you may not use this file except in compliance with the License.\\n\",\n    \"# You may obtain a copy of the License at\\n\",\n    \"#\\n\",\n    \"#     http://www.apache.org/licenses/LICENSE-2.0\\n\",\n    \"#\\n\",\n    \"# Unless required by applicable law or agreed to in writing, software\\n\",\n    \"# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n\",\n    \"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n\",\n    \"# See the License for the specific language governing permissions and\\n\",\n    \"# limitations under the License.\\n\",\n    \"#\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# Distributed Hyperparameter Tuning: Optuna + Spark Dataframes\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"This demo demonstrates distributed hyperparameter tuning for XGBoost using Spark Dataframes.  \\n\",\n    \"We implement best practices to precompute data and maximize computations on the GPU.  
\\n\",\n    \"\\n\",\n    \"Reference: https://forecastegy.com/posts/xgboost-hyperparameter-tuning-with-optuna/\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Note:\\n\",\n    \"Before running, please make sure you've followed the relevant [setup instructions](../README.md) for your environment (standalone or databricks).\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from typing import Iterable, List, Dict, Optional, Union, Sequence, Any\\n\",\n    \"import math\\n\",\n    \"import os\\n\",\n    \"import requests\\n\",\n    \"import pandas as pd\\n\",\n    \"import optuna\\n\",\n    \"from optuna.samplers import TPESampler\\n\",\n    \"import xgboost as xgb\\n\",\n    \"from pyspark.sql import SparkSession, DataFrame\\n\",\n    \"from pyspark import TaskContext, SparkConf\\n\",\n    \"from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType, StringType, BooleanType\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Download the dataset\\n\",\n    \"\\n\",\n    \"We'll use the [red wine quality dataset](https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv) to regress wine quality based on features such as acidity, sugar content, etc.  \\n\",\n    \"\\n\",\n    \"**Note**: This example uses a small dataset for demonstration purposes. 
The performance advantages of distributed training are best realized with large datasets and computational workloads.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"cwd = os.getcwd()\\n\",\n    \"os.mkdir(os.path.join(cwd, \\\"data\\\")) if not os.path.exists(os.path.join(cwd, \\\"data\\\")) else None\\n\",\n    \"filepath = os.path.join(cwd, \\\"data\\\", \\\"winequality-red.csv\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"File downloaded and saved to /home/rishic/Code/myforks/spark-rapids-examples/examples/ML+DL-Examples/Optuna-Spark/optuna-examples/data/winequality-red.csv\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"url = \\\"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\\\"\\n\",\n    \"\\n\",\n    \"response = requests.get(url)\\n\",\n    \"if response.status_code == 200:\\n\",\n    \"    with open(filepath, \\\"wb\\\") as f:\\n\",\n    \"        f.write(response.content)\\n\",\n    \"    print(f\\\"File downloaded and saved to {filepath}\\\")\\n\",\n    \"else:\\n\",\n    \"    print(f\\\"Failed to download the file. Status code: {response.status_code}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Part 1. 
Running Optuna locally\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import cudf\\n\",\n    \"from cuml.metrics.regression import mean_squared_error\\n\",\n    \"from cuml.model_selection import train_test_split\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Prepare data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>fixed acidity</th>\\n\",\n       \"      <th>volatile acidity</th>\\n\",\n       \"      <th>citric acid</th>\\n\",\n       \"      <th>residual sugar</th>\\n\",\n       \"      <th>chlorides</th>\\n\",\n       \"      <th>free sulfur dioxide</th>\\n\",\n       \"      <th>total sulfur dioxide</th>\\n\",\n       \"      <th>density</th>\\n\",\n       \"      <th>pH</th>\\n\",\n       \"      <th>sulphates</th>\\n\",\n       \"      <th>alcohol</th>\\n\",\n       \"      <th>quality</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>0</th>\\n\",\n       \"      <td>7.4</td>\\n\",\n       \"      <td>0.70</td>\\n\",\n       \"      <td>0.00</td>\\n\",\n     
  \"      <td>1.9</td>\\n\",\n       \"      <td>0.076</td>\\n\",\n       \"      <td>11.0</td>\\n\",\n       \"      <td>34.0</td>\\n\",\n       \"      <td>0.9978</td>\\n\",\n       \"      <td>3.51</td>\\n\",\n       \"      <td>0.56</td>\\n\",\n       \"      <td>9.4</td>\\n\",\n       \"      <td>5</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>1</th>\\n\",\n       \"      <td>7.8</td>\\n\",\n       \"      <td>0.88</td>\\n\",\n       \"      <td>0.00</td>\\n\",\n       \"      <td>2.6</td>\\n\",\n       \"      <td>0.098</td>\\n\",\n       \"      <td>25.0</td>\\n\",\n       \"      <td>67.0</td>\\n\",\n       \"      <td>0.9968</td>\\n\",\n       \"      <td>3.20</td>\\n\",\n       \"      <td>0.68</td>\\n\",\n       \"      <td>9.8</td>\\n\",\n       \"      <td>5</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>2</th>\\n\",\n       \"      <td>7.8</td>\\n\",\n       \"      <td>0.76</td>\\n\",\n       \"      <td>0.04</td>\\n\",\n       \"      <td>2.3</td>\\n\",\n       \"      <td>0.092</td>\\n\",\n       \"      <td>15.0</td>\\n\",\n       \"      <td>54.0</td>\\n\",\n       \"      <td>0.9970</td>\\n\",\n       \"      <td>3.26</td>\\n\",\n       \"      <td>0.65</td>\\n\",\n       \"      <td>9.8</td>\\n\",\n       \"      <td>5</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>3</th>\\n\",\n       \"      <td>11.2</td>\\n\",\n       \"      <td>0.28</td>\\n\",\n       \"      <td>0.56</td>\\n\",\n       \"      <td>1.9</td>\\n\",\n       \"      <td>0.075</td>\\n\",\n       \"      <td>17.0</td>\\n\",\n       \"      <td>60.0</td>\\n\",\n       \"      <td>0.9980</td>\\n\",\n       \"      <td>3.16</td>\\n\",\n       \"      <td>0.58</td>\\n\",\n       \"      <td>9.8</td>\\n\",\n       \"      <td>6</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>4</th>\\n\",\n       \"      <td>7.4</td>\\n\",\n       \"      
<td>0.70</td>\\n\",\n       \"      <td>0.00</td>\\n\",\n       \"      <td>1.9</td>\\n\",\n       \"      <td>0.076</td>\\n\",\n       \"      <td>11.0</td>\\n\",\n       \"      <td>34.0</td>\\n\",\n       \"      <td>0.9978</td>\\n\",\n       \"      <td>3.51</td>\\n\",\n       \"      <td>0.56</td>\\n\",\n       \"      <td>9.4</td>\\n\",\n       \"      <td>5</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"</div>\"\n      ],\n      \"text/plain\": [\n       \"   fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \\\\\\n\",\n       \"0            7.4              0.70         0.00             1.9      0.076   \\n\",\n       \"1            7.8              0.88         0.00             2.6      0.098   \\n\",\n       \"2            7.8              0.76         0.04             2.3      0.092   \\n\",\n       \"3           11.2              0.28         0.56             1.9      0.075   \\n\",\n       \"4            7.4              0.70         0.00             1.9      0.076   \\n\",\n       \"\\n\",\n       \"   free sulfur dioxide  total sulfur dioxide  density    pH  sulphates  \\\\\\n\",\n       \"0                 11.0                  34.0   0.9978  3.51       0.56   \\n\",\n       \"1                 25.0                  67.0   0.9968  3.20       0.68   \\n\",\n       \"2                 15.0                  54.0   0.9970  3.26       0.65   \\n\",\n       \"3                 17.0                  60.0   0.9980  3.16       0.58   \\n\",\n       \"4                 11.0                  34.0   0.9978  3.51       0.56   \\n\",\n       \"\\n\",\n       \"   alcohol  quality  \\n\",\n       \"0      9.4        5  \\n\",\n       \"1      9.8        5  \\n\",\n       \"2      9.8        5  \\n\",\n       \"3      9.8        6  \\n\",\n       \"4      9.4        5  \"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    
}\n   ],\n   \"source\": [\n    \"data = cudf.read_csv(filepath, delimiter=\\\";\\\")\\n\",\n    \"data.head()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Prepare the train/validation sets. Precompute the Quantile DMatrix, which is used by histogram-based tree methods to save memory.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"X = data.iloc[:, :-1].values\\n\",\n    \"y = data[\\\"quality\\\"].values\\n\",\n    \"X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\\n\",\n    \"Xy_train_qdm = xgb.QuantileDMatrix(X_train, y_train)  # Precompute Quantile DMatrix to avoid repeated quantization every trial.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Objective function\\n\",\n    \"\\n\",\n    \"We define the objective and a hyperparameter search space to optimize via the `trial.suggest_` methods.  \\n\",\n    \"\\n\",\n    \"In each trial, new hyperparameters will be suggested based on previous results. 
See [optuna.trial.Trial](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html) API for a full list of functions.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def objective(trial):\\n\",\n    \"    params = {\\n\",\n    \"        \\\"objective\\\": \\\"reg:squarederror\\\",\\n\",\n    \"        \\\"verbosity\\\": 0,\\n\",\n    \"        \\\"learning_rate\\\": trial.suggest_float(\\\"learning_rate\\\", 1e-3, 0.1, log=True),\\n\",\n    \"        \\\"max_depth\\\": trial.suggest_int(\\\"max_depth\\\", 1, 10),\\n\",\n    \"        \\\"subsample\\\": trial.suggest_float(\\\"subsample\\\", 0.05, 1.0),\\n\",\n    \"        \\\"colsample_bytree\\\": trial.suggest_float(\\\"colsample_bytree\\\", 0.05, 1.0),\\n\",\n    \"        \\\"min_child_weight\\\": trial.suggest_int(\\\"min_child_weight\\\", 1, 20),\\n\",\n    \"        \\\"tree_method\\\": \\\"gpu_hist\\\",\\n\",\n    \"        \\\"device\\\": \\\"cuda\\\",\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    booster = xgb.train(params=params, dtrain=Xy_train_qdm, num_boost_round=trial.suggest_int(\\\"num_boost_round\\\", 100, 500))\\n\",\n    \"    predictions = booster.inplace_predict(X_val)\\n\",\n    \"    rmse = mean_squared_error(y_val, predictions, squared=False).get()\\n\",\n    \"    \\n\",\n    \"    return rmse   \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Create the study and optimize. 
By default, the study results will be stored in memory.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[I 2024-12-11 23:47:48,356] A new study created in memory with name: optuna-xgboost-local\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[I 2024-12-11 23:47:48,724] Trial 0 finished with value: 0.6377619522504244 and parameters: {'learning_rate': 0.005611516415334507, 'max_depth': 10, 'subsample': 0.7453942447208348, 'colsample_bytree': 0.6187255599871848, 'min_child_weight': 4, 'num_boost_round': 162}. Best is trial 0 with value: 0.6377619522504244.\\n\",\n      \"[I 2024-12-11 23:47:49,676] Trial 1 finished with value: 0.6703788974319568 and parameters: {'learning_rate': 0.0013066739238053278, 'max_depth': 9, 'subsample': 0.6210592611560484, 'colsample_bytree': 0.7226689489062432, 'min_child_weight': 1, 'num_boost_round': 488}. Best is trial 0 with value: 0.6377619522504244.\\n\",\n      \"[I 2024-12-11 23:47:49,819] Trial 2 finished with value: 0.6181751362616256 and parameters: {'learning_rate': 0.04622589001020832, 'max_depth': 3, 'subsample': 0.2227337188467456, 'colsample_bytree': 0.22423428436076215, 'min_child_weight': 7, 'num_boost_round': 310}. Best is trial 2 with value: 0.6181751362616256.\\n\",\n      \"[I 2024-12-11 23:47:49,942] Trial 3 finished with value: 0.6698576232920956 and parameters: {'learning_rate': 0.007309539835912915, 'max_depth': 3, 'subsample': 0.6312602499862605, 'colsample_bytree': 0.18251916761943976, 'min_child_weight': 6, 'num_boost_round': 246}. 
Best is trial 2 with value: 0.6181751362616256.\\n\",\n      \"[I 2024-12-11 23:47:50,060] Trial 4 finished with value: 0.6704590546150145 and parameters: {'learning_rate': 0.008168455894760165, 'max_depth': 8, 'subsample': 0.23969009305044175, 'colsample_bytree': 0.538522716492931, 'min_child_weight': 12, 'num_boost_round': 118}. Best is trial 2 with value: 0.6181751362616256.\\n\",\n      \"[I 2024-12-11 23:47:50,214] Trial 5 finished with value: 0.6088806682631155 and parameters: {'learning_rate': 0.016409286730647923, 'max_depth': 2, 'subsample': 0.11179901333601554, 'colsample_bytree': 0.9514412603906666, 'min_child_weight': 20, 'num_boost_round': 424}. Best is trial 5 with value: 0.6088806682631155.\\n\",\n      \"[I 2024-12-11 23:47:50,289] Trial 6 finished with value: 0.7103495949713845 and parameters: {'learning_rate': 0.0040665633135147945, 'max_depth': 1, 'subsample': 0.700021375186549, 'colsample_bytree': 0.4681448690526212, 'min_child_weight': 3, 'num_boost_round': 298}. Best is trial 5 with value: 0.6088806682631155.\\n\",\n      \"[I 2024-12-11 23:47:50,693] Trial 7 finished with value: 0.7255199474722185 and parameters: {'learning_rate': 0.001171593739230706, 'max_depth': 10, 'subsample': 0.29584098252001606, 'colsample_bytree': 0.6793961701362828, 'min_child_weight': 7, 'num_boost_round': 308}. Best is trial 5 with value: 0.6088806682631155.\\n\",\n      \"[I 2024-12-11 23:47:50,858] Trial 8 finished with value: 0.6060010014477214 and parameters: {'learning_rate': 0.0123999678368461, 'max_depth': 2, 'subsample': 0.9711053963763306, 'colsample_bytree': 0.7863761821930588, 'min_child_weight': 19, 'num_boost_round': 458}. Best is trial 8 with value: 0.6060010014477214.\\n\",\n      \"[I 2024-12-11 23:47:51,199] Trial 9 finished with value: 0.6292433375858283 and parameters: {'learning_rate': 0.015696396388661146, 'max_depth': 10, 'subsample': 0.13406787694932354, 'colsample_bytree': 0.23618371929818793, 'min_child_weight': 1, 'num_boost_round': 230}. 
Best is trial 8 with value: 0.6060010014477214.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"study = optuna.create_study(study_name=\\\"optuna-xgboost-local\\\", sampler=TPESampler(seed=42))\\n\",\n    \"study.optimize(objective, n_trials=10)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Best RMSE:  0.6060010014477214\\n\",\n      \"Best hyperparameters:  {'learning_rate': 0.0123999678368461, 'max_depth': 2, 'subsample': 0.9711053963763306, 'colsample_bytree': 0.7863761821930588, 'min_child_weight': 19, 'num_boost_round': 458}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"trial = study.best_trial\\n\",\n    \"print(\\\"Best RMSE: \\\", trial.value)\\n\",\n    \"print(\\\"Best hyperparameters: \\\", trial.params)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Part 2. Distributed Optuna on Spark \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### PySpark\\n\",\n    \"\\n\",\n    \"For standalone users, we need to create the Spark session with the Spark-Rapids plugin. For Databricks users, the Spark session will be preconfigured and this cell can be skipped.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Plugin file already exists. 
Skipping download.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"24/12/11 23:47:51 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"24/12/11 23:47:51 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"24/12/11 23:47:52 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"24/12/11 23:47:52 WARN RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1, private revision bd4e99e18e20234ee0c54f95f4b0bfce18a6255e\\n\",\n      \"24/12/11 23:47:52 WARN RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def get_rapids_jar():\\n\",\n    \"    SPARK_RAPIDS_VERSION = \\\"26.02.0\\\"\\n\",\n    \"    rapids_jar = f\\\"rapids-4-spark_2.12-{SPARK_RAPIDS_VERSION}.jar\\\"\\n\",\n    \"    if not os.path.exists(rapids_jar):\\n\",\n    \"        print(\\\"Downloading Spark Rapids jar\\\")\\n\",\n    \"        url = f\\\"https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/{SPARK_RAPIDS_VERSION}/{rapids_jar}\\\"\\n\",\n    \"        response = requests.get(url)\\n\",\n    \"        if response.status_code == 200:\\n\",\n    \"            with open(rapids_jar, \\\"wb\\\") as f:\\n\",\n    \"                f.write(response.content)\\n\",\n    \"            print(f\\\"File '{rapids_jar}' downloaded and saved successfully.\\\")\\n\",\n    \"        else:\\n\",\n    \"            print(f\\\"Failed to download the plugin. 
Status code: {response.status_code}\\\")\\n\",\n    \"    else:\\n\",\n    \"        print(\\\"Plugin file already exists. Skipping download.\\\")\\n\",\n    \"    return rapids_jar\\n\",\n    \"\\n\",\n    \"def initialize_spark(rapids_jar: str):\\n\",\n    \"    import socket\\n\",\n    \"    hostname = socket.gethostname()\\n\",\n    \"    conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"\\n\",\n    \"    conf = SparkConf()\\n\",\n    \"    conf.setMaster(f\\\"spark://{hostname}:7077\\\")  # Assuming master is on host and default port. \\n\",\n    \"    conf.set(\\\"spark.task.maxFailures\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", f\\\"{1/4}\\\")  # Setting to 1/4 for single-node demo. In practice, set to 1. \\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"    conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"    conf.set(\\\"spark.jars\\\", rapids_jar)\\n\",\n    \"    conf.set(\\\"spark.executorEnv.PYTHONPATH\\\", rapids_jar)\\n\",\n    \"    conf.set(\\\"spark.rapids.memory.gpu.minAllocFraction\\\", \\\"0.0001\\\")\\n\",\n    \"    conf.set(\\\"spark.plugins\\\", \\\"com.nvidia.spark.SQLPlugin\\\")\\n\",\n    \"    conf.set(\\\"spark.locality.wait\\\", \\\"0s\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.cache.serializer\\\", \\\"com.nvidia.spark.ParquetCachedBatchSerializer\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.memory.gpu.pooling.enabled\\\", \\\"false\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.sortBeforeRepartition\\\", \\\"false\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.sql.format.parquet.reader.type\\\", \\\"MULTITHREADED\\\")\\n\",\n    \"    
conf.set(\\\"spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel\\\", \\\"20\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.sql.multiThreadedRead.numThreads\\\", \\\"20\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.sql.python.gpu.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.memory.pinnedPool.size\\\", \\\"2G\\\")\\n\",\n    \"    conf.set(\\\"spark.python.daemon.module\\\", \\\"rapids.daemon\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.sql.batchSizeBytes\\\", \\\"512m\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.adaptive.enabled\\\", \\\"false\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.files.maxPartitionBytes\\\", \\\"512m\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.sql.concurrentGpuTasks\\\", \\\"2\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.sql.explain\\\", \\\"NONE\\\")\\n\",\n    \"    \\n\",\n    \"    spark = SparkSession.builder.appName(\\\"optuna-spark-xgboost\\\").config(conf=conf).getOrCreate()\\n\",\n    \"    return spark\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    rapids_jar = get_rapids_jar()\\n\",\n    \"    spark = initialize_spark(rapids_jar)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Helper Class\\n\",\n    \"\\n\",\n    \"First we'll define a helper class. 
This will store the hyperparameters we want optimized in each trial, and easily convert that into a schema for the output dataframe.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"class OptunaParams:\\n\",\n    \"    def __init__(self):\\n\",\n    \"        self.hyperparameters = {}\\n\",\n    \"\\n\",\n    \"    def add_categorical_param(self, name: str, choices: Sequence[Union[None, bool, int, float, str]]):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Adds a categorical hyperparameter to be tuned via Optuna's trial.suggest_categorical().\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        self.hyperparameters[name] = { \\\"type\\\": \\\"categorical\\\", \\\"choices\\\": choices }\\n\",\n    \"    \\n\",\n    \"    def add_int_param(self, name: str, low: int, high: int, step: int = 1, log: bool = False):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Adds an integer hyperparameter to be tuned via Optuna's trial.suggest_int().\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        self.hyperparameters[name] = { \\\"type\\\": \\\"int\\\", \\\"low\\\": low, \\\"high\\\": high, \\\"step\\\": step, \\\"log\\\": log }\\n\",\n    \"    \\n\",\n    \"    def add_float_param(self, name: str, low: float, high: float, step: Optional[float] = None, log: bool = False):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Adds a float hyperparameter to be tuned via Optuna's trial.suggest_float().\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        self.hyperparameters[name] = { \\\"type\\\": \\\"float\\\", \\\"low\\\": low, \\\"high\\\": high, \\\"step\\\": step,\\\"log\\\": log }\\n\",\n    \"\\n\",\n    \"    def suggest_params(self, trial) -> Dict[str, Union[int, float, str, bool]]:\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Converts the hyperparameter space into a dictionary of suggested values in Optuna format,\\n\",\n    \"        to be 
called within the objective function.\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        suggested_params = {}\\n\",\n    \"        for name, config in self.hyperparameters.items():\\n\",\n    \"            if config[\\\"type\\\"] == \\\"categorical\\\":\\n\",\n    \"                suggested_params[name] = trial.suggest_categorical(name, config[\\\"choices\\\"])\\n\",\n    \"            elif config[\\\"type\\\"] == \\\"int\\\":\\n\",\n    \"                suggested_params[name] = trial.suggest_int(\\n\",\n    \"                    name, config[\\\"low\\\"], config[\\\"high\\\"], step=config[\\\"step\\\"], log=config[\\\"log\\\"]\\n\",\n    \"                )\\n\",\n    \"            elif config[\\\"type\\\"] == \\\"float\\\":\\n\",\n    \"                suggested_params[name] = trial.suggest_float(\\n\",\n    \"                    name, config[\\\"low\\\"], config[\\\"high\\\"], step=config.get(\\\"step\\\", None), log=config[\\\"log\\\"]\\n\",\n    \"                )\\n\",\n    \"        return suggested_params\\n\",\n    \"\\n\",\n    \"    def to_schema(self) -> StructType:\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Converts the hyperparameter space into a Spark StructType output schema.\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        fields = []\\n\",\n    \"        for name, config in self.hyperparameters.items():\\n\",\n    \"            if config[\\\"type\\\"] == \\\"float\\\":\\n\",\n    \"                fields.append(StructField(name, DoubleType(), False))\\n\",\n    \"            elif config[\\\"type\\\"] == \\\"int\\\":\\n\",\n    \"                fields.append(StructField(name, IntegerType(), False))\\n\",\n    \"            elif config[\\\"type\\\"] == \\\"categorical\\\":\\n\",\n    \"                if isinstance(config[\\\"choices\\\"][0], str):\\n\",\n    \"                    fields.append(StructField(name, StringType(), False))\\n\",\n    \"                elif isinstance(config[\\\"choices\\\"][0], bool):\\n\",\n    
\"                    fields.append(StructField(name, BooleanType(), False))\\n\",\n    \"                elif isinstance(config[\\\"choices\\\"][0], (int, float)):\\n\",\n    \"                    fields.append(StructField(name, DoubleType(), False))\\n\",\n    \"                else:\\n\",\n    \"                    raise ValueError(f\\\"Unsupported categorical type for field {name}\\\")\\n\",\n    \"        \\n\",\n    \"        # Study will also return the best achieved loss:\\n\",\n    \"        fields.append(StructField(\\\"best_value\\\", DoubleType(), False)) \\n\",\n    \"        return StructType(fields)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Optuna Task\\n\",\n    \"\\n\",\n    \"This implementation demonstrates **Spark I/O**.\\n\",\n    \"\\n\",\n    \"This means that Spark will read the dataset and create a duplicate of the dataset for each worker (1 partition = 1 duplicate), then map the tuning task onto each partition.  \\n\",\n    \"In practice, this enables the code to be chained to other Dataframe operations (e.g. ETL stages) without the intermediate step of writing to DBFS, at the cost of some overhead during duplication.\\n\",\n    \"\\n\",\n    \"For the alternative implementation using **Worker I/O**, see the [JoblibSpark notebook](optuna-joblibspark.ipynb). \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"In the task, each worker will:\\n\",\n    \"1. Concatenate the pandas partition batches to form the dataset\\n\",\n    \"2. Load the study from the MySQL storage backend\\n\",\n    \"3. 
Optimize over the objective for the assigned number of trials, sending results back to the database after each iteration\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def task_udf(pdf_iter: Iterable[pd.DataFrame],\\n\",\n    \"             xgb_params: Dict[str, Any],\\n\",\n    \"             optuna_params: OptunaParams,\\n\",\n    \"             trials_per_task: List[int],            \\n\",\n    \"             driver_ip: str,\\n\",\n    \"             study_name: str,\\n\",\n    \"             seed: int) -> Iterable[pd.DataFrame]:\\n\",\n    \"\\n\",\n    \"    import cudf\\n\",\n    \"    from cuml.metrics.regression import mean_squared_error\\n\",\n    \"    from cuml.model_selection import train_test_split\\n\",\n    \"    \\n\",\n    \"    tc = TaskContext.get()\\n\",\n    \"    assert \\\"gpu\\\" in tc.resources(), \\\"GPU resource not found.\\\"\\n\",\n    \"    num_trials = trials_per_task[tc.partitionId()]\\n\",\n    \"\\n\",\n    \"    df_list = []\\n\",\n    \"    for pdf in pdf_iter:\\n\",\n    \"        df_list.append(cudf.DataFrame.from_pandas(pdf))\\n\",\n    \"    \\n\",\n    \"    data = cudf.concat(df_list)\\n\",\n    \"    X = data.iloc[:, :-1].values\\n\",\n    \"    y = data[\\\"quality\\\"].values\\n\",\n    \"    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\\n\",\n    \"\\n\",\n    \"    tuning_max_bin = \\\"max_bin\\\" in optuna_params.hyperparameters\\n\",\n    \"    if not tuning_max_bin:\\n\",\n    \"        max_bin = xgb_params.get(\\\"max_bin\\\", 256)\\n\",\n    \"        # Precompute Quantile DMatrix to avoid repeated quantization every trial.\\n\",\n    \"        Xy_train_qdm = xgb.QuantileDMatrix(X_train, y_train, max_bin=max_bin)\\n\",\n    \"\\n\",\n    \"    def objective(trial):\\n\",\n    \"        tuning_params = optuna_params.suggest_params(trial)\\n\",\n    \"        
xgb_params.update(tuning_params)\\n\",\n    \"\\n\",\n    \"        if tuning_max_bin:\\n\",\n    \"            # If tuning the max_bin param, we must recompute the QDM every trial, since the quantiles change.\\n\",\n    \"            if \\\"n_estimators\\\" not in xgb_params:\\n\",\n    \"                xgb_params[\\\"n_estimators\\\"] = 100  # Default value if not tuning.\\n\",\n    \"\\n\",\n    \"            model = xgb.XGBRegressor(**xgb_params)\\n\",\n    \"            model.fit(X_train, y_train)\\n\",\n    \"            booster = model.get_booster()\\n\",\n    \"        else:\\n\",\n    \"            # Train the model with xgb.train() API using the precomputed QDM.\\n\",\n    \"            num_boost_round = xgb_params.get(\\\"n_estimators\\\", 100)\\n\",\n    \"            booster = xgb.train(params=xgb_params, dtrain=Xy_train_qdm, num_boost_round=num_boost_round)\\n\",\n    \"        \\n\",\n    \"        predictions = booster.inplace_predict(X_val)\\n\",\n    \"        rmse = mean_squared_error(y_val, predictions, squared=False).get()\\n\",\n    \"        \\n\",\n    \"        return rmse\\n\",\n    \"\\n\",\n    \"    study = optuna.load_study(\\n\",\n    \"        study_name=study_name,\\n\",\n    \"        storage=f\\\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\\\",\\n\",\n    \"        sampler=TPESampler(seed=seed),\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    print(f\\\"Running {num_trials} trials on partition {tc.partitionId()}.\\\")\\n\",\n    \"    study.optimize(objective, n_trials=num_trials)\\n\",\n    \"\\n\",\n    \"    result_dict = {f\\\"{key}\\\": [value] for key, value in study.best_params.items()}\\n\",\n    \"    result_dict['best_value'] = [study.best_value]\\n\",\n    \"    \\n\",\n    \"    yield pd.DataFrame(result_dict)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Setup and run the Optuna study\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   
\"metadata\": {},\n   \"source\": [\n    \"Get the driver IP for the MySQL database.  \\n\",\n    \"- For standalone users, make sure you've followed the [database setup instructions](../README.md#setup-database-for-optuna). The database should be on 'localhost'. \\n\",\n    \"- For databricks users, the database should already be setup on the driver node by the init script.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# check if we're running on databricks\\n\",\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"MySQL database is hosted on localhost\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"if on_databricks:\\n\",\n    \"    driver_ip = spark.conf.get(\\\"spark.driver.host\\\")\\n\",\n    \"else:\\n\",\n    \"    driver_ip = \\\"localhost\\\"\\n\",\n    \"\\n\",\n    \"print(f\\\"MySQL database is hosted on {driver_ip}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Create a new study, referencing the MySQL database as the storage backend.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[I 2024-12-11 23:47:53,347] A new study created in RDB with name: optuna-xgboost-dataframe\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<optuna.study.study.Study at 0x756423c12560>\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"study_name = 
\\\"optuna-xgboost-dataframe\\\"\\n\",\n    \"seed = 42\\n\",\n    \"\\n\",\n    \"try:\\n\",\n    \"    # Delete the study if it already exists\\n\",\n    \"    optuna.delete_study(\\n\",\n    \"        study_name=study_name, \\n\",\n    \"        storage=f\\\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\\\"\\n\",\n    \"    )\\n\",\n    \"except:\\n\",\n    \"    pass\\n\",\n    \"\\n\",\n    \"optuna.create_study(\\n\",\n    \"    study_name=study_name,\\n\",\n    \"    storage=f\\\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\\\",\\n\",\n    \"    sampler=TPESampler(seed=seed)\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the number of tasks, number of trials, and trials per task. \\n\",\n    \"\\n\",\n    \"**NOTE**: for standalone users running on a single worker, the 4 tasks will all be assigned to the same worker and will time-share the GPU for demonstration. In practice, you should set `spark.task.resource.gpu.amount=1` and set num_tasks to the number of workers in the cluster so that each task gets full access to the GPU.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def partition_trials(total_trials: int, total_tasks: int) -> List[int]:\\n\",\n    \"    base_size = total_trials // total_tasks\\n\",\n    \"    extra = total_trials % total_tasks\\n\",\n    \"    partitions = [base_size] * total_tasks\\n\",\n    \"    for i in range(extra):\\n\",\n    \"        partitions[i] += 1\\n\",\n    \"    \\n\",\n    \"    return partitions\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Trials per task: [25, 25, 25, 25]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"num_tasks = 4\\n\",\n    \"num_trials = 
100\\n\",\n    \"trials_per_task = partition_trials(num_trials, num_tasks)\\n\",\n    \"print(f\\\"Trials per task: {trials_per_task}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define params\\n\",\n    \"Define the XGBoost model params and the hyperparams for Optuna to tune. \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Keep these params consistent:\\n\",\n    \"xgb_params = {\\n\",\n    \"    \\\"objective\\\": \\\"reg:squarederror\\\",\\n\",\n    \"    \\\"verbosity\\\": 0,\\n\",\n    \"    \\\"tree_method\\\": \\\"gpu_hist\\\",\\n\",\n    \"    \\\"device\\\": \\\"cuda\\\",\\n\",\n    \"    \\\"seed\\\": seed,\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Tune these params:\\n\",\n    \"hyperparams = OptunaParams()\\n\",\n    \"hyperparams.add_int_param(\\\"n_estimators\\\", low=100, high=500)\\n\",\n    \"hyperparams.add_float_param(\\\"learning_rate\\\", low=1e-3, high=0.1, log=True)\\n\",\n    \"hyperparams.add_int_param(\\\"max_depth\\\", low=1, high=10)\\n\",\n    \"hyperparams.add_float_param(\\\"subsample\\\", low=0.05, high=1.0)\\n\",\n    \"hyperparams.add_float_param(\\\"colsample_bytree\\\", low=0.05, high=1.0)\\n\",\n    \"hyperparams.add_int_param(\\\"min_child_weight\\\", low=1, high=20)\\n\",\n    \"\\n\",\n    \"out_schema = hyperparams.to_schema()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"We'll also define the following helper function, which will create duplicates of the dataframe held in separate partitions.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def coalesce_tree_union(df: DataFrame, num_duplicates: 
int):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Coalesce the DataFrame to a single partition and recursively self-union to create duplicates.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    input_df = df.coalesce(1).cache()\\n\",\n    \"    current_df = input_df\\n\",\n    \"    \\n\",\n    \"    if num_duplicates <= 1:\\n\",\n    \"        return current_df\\n\",\n    \"\\n\",\n    \"    recursions = int(math.log(num_duplicates, 2))\\n\",\n    \"    remainder = num_duplicates - 2 ** recursions\\n\",\n    \"\\n\",\n    \"    for _ in range(recursions):\\n\",\n    \"        current_df = current_df.union(current_df)\\n\",\n    \"\\n\",\n    \"    for _ in range(remainder):\\n\",\n    \"        current_df = current_df.union(input_df)\\n\",\n    \"    \\n\",\n    \"    return current_df\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load dataset\\n\",\n    \"\\n\",\n    \"Read the data from the local directory with Spark and then duplicate it to prepare to run the task.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if on_databricks:\\n\",\n    \"    # once the dataset is in dbfs, databricks appends \\\"dbfs:\\\" to the filepath automatically\\n\",\n    \"    filepath = '/FileStore/optuna-data/winequality-red.csv'\\n\",\n    \"else:\\n\",\n    \"    cwd = os.getcwd()\\n\",\n    \"    filepath = os.path.join(cwd, \\\"data\\\", \\\"winequality-red.csv\\\")\\n\",\n    \"\\n\",\n    \"in_schema = StructType([\\n\",\n    \"    StructField(\\\"fixed acidity\\\", DoubleType(), True),\\n\",\n    \"    StructField(\\\"volatile acidity\\\", DoubleType(), True),\\n\",\n    \"    StructField(\\\"citric acid\\\", DoubleType(), True),\\n\",\n    \"    StructField(\\\"residual sugar\\\", DoubleType(), True),\\n\",\n    \"    StructField(\\\"chlorides\\\", DoubleType(), True),\\n\",\n    \"    StructField(\\\"free sulfur 
dioxide\\\", DoubleType(), True),\\n\",\n    \"    StructField(\\\"total sulfur dioxide\\\", DoubleType(), True),\\n\",\n    \"    StructField(\\\"density\\\", DoubleType(), True),\\n\",\n    \"    StructField(\\\"pH\\\", DoubleType(), True),\\n\",\n    \"    StructField(\\\"sulphates\\\", DoubleType(), True),\\n\",\n    \"    StructField(\\\"alcohol\\\", DoubleType(), True),\\n\",\n    \"    StructField(\\\"quality\\\", IntegerType(), True)\\n\",\n    \"])\\n\",\n    \"\\n\",\n    \"data_df = spark.read.csv(filepath, header=True, schema=in_schema, sep=\\\";\\\")\\n\",\n    \"data_df = coalesce_tree_union(data_df, num_duplicates=num_tasks)    \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Run the study\\n\",\n    \"\\n\",\n    \"Map the Optuna task onto the dataframe and collect the results (it might take a few minutes).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"result_df = data_df.mapInPandas(lambda pdf_iter: \\n\",\n    \"                                task_udf(pdf_iter,\\n\",\n    \"                                         xgb_params=xgb_params,\\n\",\n    \"                                         optuna_params=hyperparams,\\n\",\n    \"                                         trials_per_task=trials_per_task,\\n\",\n    \"                                         driver_ip=driver_ip,\\n\",\n    \"                                         study_name=study_name,\\n\",\n    \"                                         seed=seed),\\n\",\n    \"                                         schema=out_schema).toPandas()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n 
  \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Best parameters: {'n_estimators': 419.0, 'learning_rate': 0.015039610889407229, 'max_depth': 10.0, 'subsample': 0.6630214978050138, 'colsample_bytree': 0.8524338650689898, 'min_child_weight': 2.0}\\n\",\n      \"Best value: 0.533100375625104\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"results = result_df.iloc[0].to_dict()\\n\",\n    \"best_value = results.pop(\\\"best_value\\\")\\n\",\n    \"\\n\",\n    \"print(f\\\"Best parameters: {results}\\\")\\n\",\n    \"print(f\\\"Best value: {best_value}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"optuna-spark\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Optuna-Spark/optuna-examples/optuna-joblibspark.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#\\n\",\n    \"# Copyright (c) 2024, NVIDIA CORPORATION.\\n\",\n    \"#\\n\",\n    \"# Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n\",\n    \"# you may not use this file except in compliance with the License.\\n\",\n    \"# You may obtain a copy of the License at\\n\",\n    \"#\\n\",\n    \"#     http://www.apache.org/licenses/LICENSE-2.0\\n\",\n    \"#\\n\",\n    \"# Unless required by applicable law or agreed to in writing, software\\n\",\n    \"# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n\",\n    \"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n\",\n    \"# See the License for the specific language governing permissions and\\n\",\n    \"# limitations under the License.\\n\",\n    \"#\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# Distributed Hyperparameter Tuning: Optuna + JoblibSpark\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"This demo demonstrates distributed hyperparameter tuning for XGBoost using the [JoblibSpark backend](https://github.com/joblib/joblib-spark), building on this [example from Databricks](https://docs.databricks.com/en/machine-learning/automl-hyperparam-tuning/optuna.html).  \\n\",\n    \"We implement best practices to precompute data and maximize computations on the GPU.  
\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"Reference: https://forecastegy.com/posts/xgboost-hyperparameter-tuning-with-optuna/\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Note:\\n\",\n    \"Before running, please make sure you've followed the relevant [setup instructions](../README.md) for your environment (standalone or databricks).\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from typing import List\\n\",\n    \"import os\\n\",\n    \"import requests\\n\",\n    \"import joblib\\n\",\n    \"from joblibspark import register_spark\\n\",\n    \"import optuna\\n\",\n    \"from optuna.samplers import TPESampler\\n\",\n    \"import xgboost as xgb\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark import TaskContext, SparkConf\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Download the dataset\\n\",\n    \"\\n\",\n    \"We'll use the [red wine quality dataset](https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv) to regress wine quality based on features such as acidity, sugar content, etc.  \\n\",\n    \"\\n\",\n    \"**Note**: This example uses a small dataset for demonstration purposes. 
The performance advantages of distributed training are best realized with large datasets and computational workloads.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"cwd = os.getcwd()\\n\",\n    \"os.mkdir(os.path.join(cwd, \\\"data\\\")) if not os.path.exists(os.path.join(cwd, \\\"data\\\")) else None\\n\",\n    \"filepath = os.path.join(cwd, \\\"data\\\", \\\"winequality-red.csv\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"File downloaded and saved to /home/rishic/Code/myforks/spark-rapids-examples/examples/ML+DL-Examples/Optuna-Spark/optuna-examples/data/winequality-red.csv\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"url = \\\"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\\\"\\n\",\n    \"\\n\",\n    \"response = requests.get(url)\\n\",\n    \"if response.status_code == 200:\\n\",\n    \"    with open(filepath, \\\"wb\\\") as f:\\n\",\n    \"        f.write(response.content)\\n\",\n    \"    print(f\\\"File downloaded and saved to {filepath}\\\")\\n\",\n    \"else:\\n\",\n    \"    print(f\\\"Failed to download the file. Status code: {response.status_code}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Part 1. 
Running Optuna locally\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import cudf\\n\",\n    \"from cuml.metrics.regression import mean_squared_error\\n\",\n    \"from cuml.model_selection import train_test_split\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Prepare data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>fixed acidity</th>\\n\",\n       \"      <th>volatile acidity</th>\\n\",\n       \"      <th>citric acid</th>\\n\",\n       \"      <th>residual sugar</th>\\n\",\n       \"      <th>chlorides</th>\\n\",\n       \"      <th>free sulfur dioxide</th>\\n\",\n       \"      <th>total sulfur dioxide</th>\\n\",\n       \"      <th>density</th>\\n\",\n       \"      <th>pH</th>\\n\",\n       \"      <th>sulphates</th>\\n\",\n       \"      <th>alcohol</th>\\n\",\n       \"      <th>quality</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>0</th>\\n\",\n       \"      <td>7.4</td>\\n\",\n       \"      <td>0.70</td>\\n\",\n       \"      <td>0.00</td>\\n\",\n     
  \"      <td>1.9</td>\\n\",\n       \"      <td>0.076</td>\\n\",\n       \"      <td>11.0</td>\\n\",\n       \"      <td>34.0</td>\\n\",\n       \"      <td>0.9978</td>\\n\",\n       \"      <td>3.51</td>\\n\",\n       \"      <td>0.56</td>\\n\",\n       \"      <td>9.4</td>\\n\",\n       \"      <td>5</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>1</th>\\n\",\n       \"      <td>7.8</td>\\n\",\n       \"      <td>0.88</td>\\n\",\n       \"      <td>0.00</td>\\n\",\n       \"      <td>2.6</td>\\n\",\n       \"      <td>0.098</td>\\n\",\n       \"      <td>25.0</td>\\n\",\n       \"      <td>67.0</td>\\n\",\n       \"      <td>0.9968</td>\\n\",\n       \"      <td>3.20</td>\\n\",\n       \"      <td>0.68</td>\\n\",\n       \"      <td>9.8</td>\\n\",\n       \"      <td>5</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>2</th>\\n\",\n       \"      <td>7.8</td>\\n\",\n       \"      <td>0.76</td>\\n\",\n       \"      <td>0.04</td>\\n\",\n       \"      <td>2.3</td>\\n\",\n       \"      <td>0.092</td>\\n\",\n       \"      <td>15.0</td>\\n\",\n       \"      <td>54.0</td>\\n\",\n       \"      <td>0.9970</td>\\n\",\n       \"      <td>3.26</td>\\n\",\n       \"      <td>0.65</td>\\n\",\n       \"      <td>9.8</td>\\n\",\n       \"      <td>5</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>3</th>\\n\",\n       \"      <td>11.2</td>\\n\",\n       \"      <td>0.28</td>\\n\",\n       \"      <td>0.56</td>\\n\",\n       \"      <td>1.9</td>\\n\",\n       \"      <td>0.075</td>\\n\",\n       \"      <td>17.0</td>\\n\",\n       \"      <td>60.0</td>\\n\",\n       \"      <td>0.9980</td>\\n\",\n       \"      <td>3.16</td>\\n\",\n       \"      <td>0.58</td>\\n\",\n       \"      <td>9.8</td>\\n\",\n       \"      <td>6</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>4</th>\\n\",\n       \"      <td>7.4</td>\\n\",\n       \"      
<td>0.70</td>\\n\",\n       \"      <td>0.00</td>\\n\",\n       \"      <td>1.9</td>\\n\",\n       \"      <td>0.076</td>\\n\",\n       \"      <td>11.0</td>\\n\",\n       \"      <td>34.0</td>\\n\",\n       \"      <td>0.9978</td>\\n\",\n       \"      <td>3.51</td>\\n\",\n       \"      <td>0.56</td>\\n\",\n       \"      <td>9.4</td>\\n\",\n       \"      <td>5</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"</div>\"\n      ],\n      \"text/plain\": [\n       \"   fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \\\\\\n\",\n       \"0            7.4              0.70         0.00             1.9      0.076   \\n\",\n       \"1            7.8              0.88         0.00             2.6      0.098   \\n\",\n       \"2            7.8              0.76         0.04             2.3      0.092   \\n\",\n       \"3           11.2              0.28         0.56             1.9      0.075   \\n\",\n       \"4            7.4              0.70         0.00             1.9      0.076   \\n\",\n       \"\\n\",\n       \"   free sulfur dioxide  total sulfur dioxide  density    pH  sulphates  \\\\\\n\",\n       \"0                 11.0                  34.0   0.9978  3.51       0.56   \\n\",\n       \"1                 25.0                  67.0   0.9968  3.20       0.68   \\n\",\n       \"2                 15.0                  54.0   0.9970  3.26       0.65   \\n\",\n       \"3                 17.0                  60.0   0.9980  3.16       0.58   \\n\",\n       \"4                 11.0                  34.0   0.9978  3.51       0.56   \\n\",\n       \"\\n\",\n       \"   alcohol  quality  \\n\",\n       \"0      9.4        5  \\n\",\n       \"1      9.8        5  \\n\",\n       \"2      9.8        5  \\n\",\n       \"3      9.8        6  \\n\",\n       \"4      9.4        5  \"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    
}\n   ],\n   \"source\": [\n    \"data = cudf.read_csv(filepath, delimiter=\\\";\\\")\\n\",\n    \"data.head()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Prepare the train/validation sets. Precompute the Quantile DMatrix, which is used by histogram-based tree methods to save memory.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"X = data.iloc[:, :-1].values\\n\",\n    \"y = data[\\\"quality\\\"].values\\n\",\n    \"X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\\n\",\n    \"Xy_train_qdm = xgb.QuantileDMatrix(X_train, y_train)  # Precompute Quantile DMatrix to avoid repeated quantization every trial.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Objective function\\n\",\n    \"\\n\",\n    \"We define the objective and a hyperparameter search space to optimize via the `trial.suggest_` methods.  \\n\",\n    \"\\n\",\n    \"In each trial, new hyperparameters will be suggested based on previous results. 
See [optuna.trial.Trial](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html) API for a full list of functions.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def objective(trial):\\n\",\n    \"    params = {\\n\",\n    \"        \\\"objective\\\": \\\"reg:squarederror\\\",\\n\",\n    \"        \\\"verbosity\\\": 0,\\n\",\n    \"        \\\"learning_rate\\\": trial.suggest_float(\\\"learning_rate\\\", 1e-3, 0.1, log=True),\\n\",\n    \"        \\\"max_depth\\\": trial.suggest_int(\\\"max_depth\\\", 1, 10),\\n\",\n    \"        \\\"subsample\\\": trial.suggest_float(\\\"subsample\\\", 0.05, 1.0),\\n\",\n    \"        \\\"colsample_bytree\\\": trial.suggest_float(\\\"colsample_bytree\\\", 0.05, 1.0),\\n\",\n    \"        \\\"min_child_weight\\\": trial.suggest_int(\\\"min_child_weight\\\", 1, 20),\\n\",\n    \"        \\\"tree_method\\\": \\\"gpu_hist\\\",\\n\",\n    \"        \\\"device\\\": \\\"cuda\\\",\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    booster = xgb.train(params=params, dtrain=Xy_train_qdm, num_boost_round=trial.suggest_int(\\\"num_boost_round\\\", 100, 500))\\n\",\n    \"    predictions = booster.inplace_predict(X_val)\\n\",\n    \"    rmse = mean_squared_error(y_val, predictions, squared=False).get()\\n\",\n    \"    \\n\",\n    \"    return rmse   \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Create the study and optimize. 
By default, the study results will be stored in memory.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[I 2024-12-11 23:42:09,341] A new study created in memory with name: optuna-xgboost-local\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[I 2024-12-11 23:42:09,715] Trial 0 finished with value: 0.6377619522504244 and parameters: {'learning_rate': 0.005611516415334507, 'max_depth': 10, 'subsample': 0.7453942447208348, 'colsample_bytree': 0.6187255599871848, 'min_child_weight': 4, 'num_boost_round': 162}. Best is trial 0 with value: 0.6377619522504244.\\n\",\n      \"[I 2024-12-11 23:42:10,666] Trial 1 finished with value: 0.6703788974319568 and parameters: {'learning_rate': 0.0013066739238053278, 'max_depth': 9, 'subsample': 0.6210592611560484, 'colsample_bytree': 0.7226689489062432, 'min_child_weight': 1, 'num_boost_round': 488}. Best is trial 0 with value: 0.6377619522504244.\\n\",\n      \"[I 2024-12-11 23:42:10,806] Trial 2 finished with value: 0.6181751362616256 and parameters: {'learning_rate': 0.04622589001020832, 'max_depth': 3, 'subsample': 0.2227337188467456, 'colsample_bytree': 0.22423428436076215, 'min_child_weight': 7, 'num_boost_round': 310}. Best is trial 2 with value: 0.6181751362616256.\\n\",\n      \"[I 2024-12-11 23:42:10,922] Trial 3 finished with value: 0.6698576232920956 and parameters: {'learning_rate': 0.007309539835912915, 'max_depth': 3, 'subsample': 0.6312602499862605, 'colsample_bytree': 0.18251916761943976, 'min_child_weight': 6, 'num_boost_round': 246}. 
Best is trial 2 with value: 0.6181751362616256.\\n\",\n      \"[I 2024-12-11 23:42:11,039] Trial 4 finished with value: 0.6704590546150145 and parameters: {'learning_rate': 0.008168455894760165, 'max_depth': 8, 'subsample': 0.23969009305044175, 'colsample_bytree': 0.538522716492931, 'min_child_weight': 12, 'num_boost_round': 118}. Best is trial 2 with value: 0.6181751362616256.\\n\",\n      \"[I 2024-12-11 23:42:11,191] Trial 5 finished with value: 0.6088806682631155 and parameters: {'learning_rate': 0.016409286730647923, 'max_depth': 2, 'subsample': 0.11179901333601554, 'colsample_bytree': 0.9514412603906666, 'min_child_weight': 20, 'num_boost_round': 424}. Best is trial 5 with value: 0.6088806682631155.\\n\",\n      \"[I 2024-12-11 23:42:11,266] Trial 6 finished with value: 0.7103495949713845 and parameters: {'learning_rate': 0.0040665633135147945, 'max_depth': 1, 'subsample': 0.700021375186549, 'colsample_bytree': 0.4681448690526212, 'min_child_weight': 3, 'num_boost_round': 298}. Best is trial 5 with value: 0.6088806682631155.\\n\",\n      \"[I 2024-12-11 23:42:11,666] Trial 7 finished with value: 0.7255199474722185 and parameters: {'learning_rate': 0.001171593739230706, 'max_depth': 10, 'subsample': 0.29584098252001606, 'colsample_bytree': 0.6793961701362828, 'min_child_weight': 7, 'num_boost_round': 308}. Best is trial 5 with value: 0.6088806682631155.\\n\",\n      \"[I 2024-12-11 23:42:11,829] Trial 8 finished with value: 0.6060010014477214 and parameters: {'learning_rate': 0.0123999678368461, 'max_depth': 2, 'subsample': 0.9711053963763306, 'colsample_bytree': 0.7863761821930588, 'min_child_weight': 19, 'num_boost_round': 458}. Best is trial 8 with value: 0.6060010014477214.\\n\",\n      \"[I 2024-12-11 23:42:12,168] Trial 9 finished with value: 0.6292433375858283 and parameters: {'learning_rate': 0.015696396388661146, 'max_depth': 10, 'subsample': 0.13406787694932354, 'colsample_bytree': 0.23618371929818793, 'min_child_weight': 1, 'num_boost_round': 230}. 
Best is trial 8 with value: 0.6060010014477214.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"study = optuna.create_study(study_name=\\\"optuna-xgboost-local\\\", sampler=TPESampler(seed=42))\\n\",\n    \"study.optimize(objective, n_trials=10)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Best RMSE:  0.6060010014477214\\n\",\n      \"Best hyperparameters:  {'learning_rate': 0.0123999678368461, 'max_depth': 2, 'subsample': 0.9711053963763306, 'colsample_bytree': 0.7863761821930588, 'min_child_weight': 19, 'num_boost_round': 458}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"trial = study.best_trial\\n\",\n    \"print(\\\"Best RMSE: \\\", trial.value)\\n\",\n    \"print(\\\"Best hyperparameters: \\\", trial.params)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Part 2. Distributed Optuna on Spark \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### PySpark\\n\",\n    \"\\n\",\n    \"For standalone users, we need to create the Spark session. For Databricks users, the Spark session will be preconfigured.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"24/12/11 23:42:12 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"24/12/11 23:42:12 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). 
For SparkR, use setLogLevel(newLevel).\\n\",\n      \"24/12/11 23:42:13 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def initialize_spark():\\n\",\n    \"    import socket\\n\",\n    \"    hostname = socket.gethostname()\\n\",\n    \"    conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"\\n\",\n    \"    conf = SparkConf()\\n\",\n    \"    conf.setMaster(f\\\"spark://{hostname}:7077\\\")  # Assuming master is on host and default port. \\n\",\n    \"    conf.set(\\\"spark.task.maxFailures\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"    conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"    \\n\",\n    \"    spark = SparkSession.builder.appName(\\\"optuna-joblibspark-xgboost\\\").config(conf=conf).getOrCreate()\\n\",\n    \"    return spark\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    spark = initialize_spark()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Optuna Task\\n\",\n    \"\\n\",\n    \"This implementation demonstrates **Worker I/O**. \\n\",\n    \"\\n\",\n    \"This means that each worker will read the full dataset from the filepath rather than passing the data in a dataframe.  \\n\",\n    \"In practice, this requires the dataset to be written to a distributed file system accessible to all workers prior to tuning. 
\\n\",\n    \"\\n\",\n    \"For the alternative implementation using **Spark I/O**, see the [Spark Dataframe notebook](optuna-dataframe.ipynb).\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"In the task, each worker will:\\n\",\n    \"1. Read the dataset from the filepath\\n\",\n    \"2. Load the study from the MySQL storage backend\\n\",\n    \"3. Optimize over the objective for the assigned number of trials, sending results back to the database after each iteration\\n\",\n    \"\\n\",\n    \"Here we use Optuna's [Define-and-Run](https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/009_ask_and_tell.html#define-and-run) API, which allows us to predefine the hyperparameter space and pass it to the task.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def task(num_trials: int, xgb_params: dict, optuna_params: dict, driver_ip: str, study_name: str, seed: int, filepath: str):\\n\",\n    \"    import cudf\\n\",\n    \"    from cuml.metrics.regression import mean_squared_error\\n\",\n    \"    from cuml.model_selection import train_test_split\\n\",\n    \"\\n\",\n    \"    tc = TaskContext.get()\\n\",\n    \"    assert \\\"gpu\\\" in tc.resources(), \\\"GPU resource not found.\\\"\\n\",\n    \"\\n\",\n    \"    if filepath.startswith(\\\"/dbfs/\\\"):\\n\",\n    \"        # Check to ensure GPU direct storage is disabled for cuDF on databricks.\\n\",\n    \"        libcudf_policy = os.environ.get('LIBCUDF_CUFILE_POLICY')\\n\",\n    \"        if libcudf_policy != 'OFF':\\n\",\n    \"            raise RuntimeError(\\\"Set LIBCUDF_CUFILE_POLICY=OFF to read from DBFS with cuDF.\\\")\\n\",\n    \"    \\n\",\n    \"    data = cudf.read_csv(filepath, delimiter=\\\";\\\")\\n\",\n    \"    X = data.iloc[:, :-1].values\\n\",\n    \"    y = data[\\\"quality\\\"].values\\n\",\n    \"    X_train, X_val, y_train, y_val = 
train_test_split(X, y, test_size=0.2, random_state=seed)\\n\",\n    \"\\n\",\n    \"    tuning_max_bin = \\\"max_bin\\\" in optuna_params\\n\",\n    \"    if not tuning_max_bin:\\n\",\n    \"        max_bin = xgb_params.get(\\\"max_bin\\\", 256)\\n\",\n    \"        # Precompute Quantile DMatrix to avoid repeated quantization every trial.\\n\",\n    \"        Xy_train_qdm = xgb.QuantileDMatrix(X_train, y_train, max_bin=max_bin)\\n\",\n    \"\\n\",\n    \"    study = optuna.load_study(\\n\",\n    \"        study_name=study_name,\\n\",\n    \"        storage=f\\\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\\\",\\n\",\n    \"        sampler=TPESampler(seed=seed),\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    print(f\\\"Running {num_trials} trials on partition {tc.partitionId()}.\\\")\\n\",\n    \"\\n\",\n    \"    ### Objective ###\\n\",\n    \"    for _ in range(num_trials):\\n\",\n    \"        trial = study.ask(optuna_params)\\n\",\n    \"        xgb_params.update(trial.params)\\n\",\n    \"\\n\",\n    \"        if tuning_max_bin:\\n\",\n    \"            # If tuning the max_bin param, we must recompute the QDM every trial.\\n\",\n    \"            if \\\"n_estimators\\\" not in xgb_params:\\n\",\n    \"                xgb_params[\\\"n_estimators\\\"] = 100  # Default value if not tuning.\\n\",\n    \"\\n\",\n    \"            model = xgb.XGBRegressor(**xgb_params)\\n\",\n    \"            model.fit(X_train, y_train)\\n\",\n    \"            booster = model.get_booster()\\n\",\n    \"        else:\\n\",\n    \"            # Train the model with xgb.train() API using the precomputed QDM.\\n\",\n    \"            num_boost_round = xgb_params.get(\\\"n_estimators\\\", 100)\\n\",\n    \"            booster = xgb.train(params=xgb_params, dtrain=Xy_train_qdm, num_boost_round=num_boost_round)\\n\",\n    \"            \\n\",\n    \"        # Perform in-place predictions on GPU using the booster.\\n\",\n    \"        predictions = 
booster.inplace_predict(X_val)\\n\",\n    \"        rmse = mean_squared_error(y_val, predictions, squared=False).get()\\n\",\n    \"        \\n\",\n    \"        study.tell(trial, rmse)\\n\",\n    \"\\n\",\n    \"    return study.best_params, study.best_value\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# This will register the Spark Session with the Joblib Spark backend.\\n\",\n    \"register_spark()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Setup and run the Optuna study\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the driver IP for the MySQL database.  \\n\",\n    \"- For standalone users, make sure you've followed the [database setup instructions](../README.md#setup-database-for-optuna). The database should be on 'localhost'. \\n\",\n    \"- For databricks users, the database should already be setup on the driver node by the init script.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# check if we're running on databricks\\n\",\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"MySQL database is hosted on localhost\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"if on_databricks:\\n\",\n    \"    driver_ip = spark.conf.get(\\\"spark.driver.host\\\")\\n\",\n    \"else:\\n\",\n    \"    driver_ip = \\\"localhost\\\"\\n\",\n    \"\\n\",\n    \"print(f\\\"MySQL database is hosted on {driver_ip}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Create a new 
study, referencing the MySQL database as the storage backend.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[I 2024-12-11 23:42:13,928] A new study created in RDB with name: optuna-xgboost-joblibspark\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<optuna.study.study.Study at 0x76ae75922b00>\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"study_name = \\\"optuna-xgboost-joblibspark\\\"\\n\",\n    \"seed = 42\\n\",\n    \"\\n\",\n    \"try:\\n\",\n    \"    # Delete the study if it already exists\\n\",\n    \"    optuna.delete_study(\\n\",\n    \"        study_name=study_name, \\n\",\n    \"        storage=f\\\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\\\"\\n\",\n    \"    )\\n\",\n    \"except:\\n\",\n    \"    pass\\n\",\n    \"\\n\",\n    \"optuna.create_study(\\n\",\n    \"    study_name=study_name,\\n\",\n    \"    storage=f\\\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\\\",\\n\",\n    \"    sampler=TPESampler(seed=seed)\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the number of tasks, number of trials, and trials per task. \\n\",\n    \"\\n\",\n    \"**NOTE**: for standalone users running on a single worker, the 4 tasks will all be assigned to the same worker and executed sequentially in this demonstration. 
This can easily be scaled up to run concurrently by adding more workers.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def partition_trials(total_trials: int, total_tasks: int) -> List[int]:\\n\",\n    \"    base_size = total_trials // total_tasks\\n\",\n    \"    extra = total_trials % total_tasks\\n\",\n    \"    partitions = [base_size] * total_tasks\\n\",\n    \"    for i in range(extra):\\n\",\n    \"        partitions[i] += 1\\n\",\n    \"    \\n\",\n    \"    return partitions\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Trials per task: [25, 25, 25, 25]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"num_tasks = 4\\n\",\n    \"num_trials = 100\\n\",\n    \"trials_per_task = partition_trials(num_trials, num_tasks)\\n\",\n    \"print(f\\\"Trials per task: {trials_per_task}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define params\\n\",\n    \"Define the XGBoost model params and the hyperparams for Optuna to tune. 
\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Keep these params consistent:\\n\",\n    \"xgb_params = {\\n\",\n    \"    \\\"objective\\\": \\\"reg:squarederror\\\",\\n\",\n    \"    \\\"verbosity\\\": 0,\\n\",\n    \"    \\\"tree_method\\\": \\\"gpu_hist\\\",\\n\",\n    \"    \\\"device\\\": f\\\"cuda\\\",\\n\",\n    \"    \\\"seed\\\": seed,\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Tune these params:\\n\",\n    \"optuna_params = {\\n\",\n    \"    \\\"n_estimators\\\": optuna.distributions.IntDistribution(100, 500),\\n\",\n    \"    \\\"learning_rate\\\": optuna.distributions.FloatDistribution(1e-3, 0.1, log=True),\\n\",\n    \"    \\\"max_depth\\\": optuna.distributions.IntDistribution(1, 10),\\n\",\n    \"    \\\"subsample\\\": optuna.distributions.FloatDistribution(0.05, 1.0),\\n\",\n    \"    \\\"colsample_bytree\\\": optuna.distributions.FloatDistribution(0.05, 1.0),\\n\",\n    \"    \\\"min_child_weight\\\": optuna.distributions.IntDistribution(1, 20),\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**For Databricks**: we must download the dataset to DBFS so that all workers can access it.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/optuna-data\\\")\\n\",\n    \"    filepath = \\\"/dbfs/FileStore/optuna-data/winequality-red.csv\\\"\\n\",\n    \"    url = \\\"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\\\"\\n\",\n    \"\\n\",\n    \"    response = requests.get(url)\\n\",\n    \"    if response.status_code == 200:\\n\",\n    \"        with open(filepath, 
\\\"wb\\\") as f:\\n\",\n    \"            f.write(response.content)\\n\",\n    \"        print(f\\\"File downloaded and saved to {filepath}\\\")\\n\",\n    \"    else:\\n\",\n    \"        print(f\\\"Failed to download the file. Status code: {response.status_code}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Run the study\\n\",\n    \"\\n\",\n    \"Run parallel threads to execute the Optuna task and collect the results (it might take a few minutes).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/home/rishic/anaconda3/envs/optuna-spark/lib/python3.10/site-packages/joblibspark/backend.py:115: UserWarning: Spark version does not support stage-level scheduling.\\n\",\n      \"  warnings.warn(\\\"Spark version does not support stage-level scheduling.\\\")\\n\",\n      \"/home/rishic/anaconda3/envs/optuna-spark/lib/python3.10/site-packages/joblibspark/backend.py:154: UserWarning: User-specified n_jobs (4) is greater than the max number of concurrent tasks (1) this cluster can run now.If dynamic allocation is enabled for the cluster, you might see more executors allocated.\\n\",\n      \"  warnings.warn(f\\\"User-specified n_jobs ({n_jobs}) is greater than the max number of \\\"\\n\",\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"with joblib.parallel_backend(\\\"spark\\\", n_jobs=num_tasks):\\n\",\n    \"    results = joblib.Parallel()(\\n\",\n    \"        joblib.delayed(task)(num_trials,\\n\",\n    \"                             xgb_params,\\n\",\n    \"                             optuna_params,\\n\",\n    \"                             driver_ip,\\n\",\n    \"                             study_name,\\n\",\n    \"                             
seed,\\n\",\n    \"                             filepath) for num_trials in trials_per_task\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Best parameters: {'n_estimators': 463, 'learning_rate': 0.05206124631137337, 'max_depth': 9, 'subsample': 0.7434942725744815, 'colsample_bytree': 0.877391644494205, 'min_child_weight': 4}\\n\",\n      \"Best value: 0.5324732150787205\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"best_params = min(results, key=lambda x: x[1])[0]\\n\",\n    \"best_value = min(results, key=lambda x: x[1])[1]\\n\",\n    \"\\n\",\n    \"print(f\\\"Best parameters: {best_params}\\\")\\n\",\n    \"print(f\\\"Best value: {best_value}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"optuna-spark\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/README.md",
    "content": "# Deep Learning Inference on Spark\n\nExample notebooks demonstrating **distributed deep learning inference** using the [predict_batch_udf](https://developer.nvidia.com/blog/distributed-deep-learning-made-easy-with-spark-3-4/#distributed_inference) introduced in Spark 3.4.0.\nThese notebooks also demonstrate model serving integrations with [Triton Inference Server](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html) and [vLLM serve](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html).\n\n## Contents:\n- [Overview](#overview)\n- [Running Locally](#running-locally)\n- [Running on Cloud](#running-on-cloud-environments)\n- [Inference Serving Integration](#inference-serving)\n\n## Overview\n\nThese notebooks demonstrate how models from external frameworks (Torch, Huggingface, Tensorflow, vLLM) trained on single-worker machines can be used for large-scale distributed inference on Spark clusters.  \nFor example, a basic model trained in TensorFlow and saved on disk as \"mnist_model\" can be used in Spark as follows:\n```\nimport numpy as np\nfrom pyspark.sql.functions import predict_batch_udf\nfrom pyspark.sql.types import ArrayType, FloatType\n\ndef predict_batch_fn():\n    import tensorflow as tf\n    model = tf.keras.models.load_model(\"/path/to/mnist_model\")\n    def predict(inputs: np.ndarray) -> np.ndarray:\n        return model.predict(inputs)\n    return predict\n\nmnist = predict_batch_udf(predict_batch_fn,\n                          return_type=ArrayType(FloatType()),\n                          batch_size=1024,\n                          input_tensor_shapes=[[784]])\n\ndf = spark.read.parquet(\"mnist_data\")\npredictions = df.withColumn(\"preds\", mnist(\"data\")).collect()\n```\n\nIn this simple case, the `predict_batch_fn` will use TensorFlow APIs to load the model and return a simple `predict` function.  
The `predict_batch_udf` will handle the data conversion from Spark DataFrame columns into batched numpy inputs.\n\n\n#### Notebook List\n\nBelow is a full list of the notebooks and their links. All notebooks have been saved with sample outputs for quick browsing.  \n\n|   | Framework  | Notebook Name | Description | Link\n| ------------- | ------------- | ------------- | ------------- | ------------- \n| 1 | HuggingFace | DeepSeek-R1 | LLM batch inference using the DeepSeek-R1-Distill-Llama reasoning model to solve word problems. | [Link](huggingface/deepseek-r1_torch.ipynb)\n| 2 | HuggingFace | Qwen-2.5-7b | LLM batch inference using the Qwen-2.5-7b model for text summarization. | [Link](huggingface/qwen-2.5-7b_torch.ipynb)\n| 3 | HuggingFace | Gemma-7b | LLM batch inference using the Google Gemma-7b model for code comprehension tasks. | [Link](huggingface/gemma-7b_torch.ipynb)\n| 4 | HuggingFace | Sentence Transformers | Sentence embeddings using SentenceTransformers in Torch. | [Link](huggingface/sentence_transformers_torch.ipynb)\n| 5+6 | HuggingFace | Conditional Generation | Sentence translation using the T5 text-to-text transformer (Torch and Tensorflow). | [Torch Link](huggingface/conditional_generation_torch.ipynb), [TF Link](huggingface/conditional_generation_tf.ipynb)\n| 7+8 | HuggingFace | Pipelines | Sentiment analysis using Huggingface pipelines (Torch and Tensorflow). | [Torch Link](huggingface/pipelines_torch.ipynb), [TF Link](huggingface/pipelines_tf.ipynb)\n| 9 | vLLM | Qwen-2.5-14b-tensor-parallel | Tensor-parallel LLM batch inference using the Qwen-2.5-14b model to summarize unstructured text data into a structured schema, using vLLM serve. | [Link](vllm/qwen-2.5-14b-tensor-parallel_vllm.ipynb)\n| 10 | vLLM | Qwen-2.5-7b | LLM batch inference using the Qwen-2.5-7b model for text summarization, using vLLM serve. 
| [Link](vllm/qwen-2.5-7b_vllm.ipynb)\n| 11 | PyTorch | Image Classification | Training a model to predict clothing categories in FashionMNIST, and deploying with Torch-TensorRT accelerated inference. | [Link](pytorch/image_classification_torch.ipynb)\n| 12 | PyTorch | Housing Regression | Training and deploying a model to predict housing prices in the California Housing Dataset, and deploying with Torch-TensorRT accelerated inference. | [Link](pytorch/housing_regression_torch.ipynb)\n| 13 | Tensorflow | Image Classification | Training and deploying a model to predict hand-written digits in MNIST. | [Link](tensorflow/image_classification_tf.ipynb)\n| 14 | Tensorflow | Keras Preprocessing | Training and deploying a model with preprocessing layers to predict likelihood of pet adoption in the PetFinder mini dataset. | [Link](tensorflow/keras_preprocessing_tf.ipynb)\n| 15 | Tensorflow | Keras Resnet50 | Deploying ResNet-50 to perform flower recognition from flower images. | [Link](tensorflow/keras_resnet50_tf.ipynb)\n| 16 | Tensorflow | Text Classification | Training and deploying a model to perform sentiment analysis on the IMDB dataset. 
| [Link](tensorflow/text_classification_tf.ipynb)\n\n\n## Running Locally\n\nTo run the notebooks locally, please follow these instructions:\n\n#### Create environment\n\nEach notebook has a suffix `_torch`, `_tf`, or `_vllm` specifying the environment used.\n\n**For PyTorch:**\n```\nconda create -n spark-dl-torch -c conda-forge python=3.11\nconda activate spark-dl-torch\nconda install -c conda-forge libstdcxx-ng\npip install -r torch_requirements.txt\n```\n**For TensorFlow:**\n```\nconda create -n spark-dl-tf -c conda-forge python=3.11\nconda activate spark-dl-tf\nconda install -c conda-forge libstdcxx-ng\npip install -r tf_requirements.txt\n```\n**For vLLM:**\n```\nconda create -n spark-dl-vllm -c conda-forge python=3.11\nconda activate spark-dl-vllm\npip install -r vllm_requirements.txt\n```\n\n#### Start Cluster\n\nFor demonstration, these instructions just use a local Standalone cluster with a single executor, but they can be run on any distributed Spark cluster. If you haven't already, [install Spark](https://spark.apache.org/downloads.html) on your system. \n```shell\n# Replace with your Spark installation path\nexport SPARK_HOME=</path/to/spark>\n```\n\n```shell\n# Configure and start cluster\nexport MASTER=spark://$(hostname):7077\nexport SPARK_WORKER_INSTANCES=1\nexport CORES_PER_WORKER=8\nexport SPARK_WORKER_OPTS=\"-Dspark.worker.resource.gpu.amount=1 \\\n                          -Dspark.worker.resource.gpu.discoveryScript=$SPARK_HOME/examples/src/main/scripts/getGpusResources.sh\"\n${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m 16G ${MASTER}\n```\n\nThe notebooks are ready to run! Each notebook has a cell to connect to the standalone cluster and create a SparkSession.\n\n**Notes**: \n- Please create separate environments for different frameworks as specified above. This will avoid conflicts between the CUDA libraries bundled with their respective versions. \n- `requirements.txt` installs pyspark>=3.4.0. 
Make sure the installed PySpark version is compatible with your system's Spark installation.\n- The notebooks require an NVIDIA GPU on your system.  \n- The PyTorch notebooks include model compilation and accelerated inference with TensorRT. While not included in the notebooks, Tensorflow also supports [integration with TensorRT](https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html), but as of writing it is not supported in TF==2.17.0. \n- Note that some Huggingface models may be gated and will require a login, e.g.,:\n    ```python\n    from huggingface_hub import login\n    login()\n    ```\n\n## Running on Cloud Environments\n\nWe also provide instructions to run the notebooks on CSP Spark environments.  \nSee the instructions for [Databricks](databricks/README.md) and [GCP Dataproc](dataproc/README.md).\n\n## Inference Serving\n\n<img src=\"images/spark-server.png\" alt=\"drawing\" width=\"900\"/>\n\nThe notebooks demonstrate deploying models on an inference server as a sidecar process, as shown above. The process looks like this:\n- Prior to inference, launch a server process on each node.\n- Define a predict function, which creates a client that sends/receives inference requests to the local server.\n- Wrap the predict function in a predict_batch_udf to launch parallel inference requests using Spark.\n\nThis logically separates the CPU parallelism from the GPU parallelism for streamlined deployment. \nFor instance, say we want to run a 20GB model on a GPU with 25GB of memory.\n- With `predict_batch_udf` using an in-process framework, we must set `spark.task.resource.gpu.amount=1`, which limits parallelism to 1 task (i.e. model instance) per GPU for the entire application due to memory constraints. 
\n- Using an inference server, we can set `spark.task.resource.gpu.amount=(num_cores)` to leverage all the executor CPUs for Dataframe operations (reading/preprocessing/writing), while the server loads 1 instance of the model on the GPU for inference.\n\nSee [`server_utils.py`](server_utils.py) for more details on how we manage servers on the Spark cluster.\n\n### Triton Inference Server\n\nEach notebook has a section that demonstrates model serving with [Triton Inference Server](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html), an open-source serving platform for deep learning models, which includes many [major features](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html#triton-major-features) to streamline inference. To leverage Triton through Python, we use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles communication with the Triton server.  \n\nTriton allows you to define a Python function encapsulating the inference logic, including complex pipelines such as model ensembles or concurrent execution. For more information on how PyTriton works, see the [PyTriton docs](https://triton-inference-server.github.io/pytriton/latest/high_level_design/).\n\n### vLLM Server\n\nThe vLLM notebooks demonstrate serving with [vLLM serve](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html), an OpenAI-compatible HTTP server to deploy vLLM models. If you do not need the custom inference logic provided by Triton, vLLM serve is a straightforward alternative to deploy a vLLM-compatible LLM."
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/README.md",
    "content": "# Spark DL Inference on Databricks\n\n**Note**: fields in \\<brackets\\> require user inputs.  \nMake sure you are in [this](./) directory.\n\n## Setup\n\n1. Install the latest [databricks-cli](https://docs.databricks.com/en/dev-tools/cli/tutorial.html) and configure for your workspace.\n\n2. Specify the path to your Databricks workspace:\n    ```shell\n    export WS_PATH=</Users/someone@example.com>\n    ```\n\n    ```shell\n    export SPARK_DL_WS=${WS_PATH}/spark-dl\n    databricks workspace mkdirs ${SPARK_DL_WS}\n    ```\n3. Specify the local paths to the notebook you wish to run.\n    As an example for a PyTorch notebook:\n    ```shell\n    export NOTEBOOK_SRC=</path/to/notebook_torch.ipynb>\n    ```\n4. Specify the framework to torch, tf, or vllm, corresponding to the notebook you wish to run. Continuing with the PyTorch example:\n    ```shell\n    export FRAMEWORK=torch\n    ```\n    This will tell the init script which libraries to install on the cluster.\n\n5. Copy the notebook, the utils file, and the init script to the Databricks Workspace:\n    ```shell\n    databricks workspace import ${SPARK_DL_WS}/$(basename \"$NOTEBOOK_SRC\") --format JUPYTER --file $NOTEBOOK_SRC\n    databricks workspace import ${SPARK_DL_WS}/server_utils.py --format AUTO --file $(realpath ../server_utils.py)\n    databricks workspace import ${SPARK_DL_WS}/init_spark_dl.sh --format AUTO --file $(pwd)/setup/init_spark_dl.sh\n    ```\n\n6. Launch the cluster with the provided script with the argument `aws` or `azure` based on your provider. Modify the scripts if you do not have the specific instance types. By default the script will create a cluster with 2 A10 workers and 1 A10 driver. 
\n    ```shell\n    cd setup\n    chmod +x start_cluster.sh\n    ./start_cluster.sh aws  # or ./start_cluster.sh azure\n    ```\n    To create a cluster capable of tensor parallelism, include the argument `tp` to acquire multiple GPUs per node:\n    ```shell\n    ./start_cluster.sh aws tp  # or ./start_cluster.sh azure tp\n    ```\n    In this case, the Azure worker nodes will have 2 GPUs each and the AWS workers will have 4 GPUs each (since AWS does not have an instance type with 2 GPUs) to run the tensor parallel example.* \n\n7. Navigate to the notebook in your workspace and attach it to the cluster. The default cluster name is `spark-dl-inference-$FRAMEWORK`.  \n\n*Note that the RAPIDS Accelerator for Apache Spark is not compatible with this case, since [multiple GPUs per executor are not yet supported](https://docs.nvidia.com/spark-rapids/user-guide/latest/faq.html#why-are-multiple-gpus-per-executor-not-supported)."
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/init_spark_dl.sh",
    "content": "#!/bin/bash\n# Copyright (c) 2025, NVIDIA CORPORATION.\n\nset -euxo pipefail\n\n# install requirements\nsudo /databricks/python3/bin/pip3 install --upgrade pip\n\nif [[ \"${FRAMEWORK}\" == \"torch\" ]]; then\n    cat <<EOF > temp_requirements.txt\ndatasets==3.*\ntransformers\nnvidia-pytriton\ntorch<=2.5.1\ntorchvision --extra-index-url https://download.pytorch.org/whl/cu121\ntorch-tensorrt\ntensorrt --extra-index-url https://download.pytorch.org/whl/cu121\nsentence_transformers\nsentencepiece\nnvidia-modelopt[all] --extra-index-url https://pypi.nvidia.com\nEOF\nelif [[ \"${FRAMEWORK}\" == \"tf\" ]]; then\n    cat <<EOF > temp_requirements.txt\ndatasets==3.*\ntransformers\nnvidia-pytriton\nEOF\nelif [[ \"${FRAMEWORK}\" == \"vllm\" ]]; then\n    cat <<EOF > temp_requirements.txt\nvllm==0.8.2\nEOF\nelse\n    echo \"Please export FRAMEWORK as torch, tf, or vllm per README\"\n    exit 1\nfi\n\nsudo /databricks/python3/bin/pip3 install --upgrade --force-reinstall -r temp_requirements.txt\nrm temp_requirements.txt\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/start_cluster.sh",
    "content": "#!/bin/bash\n# Copyright (c) 2025, NVIDIA CORPORATION.\n\nset -eo pipefail\n\nif [ $# -lt 1 ] || [ $# -gt 2 ]; then\n    echo \"Usage: $0 <aws|azure> [tp]\"\n    exit 1\nfi\n\nCLOUD_PROVIDER=$1\nTENSOR_PARALLEL=false\n\n# Check if the second argument is \"tp\" for tensor parallelism\nif [ $# -eq 2 ] && [ \"$2\" == \"tp\" ]; then\n    TENSOR_PARALLEL=true\nfi\n\nif [[ \"${FRAMEWORK}\" != \"vllm\" && \"${FRAMEWORK}\" != \"torch\" && \"${FRAMEWORK}\" != \"tf\" ]]; then\n    echo \"Error: Please export FRAMEWORK as torch, tf, or vllm per README\"\n    exit 1\nfi\n\n# Modify the node types below if your Databricks account does not have these specific instance types. \n# Modify EXECUTOR_CORES=(cores per node) and EXECUTOR_GPU_AMT=(GPUs per node) accordingly.\n# We recommend selecting instances with A10/L4+ GPUs for these examples.\nif [[ \"${CLOUD_PROVIDER}\" == \"aws\" ]]; then\n    DRIVER_NODE_TYPE=\"g5.2xlarge\"\n    \n    if [[ \"${TENSOR_PARALLEL}\" == \"true\" ]]; then\n        # For tensor-parallelism examples, we default to the g5.12xlarge with 4 A10 GPUs (AWS does not have 2-GPU instances). 
\n        NODE_TYPE=\"g5.12xlarge\"\n        EXECUTOR_CORES=48\n        EXECUTOR_GPU_AMT=4\n    else\n        NODE_TYPE=\"g5.4xlarge\"\n        EXECUTOR_CORES=16\n        EXECUTOR_GPU_AMT=1\n    fi\nelif [[ \"${CLOUD_PROVIDER}\" == \"azure\" ]]; then\n    DRIVER_NODE_TYPE=\"Standard_NV36ads_A10_v5\"\n    \n    if [[ \"${TENSOR_PARALLEL}\" == \"true\" ]]; then\n        # For tensor-parallelism examples, we default to the Standard_NV72ads_A10_v5 with 2 A10 GPUs.\n        NODE_TYPE=\"Standard_NV72ads_A10_v5\"\n        EXECUTOR_CORES=72\n        EXECUTOR_GPU_AMT=2\n    else\n        NODE_TYPE=\"Standard_NV36ads_A10_v5\"\n        EXECUTOR_CORES=36\n        EXECUTOR_GPU_AMT=1\n    fi\nelse\n    echo \"Error: Cloud provider must be either 'aws' or 'azure'\"\n    exit 1\nfi\n\nCLUSTER_SUFFIX=\"${FRAMEWORK}\"\nif [[ \"${TENSOR_PARALLEL}\" == \"true\" ]]; then\n    CLUSTER_SUFFIX=\"${FRAMEWORK}-tp\"\nfi\n\n# Task GPU amount = Executor GPU amount / Executor cores\nTASK_GPU_AMT=$(awk \"BEGIN {print ${EXECUTOR_GPU_AMT}/${EXECUTOR_CORES}}\")\n\njson_config=$(cat <<EOF\n{\n    \"cluster_name\": \"spark-dl-inference-${CLUSTER_SUFFIX}\",\n    \"spark_version\": \"15.4.x-gpu-ml-scala2.12\",\n    \"spark_conf\": {\n        \"spark.executor.resource.gpu.amount\": \"${EXECUTOR_GPU_AMT}\",\n        \"spark.python.worker.reuse\": \"true\",\n        \"spark.sql.execution.arrow.pyspark.enabled\": \"true\",\n        \"spark.task.resource.gpu.amount\": \"${TASK_GPU_AMT}\",\n        \"spark.executor.cores\": \"${EXECUTOR_CORES}\"\n    },\n    \"node_type_id\": \"${NODE_TYPE}\",\n    \"driver_node_type_id\": \"${DRIVER_NODE_TYPE}\",\n    \"spark_env_vars\": {\n        \"TF_GPU_ALLOCATOR\": \"cuda_malloc_async\",\n        \"FRAMEWORK\": \"${FRAMEWORK}\"\n    },\n    \"autotermination_minutes\": 60,\n    \"enable_elastic_disk\": true,\n    \"init_scripts\": [\n        {\n            \"workspace\": {\n                \"destination\": \"${SPARK_DL_WS}/init_spark_dl.sh\"\n            }\n        }\n 
   ],\n    \"runtime_engine\": \"STANDARD\",\n    \"num_workers\": \"2\"\n}\nEOF\n)\n\ndatabricks clusters create --json \"$json_config\"\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/README.md",
    "content": "# Spark DL Inference on Dataproc\n\n## Setup\n\n**Note**: fields in \\<brackets\\> require user inputs.  \nMake sure you are in [this](./) directory.\n\n#### Setup GCloud CLI\n\n1. Install the latest [gcloud-cli](https://cloud.google.com/sdk/docs/install) and initialize with `gcloud init`.\n\n2. Configure the following settings:\n    ```shell\n    export PROJECT=<your_project>\n    export DATAPROC_REGION=<your_dataproc_region>\n    export COMPUTE_REGION=<your_compute_region>\n    export COMPUTE_ZONE=<your_compute_zone>\n\n    gcloud config set project ${PROJECT}\n    gcloud config set dataproc/region ${DATAPROC_REGION}\n    gcloud config set compute/region ${COMPUTE_REGION}\n    gcloud config set compute/zone ${COMPUTE_ZONE}\n    ```\n\n#### Copy files to GCS\n\n3. Create a GCS bucket if you don't already have one:\n    ```shell\n    export GCS_BUCKET=<your_gcs_bucket_name>\n\n    gcloud storage buckets create gs://${GCS_BUCKET} \n    ```\n\n4.  Specify the local path to the notebook(s) and copy to the GCS bucket.\n    As an example for a torch notebook:\n    ```shell\n    export SPARK_DL_HOME=${GCS_BUCKET}/spark-dl\n    \n    gcloud storage cp </path/to/notebook_name_torch.ipynb> gs://${SPARK_DL_HOME}/notebooks/\n    ```\n    Repeat this step for any notebooks you wish to run. All notebooks under `gs://${SPARK_DL_HOME}/notebooks/` will be copied to the master node during initialization.\n\n5. Copy the utils file to the GCS bucket.\n    ```shell\n    gcloud storage cp $(realpath ../server_utils.py) gs://${SPARK_DL_HOME}/\n    ```\n\n#### Start cluster and run\n\n6. Specify the framework to use (torch, tf, or vllm), which will determine what libraries to install on the cluster. For example:\n    ```shell\n    export FRAMEWORK=torch\n    ```\n    Run the cluster startup script. 
The script will also retrieve and use the [spark-rapids initialization script](https://github.com/GoogleCloudDataproc/initialization-actions/blob/master/spark-rapids/spark-rapids.sh) to setup GPU resources. The script will create 2 L4 worker nodes and 1 L4 driver node by default, named `${USER}-spark-dl-inference-${FRAMEWORK}`. \n    ```shell\n    cd setup\n    chmod +x start_cluster.sh\n    ./start_cluster.sh\n    ```\n    To create a cluster capable of tensor parallelism, include the argument `tp` to acquire multiple GPUs per node:\n    ```shell\n    ./start_cluster.sh tp\n    ```\n    In this case, the worker nodes will have 2 L4s each to run the tensor parallel example.*\n\n7. Browse to the Jupyter web UI:\n    - Go to `Dataproc` > `Clusters` > `(Cluster Name)` > `Web Interfaces` > `Jupyter/Lab`\n    \n    Or, get the link by running this command (under httpPorts > Jupyter/Lab):\n    ```shell\n    gcloud dataproc clusters describe ${CLUSTER_NAME} --region=${COMPUTE_REGION}\n    ```\n\n8. Open and run the notebook interactively with the **Python 3 kernel**.  \nThe notebooks can be found under `Local Disk/spark-dl-notebooks` on the master node (folder icon on the top left > Local Disk).\n\n*Note that the RAPIDS Accelerator for Apache Spark is not applicable in this case, since [multiple GPUs per executor are not yet supported](https://docs.nvidia.com/spark-rapids/user-guide/latest/faq.html#why-are-multiple-gpus-per-executor-not-supported)."
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/setup/init_spark_dl.sh",
    "content": "#!/bin/bash\n# Copyright (c) 2025, NVIDIA CORPORATION.\n\nset -euxo pipefail\n\nfunction get_metadata_attribute() {\n  local -r attribute_name=$1\n  local -r default_value=$2\n  /usr/share/google/get_metadata_value \"attributes/${attribute_name}\" || echo -n \"${default_value}\"\n}\n\nSPARK_DL_HOME=$(get_metadata_attribute spark-dl-home UNSET)\nif [[ ${SPARK_DL_HOME} == \"UNSET\" ]]; then\n    echo \"Please set --metadata spark-dl-home\"\n    exit 1\nfi\n\nGCS_BUCKET=$(get_metadata_attribute gcs-bucket UNSET)\nif [[ ${GCS_BUCKET} == \"UNSET\" ]]; then\n    echo \"Please set --metadata gcs-bucket\"\n    exit 1\nfi\n\nREQUIREMENTS=$(get_metadata_attribute requirements UNSET)\nif [[ ${REQUIREMENTS} == \"UNSET\" ]]; then\n    echo \"Please set --metadata requirements\"\n    exit 1\nfi\n\n# mount gcs bucket as fuse\nexport GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`\necho \"deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main\" | sudo tee /etc/apt/sources.list.d/gcsfuse.list\ncurl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -\nsudo apt-get update\nsudo apt-get install -y fuse gcsfuse\nsudo mkdir -p /mnt/gcs\ngcsfuse -o allow_other --implicit-dirs ${GCS_BUCKET} /mnt/gcs\nsudo chmod -R 777 /mnt/gcs\n\n# install requirements\npip install --upgrade pip\necho \"${REQUIREMENTS}\" > temp_requirements.txt\npip install --upgrade --force-reinstall -r temp_requirements.txt\nrm temp_requirements.txt\n\n# copy notebooks to master\nROLE=$(/usr/share/google/get_metadata_value attributes/dataproc-role)\nif [[ \"${ROLE}\" == 'Master' ]]; then\n    if gsutil -q stat gs://${SPARK_DL_HOME}/notebooks/**; then\n        mkdir spark-dl-notebooks\n        gcloud storage cp -r gs://${SPARK_DL_HOME}/notebooks/* spark-dl-notebooks\n        gcloud storage cp gs://${SPARK_DL_HOME}/server_utils.py .\n    else\n        echo \"Failed to retrieve notebooks from gs://${SPARK_DL_HOME}/notebooks/\"\n        exit 1\n    fi\nfi\n\nsudo chmod -R a+rw 
/home/\nsudo systemctl daemon-reload\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/setup/start_cluster.sh",
    "content": "#!/bin/bash\n# Copyright (c) 2025, NVIDIA CORPORATION.\n\nset -eo pipefail\n\nTENSOR_PARALLEL=false\nif [[ $# -gt 0 && \"$1\" == \"tp\" ]]; then\n    TENSOR_PARALLEL=true\n    echo \"Tensor parallelism enabled - will use larger machine types with multiple GPUs\"\nfi\n\n# configure arguments\nif [[ -z ${GCS_BUCKET} ]]; then\n    echo \"Please export GCS_BUCKET per README.md\"\n    exit 1\nfi\n\nif [[ -z ${FRAMEWORK} ]]; then\n    echo \"Please export FRAMEWORK as 'torch', 'tf', or 'vllm'\"\n    exit 1\nfi\n\nif [[ -z ${COMPUTE_REGION} ]]; then\n    COMPUTE_REGION=$(gcloud config get-value compute/region)\n    if [[ -z ${COMPUTE_REGION} ]]; then\n        echo \"Please export COMPUTE_REGION per README.md or set it in gcloud config.\"\n        exit 1\n    fi\nfi\n\nSPARK_DL_HOME=${SPARK_DL_HOME:-${GCS_BUCKET}/spark-dl}\n\n# copy init script to gcs\ngcloud storage cp init_spark_dl.sh gs://${SPARK_DL_HOME}/init/\nINIT_PATH=gs://${SPARK_DL_HOME}/init/init_spark_dl.sh\n\n# retrieve and upload spark-rapids initialization script to gcs\ncurl -LO https://raw.githubusercontent.com/GoogleCloudDataproc/initialization-actions/master/spark-rapids/spark-rapids.sh\n# don't enable rapids plugin by default\nsed -i '/spark.plugins=com.nvidia.spark.SQLPlugin/d' spark-rapids.sh\ngcloud storage cp spark-rapids.sh gs://${SPARK_DL_HOME}/init/\n# rm spark-rapids.sh\n\nCOMMON_REQUIREMENTS=\"numpy\npandas\nmatplotlib\nportalocker\npyarrow\npydot\nscikit-learn\nhuggingface\ndatasets==3.*\ntransformers\nnvidia-pytriton\"\n\nTORCH_REQUIREMENTS=\"${COMMON_REQUIREMENTS}\ntorch<=2.5.1\ntorchvision --extra-index-url https://download.pytorch.org/whl/cu121\ntorch-tensorrt\ntensorrt --extra-index-url https://download.pytorch.org/whl/cu121\nsentence_transformers\nsentencepiece\nnvidia-modelopt[all] --extra-index-url 
https://pypi.nvidia.com\"\n\nTF_REQUIREMENTS=\"${COMMON_REQUIREMENTS}\ntensorflow[and-cuda]\ntf-keras\"\n\nVLLM_REQUIREMENTS=\"datasets==3.*\nvllm==0.8.2\"\n\ncluster_name=${USER}-spark-dl-inference-${FRAMEWORK}\nif [[ \"${TENSOR_PARALLEL}\" == \"true\" ]]; then\n    cluster_name=\"${cluster_name}-tp\"\nfi\n\nif [[ ${FRAMEWORK} == \"torch\" ]]; then\n    requirements=${TORCH_REQUIREMENTS}\n    echo \"=========================================================\"\n    echo \"Starting PyTorch cluster ${cluster_name}\"\n    echo \"=========================================================\"\nelif [[ ${FRAMEWORK} == \"tf\" ]]; then\n    requirements=${TF_REQUIREMENTS}\n    echo \"=========================================================\"\n    echo \"Starting Tensorflow cluster ${cluster_name}\"\n    echo \"=========================================================\"\nelif [[ ${FRAMEWORK} == \"vllm\" ]]; then\n    requirements=${VLLM_REQUIREMENTS}\n    echo \"=========================================================\"\n    echo \"Starting vLLM cluster ${cluster_name}\"\n    echo \"=========================================================\"\nelse\n    echo \"Please export FRAMEWORK as torch, tf, or vllm\"\n    exit 1\nfi\n\nif [[ \"${TENSOR_PARALLEL}\" == \"true\" ]]; then\n    WORKER_MACHINE_TYPE=\"g2-standard-24\"  # 2 L4 GPUs per node\nelse\n    WORKER_MACHINE_TYPE=\"g2-standard-8\"   # 1 L4 GPU per node\nfi\n\nif gcloud dataproc clusters list | grep -q \"${cluster_name}\"; then\n    echo \"Cluster ${cluster_name} already exists.\"\n    exit 0\nfi\n\nCLUSTER_PARAMS=(\n    --image-version=2.2-ubuntu\n    --region \"${COMPUTE_REGION}\"\n    --num-workers 2\n    --master-machine-type g2-standard-8\n    --worker-machine-type \"${WORKER_MACHINE_TYPE}\"\n    --initialization-actions gs://\"${SPARK_DL_HOME}\"/init/spark-rapids.sh,\"${INIT_PATH}\"\n    --metadata gpu-driver-provider=\"NVIDIA\"\n    --metadata gcs-bucket=\"${GCS_BUCKET}\"\n    --metadata 
spark-dl-home=\"${SPARK_DL_HOME}\"\n    --metadata requirements=\"${requirements}\"\n    --worker-local-ssd-interface=NVME\n    --optional-components=JUPYTER\n    --bucket \"${GCS_BUCKET}\"\n    --enable-component-gateway\n    --max-idle \"60m\"\n    --subnet=default\n    --no-shielded-secure-boot\n)\n\ngcloud dataproc clusters create ${cluster_name} \"${CLUSTER_PARAMS[@]}\"\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_tf.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"777fc40d\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark Huggingface Inferencing\\n\",\n    \"### Conditional generation with Tensorflow\\n\",\n    \"\\n\",\n    \"In this notebook, we demonstrate distributed inference with the T5 transformer to perform sentence translation.  \\n\",\n    \"From: https://huggingface.co/docs/transformers/model_doc/t5\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"05c79ac4-bf25-421e-b55e-020d6d9e15d5\",\n   \"metadata\": {},\n   \"source\": [\n    \"Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"f6f0dbf3-712b-4c58-85eb-261ce15bb2be\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:53:50.831324: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\\n\",\n      \"2025-02-04 13:53:50.838528: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\\n\",\n      \"2025-02-04 13:53:50.846226: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\\n\",\n      \"2025-02-04 13:53:50.848585: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\\n\",\n      \"2025-02-04 13:53:50.854859: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\\n\",\n      \"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\\n\",\n      \"2025-02-04 13:53:51.229622: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from transformers import AutoTokenizer, TFT5ForConditionalGeneration\\n\",\n    \"\\n\",\n    \"# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\\n\",\n    \"# See (https://github.com/huggingface/transformers/issues/5486) for more info. 
\\n\",\n    \"import os\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"true\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"275890d7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2.17.0\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\\n\",\n      \"I0000 00:00:1738706031.770264 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706031.793270 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706031.796251 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import tensorflow as tf\\n\",\n    \"\\n\",\n    \"# Enable GPU memory growth\\n\",\n    \"gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"if gpus:\\n\",\n    \"    try:\\n\",\n    \"        for gpu in gpus:\\n\",\n    \"            tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"    except RuntimeError as e:\\n\",\n    \"        print(e)\\n\",\n    \"        \\n\",\n    \"print(tf.__version__)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"2684fb41-9467-40c0-9d7e-a1cc867c5a3c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"I0000 00:00:1738706032.132191 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706032.134996 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706032.137528 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706032.251302 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706032.252345 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706032.253281 3625306 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"2025-02-04 13:53:52.254192: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 43462 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\\n\",\n      \"All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.\\n\",\n      \"\\n\",\n      \"All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.\\n\",\n      \"If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"google-t5/t5-small\\\")\\n\",\n    \"model = TFT5ForConditionalGeneration.from_pretrained(\\\"google-t5/t5-small\\\")\\n\",\n    \"\\n\",\n    \"task_prefix = \\\"translate 
English to German: \\\"\\n\",\n    \"\\n\",\n    \"lines = [\\n\",\n    \"    \\\"The house is wonderful\\\",\\n\",\n    \"    \\\"Welcome to NYC\\\",\\n\",\n    \"    \\\"HuggingFace is a company\\\"\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"input_sequences = [task_prefix + l for l in lines]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"6eb2dfdb-0ad3-4d0f-81a4-268d92c53759\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\\n\",\n      \"I0000 00:00:1738706033.555987 3625654 service.cc:146] XLA service 0x712d300025f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\\n\",\n      \"I0000 00:00:1738706033.556005 3625654 service.cc:154]   StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6\\n\",\n      \"2025-02-04 13:53:53.558887: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\\n\",\n      \"2025-02-04 13:53:53.569767: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907\\n\",\n      \"I0000 00:00:1738706033.604327 3625654 device_compiler.h:188] Compiled cluster using XLA!  
This line is logged at most once for the lifetime of the process.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"inputs = tokenizer(input_sequences, \\n\",\n    \"                      padding=True,\\n\",\n    \"                      return_tensors=\\\"tf\\\")\\n\",\n    \"outputs = model.generate(input_ids=inputs[\\\"input_ids\\\"], attention_mask=inputs[\\\"attention_mask\\\"], max_length=128)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"720158d4-e0e0-4904-b096-e5aede756afd\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"['Das Haus ist wunderbar',\\n\",\n       \" 'Willkommen in NYC',\\n\",\n       \" 'HuggingFace ist ein Unternehmen']\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"546eabe0\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"68121304-f1df-466e-9347-c9d2b36a9b3a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.functions import pandas_udf, col, struct\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"2f6db1f0-7d68-4af7-8bd6-c9fa45906c61\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"import pandas as pd\\n\",\n    \"import datasets\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"datasets.disable_progress_bars()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": 
\"0d636975\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific Spark configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"ca351245\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d3199f8b\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"6279a849\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:53:54 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/04 13:53:54 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/04 13:53:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        \\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"    elif on_dataproc:\\n\",\n    \"        conf.set(\\\"spark.executorEnv.TF_GPU_ALLOCATOR\\\", \\\"cuda_malloc_async\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", \\\"1000\\\")\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7f311650\",\n   \"metadata\": {},\n   \"source\": [\n    \"Load the IMBD Movie Reviews dataset from Huggingface.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"b8453111-d068-49bb-ab91-8ae3d8bcdb7a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataset = load_dataset(\\\"imdb\\\", split=\\\"test\\\")\\n\",\n    \"dataset = dataset.to_pandas().drop(columns=\\\"label\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"markdown\",\n   \"id\": \"6fd5b472-47e8-4804-9907-772793fedb2b\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Create PySpark DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"id\": \"d24d9404-0269-476e-a9dd-1842667c915a\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType([StructField('text', StringType(), True)])\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df = spark.createDataFrame(dataset).repartition(8)\\n\",\n    \"df.schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"c76314b7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"25000\"\n      ]\n     },\n     \"execution_count\": 12,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df.count()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"4384c762-1f79-4f60-876c-94b1f552e8fb\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:54:01 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[Row(text=\\\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.<br /><br />The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. 
CKY2K.<br /><br />The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.<br /><br />I really got nothing much left to say except, give us back CKY2K, cause Bam suck..<br /><br />I enjoy watching Steve-o, Knoxville etc. a thousand times more.\\\")]\"\n      ]\n     },\n     \"execution_count\": 13,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df.take(1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"42ba3513-82dd-47e7-8193-eb4389458757\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save the test dataset as parquet files\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"e7eec8ec-4126-4890-b957-025809fad67d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:54:02 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/imdb_test\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"078425e1\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\\n\",\n    \"\\n\",\n    \"Define our preprocess function. 
We'll take the first sentence from each sample as our input for translation.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"b9a0889a-35b4-493a-8197-1146fc7efd53\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def preprocess(text: pd.Series, prefix: str = \\\"\\\") -> pd.Series:\\n\",\n    \"    @pandas_udf(\\\"string\\\")\\n\",\n    \"    def _preprocess(text: pd.Series) -> pd.Series:\\n\",\n    \"        return pd.Series([prefix + s.split(\\\".\\\")[0] for s in text])\\n\",\n    \"    return _preprocess(text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"id\": \"c483e4d4-9ab1-416f-a766-694e17490fd3\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                                text|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\\n\",\n      \"|There were two things I hated about WASTED : The directing and the script . I know I`m opening my...|\\n\",\n      \"|I'm rather surprised that anybody found this film touching or moving.<br /><br />The basic premis...|\\n\",\n      \"|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\\n\",\n      \"|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\\n\",\n      \"|This movie has been done before. 
It is basically a unoriginal combo of \\\"Napoleon Dynamite\\\" and \\\"S...|\\n\",\n      \"|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\\n\",\n      \"|This movie is over hyped!! I am sad to say that I manage to watch the first 15 minutes of this mo...|\\n\",\n      \"|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\\n\",\n      \"|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did such talented actors get involved in such mindles...|\\n\",\n      \"|There is not one character on this sitcom with any redeeming qualities. They are all self-centere...|\\n\",\n      \"|Tommy Lee Jones was the best Woodroe and no one can play Woodroe F. Call better than he. Not only...|\\n\",\n      \"|My wife rented this movie and then conveniently never got to see it. If I ever want to torture he...|\\n\",\n      \"|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\\n\",\n      \"|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\\n\",\n      \"|you will likely be sorely disappointed by this sequel that's not a sequel.AWIL is a classic.but t...|\\n\",\n      \"|If I was British, I would be embarrassed by this portrayal of incompetence. A protection agent of...|\\n\",\n      \"|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\\n\",\n      \"|This show is like watching someone who is in training to someday host a show. There are some good...|\\n\",\n      \"|Sigh. I'm baffled when I see a short like this get attention and assignments and whatnot. 
I saw t...|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Limit to N rows, since this can be slow\\n\",\n    \"df = spark.read.parquet(data_path).limit(256).repartition(8)\\n\",\n    \"df.show(truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a9f8e538\",\n   \"metadata\": {},\n   \"source\": [\n    \"Append a prefix to tell the model to translate English to French:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"id\": \"831bc52c-a5c6-4c29-a6da-0566b5167773\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                               input|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|translate English to French: Doesn't anyone bother to check where this kind of sludge comes from ...|\\n\",\n      \"|translate English to French: There were two things I hated about WASTED : The directing and the s...|\\n\",\n      \"|   translate English to French: I'm rather surprised that anybody found this film touching or moving|\\n\",\n      \"|translate English to French: Cultural Vandalism Is the new Hallmark production of Gulliver's Trav...|\\n\",\n      \"|translate English to French: I was at Wrestlemania VI in Toronto as a 10 year old, and the event ...|\\n\",\n      \"|                                        translate English to French: This movie has been done before|\\n\",\n      \"|translate English to French: [ as a new resolution for this year 2005, i decide to 
write a commen...|\\n\",\n      \"|translate English to French: This movie is over hyped!! I am sad to say that I manage to watch th...|\\n\",\n      \"|translate English to French: This show had a promising start as sort of the opposite of 'Oceans 1...|\\n\",\n      \"|translate English to French: MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did such talented actors...|\\n\",\n      \"| translate English to French: There is not one character on this sitcom with any redeeming qualities|\\n\",\n      \"|     translate English to French: Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\\n\",\n      \"|    translate English to French: My wife rented this movie and then conveniently never got to see it|\\n\",\n      \"|translate English to French: This is one of those star-filled over-the-top comedies that could a)...|\\n\",\n      \"|translate English to French: This excruciatingly boring and unfunny movie made me think that Chap...|\\n\",\n      \"|translate English to French: you will likely be sorely disappointed by this sequel that's not a s...|\\n\",\n      \"|translate English to French: If I was British, I would be embarrassed by this portrayal of incomp...|\\n\",\n      \"|translate English to French: One of those movies in which there are no big twists whatsoever and ...|\\n\",\n      \"|translate English to French: This show is like watching someone who is in training to someday hos...|\\n\",\n      \"|                                                                   translate English to French: Sigh|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"input_df = 
df.select(preprocess(col(\\\"text\\\"), \\\"translate English to French: \\\").alias(\\\"input\\\")).cache()\\n\",\n    \"input_df.show(truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ec53a65c\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Inference using Spark DL API\\n\",\n    \"\\n\",\n    \"Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\\n\",\n    \"\\n\",\n    \"- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \\n\",\n    \"- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"id\": \"e7ae69d3-70c2-4765-928f-c96a7ba59829\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def predict_batch_fn():\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"    import numpy as np\\n\",\n    \"    from transformers import TFT5ForConditionalGeneration, AutoTokenizer\\n\",\n    \"\\n\",\n    \"    # Enable GPU memory growth\\n\",\n    \"    print(\\\"initializing model\\\")\\n\",\n    \"    gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"\\n\",\n    \"    model = TFT5ForConditionalGeneration.from_pretrained(\\\"google-t5/t5-small\\\")\\n\",\n    \"    tokenizer = AutoTokenizer.from_pretrained(\\\"google-t5/t5-small\\\")\\n\",\n    \"\\n\",\n    \"    def predict(inputs):\\n\",\n    \"        flattened = np.squeeze(inputs).tolist()\\n\",\n    \"        inputs = 
tokenizer(flattened, \\n\",\n    \"                            padding=True, \\n\",\n    \"                            return_tensors=\\\"tf\\\")\\n\",\n    \"        outputs = model.generate(input_ids=inputs[\\\"input_ids\\\"],\\n\",\n    \"                                 attention_mask=inputs[\\\"attention_mask\\\"],\\n\",\n    \"                                 max_length=128)\\n\",\n    \"        string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])\\n\",\n    \"        print(\\\"predict: {}\\\".format(len(flattened)))\\n\",\n    \"        return string_outputs\\n\",\n    \"    \\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"id\": \"36684f59-d947-43f8-a2e8-c7a423764e88\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"generate = predict_batch_udf(predict_batch_fn,\\n\",\n    \"                             return_type=StringType(),\\n\",\n    \"                             batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"id\": \"6a01c855-8fa1-4765-a3a5-2c9dd872df10\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 24:====================================>                     (5 + 3) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 9.07 ms, sys: 8.83 ms, total: 17.9 ms\\n\",\n      \"Wall time: 19.3 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"preds = input_df.withColumn(\\\"preds\\\", generate(struct(\\\"input\\\")))\\n\",\n    \"results = 
preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"d912d4b0-cd0b-44ea-859a-b23455cc2700\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 27:==================================================>       (7 + 1) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 7.51 ms, sys: 4.96 ms, total: 12.5 ms\\n\",\n      \"Wall time: 12.4 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = input_df.withColumn(\\\"preds\\\", generate(\\\"input\\\"))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"id\": \"5fe3d88b-30f7-468f-8db8-1f4118d0f26c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 30:=====================>                                    (3 + 5) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 5.46 ms, sys: 5.98 ms, total: 11.4 ms\\n\",\n      \"Wall time: 11.4 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = input_df.withColumn(\\\"preds\\\", generate(col(\\\"input\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"id\": 
\"4ad9b365-4b9a-438e-8fdf-47da55cb1cf4\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 33:>                                                         (0 + 1) / 1]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|                                             input|                                             preds|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|translate English to French: Doesn't anyone bot...|Ne s'ennuie-t-il pas de vérifier où viennent ce...|\\n\",\n      \"|translate English to French: There were two thi...|Il y avait deux choses que j'ai hâte de voir : ...|\\n\",\n      \"|translate English to French: I'm rather surpris...|Je suis plutôt surpris que quelqu'un ait trouvé...|\\n\",\n      \"|translate English to French: Cultural Vandalism...|Vandalisme culturel La nouvelle production Hall...|\\n\",\n      \"|translate English to French: I was at Wrestlema...|J'étais à Wrestlemania VI à Toronto en 10 ans, ...|\\n\",\n      \"|translate English to French: This movie has bee...|                       Ce film a été réalisé avant|\\n\",\n      \"|translate English to French: [ as a new resolut...|[ en tant que nouvelle résolution pour cette an...|\\n\",\n      \"|translate English to French: This movie is over...|Je suis triste de dire que je parviens à regard...|\\n\",\n      \"|translate English to French: This show had a pr...|Ce spectacle a eu un début prometteur en l'espè...|\\n\",\n      \"|translate English to French: MINOR PLOT SPOILER...|br />br /> Comment ces acteurs talentueux ont-i...|\\n\",\n      \"|translate English to French: There is not one c...|Il n'y a pas d'un 
personnage sur ce sitcom ayan...|\\n\",\n      \"|translate English to French: Tommy Lee Jones wa...|Tommy Lee Jones était le meilleur Woodroe et pe...|\\n\",\n      \"|translate English to French: My wife rented thi...|Ma femme a loué ce film et n'a jamais pu le voi...|\\n\",\n      \"|translate English to French: This is one of tho...|C’est l’une des comédies en étoiles à l’étoile ...|\\n\",\n      \"|translate English to French: This excruciatingl...|Ce film excruciant ennuyant et infaillible m’a ...|\\n\",\n      \"|translate English to French: you will likely be...|Vous serez probablement très déçu par cette séq...|\\n\",\n      \"|translate English to French: If I was British, ...|Si j'étais britannique, je seraitis embarrassé ...|\\n\",\n      \"|translate English to French: One of those movie...|Un des films dans lesquels il n'y a pas de gros...|\\n\",\n      \"|translate English to French: This show is like ...|Ce spectacle ressemble à l'observation d'une pe...|\\n\",\n      \"|                 translate English to French: Sigh|                                             Pesée|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show(truncate=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"id\": \"1eb0c83b-d91b-4f8c-a5e7-c35f55c88108\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"input_df2 = df.select(preprocess(col(\\\"text\\\"), \\\"translate English to German: \\\").alias(\\\"input\\\")).cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"id\": 
\"6f6b70f9-188a-402b-9143-78a5788140e4\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 36:==================================================>       (7 + 1) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 9.1 ms, sys: 4.04 ms, total: 13.1 ms\\n\",\n      \"Wall time: 14.9 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"preds = input_df2.withColumn(\\\"preds\\\", generate(struct(\\\"input\\\")))\\n\",\n    \"result = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"id\": \"031a6a5e-7999-4653-b394-19ed478d8c96\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 39:==================================================>       (7 + 1) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 6.62 ms, sys: 5.23 ms, total: 11.9 ms\\n\",\n      \"Wall time: 11.9 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = input_df2.withColumn(\\\"preds\\\", generate(\\\"input\\\"))\\n\",\n    \"result = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 27,\n   \"id\": \"229b6515-82f6-4e9c-90f0-a9c3cfb26301\",\n   \"metadata\": {},\n   \"outputs\": [\n 
   {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 42:==============>                                           (2 + 6) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 8.67 ms, sys: 3.27 ms, total: 11.9 ms\\n\",\n      \"Wall time: 11.7 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = input_df2.withColumn(\\\"preds\\\", generate(col(\\\"input\\\")))\\n\",\n    \"result = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"id\": \"8be750ac-fa39-452e-bb4c-c2270bc2f70d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 45:>                                                         (0 + 1) / 1]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|                                             input|                                             preds|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|translate English to German: Doesn't anyone bot...|Warum hat man sich nicht angeschaut, woher der ...|\\n\",\n      \"|translate English to German: There were two thi...|Es gab zwei Dinge, die ich hat an WASTED gehass...|\\n\",\n      \"|translate English to German: I'm rather surpris...|Ich bin ziemlich überrascht, dass jemand diesen...|\\n\",\n      \"|translate English to German: Cultural 
Vandalism...|Kultureller Vandalismus Ist die neue Hallmark-P...|\\n\",\n      \"|translate English to German: I was at Wrestlema...|Ich war als 10 Jahre alt bei Wrestlemania VI in...|\\n\",\n      \"|translate English to German: This movie has bee...|             Dieser Film wurde bereits vorgenommen|\\n\",\n      \"|translate English to German: [ as a new resolut...|[ als neue Entschließung für dieses Jahr 2005, ...|\\n\",\n      \"|translate English to German: This movie is over...|Ich hoffe, dass ich die ersten 15 Minuten diese...|\\n\",\n      \"|translate English to German: This show had a pr...|Diese Show hatte einen vielversprechenden Start...|\\n\",\n      \"|translate English to German: MINOR PLOT SPOILER...|br />br />Wie haben sich so talentierte Schausp...|\\n\",\n      \"|translate English to German: There is not one c...|Es gibt keinen Charakter auf dieser Seite mit i...|\\n\",\n      \"|translate English to German: Tommy Lee Jones wa...|Tommy Lee Jones war der beste Woodroe und niema...|\\n\",\n      \"|translate English to German: My wife rented thi...|Meine Frau hat diesen Film vermietet und dann b...|\\n\",\n      \"|translate English to German: This is one of tho...|Dies ist eines der Sterne-gefüllten über-the-to...|\\n\",\n      \"|translate English to German: This excruciatingl...|Dieser schreckliche langweilige und unfunnelnde...|\\n\",\n      \"|translate English to German: you will likely be...|Sie werden wahrscheinlich ernsthaft enttäuscht ...|\\n\",\n      \"|translate English to German: If I was British, ...|Wenn ich Britisch wäre, wäre ich beschämt über ...|\\n\",\n      \"|translate English to German: One of those movie...|Einer der Filme, in denen es keine großen Drehu...|\\n\",\n      \"|translate English to German: This show is like ...|Diese Show ist wie ein jemanden, der in Ausbild...|\\n\",\n      \"|                 translate English to German: Sigh|                                            Segnen|\\n\",\n      
\"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show(truncate=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f5803188\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  \\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"id\": \"6d09f972\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2964ffee\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the 
helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"id\": \"f1083dc8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"066c8695\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 31,\n   \"id\": \"afd00b7e-8150-4c95-a2e4-037e9c90f92a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"    from transformers import TFT5ForConditionalGeneration, AutoTokenizer\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing Conditional Generation model on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"\\n\",\n    \"    # Enable GPU memory growth\\n\",\n    \"    gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"    \\n\",\n    \"    tokenizer = AutoTokenizer.from_pretrained(\\\"google-t5/t5-small\\\")\\n\",\n    \"    model = TFT5ForConditionalGeneration.from_pretrained(\\\"google-t5/t5-small\\\")\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def 
_infer_fn(**inputs):\\n\",\n    \"        sentences = np.squeeze(inputs[\\\"text\\\"]).tolist()\\n\",\n    \"        print(f\\\"SERVER: Received batch of size {len(sentences)}\\\")\\n\",\n    \"        decoded_sentences = [s.decode(\\\"utf-8\\\") for s in sentences]\\n\",\n    \"        inputs = tokenizer(decoded_sentences,\\n\",\n    \"                            padding=True,\\n\",\n    \"                            return_tensors=\\\"tf\\\")\\n\",\n    \"        output_ids = model.generate(input_ids=inputs[\\\"input_ids\\\"],\\n\",\n    \"                                    attention_mask=inputs[\\\"attention_mask\\\"],\\n\",\n    \"                                    max_length=128)\\n\",\n    \"        outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])\\n\",\n    \"        return {\\n\",\n    \"            \\\"translations\\\": outputs,\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    workspace_path = f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"ConditionalGeneration\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"text\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"translations\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=64,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, 
frame):\\n\",\n    \"            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"527da1b0\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"4142ebfc\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"3d522f30\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"ConditionalGeneration\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"7c18994c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': 
(2020631, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3f284eb3\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"237e56dd\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"826db582\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f3f58e7b\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 36,\n   \"id\": \"aff88b3f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"Connecting to Triton model {model_name} at {url}.\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=240) as client:\\n\",\n    \"            flattened = np.squeeze(inputs).tolist() \\n\",\n    \"            # Encode batch\\n\",\n    \"            encoded_batch = [[text.encode(\\\"utf-8\\\")] for text in flattened]\\n\",\n    \"            encoded_batch_np = 
np.array(encoded_batch, dtype=np.bytes_)\\n\",\n    \"            # Run inference\\n\",\n    \"            result_data = client.infer_batch(encoded_batch_np)\\n\",\n    \"            result_data = np.squeeze(result_data[\\\"translations\\\"], -1)\\n\",\n    \"            return result_data\\n\",\n    \"        \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 40,\n   \"id\": \"5d10c61c-6102-4d19-8dd6-0c7b5b65343e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                             return_type=StringType(),\\n\",\n    \"                             input_tensor_shapes=[[1]],\\n\",\n    \"                             batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a85e2ceb\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 37,\n   \"id\": \"2fa3664e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def preprocess(text: pd.Series, prefix: str = \\\"\\\") -> pd.Series:\\n\",\n    \"    @pandas_udf(\\\"string\\\")\\n\",\n    \"    def _preprocess(text: pd.Series) -> pd.Series:\\n\",\n    \"        return pd.Series([prefix + s.split(\\\".\\\")[0] for s in text])\\n\",\n    \"    return _preprocess(text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 38,\n   \"id\": \"5d6c54e7-534d-406f-b8e6-fd592efd0ab2\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.parquet(data_path).limit(256).repartition(8)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 39,\n   \"id\": \"dc1bbbe3-4232-49e5-80f6-99976524b73b\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": 
[\n      \"25/02/04 13:55:37 WARN CacheManager: Asked to cache already cached data.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"input_df = df.select(preprocess(col(\\\"text\\\"), \\\"translate English to French: \\\").alias(\\\"input\\\")).cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e71f07d4\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 41,\n   \"id\": \"2e0907da-a5d9-4c3b-9db4-ce5e70ca9bb4\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 51:==================================================>       (7 + 1) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 10.8 ms, sys: 8.12 ms, total: 18.9 ms\\n\",\n      \"Wall time: 30 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"preds = input_df.withColumn(\\\"preds\\\", generate(struct(\\\"input\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"id\": \"9308bdd7-6f67-484d-8b51-dd1e1b2960ba\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 54:===========================================>              (6 + 2) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 7.23 ms, sys: 3.43 ms, total: 10.7 ms\\n\",\n      \"Wall time: 21.2 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     
\"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = input_df.withColumn(\\\"preds\\\", generate(\\\"input\\\"))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 43,\n   \"id\": \"38484ffd-370d-492b-8ca4-9eff9f242a9f\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 57:===========================================>              (6 + 2) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 2.81 ms, sys: 12.7 ms, total: 15.5 ms\\n\",\n      \"Wall time: 22.3 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = input_df.withColumn(\\\"preds\\\", generate(col(\\\"input\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 44,\n   \"id\": \"ebcb6699-3ac2-4529-ab0f-fab0a5e792da\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 60:>                                                         (0 + 1) / 1]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|                                             input|                                             preds|\\n\",\n      
\"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|translate English to French: Doesn't anyone bot...|Ne s'ennuie-t-il pas de vérifier où viennent ce...|\\n\",\n      \"|translate English to French: There were two thi...|Il y avait deux choses que j'ai hâte de voir : ...|\\n\",\n      \"|translate English to French: I'm rather surpris...|Je suis plutôt surpris que quelqu'un ait trouvé...|\\n\",\n      \"|translate English to French: Cultural Vandalism...|Vandalisme culturel La nouvelle production Hall...|\\n\",\n      \"|translate English to French: I was at Wrestlema...|J'étais à Wrestlemania VI à Toronto en 10 ans, ...|\\n\",\n      \"|translate English to French: This movie has bee...|                       Ce film a été réalisé avant|\\n\",\n      \"|translate English to French: [ as a new resolut...|[ en tant que nouvelle résolution pour cette an...|\\n\",\n      \"|translate English to French: This movie is over...|Je suis triste de dire que je parviens à regard...|\\n\",\n      \"|translate English to French: This show had a pr...|Ce spectacle a eu un début prometteur en l'espè...|\\n\",\n      \"|translate English to French: MINOR PLOT SPOILER...|br />br /> Comment ces acteurs talentueux ont-i...|\\n\",\n      \"|translate English to French: There is not one c...|Il n'y a pas d'un personnage sur ce sitcom ayan...|\\n\",\n      \"|translate English to French: Tommy Lee Jones wa...|Tommy Lee Jones était le meilleur Woodroe et pe...|\\n\",\n      \"|translate English to French: My wife rented thi...|Ma femme a loué ce film et n'a jamais pu le voi...|\\n\",\n      \"|translate English to French: This is one of tho...|C’est l’une des comédies en étoiles à l’étoile ...|\\n\",\n      \"|translate English to French: This excruciatingl...|Ce film excruciant ennuyant et infaillible m’a ...|\\n\",\n      \"|translate English to French: you will likely be...|Vous serez probablement très déçu par 
cette séq...|\\n\",\n      \"|translate English to French: If I was British, ...|Si j'étais britannique, je seraitis embarrassé ...|\\n\",\n      \"|translate English to French: One of those movie...|Un des films dans lesquels il n'y a pas de gros...|\\n\",\n      \"|translate English to French: This show is like ...|Ce spectacle ressemble à l'observation d'une pe...|\\n\",\n      \"|                 translate English to French: Sigh|                                             Pesée|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show(truncate=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"919e3113-64dd-482a-9233-6607b3f63c1e\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"#### Shut down server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 45,\n   \"id\": \"425d3b28-7705-45ba-8a18-ad34fc895219\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:56:54,506 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-04 13:56:59,695 - INFO - Sucessfully stopped 1 servers.                 
\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 45,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 46,\n   \"id\": \"2dec80ca-7a7c-46a9-97c0-7afb1572f5b9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"f43118ab-fc0a-4f64-a126-4302e615654a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-tf\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_torch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"8f6659b4-88da-4207-8d32-2674da5383a0\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark DL Inference\\n\",\n    \"### Conditional generation with Huggingface\\n\",\n    \"\\n\",\n    \"In this notebook, we demonstrate distributed inference with the T5 transformer to perform sentence translation.  \\n\",\n    \"From: https://huggingface.co/docs/transformers/model_doc/t5\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import T5Tokenizer, T5ForConditionalGeneration\\n\",\n    \"\\n\",\n    \"# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\\n\",\n    \"# See (https://github.com/huggingface/transformers/issues/5486) for more info. \\n\",\n    \"import os\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"true\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. 
This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"tokenizer = T5Tokenizer.from_pretrained(\\\"google-t5/t5-small\\\")\\n\",\n    \"model = T5ForConditionalGeneration.from_pretrained(\\\"google-t5/t5-small\\\")\\n\",\n    \"\\n\",\n    \"task_prefix = \\\"translate English to German: \\\"\\n\",\n    \"\\n\",\n    \"lines = [\\n\",\n    \"    \\\"The house is wonderful\\\",\\n\",\n    \"    \\\"Welcome to NYC\\\",\\n\",\n    \"    \\\"HuggingFace is a company\\\"\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"input_sequences = [task_prefix + l for l in lines]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"inputs = tokenizer(input_sequences,\\n\",\n    \"                      padding=True, \\n\",\n    \"                      return_tensors=\\\"pt\\\")\\n\",\n    \"\\n\",\n    \"outputs = model.generate(input_ids=inputs[\\\"input_ids\\\"], attention_mask=inputs[\\\"attention_mask\\\"], max_length=128)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"['Das Haus ist wunderbar',\\n\",\n       \" 'Willkommen in NYC',\\n\",\n       \" 'HuggingFace ist ein Unternehmen']\"\n      ]\n     },\n     \"execution_count\": 4,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 
2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"1b8dae4a-3bfc-4430-b28a-7350db5efed4\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.functions import pandas_udf, col, struct\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"a93a1424-e483-4d37-a719-32fabee3f285\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"import pandas as pd\\n\",\n    \"import datasets\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"datasets.disable_progress_bars()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific Spark configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
\\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:34:55 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/04 13:34:55 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/04 13:34:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"\\n\",\n    
\"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", \\\"1000\\\")\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"f08c37a5-fb0c-45f6-8630-d2af67831641\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"Load the IMDB Movie Reviews dataset from Huggingface.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"f0ec30c9-365a-43c5-9c53-3497400ee548\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"dataset = load_dataset(\\\"imdb\\\", split=\\\"test\\\")\\n\",\n    \"dataset = dataset.to_pandas().drop(columns=\\\"label\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"1e4269da-d2b3-46a5-9309-38a1ba825a47\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"#### Create PySpark DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     
\"inputWidgets\": {},\n     \"nuid\": \"30dab34d-8e4b-4f30-b7c2-3dff49da018b\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType([StructField('text', StringType(), True)])\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df = spark.createDataFrame(dataset).repartition(8)\\n\",\n    \"df.schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"55c33cc0-5dfb-449c-ae79-80972fb04405\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"25000\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df.count()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"efd6d6d9-1c2c-4131-8df4-a3ef75c3fc57\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:35:02 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). 
The maximum recommended task size is 1000 KiB.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[Row(text=\\\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.<br /><br />The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.<br /><br />The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.<br /><br />I really got nothing much left to say except, give us back CKY2K, cause Bam suck..<br /><br />I enjoy watching Steve-o, Knoxville etc. a thousand times more.\\\")]\"\n      ]\n     },\n     \"execution_count\": 12,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df.take(1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"65a5b258-1634-441e-8b36-29777e54592d\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:35:02 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). 
The maximum recommended task size is 1000 KiB.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/imdb_test\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"89b909f4-5732-428b-ad61-9a6c5cf94df2\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"#### Load and preprocess DataFrame\\n\",\n    \"\\n\",\n    \"Define our preprocess function. We'll take the first sentence from each sample as our input for translation.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"eb7e53d6-bbd0-48d2-a3be-36847275e2a9\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"def preprocess(text: pd.Series, prefix: str = \\\"\\\") -> pd.Series:\\n\",\n    \"    @pandas_udf(\\\"string\\\")\\n\",\n    \"    def _preprocess(text: pd.Series) -> pd.Series:\\n\",\n    \"        return pd.Series([prefix + s.split(\\\".\\\")[0] for s in text])\\n\",\n    \"    return _preprocess(text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n 
    },\n     \"inputWidgets\": {},\n     \"nuid\": \"97eee1a4-9dc4-43b0-9578-6d7f8ff338bd\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                                text|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|The only reason I'm even giving this movie a 4 is because it was made in to an episode of Mystery...|\\n\",\n      \"|Awkward disaster mishmash has a team of scavengers coming across the overturned S.S. Poseidon, ho...|\\n\",\n      \"|Here is a fantastic concept for a film - a series of meteors crash into a small town and the resu...|\\n\",\n      \"|I walked out of the cinema having suffered this film after 30 mins. I left two friends pinned in ...|\\n\",\n      \"|A wildly uneven film where the major problem is the uneasy mix of comedy and thriller. To me, the...|\\n\",\n      \"|Leonard Rossiter and Frances de la Tour carry this film, not without a struggle, as the script wa...|\\n\",\n      \"|A good cast... A good idea but turns out it is flawed as hypnosis is not allowed as evidence in c...|\\n\",\n      \"|Yet again, I appear to be the only person on planet Earth who is capable of criticizing Japanese ...|\\n\",\n      \"|As a serious horror fan, I get that certain marketing ploys are used to sell movies, especially t...|\\n\",\n      \"|Upon writing this review I have difficulty trying to think of what to write about. Nothing much h...|\\n\",\n      \"|Simply awful. 
I'm including a spoiler warning here only because of including a coupla jokes from ...|\\n\",\n      \"|I am a fan of Ed Harris' work and I really had high expectations about this film. Having so good ...|\\n\",\n      \"|Well...I like Patricia Kaas. She is a beautiful lady and an extremely gifted and versatile singer...|\\n\",\n      \"|This is a new approach to comedy. It isn't funny.<br /><br />The joke is that this, in and of its...|\\n\",\n      \"|It's been mentioned by others the inane dialogue in this series and I agree.<br /><br />If Mom an...|\\n\",\n      \"|One of the most boring movies I've ever had to sit through, it's completely formulaic. Just a coo...|\\n\",\n      \"|This movie was playing on Lifetime Movie Network last month and I decided to check it out. I watc...|\\n\",\n      \"|1983's \\\"Frightmare\\\" is an odd little film. The director seems to be trying to combine the atmosph...|\\n\",\n      \"|'Felony' is a B-movie. No doubt about it.<br /><br />Of course, if you take a look at the cast li...|\\n\",\n      \"|This movie defines the word \\\"confused\\\". All the actors stay true to the script. 
More's the pity, ...|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Limit to N rows, since this can be slow\\n\",\n    \"df = spark.read.parquet(data_path).limit(512).repartition(8)\\n\",\n    \"df.show(truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Append a prefix to tell the model to translate English to French:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"fa14304d-b409-4d07-99ef-9da7c7c76158\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                               input|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|translate English to French: The only reason I'm even giving this movie a 4 is because it was mad...|\\n\",\n      \"|translate English to French: Awkward disaster mishmash has a team of scavengers coming across the...|\\n\",\n      \"|translate English to French: Here is a fantastic concept for a film - a series of meteors crash i...|\\n\",\n      \"|     translate English to French: I walked out of the cinema having suffered this film after 30 mins|\\n\",\n      \"|translate English to French: A wildly uneven film where the major 
problem is the uneasy mix of co...|\\n\",\n      \"|translate English to French: Leonard Rossiter and Frances de la Tour carry this film, not without...|\\n\",\n      \"|                                                            translate English to French: A good cast|\\n\",\n      \"|translate English to French: Yet again, I appear to be the only person on planet Earth who is cap...|\\n\",\n      \"|translate English to French: As a serious horror fan, I get that certain marketing ploys are used...|\\n\",\n      \"|translate English to French: Upon writing this review I have difficulty trying to think of what t...|\\n\",\n      \"|                                                           translate English to French: Simply awful|\\n\",\n      \"|translate English to French: I am a fan of Ed Harris' work and I really had high expectations abo...|\\n\",\n      \"|                                                                   translate English to French: Well|\\n\",\n      \"|                                       translate English to French: This is a new approach to comedy|\\n\",\n      \"|translate English to French: It's been mentioned by others the inane dialogue in this series and ...|\\n\",\n      \"|translate English to French: One of the most boring movies I've ever had to sit through, it's com...|\\n\",\n      \"|translate English to French: This movie was playing on Lifetime Movie Network last month and I de...|\\n\",\n      \"|                              translate English to French: 1983's \\\"Frightmare\\\" is an odd little film|\\n\",\n      \"|                                                  translate English to French: 'Felony' is a B-movie|\\n\",\n      \"|                                 translate English to French: This movie defines the word \\\"confused\\\"|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n  
   ]\n    }\n   ],\n   \"source\": [\n    \"input_df = df.select(preprocess(col(\\\"text\\\"), \\\"translate English to French: \\\").alias(\\\"input\\\")).cache()\\n\",\n    \"input_df.show(truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"bc9cbdd2-1ca6-48e4-a549-792b3726525b\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"## Inference using Spark DL API\\n\",\n    \"\\n\",\n    \"Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\\n\",\n    \"\\n\",\n    \"- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \\n\",\n    \"- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"adb81177-442d-42ab-b86d-d8792201b4c8\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"def predict_batch_fn():\\n\",\n    \"    import numpy as np\\n\",\n    \"    import torch\\n\",\n    \"    from transformers import T5ForConditionalGeneration, T5Tokenizer\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"\\n\",\n    \"    device = torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\")\\n\",\n    \"    
print(f\\\"Initializing model on worker {TaskContext.get().partitionId()}, device {device}.\\\")\\n\",\n    \"    model = T5ForConditionalGeneration.from_pretrained(\\\"t5-small\\\").to(device)\\n\",\n    \"    tokenizer = T5Tokenizer.from_pretrained(\\\"t5-small\\\")\\n\",\n    \"\\n\",\n    \"    def predict(inputs):\\n\",\n    \"        flattened = np.squeeze(inputs).tolist()\\n\",\n    \"        inputs = tokenizer(flattened, \\n\",\n    \"                           padding=True,\\n\",\n    \"                           return_tensors=\\\"pt\\\").to(device)\\n\",\n    \"        outputs = model.generate(input_ids=inputs[\\\"input_ids\\\"],\\n\",\n    \"                                 attention_mask=inputs[\\\"attention_mask\\\"],\\n\",\n    \"                                 max_length=128)\\n\",\n    \"        string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])\\n\",\n    \"        print(\\\"predict: {}\\\".format(len(flattened)))\\n\",\n    \"        return string_outputs\\n\",\n    \"    \\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"20aab3a1-2284-4c07-9ce1-a20cf54d88f3\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"generate = predict_batch_udf(predict_batch_fn,\\n\",\n    \"                             return_type=StringType(),\\n\",\n    \"                             batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     
\"nuid\": \"a8d6f48e-09e7-4fc7-9d2f-1b68bc2976a7\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 24:=============================>                            (4 + 4) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 10.2 ms, sys: 5.05 ms, total: 15.2 ms\\n\",\n      \"Wall time: 7.41 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"preds = input_df.withColumn(\\\"preds\\\", generate(struct(\\\"input\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"abe2271d-0077-48f6-98b1-93524dd86447\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 27:=============================>                            (4 + 4) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 3.93 ms, sys: 1.98 ms, total: 5.91 ms\\n\",\n      \"Wall time: 4.08 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                       
         \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = input_df.withColumn(\\\"preds\\\", generate(\\\"input\\\"))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"77623711-a742-4262-8839-16fc3ddd1af7\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 30:==============>                                           (2 + 6) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 3.85 ms, sys: 1.75 ms, total: 5.6 ms\\n\",\n      \"Wall time: 4.08 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = input_df.withColumn(\\\"preds\\\", generate(col(\\\"input\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"f339c654-52fd-4992-b054-188dfb260e5d\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      
\"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|                                             input|                                             preds|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|translate English to French: The only reason I'...|La seule raison pour laquelle je donne même ce ...|\\n\",\n      \"|translate English to French: Awkward disaster m...|La mishmash d’Awkward a eu une équipe de scaven...|\\n\",\n      \"|translate English to French: Here is a fantasti...|Voici un concept fantastique pour un film : une...|\\n\",\n      \"|translate English to French: I walked out of th...|Je me suis rendu du cinéma après avoir subi ce ...|\\n\",\n      \"|translate English to French: A wildly uneven fi...|Un film extrêmement inégal où le problème majeu...|\\n\",\n      \"|translate English to French: Leonard Rossiter a...|Leonard Rossiter et Frances de la Tour mettent ...|\\n\",\n      \"|          translate English to French: A good cast|                                  Une bonne étoile|\\n\",\n      \"|translate English to French: Yet again, I appea...|Encore une fois, je semble être la seule person...|\\n\",\n      \"|translate English to French: As a serious horro...|En tant que grand fan d'horreur, je peux obteni...|\\n\",\n      \"|translate English to French: Upon writing this ...|la suite de cette étude, j'ai de la difficulté ...|\\n\",\n      \"|         translate English to French: Simply awful|                          Tout simplement terrible|\\n\",\n      \"|translate English to French: I am a fan of Ed H...|Je suis un fan de l'oeuvre d'Ed Harris et j'ai ...|\\n\",\n      \"|                 translate English to French: Well|                                           Eh bien|\\n\",\n      \"|translate English to French: This is a new appr...|  Il s’agit d’une nouvelle approche de la 
comédie.|\\n\",\n      \"|translate English to French: It's been mentione...|Il a été mentionné par d'autres le dialogue ina...|\\n\",\n      \"|translate English to French: One of the most bo...|Un des films les plus ennuyeux que je n'ai jama...|\\n\",\n      \"|translate English to French: This movie was pla...|Ce film jouait sur Lifetime Movie Network le mo...|\\n\",\n      \"|translate English to French: 1983's \\\"Frightmare...|Le film \\\"Frightmare\\\" de 1983 est un petit film ...|\\n\",\n      \"|translate English to French: 'Felony' is a B-movie|                       'Felony' est un mouvement B|\\n\",\n      \"|translate English to French: This movie defines...|                   Ce film définit le mot «confus»|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show(truncate=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Let's try English to German:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"input_df2 = df.select(preprocess(col(\\\"text\\\"), \\\"translate English to German: \\\").alias(\\\"input\\\")).cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 36:==================================================>       (7 + 1) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 6.02 ms, sys: 705 μs, total: 6.73 ms\\n\",\n      \"Wall time: 4.24 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"    
                                                                            \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"preds = input_df2.withColumn(\\\"preds\\\", generate(struct(\\\"input\\\")))\\n\",\n    \"result = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 39:==============>                                           (2 + 6) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 6.12 ms, sys: 319 μs, total: 6.43 ms\\n\",\n      \"Wall time: 3.88 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = input_df2.withColumn(\\\"preds\\\", generate(\\\"input\\\"))\\n\",\n    \"result = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 42:==============>                                           (2 + 6) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 7.03 ms, sys: 16 μs, total: 7.05 ms\\n\",\n      \"Wall time: 3.9 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = input_df2.withColumn(\\\"preds\\\", 
generate(col(\\\"input\\\")))\\n\",\n    \"result = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 27,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|                                             input|                                             preds|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|translate English to German: The only reason I'...|Der einzige Grund, warum ich sogar diesen Film ...|\\n\",\n      \"|translate English to German: Awkward disaster m...|Awkward-Katastrophenmischmash hat ein Team von ...|\\n\",\n      \"|translate English to German: Here is a fantasti...|Hier ist ein fantastisches Konzept für einen Fi...|\\n\",\n      \"|translate English to German: I walked out of th...|Ich ging aus dem Kino, nachdem ich diesen Film ...|\\n\",\n      \"|translate English to German: A wildly uneven fi...|Ein völlig ungleicher Film, in dem das Hauptpro...|\\n\",\n      \"|translate English to German: Leonard Rossiter a...|Leonard Rossiter und Frances de la Tour tragen ...|\\n\",\n      \"|          translate English to German: A good cast|                                     Gutes Casting|\\n\",\n      \"|translate English to German: Yet again, I appea...|Ich scheine wieder einmal die einzige Person au...|\\n\",\n      \"|translate English to German: As a serious horro...|Als ernsthafter Horrorfan erhalte ich, dass bes...|\\n\",\n      \"|translate English to German: Upon writing this ...|Ich habe Schwierigkeiten, mich an die Regeln zu...|\\n\",\n      \"|         translate English to German: Simply awful|                               Einfach schrecklich|\\n\",\n      \"|translate English to German: I am a fan of Ed 
H...|Ich bin ein Fan von Ed Harris' Arbeit und hatte...|\\n\",\n      \"|                 translate English to German: Well|                                               Nun|\\n\",\n      \"|translate English to German: This is a new appr...|          Das ist ein neuer Ansatz für die Komödie|\\n\",\n      \"|translate English to German: It's been mentione...|Es wurde von anderen erwähnt, die unangenehme D...|\\n\",\n      \"|translate English to German: One of the most bo...|Einer der langwierigen Filme, die ich jemals du...|\\n\",\n      \"|translate English to German: This movie was pla...|Dieser Film spielte im letzten Monat auf Lifeti...|\\n\",\n      \"|translate English to German: 1983's \\\"Frightmare...|       1983 ist \\\"Frightmare\\\" ein merkwürdiger Film|\\n\",\n      \"|translate English to German: 'Felony' is a B-movie|                           'Felony' ist ein B-Film|\\n\",\n      \"|translate English to German: This movie defines...|         Dieser Film definiert das Wort \\\"verwirrt\\\"|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show(truncate=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"a79a6f3a-cc34-46a4-aadd-16870423fffa\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
\\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  \\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"1e73757e-a451-4835-98e0-257ccf7a9025\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"metadata\": {\n    
\"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"71b1cb49-3d8f-4eeb-937a-c0c334bd2947\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    import torch\\n\",\n    \"    from transformers import T5Tokenizer, T5ForConditionalGeneration\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing Conditional Generation model on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"    tokenizer = T5Tokenizer.from_pretrained(\\\"t5-small\\\")\\n\",\n    \"    model = T5ForConditionalGeneration.from_pretrained(\\\"t5-small\\\")\\n\",\n    \"    \\n\",\n    \"    DEVICE = \\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\"\\n\",\n    \"    print(f\\\"SERVER: Using {DEVICE} device.\\\")\\n\",\n    \"    model = model.to(DEVICE)\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        sentences = np.squeeze(inputs[\\\"text\\\"]).tolist()\\n\",\n    \"        print(f\\\"SERVER: Received batch of size {len(sentences)}\\\")\\n\",\n    \"        decoded_sentences = [s.decode(\\\"utf-8\\\") for s in sentences]\\n\",\n    \"        inputs = tokenizer(decoded_sentences,\\n\",\n    \"                        padding=True,\\n\",\n    \"                        return_tensors=\\\"pt\\\").to(DEVICE)\\n\",\n    \"        output_ids = model.generate(input_ids=inputs[\\\"input_ids\\\"],\\n\",\n    \"                  
                  attention_mask=inputs[\\\"attention_mask\\\"],\\n\",\n    \"                                    max_length=128)\\n\",\n    \"        outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])\\n\",\n    \"        return {\\n\",\n    \"            \\\"translations\\\": outputs,\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    workspace_path = f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"ConditionalGeneration\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"text\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"translations\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=64,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. 
Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"1bf14846-15a3-4bc8-b0c5-ce71680d3550\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 32,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"5bf1fafc-d9c9-4fd7-901d-da97cf4ff496\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"ConditionalGeneration\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n    
  \"2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 35,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"e203eb19-166d-4177-aa87-fd31b7e3c90e\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.client import 
ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"Connecting to Triton model {model_name} at {url}.\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=240) as client:\\n\",\n    \"            flattened = np.squeeze(inputs).tolist() \\n\",\n    \"            # Encode batch\\n\",\n    \"            encoded_batch = [[text.encode(\\\"utf-8\\\")] for text in flattened]\\n\",\n    \"            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\\n\",\n    \"            # Run inference\\n\",\n    \"            result_data = client.infer_batch(encoded_batch_np)\\n\",\n    \"            result_data = np.squeeze(result_data[\\\"translations\\\"], -1)\\n\",\n    \"            return result_data\\n\",\n    \"        \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 39,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"be692f4a-cf86-4cf4-9530-7c62e479cacd\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                             return_type=StringType(),\\n\",\n    \"                             input_tensor_shapes=[[1]],\\n\",\n    \"                             batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"1b6b2a05-aea4-4e4d-a87d-0a6bd5ab554c\",\n     \"showTitle\": 
false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"#### Load and preprocess DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 36,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"a5e83230-5178-4fec-bba2-0e69be40e68c\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"def preprocess(text: pd.Series, prefix: str = \\\"\\\") -> pd.Series:\\n\",\n    \"    @pandas_udf(\\\"string\\\")\\n\",\n    \"    def _preprocess(text: pd.Series) -> pd.Series:\\n\",\n    \"        return pd.Series([prefix + s.split(\\\".\\\")[0] for s in text])\\n\",\n    \"    return _preprocess(text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 37,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"aad299b0-34bb-4edb-b1e4-cd0c82bb7455\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.parquet(data_path).limit(512).repartition(8)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 38,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"7934a6fc-57bc-4104-a52c-076351e77cbe\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      
\"25/02/04 13:35:39 WARN CacheManager: Asked to cache already cached data.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"input_df = df.select(preprocess(col(\\\"text\\\"), \\\"translate English to French: \\\").alias(\\\"input\\\")).cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 40,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"0f6229ef-01c8-43c9-a259-c5df6a18d689\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 51:====================================>                     (5 + 3) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 5.09 ms, sys: 4.41 ms, total: 9.5 ms\\n\",\n      \"Wall time: 4.96 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"preds = input_df.withColumn(\\\"preds\\\", generate(struct(\\\"input\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 41,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"5a543b4c-8b29-4f61-9773-2639bbc7f728\",\n     \"showTitle\": false,\n     
\"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 54:===========================================>              (6 + 2) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 5.4 ms, sys: 1.12 ms, total: 6.52 ms\\n\",\n      \"Wall time: 4.41 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = input_df.withColumn(\\\"preds\\\", generate(\\\"input\\\"))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"4c0cfc4e-ef0a-435e-9fdf-72b72b6def93\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 57:===========================================>              (6 + 2) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 4.59 ms, sys: 1.79 ms, total: 6.38 ms\\n\",\n      \"Wall time: 4.55 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = input_df.withColumn(\\\"preds\\\", 
generate(col(\\\"input\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 43,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"2d756e2e-8b60-43cb-b5f9-e27de11be24d\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|                                             input|                                             preds|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|translate English to French: The only reason I'...|La seule raison pour laquelle je donne même ce ...|\\n\",\n      \"|translate English to French: Awkward disaster m...|La mishmash d’Awkward a eu une équipe de scaven...|\\n\",\n      \"|translate English to French: Here is a fantasti...|Voici un concept fantastique pour un film : une...|\\n\",\n      \"|translate English to French: I walked out of th...|Je me suis rendu du cinéma après avoir subi ce ...|\\n\",\n      \"|translate English to French: A wildly uneven fi...|Un film extrêmement inégal où le problème majeu...|\\n\",\n      \"|translate English to French: Leonard Rossiter a...|Leonard Rossiter et Frances de la Tour mettent ...|\\n\",\n      \"|          translate English to French: A good cast|                                  Une bonne étoile|\\n\",\n      \"|translate English to French: Yet again, I appea...|Encore une fois, je semble être la seule person...|\\n\",\n      \"|translate English to French: As a serious horro...|En tant que grand fan 
d'horreur, je peux obteni...|\\n\",\n      \"|translate English to French: Upon writing this ...|la suite de cette étude, j'ai de la difficulté ...|\\n\",\n      \"|         translate English to French: Simply awful|                          Tout simplement terrible|\\n\",\n      \"|translate English to French: I am a fan of Ed H...|Je suis un fan de l'oeuvre d'Ed Harris et j'ai ...|\\n\",\n      \"|                 translate English to French: Well|                                           Eh bien|\\n\",\n      \"|translate English to French: This is a new appr...|  Il s’agit d’une nouvelle approche de la comédie.|\\n\",\n      \"|translate English to French: It's been mentione...|Il a été mentionné par d'autres le dialogue ina...|\\n\",\n      \"|translate English to French: One of the most bo...|Un des films les plus ennuyeux que je n'ai jama...|\\n\",\n      \"|translate English to French: This movie was pla...|Ce film jouait sur Lifetime Movie Network le mo...|\\n\",\n      \"|translate English to French: 1983's \\\"Frightmare...|Le film \\\"Frightmare\\\" de 1983 est un petit film ...|\\n\",\n      \"|translate English to French: 'Felony' is a B-movie|                       'Felony' est un mouvement B|\\n\",\n      \"|translate English to French: This movie defines...|                   Ce film définit le mot «confus»|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show(truncate=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"86ae68d4-57da-41d9-91b4-625ef9465d60\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   
},\n   \"source\": [\n    \"#### Shut down servers on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 44,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:35:53,794 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-04 13:35:58,983 - INFO - Sucessfully stopped 1 servers.                 \\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 44,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 45,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks:  # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"008c3e50-d321-4431-a9ab-919b35d1b042\",\n     \"showTitle\": false,\n     \"tableResultSettingsMap\": {},\n     \"title\": \"\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"application/vnd.databricks.v1+notebook\": {\n   \"dashboards\": [],\n   \"environmentMetadata\": null,\n   \"language\": \"python\",\n   \"notebookMetadata\": {\n    \"mostRecentlyExecutedCommandWithImplicitDF\": {\n     \"commandId\": 421988607303514,\n     \"dataframes\": [\n      \"_sqldf\"\n     ]\n    },\n    \"pythonIndentUnit\": 4\n   },\n   \"notebookName\": \"spark-triton-db.ipynb\",\n   \"widgets\": {}\n  },\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-torch\",\n   
\"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.11\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/deepseek-r1_torch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark LLM Inference: DeepSeek-R1 Reasoning Q/A\\n\",\n    \"\\n\",\n    \"In this notebook, we demonstrate distributed batch inference with [DeepSeek-R1](https://github.com/deepseek-ai/DeepSeek-R1), using open weights on Huggingface.\\n\",\n    \"\\n\",\n    \"We use [DeepSeek-R1-Distill-Llama-8B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B) as demonstration. DeepSeek's distilled models are based on open-source LLMs (such as Llama/Qwen), and are fine-tuned using samples generated by DeepSeek-R1. We'll show how to use the model to reason through word problems.\\n\",\n    \"\\n\",\n    \"**Note:** Running this model on GPU with 16-bit precision requires **~18GB** of GPU RAM. Make sure your instances have sufficient GPU capacity.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\\n\",\n    \"# See (https://github.com/huggingface/transformers/issues/5486) for more info. 
\\n\",\n    \"import os\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"true\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# For cloud environments, load the model to the distributed file system.\\n\",\n    \"if on_databricks:\\n\",\n    \"    models_dir = \\\"/dbfs/FileStore/spark-dl-models\\\"\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-models\\\")\\n\",\n    \"    model_path = f\\\"{models_dir}/deepseek-r1-distill-llama-8b\\\"\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    models_dir = \\\"/mnt/gcs/spark-dl-models\\\"\\n\",\n    \"    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\\n\",\n    \"    model_path = f\\\"{models_dir}/deepseek-r1-distill-llama-8b\\\"\\n\",\n    \"else:\\n\",\n    \"    model_path = os.path.abspath(\\\"deepseek-r1-distill-llama-8b\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Download the model from huggingface hub.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from huggingface_hub import snapshot_download\\n\",\n    \"\\n\",\n    \"model_path = snapshot_download(\\n\",\n    \"    repo_id=\\\"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\\\",\\n\",\n    \"    local_dir=model_path\\n\",\n    \")\"\n   ]\n  
},\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Warmup: Running locally\\n\",\n    \"\\n\",\n    \"**Note:** If the driver node does not have sufficient GPU capacity, proceed to the PySpark section.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"0ab193983c774a948e375407d7df1f83\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Device set to use cuda\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import torch\\n\",\n    \"from transformers import pipeline\\n\",\n    \"\\n\",\n    \"pipe = pipeline(\\\"text-generation\\\", model=model_path, torch_dtype=torch.bfloat16, device=\\\"cuda\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \" How many r's are there in strawberry? Let's count them.\\n\",\n      \"\\n\",\n      \"First, I'll write down the word: S T R A W B E R R Y.\\n\",\n      \"\\n\",\n      \"Now, I'll go through each letter one by one.\\n\",\n      \"\\n\",\n      \"1. S - no R.\\n\",\n      \"2. T - no R.\\n\",\n      \"3. R - that's one R.\\n\",\n      \"4. A - no R.\\n\",\n      \"5. W - no R.\\n\",\n      \"6. B - no R.\\n\",\n      \"7. E - no R.\\n\",\n      \"8. R - that's two R's.\\n\",\n      \"9. R - that's three R's.\\n\",\n      \"10. 
Y - no R.\\n\",\n      \"\\n\",\n      \"So, in total, there are three R's in the word strawberry.\\n\",\n      \"</think>\\n\",\n      \"\\n\",\n      \"To determine how many **r's** are in the word **strawberry**, let's follow these steps:\\n\",\n      \"\\n\",\n      \"1. **Write down the word:**\\n\",\n      \"   \\n\",\n      \"   S T R A W B E R R Y\\n\",\n      \"\\n\",\n      \"2. **Identify and count each occurrence of the letter R:**\\n\",\n      \"   \\n\",\n      \"   - **1.** S - no R\\n\",\n      \"   - **2.** T - no R\\n\",\n      \"   - **3.** R - **1 R**\\n\",\n      \"   - **4.** A - no R\\n\",\n      \"   - **5.** W - no R\\n\",\n      \"   - **6.** B - no R\\n\",\n      \"   - **7.** E - no R\\n\",\n      \"   - **8.** R - **2 R's**\\n\",\n      \"   - **9.** R - **3 R's**\\n\",\n      \"   - **10.** Y - no R\\n\",\n      \"\\n\",\n      \"3. **Total count of R's:**\\n\",\n      \"   \\n\",\n      \"   There are **3 R's** in the word **strawberry**.\\n\",\n      \"\\n\",\n      \"\\\\boxed{3}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"res = pipe([\\\"How many r's are there in strawberry?\\\"], max_new_tokens=512, temperature=0.1)\\n\",\n    \"print(\\\"\\\\n\\\", res[0][0]['generated_text'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \" Which number is bigger: 9.9 or 9.11? Let's see.\\n\",\n      \"\\n\",\n      \"First, I need to compare the whole number parts of both numbers. Both 9.9 and 9.11 have the same whole number part, which is 9.\\n\",\n      \"\\n\",\n      \"Since the whole numbers are equal, I'll compare the decimal parts. For 9.9, the decimal part is 0.9, and for 9.11, the decimal part is 0.11.\\n\",\n      \"\\n\",\n      \"To make it easier, I can express 0.9 as 0.90. 
Now, comparing 0.90 and 0.11, it's clear that 0.90 is greater than 0.11.\\n\",\n      \"\\n\",\n      \"Therefore, 9.9 is bigger than 9.11.\\n\",\n      \"</think>\\n\",\n      \"\\n\",\n      \"To determine which number is larger between **9.9** and **9.11**, let's compare them step by step.\\n\",\n      \"\\n\",\n      \"1. **Compare the Whole Numbers:**\\n\",\n      \"   - Both numbers have the same whole number part: **9**.\\n\",\n      \"   \\n\",\n      \"2. **Compare the Decimal Parts:**\\n\",\n      \"   - **9.9** can be written as **9.90**.\\n\",\n      \"   - **9.11** remains **9.11**.\\n\",\n      \"   \\n\",\n      \"3. **Analyze the Decimal Comparison:**\\n\",\n      \"   - Compare the tenths place:\\n\",\n      \"     - **9.90** has **9** in the tenths place.\\n\",\n      \"     - **9.11** has **1** in the tenths place.\\n\",\n      \"   - Since **9 > 1**, **9.90** is greater than **9.11**.\\n\",\n      \"\\n\",\n      \"4. **Conclusion:**\\n\",\n      \"   - Therefore, **9.9** is larger than **9.11**.\\n\",\n      \"\\n\",\n      \"\\\\boxed{9.9}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"res = pipe([\\\"Which number is bigger: 9.9 or 9.11?\\\"], max_new_tokens=512, temperature=0.1)\\n\",\n    \"print(\\\"\\\\n\\\", res[0][0]['generated_text'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"\\n\",\n    \"# Unload the model from GPU memory.\\n\",\n    \"del pipe\\n\",\n    \"torch.cuda.empty_cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.functions import 
pandas_udf, col, struct, length\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import pandas as pd\\n\",\n    \"import datasets\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"datasets.disable_progress_bars()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/10 09:40:01 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/10 09:40:01 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/10 09:40:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.maxFailures\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load DataFrame\\n\",\n    \"\\n\",\n    \"Load the first 500 samples of the [Orca Math Word Problems dataset](https://huggingface.co/datasets/microsoft/orca-math-word-problems-200k) from Huggingface and store in a Spark Dataframe.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataset = load_dataset(\\\"microsoft/orca-math-word-problems-200k\\\", split=\\\"train\\\", streaming=True)\\n\",\n    \"dataset = pd.Series([sample[\\\"question\\\"] for sample in dataset.take(500)])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                            question|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|Jungkook is the 5th place. Find the number of people who crossed the finish line faster than Jung...|\\n\",\n      \"|A number divided by 10 is 6. Yoongi got the result by subtracting 15 from a certain number. What ...|\\n\",\n      \"|Dongju selects a piece of paper with a number written on it, and wants to make a three-digit numb...|\\n\",\n      \"|You wanted to subtract 46 from a number, but you accidentally subtract 59 and get 43. How much do...|\\n\",\n      \"|The length of one span of Jinseo is about 12 centimeters (cm). 
When Jinseo measured the length of...|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df = spark.createDataFrame(dataset, schema=StringType()).withColumnRenamed(\\\"value\\\", \\\"question\\\")\\n\",\n    \"df.show(5, truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/orca_math\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").json(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Triton Inference Server\\n\",\n    \"We'll demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  
\\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports, model_path):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    import torch\\n\",\n    \"    from transformers import pipeline\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    from pyspark import 
TaskContext\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"    device = torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\")\\n\",\n    \"    pipe = pipeline(\\\"text-generation\\\", model=model_path, torch_dtype=torch.bfloat16, device=device)\\n\",\n    \"    print(f\\\"SERVER: Using {device} device.\\\")\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        prompts = np.squeeze(inputs[\\\"prompts\\\"]).tolist()\\n\",\n    \"        decoded_prompts = [p.decode(\\\"utf-8\\\") for p in prompts]\\n\",\n    \"        # limit responses to 256 tokens, since reasoning tasks can take a while\\n\",\n    \"        responses = pipe(decoded_prompts, max_new_tokens=256, temperature=0.2, return_full_text=False)\\n\",\n    \"        return {\\n\",\n    \"            \\\"responses\\\": np.array([r[0]['generated_text'] for r in responses]).reshape(-1, 1)\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    workspace_path = f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"deepseek-r1\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"prompts\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"responses\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=16,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            
strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"deepseek-r1\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name, model_path=model_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-10 09:40:17,442 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-10 09:40:17,442 - INFO - Starting 1 servers.\\n\",\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (272659, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": 
{},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server, wait_retries=24)  # allow up to 2 minutes for model loading\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_grpc_url = server_manager.host_to_grpc_url  # or server_manager.host_to_http_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"Connecting to Triton model {model_name} at {url}.\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=500) as client:\\n\",\n    \"            flattened = np.squeeze(inputs).tolist()\\n\",\n    \"            # Encode batch\\n\",\n    \"            encoded_batch = [[text.encode(\\\"utf-8\\\")] for text in flattened]\\n\",\n    \"            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\\n\",\n    \"            # Run inference\\n\",\n    \"            result_data = 
client.infer_batch(encoded_batch_np)\\n\",\n    \"            result_data = np.squeeze(result_data[\\\"responses\\\"], -1)\\n\",\n    \"            return result_data\\n\",\n    \"        \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_grpc_url),\\n\",\n    \"                             return_type=StringType(),\\n\",\n    \"                             input_tensor_shapes=[[1]],\\n\",\n    \"                             batch_size=2)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\\n\",\n    \"\\n\",\n    \"We'll select a few of the shorter questions for demonstration, since reasoning tasks can take a while.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.json(data_path)\\n\",\n    \"df = df.filter(length(col(\\\"question\\\")) <= 100).limit(16).repartition(8).cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 6:==============>                                            (2 + 6) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 18.6 ms, sys: 8.31 ms, total: 26.9 ms\\n\",\n      \"Wall time: 1min 46s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                        
                        \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"preds = df.withColumn(\\\"response\\\", generate(col(\\\"question\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 23:==================================================>       (7 + 1) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 9.55 ms, sys: 4.51 ms, total: 14.1 ms\\n\",\n      \"Wall time: 1min 45s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"response\\\", generate(\\\"question\\\"))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Sample output:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Q: There are 9 dogs and 23 cats. How many more cats are there than dogs? \\n\",\n      \"\\n\",\n      \"A:  Let me think. So, I have 23 cats and 9 dogs. To find out how many more cats there are than dogs, I need to subtract the number of dogs from the number of cats. That would be 23 minus 9. Let me do the subtraction: 23 minus 9 is 14. So, there are 14 more cats than dogs.\\n\",\n      \"\\n\",\n      \"Wait, let me double-check that. 
If I have 9 dogs and 23 cats, subtracting the number of dogs from the number of cats should give me the difference. So, 23 minus 9 is indeed 14. Yeah, that seems right. I don't think I made a mistake there. So, the answer is 14 more cats than dogs.\\n\",\n      \"\\n\",\n      \"**Final Answer**\\n\",\n      \"The number of cats exceeds the number of dogs by \\\\boxed{14}.\\n\",\n      \"\\\\boxed{14}\\n\",\n      \"</think>\\n\",\n      \"\\n\",\n      \"To determine how many more cats there are than dogs, we subtract the number of dogs from the number of cats. \\n\",\n      \"\\n\",\n      \"Given:\\n\",\n      \"- Number of cats = 23\\n\",\n      \"- Number of dogs = 9\\n\",\n      \"\\n\",\n      \"The calculation is:\\n\",\n      \"\\\\[ 23 - 9 = 14 \\\\]\\n\",\n      \"\\n\",\n      \"Thus, there are 14 more cats than dogs.\\n\",\n      \"\\n\",\n      \"\\\\[\\n\",\n      \"\\\\boxed{14}\\n\",\n      \"\\\\] \\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(f\\\"Q: {results[2].question} \\\\n\\\")\\n\",\n    \"print(f\\\"A: {results[2].response} \\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Shut down server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-10 09:43:36,499 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-10 09:43:41,701 - INFO - Sucessfully stopped 1 servers.                 
\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 30,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 31,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-torch\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.11\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/gemma-7b_torch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark LLM Inference: Gemma-7b Code Comprehension\\n\",\n    \"\\n\",\n    \"In this notebook, we demonstrate distributed inference with the Google [Gemma-7b-instruct](https://huggingface.co/google/gemma-7b-it) LLM, using open-weights on Huggingface.\\n\",\n    \"\\n\",\n    \"The Gemma-7b-instruct is an instruction-fine-tuned version of the Gemma-7b base model. We'll show how to use the model to perform code comprehension tasks.\\n\",\n    \"\\n\",\n    \"**Note:** Running this model on GPU with 16-bit precision requires **~18 GB** of GPU RAM. Make sure your instances have sufficient GPU capacity.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\\n\",\n    \"# See (https://github.com/huggingface/transformers/issues/5486) for more info. 
\\n\",\n    \"import os\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"true\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# For cloud environments, load the model to the distributed file system.\\n\",\n    \"if on_databricks:\\n\",\n    \"    models_dir = \\\"/dbfs/FileStore/spark-dl-models\\\"\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-models\\\")\\n\",\n    \"    model_path = f\\\"{models_dir}/gemma-7b-it\\\"\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    models_dir = \\\"/mnt/gcs/spark-dl-models\\\"\\n\",\n    \"    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\\n\",\n    \"    model_path = f\\\"{models_dir}/gemma-7b-it\\\"\\n\",\n    \"else:\\n\",\n    \"    model_path = os.path.abspath(\\\"gemma-7b-it\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"First visit the [Gemma Huggingface repository](https://huggingface.co/google/gemma-7b-it) to accept the terms to access the model, then login via huggingface_hub.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from huggingface_hub import login\\n\",\n    \"\\n\",\n    \"login()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Once 
you have access, you can download the model:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from huggingface_hub import snapshot_download\\n\",\n    \"\\n\",\n    \"model_path = snapshot_download(\\n\",\n    \"    repo_id=\\\"google/gemma-7b-it\\\",\\n\",\n    \"    local_dir=model_path,\\n\",\n    \"    ignore_patterns=\\\"*.gguf\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Warmup: Running locally\\n\",\n    \"\\n\",\n    \"**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"58494ca5858c40e39f924ad330a65885\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"import torch\\n\",\n    \"from transformers import AutoTokenizer, AutoModelForCausalLM\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_path)\\n\",\n    \"model = AutoModelForCausalLM.from_pretrained(model_path,\\n\",\n    \"                                             device_map=\\\"auto\\\",\\n\",\n    \"                                             torch_dtype=torch.bfloat16)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"<bos>Write me a poem about Apache Spark.\\n\",\n      \"\\n\",\n      \"In the realm of big data, a spark ignites,\\n\",\n  
    \"A framework born to conquer the night.\\n\",\n      \"Apache Spark, a lightning-fast tool,\\n\",\n      \"For processing data, swift and cool.\\n\",\n      \"\\n\",\n      \"With its resilient distributed architecture,\\n\",\n      \"It slices through terabytes with grace.\\n\",\n      \"No longer bound by memory's plight,\\n\",\n      \"Spark empowers us to analyze with might.\\n\",\n      \"\\n\",\n      \"From Python to Scala, it's a versatile spark,\\n\",\n      \"Unveiling insights hidden in the dark.\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"input_text = \\\"Write me a poem about Apache Spark.\\\"\\n\",\n    \"inputs = tokenizer(input_text, return_tensors=\\\"pt\\\").to(\\\"cuda\\\")\\n\",\n    \"\\n\",\n    \"outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.1, do_sample=True)\\n\",\n    \"print(tokenizer.decode(outputs[0]))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"\\n\",\n    \"# Unload the model from GPU memory.\\n\",\n    \"del model\\n\",\n    \"torch.cuda.empty_cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.functions import *\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import pandas as pd\\n\",\n    \"import datasets\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"datasets.disable_progress_bars()\"\n   ]\n  },\n  {\n   
\"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/10 09:44:33 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/10 09:44:33 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/10 09:44:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.maxFailures\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load DataFrame\\n\",\n    \"\\n\",\n    \"Load the first 500 samples of the [Code Comprehension dataset](https://huggingface.co/datasets/imbue/code-comprehension) from Huggingface and store in a Spark Dataframe.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataset = load_dataset(\\\"imbue/code-comprehension\\\", split=\\\"train\\\", streaming=True)\\n\",\n    \"dataset = pd.Series([sample[\\\"question\\\"] for sample in dataset.take(500)])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   
\"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.createDataFrame(dataset, schema=StringType()).withColumnRenamed(\\\"value\\\", \\\"prompt\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                              prompt|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|If we execute the code below, what will `result` be equal to?\\\\n\\\\n```python\\\\nN = 'quz'\\\\nN += 'bar'...|\\n\",\n      \"|```python\\\\nresult = 9 - 9 - 1 - 7 - 9 - 1 + 9 - 2 + 6 - 4 - 8 - 1\\\\n```\\\\n\\\\nOut of these options, w...|\\n\",\n      \"|```python\\\\nx = 'bas'\\\\nD = 'bar'.swapcase()\\\\nx = len(x)\\\\nx = str(x)\\\\nnu = 'bar'.isnumeric()\\\\nx += ...|\\n\",\n      \"|If we execute the code below, what will `result` be equal to?\\\\n\\\\n```python\\\\n\\\\nl = 'likewise'\\\\nmat...|\\n\",\n      \"|```python\\\\nresult = 'mazda' + 'isolated' + 'mistakes' + 'grew' + 'raid' + 'junk' + 'jamaica' + 'c...|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df.show(5, truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"If we execute the code below, what will `result` be equal to?\\n\",\n      \"\\n\",\n      \"```python\\n\",\n      \"N = 'quz'\\n\",\n      \"N += 
'bar'\\n\",\n      \"N = N.swapcase()\\n\",\n      \"N = len(N)\\n\",\n      \"mu = 'bar'.strip()\\n\",\n      \"N = str(N)\\n\",\n      \"Q = N.isalpha()\\n\",\n      \"if N == 'bawr':\\n\",\n      \"    N = 'BAWR'.lower()\\n\",\n      \"N = N + N\\n\",\n      \"N = '-'.join([N, N, N, 'foo'])\\n\",\n      \"if mu == N:\\n\",\n      \"    N = 'bar'.upper()\\n\",\n      \"gamma = 'BAZ'.lower()\\n\",\n      \"\\n\",\n      \"result = N\\n\",\n      \"```\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(df.take(1)[0].prompt)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/code_comprehension\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").json(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  
\\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 37,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports, model_path):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    import torch\\n\",\n    \"    from transformers import AutoTokenizer, AutoModelForCausalLM\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    
from pyspark import TaskContext\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"    device = torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\")\\n\",\n    \"    tokenizer = AutoTokenizer.from_pretrained(model_path)\\n\",\n    \"    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\\\"auto\\\", torch_dtype=torch.bfloat16)\\n\",\n    \"    print(f\\\"SERVER: Using {device} device.\\\")\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        prompts = np.squeeze(inputs[\\\"prompts\\\"]).tolist()\\n\",\n    \"        print(f\\\"SERVER: Received batch of size {len(prompts)}\\\")\\n\",\n    \"        decoded_prompts = [p.decode(\\\"utf-8\\\") for p in prompts]\\n\",\n    \"        tokenized_inputs = tokenizer(decoded_prompts, padding=True, return_tensors=\\\"pt\\\").to(device)\\n\",\n    \"        outputs = model.generate(**tokenized_inputs, max_new_tokens=256, temperature=0.1, do_sample=True)\\n\",\n    \"        # Decode only the model output (excluding the input prompt) and remove special tokens.\\n\",\n    \"        responses = np.array(tokenizer.batch_decode(outputs[:, tokenized_inputs.input_ids.shape[1]:], skip_special_tokens = True))\\n\",\n    \"        return {\\n\",\n    \"            \\\"responses\\\": responses.reshape(-1, 1),\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    workspace_path = f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"gemma-7b\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"prompts\\\", dtype=object, 
shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"responses\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=16,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 39,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"gemma-7b\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name, model_path=model_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-10 
09:06:38,803 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-10 09:06:38,805 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (252119, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"execution_count\": 40,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server, wait_retries=24)  # allow up to 2 minutes for model loading\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 41,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_grpc_url = server_manager.host_to_grpc_url  # or server_manager.host_to_http_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"Connecting to Triton model {model_name} at 
{url}.\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=500) as client:\\n\",\n    \"            flattened = np.squeeze(inputs).tolist()\\n\",\n    \"            # Encode batch\\n\",\n    \"            encoded_batch = [[text.encode(\\\"utf-8\\\")] for text in flattened]\\n\",\n    \"            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\\n\",\n    \"            # Run inference\\n\",\n    \"            result_data = client.infer_batch(encoded_batch_np)\\n\",\n    \"            result_data = np.squeeze(result_data[\\\"responses\\\"], -1)\\n\",\n    \"            return result_data\\n\",\n    \"        \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 43,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_grpc_url),\\n\",\n    \"                             return_type=StringType(),\\n\",\n    \"                             input_tensor_shapes=[[1]],\\n\",\n    \"                             batch_size=4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\\n\",\n    \"\\n\",\n    \"We'll parallelize over a small set of questions for demonstration.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 44,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.json(data_path).limit(32).repartition(8)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 45,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 30:====================================>  
                   (5 + 3) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 5.6 ms, sys: 3.51 ms, total: 9.11 ms\\n\",\n      \"Wall time: 28.1 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"preds = df.withColumn(\\\"response\\\", generate(col(\\\"prompt\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 51,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 42:=============================>                            (4 + 4) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 8.12 ms, sys: 3.13 ms, total: 11.2 ms\\n\",\n      \"Wall time: 23.1 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"response\\\", generate(\\\"prompt\\\"))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Sample output:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 54,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Q: ```python\\n\",\n      \"result = ['mirrors', 'limousines', 'meaningful', 'cats', UNKNOWN, 'striking', 'wings', 'injured', 
'wishlist', 'granny'].index('oracle')\\n\",\n      \"print(result)\\n\",\n      \"```\\n\",\n      \"\\n\",\n      \"The code above has one or more parts replaced with the word UNKNOWN. Knowing that running the code prints `4` to the console, what should go in place of UNKNOWN? \\n\",\n      \"\\n\",\n      \"A: \\n\",\n      \"\\n\",\n      \"The answer is `oracle`.\\n\",\n      \"\\n\",\n      \"The code is searching for the index of the word `oracle` in the list `result`, and the index is returned as `4`. \\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(f\\\"Q: {results[2].prompt} \\\\n\\\")\\n\",\n    \"print(f\\\"A: {results[2].response} \\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Shut down server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 55,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-10 09:11:11,880 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-10 09:11:17,105 - INFO - Sucessfully stopped 1 servers.                 
\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 55,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 56,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-torch\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.11\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/pipelines_tf.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"9e9fe848\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark Huggingface Inferencing\\n\",\n    \"### Sentiment Analysis using Pipelines with Tensorflow\\n\",\n    \"\\n\",\n    \"In this notebook, we demonstrate distributed inference with Huggingface Pipelines to perform sentiment analysis.  \\n\",\n    \"From: https://huggingface.co/docs/transformers/quicktour#pipeline-usage\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1799fd4f\",\n   \"metadata\": {},\n   \"source\": [\n    \"Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075)  \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"0dd0f77b-ee1b-4477-a038-d25a4f1da0ea\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:57:08.242673: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\\n\",\n      \"2025-02-04 13:57:08.249833: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\\n\",\n      \"2025-02-04 13:57:08.257735: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\\n\",\n      \"2025-02-04 13:57:08.259994: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\\n\",\n      \"2025-02-04 13:57:08.266655: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\\n\",\n      \"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\\n\",\n      \"2025-02-04 13:57:08.649929: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import tensorflow as tf\\n\",\n    \"from transformers import pipeline\\n\",\n    \"\\n\",\n    \"# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\\n\",\n    \"# See (https://github.com/huggingface/transformers/issues/5486) for more info. 
\\n\",\n    \"import os\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"true\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"d80fc3f8\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\\n\",\n      \"I0000 00:00:1738706229.309141 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706229.333555 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706229.336487 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"device = 0 if tf.config.list_physical_devices('GPU') else -1\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"e60a2877\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2.17.0\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Enable GPU memory growth\\n\",\n    \"gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"if gpus:\\n\",\n    \"    try:\\n\",\n    \"        for gpu in gpus:\\n\",\n    \"            tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"    except RuntimeError as e:\\n\",\n    \"        print(e)\\n\",\n    \"\\n\",\n    \"print(tf.__version__)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"553b28d2-a5d1-4d07-8a49-8f82b808e738\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision 714eb0f (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).\\n\",\n      \"Using a pipeline without specifying a model name and revision in production is not recommended.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"I0000 00:00:1738706229.617170 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706229.620218 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706229.622781 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706229.732012 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706229.733045 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706229.733965 3668091 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"2025-02-04 13:57:09.734873: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 43096 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\\n\",\n      \"All PyTorch model weights were used when initializing TFDistilBertForSequenceClassification.\\n\",\n      \"\\n\",\n      \"All the weights of TFDistilBertForSequenceClassification were initialized from the PyTorch model.\\n\",\n      \"If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForSequenceClassification for predictions without further training.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"classifier = pipeline(\\\"sentiment-analysis\\\", device=device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"3b91fe91-b725-4564-ae93-56e3fb51e47c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[{'label': 'POSITIVE', 'score': 0.9997794032096863}]\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"classifier((\\\"We are very happy to show you the 🤗 Transformers library.\\\"))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"0be39eb3-462c-42ff-b8f4-09f4e4fe3a3c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"label: POSITIVE, with score: 0.9998\\n\",\n      \"label: NEGATIVE, with score: 0.5282\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"results = classifier([\\\"We are very happy to show you the 🤗 Transformers library.\\\", \\\"We hope you don't hate it.\\\"])\\n\",\n    
\"for result in results:\\n\",\n    \"    print(f\\\"label: {result['label']}, with score: {round(result['score'], 4)}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e29ee6d8\",\n   \"metadata\": {},\n   \"source\": [\n    \"Let's try a different model and tokenizer in the pipeline.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"cd9d3349\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"nlptown/bert-base-multilingual-uncased-sentiment\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"99e21b58\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"All PyTorch model weights were used when initializing TFBertForSequenceClassification.\\n\",\n      \"\\n\",\n      \"All the weights of TFBertForSequenceClassification were initialized from the PyTorch model.\\n\",\n      \"If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForSequenceClassification for predictions without further training.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from transformers import AutoTokenizer, TFAutoModelForSequenceClassification\\n\",\n    \"\\n\",\n    \"model = TFAutoModelForSequenceClassification.from_pretrained(model_name)\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"31079133\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[{'label': '5 stars', 'score': 0.7272477746009827}]\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"classifier = pipeline(\\\"sentiment-analysis\\\", model=model, tokenizer=tokenizer, 
device=device)\\n\",\n    \"classifier(\\\"Nous sommes très heureux de vous présenter la bibliothèque 🤗 Transformers.\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e6357234\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"69dd6a1a-f450-47f0-9dbf-ad250585a011\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.functions import col, struct, pandas_udf\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark import SparkConf\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"id\": \"287b1e96\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import json\\n\",\n    \"import pandas as pd\\n\",\n    \"import datasets\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"datasets.disable_progress_bars()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"50e124cd\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific Spark configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"36001f55\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"48c7271a\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
\\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"6e0e0dd7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:57:12 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/04 13:57:12 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/04 13:57:12 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        \\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"    elif on_dataproc:\\n\",\n    \"        conf.set(\\\"spark.executorEnv.TF_GPU_ALLOCATOR\\\", \\\"cuda_malloc_async\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    
conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", \\\"1000\\\")\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"42d70208\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataset = load_dataset(\\\"imdb\\\", split=\\\"test\\\")\\n\",\n    \"dataset = dataset.to_pandas().drop(columns=\\\"label\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"95ded4b2\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create PySpark DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"ac24f3c2\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType([StructField('text', StringType(), True)])\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df = spark.createDataFrame(dataset).repartition(8)\\n\",\n    \"df.schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"id\": \"1db4db3a\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"25000\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df.count()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"id\": \"517fe2e9\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:57:20 WARN TaskSetManager: 
Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[Row(text=\\\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.<br /><br />The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.<br /><br />The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.<br /><br />I really got nothing much left to say except, give us back CKY2K, cause Bam suck..<br /><br />I enjoy watching Steve-o, Knoxville etc. a thousand times more.\\\")]\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df.take(1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"id\": \"e176d28b\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:57:20 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). 
The maximum recommended task size is 1000 KiB.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/imdb_test\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"395e0374\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\\n\",\n    \"\\n\",\n    \"Define our preprocess function. We'll take the first sentence from each sample as our input for sentiment analysis.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"id\": \"9665b7b6-d7e9-4bd4-b29d-7a449ac5b574\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@pandas_udf(\\\"string\\\")\\n\",\n    \"def preprocess(text: pd.Series) -> pd.Series:\\n\",\n    \"    return pd.Series([s.split(\\\".\\\")[0] for s in text])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"id\": \"26693020\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                               input|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\\n\",\n      \"|                          There were two things I hated about WASTED : The directing and the script |\\n\",\n      \"|                                I'm rather surprised that anybody found this film touching or moving|\\n\",\n 
     \"|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\\n\",\n      \"|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\\n\",\n      \"|                                                                     This movie has been done before|\\n\",\n      \"|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\\n\",\n      \"|This movie is over hyped!! I am sad to say that I manage to watch the first 15 minutes of this mo...|\\n\",\n      \"|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\\n\",\n      \"|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did such talented actors get involved in such mindles...|\\n\",\n      \"|                              There is not one character on this sitcom with any redeeming qualities|\\n\",\n      \"|                                  Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\\n\",\n      \"|                                 My wife rented this movie and then conveniently never got to see it|\\n\",\n      \"|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\\n\",\n      \"|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\\n\",\n      \"|                           you will likely be sorely disappointed by this sequel that's not a sequel|\\n\",\n      \"|                          If I was British, I would be embarrassed by this portrayal of incompetence|\\n\",\n      \"|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\\n\",\n      \"|                        This show is like watching someone who is in training to someday host a show|\\n\",\n      \"|                                                                                                
Sigh|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Limit to N rows, since this can be slow\\n\",\n    \"df = spark.read.parquet(data_path).limit(256).repartition(8)\\n\",\n    \"df = df.select(preprocess(col(\\\"text\\\")).alias(\\\"input\\\")).cache()\\n\",\n    \"df.show(truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"76dc525c\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Inference using Spark DL API\\n\",\n    \"\\n\",\n    \"Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\\n\",\n    \"\\n\",\n    \"- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \\n\",\n    \"- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"0da9d25c-5ebe-4503-bb19-154fcc047cbf\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def predict_batch_fn():\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"    from transformers import pipeline\\n\",\n    \"\\n\",\n    \"    # Enable GPU memory growth\\n\",\n    \"    gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"    \\n\",\n    \"    device = 0 if tf.config.list_physical_devices('GPU') else -1\\n\",\n    \"    pipe = pipeline(\\\"sentiment-analysis\\\", 
device=device)\\n\",\n    \"    def predict(inputs):\\n\",\n    \"        return pipe(inputs.tolist())\\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"id\": \"78afef29-ee30-4267-9fb6-be2dcb86cbba\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"classify = predict_batch_udf(predict_batch_fn,\\n\",\n    \"                             return_type=StructType([\\n\",\n    \"                                 StructField(\\\"label\\\", StringType(), True),\\n\",\n    \"                                 StructField(\\\"score\\\", FloatType(), True)\\n\",\n    \"                             ]),\\n\",\n    \"                             batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"id\": \"a5bc327e-89cf-4731-82e6-e66cb93deef1\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 18:=======>                                                  (1 + 7) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 8.06 ms, sys: 2.92 ms, total: 11 ms\\n\",\n      \"Wall time: 4.86 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"# note: expanding the \\\"struct\\\" return_type to top-level columns\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(struct(\\\"input\\\"))).select(\\\"input\\\", \\\"preds.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"id\": \"ac642895-cfd6-47ee-9b21-02e7835424e4\",\n   \"metadata\": {},\n   
\"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 21:>                                                         (0 + 8) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 4.5 ms, sys: 1.43 ms, total: 5.93 ms\\n\",\n      \"Wall time: 1.19 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(\\\"input\\\")).select(\\\"input\\\", \\\"preds.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"id\": \"76a44d80-d5db-405f-989c-7246379cfb95\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 24:>                                                         (0 + 8) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 5.9 ms, sys: 605 μs, total: 6.5 ms\\n\",\n      \"Wall time: 1.37 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(col(\\\"input\\\"))).select(\\\"input\\\", \\\"preds.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"id\": \"c01761b3-c766-46b0-ae0b-fcf968ffb3a1\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     
\"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 27:>                                                         (0 + 1) / 1]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------------------------------------------------------------------+--------+----------+\\n\",\n      \"|                                                                           input|   label|     score|\\n\",\n      \"+--------------------------------------------------------------------------------+--------+----------+\\n\",\n      \"|Doesn't anyone bother to check where this kind of sludge comes from before bl...|NEGATIVE| 0.9984061|\\n\",\n      \"|      There were two things I hated about WASTED : The directing and the script |NEGATIVE| 0.9979007|\\n\",\n      \"|            I'm rather surprised that anybody found this film touching or moving|POSITIVE|0.83874947|\\n\",\n      \"|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an ac...|NEGATIVE|0.99727434|\\n\",\n      \"|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw the...|POSITIVE|  0.982114|\\n\",\n      \"|                                                 This movie has been done before|NEGATIVE|0.94210696|\\n\",\n      \"|[ as a new resolution for this year 2005, i decide to write a comment for eac...|NEGATIVE| 0.9967818|\\n\",\n      \"|This movie is over hyped!! 
I am sad to say that I manage to watch the first 1...|NEGATIVE| 0.9985843|\\n\",\n      \"|This show had a promising start as sort of the opposite of 'Oceans 11' but ha...|NEGATIVE|0.99926835|\\n\",\n      \"|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did such talented actors get invo...|NEGATIVE|0.99956733|\\n\",\n      \"|          There is not one character on this sitcom with any redeeming qualities|NEGATIVE| 0.9985662|\\n\",\n      \"|              Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|POSITIVE|  0.994562|\\n\",\n      \"|             My wife rented this movie and then conveniently never got to see it|NEGATIVE|0.99841607|\\n\",\n      \"|This is one of those star-filled over-the-top comedies that could a) be hyste...|NEGATIVE| 0.9953243|\\n\",\n      \"|This excruciatingly boring and unfunny movie made me think that Chaplin was t...|NEGATIVE| 0.9997607|\\n\",\n      \"|       you will likely be sorely disappointed by this sequel that's not a sequel|NEGATIVE| 0.9997198|\\n\",\n      \"|      If I was British, I would be embarrassed by this portrayal of incompetence|NEGATIVE| 0.9965172|\\n\",\n      \"|One of those movies in which there are no big twists whatsoever and you can p...|NEGATIVE| 0.9986059|\\n\",\n      \"|    This show is like watching someone who is in training to someday host a show|NEGATIVE|0.97015846|\\n\",\n      \"|                                                                            Sigh|NEGATIVE| 0.9923151|\\n\",\n      \"+--------------------------------------------------------------------------------+--------+----------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show(truncate=80)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   
\"id\": \"fc8127d9\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  \\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 27,\n   \"id\": \"4d4be844-4b8c-47df-bd09-0c280c7ff16b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"4f15dfcb\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"id\": \"bfa7ec9d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1bf04546\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server 
function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"id\": \"7e53df9f-43cb-4c38-b8ac-dc2cbad99815\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"    from transformers import pipeline\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing pipeline on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"    # Enable GPU memory growth\\n\",\n    \"    gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"    \\n\",\n    \"    device = 0 if tf.config.list_physical_devices('GPU') else -1\\n\",\n    \"    \\n\",\n    \"    pipe = pipeline(\\\"sentiment-analysis\\\", device=device)\\n\",\n    \"    print(f\\\"SERVER: Using {device} device.\\\")\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        sentences = np.squeeze(inputs[\\\"text\\\"]).tolist()\\n\",\n    \"        print(f\\\"SERVER: Received batch of size {len(sentences)}\\\")\\n\",\n    \"        decoded_sentences = [s.decode(\\\"utf-8\\\") for s in sentences]\\n\",\n    \"        return {\\n\",\n    \"            \\\"outputs\\\": np.array([[json.dumps(o)] for o in pipe(decoded_sentences)])\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    workspace_path = 
f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"SentimentAnalysis\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"text\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"outputs\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=64,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. 
Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"19d9028d\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5354c597\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"156de815\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"SentimentAnalysis\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d003a862\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Returns {'hostname': (server_pid, 
[http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e4c4017c\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"405edc49\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"19768ddb\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"eb5dbb89\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 34,\n   \"id\": \"431b864c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"Connecting to Triton model {model_name} at {url}.\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=240) as client:\\n\",\n    \"            flattened = np.squeeze(inputs).tolist()\\n\",\n    \"            # Encode batch\\n\",\n    \"            encoded_batch = [[text.encode(\\\"utf-8\\\")] for text in flattened]\\n\",\n    \"            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\\n\",\n    \"            # Run inference\\n\",\n    \"            result_data = client.infer_batch(encoded_batch_np)\\n\",\n    \"            
result_data = np.squeeze(result_data[\\\"outputs\\\"], -1)\\n\",\n    \"            return [json.loads(o) for o in result_data]\\n\",\n    \"        \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 37,\n   \"id\": \"3930cfcd-3284-4c6a-a9b5-36b8053fe899\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"classify = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                             return_type=StructType([\\n\",\n    \"                                 StructField(\\\"label\\\", StringType(), True),\\n\",\n    \"                                 StructField(\\\"score\\\", FloatType(), True)\\n\",\n    \"                             ]),\\n\",\n    \"                             input_tensor_shapes=[[1]],\\n\",\n    \"                             batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5a8ec7be\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 35,\n   \"id\": \"d53fb283-bf9e-4571-8c68-b75a41f1f067\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@pandas_udf(\\\"string\\\")\\n\",\n    \"def preprocess(text: pd.Series) -> pd.Series:\\n\",\n    \"    return pd.Series([s.split(\\\".\\\")[0] for s in text])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 36,\n   \"id\": \"29b0cc0d-c480-4e4a-bd41-207dc314cba5\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:57:36 WARN CacheManager: Asked to cache already cached data.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df = spark.read.parquet(data_path).limit(256).repartition(8)\\n\",\n    \"df = df.select(preprocess(col(\\\"text\\\")).alias(\\\"input\\\")).cache()\"\n   ]\n  },\n  {\n   
\"cell_type\": \"markdown\",\n   \"id\": \"da39990f\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 38,\n   \"id\": \"8eecbf23-4e9e-4d4c-8645-98209b25db2c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 33:===========================================>              (6 + 2) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 13.1 ms, sys: 8.29 ms, total: 21.4 ms\\n\",\n      \"Wall time: 7.54 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"# note: expanding the \\\"struct\\\" return_type to top-level columns\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(struct(\\\"input\\\"))).select(\\\"input\\\", \\\"preds.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 39,\n   \"id\": \"566ba28c-0ca4-4479-a24a-c8a362228b89\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 36:===========================================>              (6 + 2) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 7.54 ms, sys: 3.13 ms, total: 10.7 ms\\n\",\n      \"Wall time: 7.02 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   
],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(\\\"input\\\")).select(\\\"input\\\", \\\"preds.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 40,\n   \"id\": \"44c7e776-08da-484a-ba07-9d6add1a0f15\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 39:===========================================>              (6 + 2) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 6.26 ms, sys: 3 ms, total: 9.26 ms\\n\",\n      \"Wall time: 7.03 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(col(\\\"input\\\"))).select(\\\"input\\\", \\\"preds.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 41,\n   \"id\": \"f61d79f8-661e-4d9e-a3aa-c0754b854603\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------------------------------------------------------------------+--------+----------+\\n\",\n      \"|                                                                           input|   label|     score|\\n\",\n      \"+--------------------------------------------------------------------------------+--------+----------+\\n\",\n      \"|Doesn't anyone bother to check where this kind of sludge comes from before bl...|NEGATIVE| 0.9984061|\\n\",\n      \"|      There were two things I hated about WASTED : The directing and the script |NEGATIVE| 
0.9979007|\\n\",\n      \"|            I'm rather surprised that anybody found this film touching or moving|POSITIVE|0.83874947|\\n\",\n      \"|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an ac...|NEGATIVE|0.99727434|\\n\",\n      \"|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw the...|POSITIVE|  0.982114|\\n\",\n      \"|                                                 This movie has been done before|NEGATIVE|0.94210696|\\n\",\n      \"|[ as a new resolution for this year 2005, i decide to write a comment for eac...|NEGATIVE| 0.9967818|\\n\",\n      \"|This movie is over hyped!! I am sad to say that I manage to watch the first 1...|NEGATIVE| 0.9985843|\\n\",\n      \"|This show had a promising start as sort of the opposite of 'Oceans 11' but ha...|NEGATIVE|0.99926835|\\n\",\n      \"|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did such talented actors get invo...|NEGATIVE|0.99956733|\\n\",\n      \"|          There is not one character on this sitcom with any redeeming qualities|NEGATIVE| 0.9985662|\\n\",\n      \"|              Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|POSITIVE|  0.994562|\\n\",\n      \"|             My wife rented this movie and then conveniently never got to see it|NEGATIVE|0.99841607|\\n\",\n      \"|This is one of those star-filled over-the-top comedies that could a) be hyste...|NEGATIVE| 0.9953243|\\n\",\n      \"|This excruciatingly boring and unfunny movie made me think that Chaplin was t...|NEGATIVE| 0.9997607|\\n\",\n      \"|       you will likely be sorely disappointed by this sequel that's not a sequel|NEGATIVE| 0.9997198|\\n\",\n      \"|      If I was British, I would be embarrassed by this portrayal of incompetence|NEGATIVE| 0.9965172|\\n\",\n      \"|One of those movies in which there are no big twists whatsoever and you can p...|NEGATIVE| 0.9986059|\\n\",\n      \"|    This show is like watching someone who is in training to someday host a 
show|NEGATIVE|0.97015846|\\n\",\n      \"|                                                                            Sigh|NEGATIVE| 0.9923151|\\n\",\n      \"+--------------------------------------------------------------------------------+--------+----------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show(truncate=80)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"fac2ae57\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Shut down server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"id\": \"425d3b28-7705-45ba-8a18-ad34fc895219\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:57:58,747 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-04 13:58:03,931 - INFO - Sucessfully stopped 1 servers.                 
\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 42,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 43,\n   \"id\": \"9f19643c-4ee4-44f2-b762-2078c0c8eba9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6a538c47-317d-4cac-b9b9-559e88677518\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-tf\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/pipelines_torch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"60f7ac5d-4a95-4170-a0ac-a7faac9d9ef4\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark Huggingface Inferencing\\n\",\n    \"### Sentiment Analysis using Pipelines with PyTorch\\n\",\n    \"\\n\",\n    \"In this notebook, we demonstrate distributed inference with Huggingface Pipelines to perform sentiment analysis.  \\n\",\n    \"From: https://huggingface.co/docs/transformers/quicktour#pipeline-usage\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"0dd0f77b-ee1b-4477-a038-d25a4f1da0ea\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"from transformers import pipeline\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"e1f756c6\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"device = torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"553b28d2-a5d1-4d07-8a49-8f82b808e738\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision 714eb0f (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).\\n\",\n      \"Using a pipeline without specifying a model name and revision in production is not recommended.\\n\",\n      \"Device set to use cuda\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"classifier = pipeline(\\\"sentiment-analysis\\\", device=device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   
\"id\": \"3b91fe91-b725-4564-ae93-56e3fb51e47c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[{'label': 'POSITIVE', 'score': 0.9997795224189758}]\"\n      ]\n     },\n     \"execution_count\": 4,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"classifier((\\\"We are very happy to show you the 🤗 Transformers library.\\\"))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"0be39eb3-462c-42ff-b8f4-09f4e4fe3a3c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"label: POSITIVE, with score: 0.9998\\n\",\n      \"label: NEGATIVE, with score: 0.5309\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"results = classifier([\\\"We are very happy to show you the 🤗 Transformers library.\\\", \\\"We hope you don't hate it.\\\"])\\n\",\n    \"for result in results:\\n\",\n    \"    print(f\\\"label: {result['label']}, with score: {round(result['score'], 4)}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f752f929\",\n   \"metadata\": {},\n   \"source\": [\n    \"Let's try a different model and tokenizer in the pipeline.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"9861865f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"nlptown/bert-base-multilingual-uncased-sentiment\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"506e7834\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer, AutoModelForSequenceClassification\\n\",\n    \"\\n\",\n    \"model = AutoModelForSequenceClassification.from_pretrained(model_name)\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 8,\n   \"id\": \"312017fc\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Device set to use cuda\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[{'label': '5 stars', 'score': 0.7272652983665466}]\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"classifier = pipeline(\\\"sentiment-analysis\\\", model=model, tokenizer=tokenizer, device=device)\\n\",\n    \"classifier(\\\"Nous sommes très heureux de vous présenter la bibliothèque 🤗 Transformers.\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ae92b15e-0da0-46c3-81a3-fabaedbfc42c\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"69dd6a1a-f450-47f0-9dbf-ad250585a011\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.functions import col, struct, pandas_udf\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark import SparkConf\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"42c19ad8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import json\\n\",\n    \"import pandas as pd\\n\",\n    \"import datasets\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"datasets.disable_progress_bars()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3f1a0210\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific Spark configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"id\": 
\"79aaf5ec\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b99f9c38\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"6e0e0dd7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:23:47 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/04 13:23:47 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/04 13:23:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", \\\"1000\\\")\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"42d70208\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataset = load_dataset(\\\"imdb\\\", split=\\\"test\\\")\\n\",\n    \"dataset = dataset.to_pandas().drop(columns=\\\"label\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"de0f421d\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create PySpark DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"ac24f3c2\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType([StructField('text', StringType(), 
True)])\"\n      ]\n     },\n     \"execution_count\": 14,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df = spark.createDataFrame(dataset).repartition(8)\\n\",\n    \"df.schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"b0d1876b\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"25000\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df.count()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"id\": \"06ec6bb6\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:23:54 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[Row(text=\\\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.<br /><br />The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.<br /><br />The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.<br /><br />I really got nothing much left to say except, give us back CKY2K, cause Bam suck..<br /><br />I enjoy watching Steve-o, Knoxville etc. 
a thousand times more.\\\")]\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df.take(1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"id\": \"eeadf4e2\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:23:54 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/imdb_test\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"09cddc95\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\\n\",\n    \"\\n\",\n    \"Define our preprocess function. 
We'll take the first sentence from each sample as our input for sentiment analysis.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"id\": \"9665b7b6-d7e9-4bd4-b29d-7a449ac5b574\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@pandas_udf(\\\"string\\\")\\n\",\n    \"def preprocess(text: pd.Series) -> pd.Series:\\n\",\n    \"    return pd.Series([s.split(\\\".\\\")[0] for s in text])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"id\": \"74cfa3ff\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                               input|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\\n\",\n      \"|                          There were two things I hated about WASTED : The directing and the script |\\n\",\n      \"|                                I'm rather surprised that anybody found this film touching or moving|\\n\",\n      \"|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\\n\",\n      \"|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\\n\",\n      \"|                                                                     This movie has been done before|\\n\",\n      \"|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\\n\",\n      \"|This movie is over hyped!! 
I am sad to say that I manage to watch the first 15 minutes of this mo...|\\n\",\n      \"|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\\n\",\n      \"|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did such talented actors get involved in such mindles...|\\n\",\n      \"|                              There is not one character on this sitcom with any redeeming qualities|\\n\",\n      \"|                                  Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\\n\",\n      \"|                                 My wife rented this movie and then conveniently never got to see it|\\n\",\n      \"|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\\n\",\n      \"|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\\n\",\n      \"|                           you will likely be sorely disappointed by this sequel that's not a sequel|\\n\",\n      \"|                          If I was British, I would be embarrassed by this portrayal of incompetence|\\n\",\n      \"|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\\n\",\n      \"|                        This show is like watching someone who is in training to someday host a show|\\n\",\n      \"|                                                                                                Sigh|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Limit to N rows, since this can be slow\\n\",\n    \"df = spark.read.parquet(data_path).limit(256).repartition(8)\\n\",\n    \"df = df.select(preprocess(col(\\\"text\\\")).alias(\\\"input\\\")).cache()\\n\",\n    \"df.show(truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"markdown\",\n   \"id\": \"1ad92750\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Inference using Spark DL API\\n\",\n    \"\\n\",\n    \"Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\\n\",\n    \"\\n\",\n    \"- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \\n\",\n    \"- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"id\": \"0da9d25c-5ebe-4503-bb19-154fcc047cbf\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def predict_batch_fn():\\n\",\n    \"    import torch\\n\",\n    \"    from transformers import pipeline\\n\",\n    \"    \\n\",\n    \"    device = torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\")\\n\",\n    \"    pipe = pipeline(\\\"sentiment-analysis\\\", device=device)\\n\",\n    \"    def predict(inputs):\\n\",\n    \"        return pipe(inputs.tolist())\\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"78afef29-ee30-4267-9fb6-be2dcb86cbba\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"classify = predict_batch_udf(predict_batch_fn,\\n\",\n    \"                             return_type=StructType([\\n\",\n    \"                                 StructField(\\\"label\\\", StringType(), True),\\n\",\n    \"                                 StructField(\\\"score\\\", FloatType(), True)\\n\",\n    \"                             ]),\\n\",\n    \"                             batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"id\": \"a5bc327e-89cf-4731-82e6-e66cb93deef1\",\n   \"metadata\": {},\n   
\"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 18:====================================>                     (5 + 3) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 8.82 ms, sys: 2.5 ms, total: 11.3 ms\\n\",\n      \"Wall time: 3.59 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"# note: expanding the \\\"struct\\\" return_type to top-level columns\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(struct(\\\"input\\\"))).select(\\\"input\\\", \\\"preds.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"id\": \"ac642895-cfd6-47ee-9b21-02e7835424e4\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 3.19 ms, sys: 1.65 ms, total: 4.84 ms\\n\",\n      \"Wall time: 392 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(\\\"input\\\")).select(\\\"input\\\", \\\"preds.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"id\": \"76a44d80-d5db-405f-989c-7246379cfb95\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 3.43 ms, sys: 2.33 ms, total: 5.77 ms\\n\",\n      \"Wall time: 403 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", 
classify(col(\\\"input\\\"))).select(\\\"input\\\", \\\"preds.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"id\": \"c01761b3-c766-46b0-ae0b-fcf968ffb3a1\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------------------------------------------------------------------+--------+----------+\\n\",\n      \"|                                                                           input|   label|     score|\\n\",\n      \"+--------------------------------------------------------------------------------+--------+----------+\\n\",\n      \"|Doesn't anyone bother to check where this kind of sludge comes from before bl...|NEGATIVE| 0.9984042|\\n\",\n      \"|      There were two things I hated about WASTED : The directing and the script |NEGATIVE| 0.9979019|\\n\",\n      \"|            I'm rather surprised that anybody found this film touching or moving|POSITIVE|  0.839279|\\n\",\n      \"|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an ac...|NEGATIVE|0.99726933|\\n\",\n      \"|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw the...|POSITIVE|0.98212504|\\n\",\n      \"|                                                 This movie has been done before|NEGATIVE| 0.9419482|\\n\",\n      \"|[ as a new resolution for this year 2005, i decide to write a comment for eac...|NEGATIVE|0.99678314|\\n\",\n      \"|This movie is over hyped!! 
I am sad to say that I manage to watch the first 1...|NEGATIVE| 0.9985846|\\n\",\n      \"|This show had a promising start as sort of the opposite of 'Oceans 11' but ha...|NEGATIVE|0.99926823|\\n\",\n      \"|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did such talented actors get invo...|NEGATIVE| 0.9995671|\\n\",\n      \"|          There is not one character on this sitcom with any redeeming qualities|NEGATIVE|0.99856514|\\n\",\n      \"|              Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|POSITIVE| 0.9945687|\\n\",\n      \"|             My wife rented this movie and then conveniently never got to see it|NEGATIVE| 0.9984137|\\n\",\n      \"|This is one of those star-filled over-the-top comedies that could a) be hyste...|NEGATIVE| 0.9953224|\\n\",\n      \"|This excruciatingly boring and unfunny movie made me think that Chaplin was t...|NEGATIVE| 0.9997607|\\n\",\n      \"|       you will likely be sorely disappointed by this sequel that's not a sequel|NEGATIVE|0.99971956|\\n\",\n      \"|      If I was British, I would be embarrassed by this portrayal of incompetence|NEGATIVE|0.99651587|\\n\",\n      \"|One of those movies in which there are no big twists whatsoever and you can p...|NEGATIVE|0.99860746|\\n\",\n      \"|    This show is like watching someone who is in training to someday host a show|NEGATIVE|  0.970153|\\n\",\n      \"|                                                                            Sigh|NEGATIVE|0.99231356|\\n\",\n      \"+--------------------------------------------------------------------------------+--------+----------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show(truncate=80)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8ba1a6ce\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton Inference 
Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  \\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"id\": \"4d4be844-4b8c-47df-bd09-0c280c7ff16b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ab52381b\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 32,\n   \"id\": \"4e6764c4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"bab70481\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 33,\n   \"id\": \"7e53df9f-43cb-4c38-b8ac-dc2cbad99815\",\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"def triton_server(ports):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    import torch\\n\",\n    \"    from transformers import pipeline\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing pipeline on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"    device = torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\")\\n\",\n    \"    pipe = pipeline(\\\"sentiment-analysis\\\", device=device)\\n\",\n    \"    print(f\\\"SERVER: Using {device} device.\\\")\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        sentences = np.squeeze(inputs[\\\"text\\\"]).tolist()\\n\",\n    \"        print(f\\\"SERVER: Received batch of size {len(sentences)}\\\")\\n\",\n    \"        decoded_sentences = [s.decode(\\\"utf-8\\\") for s in sentences]\\n\",\n    \"        return {\\n\",\n    \"            \\\"outputs\\\": np.array([[json.dumps(o)] for o in pipe(decoded_sentences)])\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    workspace_path = f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"SentimentAnalysis\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"text\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"outputs\\\", 
dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=64,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7c5f4f2d\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b5ef160a\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 35,\n   \"id\": \"ad13db78\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"SentimentAnalysis\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e62d9739\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: 
(cores=5, gpu=1.0)\\n\",\n      \"2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f5ae0b8e\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"9e2059f9\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"7ede428b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"72f16ff5\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 38,\n   \"id\": \"14760940\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import json\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"Connecting to Triton model {model_name} at 
{url}.\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=240) as client:\\n\",\n    \"            flattened = np.squeeze(inputs).tolist()\\n\",\n    \"            # Encode batch\\n\",\n    \"            encoded_batch = [[text.encode(\\\"utf-8\\\")] for text in flattened]\\n\",\n    \"            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\\n\",\n    \"            # Run inference\\n\",\n    \"            result_data = client.infer_batch(encoded_batch_np)\\n\",\n    \"            result_data = np.squeeze(result_data[\\\"outputs\\\"], -1)\\n\",\n    \"            return [json.loads(o) for o in result_data]\\n\",\n    \"        \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 41,\n   \"id\": \"3930cfcd-3284-4c6a-a9b5-36b8053fe899\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"classify = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                             return_type=StructType([\\n\",\n    \"                                 StructField(\\\"label\\\", StringType(), True),\\n\",\n    \"                                 StructField(\\\"score\\\", FloatType(), True)\\n\",\n    \"                             ]),\\n\",\n    \"                             input_tensor_shapes=[[1]],\\n\",\n    \"                             batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a741e23a\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 39,\n   \"id\": \"ccc884a4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@pandas_udf(\\\"string\\\")\\n\",\n    \"def preprocess(text: pd.Series) -> pd.Series:\\n\",\n    \"    return pd.Series([s.split(\\\".\\\")[0] for s in text])\"\n   ]\n  
},\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 40,\n   \"id\": \"c426fdbe\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:24:35 WARN CacheManager: Asked to cache already cached data.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df = spark.read.parquet(data_path).limit(256).repartition(8)\\n\",\n    \"df = df.select(preprocess(col(\\\"text\\\")).alias(\\\"input\\\")).cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7da06df4\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"id\": \"8eecbf23-4e9e-4d4c-8645-98209b25db2c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 10.5 ms, sys: 2.2 ms, total: 12.7 ms\\n\",\n      \"Wall time: 671 ms\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"# note: expanding the \\\"struct\\\" return_type to top-level columns\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(struct(\\\"input\\\"))).select(\\\"input\\\", \\\"preds.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 43,\n   \"id\": \"566ba28c-0ca4-4479-a24a-c8a362228b89\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 1.68 ms, sys: 1.87 ms, total: 3.55 ms\\n\",\n      \"Wall time: 396 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = 
df.withColumn(\\\"preds\\\", classify(\\\"input\\\")).select(\\\"input\\\", \\\"preds.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 44,\n   \"id\": \"44c7e776-08da-484a-ba07-9d6add1a0f15\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 3.06 ms, sys: 5.02 ms, total: 8.08 ms\\n\",\n      \"Wall time: 408 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(col(\\\"input\\\"))).select(\\\"input\\\", \\\"preds.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 45,\n   \"id\": \"f61d79f8-661e-4d9e-a3aa-c0754b854603\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------+--------+----------+\\n\",\n      \"|                                                                 input|   label|     score|\\n\",\n      \"+----------------------------------------------------------------------+--------+----------+\\n\",\n      \"|Doesn't anyone bother to check where this kind of sludge comes from...|NEGATIVE| 0.9984042|\\n\",\n      \"|There were two things I hated about WASTED : The directing and the ...|NEGATIVE| 0.9979019|\\n\",\n      \"|  I'm rather surprised that anybody found this film touching or moving|POSITIVE|  0.839279|\\n\",\n      \"|Cultural Vandalism Is the new Hallmark production of Gulliver's Tra...|NEGATIVE|0.99726933|\\n\",\n      \"|I was at Wrestlemania VI in Toronto as a 10 year old, and the event...|POSITIVE|0.98212504|\\n\",\n      \"|                                       This movie has been done before|NEGATIVE| 0.9419482|\\n\",\n      \"|[ as a new resolution for this year 2005, i decide 
to write a comme...|NEGATIVE|0.99678314|\\n\",\n      \"|This movie is over hyped!! I am sad to say that I manage to watch t...|NEGATIVE| 0.9985846|\\n\",\n      \"|This show had a promising start as sort of the opposite of 'Oceans ...|NEGATIVE|0.99926823|\\n\",\n      \"|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did such talented actor...|NEGATIVE| 0.9995671|\\n\",\n      \"|There is not one character on this sitcom with any redeeming qualities|NEGATIVE|0.99856514|\\n\",\n      \"|    Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|POSITIVE| 0.9945687|\\n\",\n      \"|   My wife rented this movie and then conveniently never got to see it|NEGATIVE| 0.9984137|\\n\",\n      \"|This is one of those star-filled over-the-top comedies that could a...|NEGATIVE| 0.9953224|\\n\",\n      \"|This excruciatingly boring and unfunny movie made me think that Cha...|NEGATIVE| 0.9997607|\\n\",\n      \"|you will likely be sorely disappointed by this sequel that's not a ...|NEGATIVE|0.99971956|\\n\",\n      \"|If I was British, I would be embarrassed by this portrayal of incom...|NEGATIVE|0.99651587|\\n\",\n      \"|One of those movies in which there are no big twists whatsoever and...|NEGATIVE|0.99860746|\\n\",\n      \"|This show is like watching someone who is in training to someday ho...|NEGATIVE|  0.970153|\\n\",\n      \"|                                                                  Sigh|NEGATIVE|0.99231356|\\n\",\n      \"+----------------------------------------------------------------------+--------+----------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show(truncate=70)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2248858c\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Shut down server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 46,\n   \"id\": \"e3a4e51f\",\n   \"metadata\": {},\n   \"outputs\": [\n    
{\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:24:40,325 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:24:45,576 - INFO - Sucessfully stopped 1 servers.                 \\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 46,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 47,\n   \"id\": \"9f19643c-4ee4-44f2-b762-2078c0c8eba9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"a8b03e1e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-torch\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.11\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/qwen-2.5-7b_torch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark LLM Inference: Qwen-2.5 Text Summarization\\n\",\n    \"\\n\",\n    \"In this notebook, we demonstrate distributed batch inference with [Qwen-2.5](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct), using open weights on Huggingface.\\n\",\n    \"\\n\",\n    \"The Qwen-2.5-7b-instruct is an instruction-fine-tuned version of the Qwen-2.5-7b base model. We'll show how to use the model to perform text summarization.\\n\",\n    \"\\n\",\n    \"**Note:** Running this model on GPU with 16-bit precision requires **~16GB** of GPU RAM. Make sure your instances have sufficient GPU capacity.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The dataset we'll use requires Zstandard compression.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Collecting zstandard\\n\",\n      \"  Downloading zstandard-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)\\n\",\n      \"Downloading zstandard-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)\\n\",\n      \"\\u001b[2K   \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m5.4/5.4 MB\\u001b[0m \\u001b[31m66.3 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hInstalling collected packages: zstandard\\n\",\n      \"Successfully installed zstandard-0.23.0\\n\",\n      \"Note: you may need to restart the kernel to use updated packages.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%pip install zstandard\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\\n\",\n    \"# See (https://github.com/huggingface/transformers/issues/5486) for more info. \\n\",\n    \"import os\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"true\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# For cloud environments, load the model to the distributed file system.\\n\",\n    \"if on_databricks:\\n\",\n    \"    models_dir = \\\"/dbfs/FileStore/spark-dl-models\\\"\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-models\\\")\\n\",\n    \"    model_path = f\\\"{models_dir}/qwen-2.5-7b\\\"\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    models_dir = \\\"/mnt/gcs/spark-dl-models\\\"\\n\",\n    \"    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\\n\",\n    \"    model_path = f\\\"{models_dir}/qwen-2.5-7b\\\"\\n\",\n    \"else:\\n\",\n    \"    model_path = os.path.abspath(\\\"qwen-2.5-7b\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Download the model from huggingface hub.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"from huggingface_hub import snapshot_download\\n\",\n    \"\\n\",\n    \"model_path = snapshot_download(\\n\",\n    \"    repo_id=\\\"Qwen/Qwen2.5-7B-Instruct\\\",\\n\",\n    \"    local_dir=model_path\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Warmup: Running locally\\n\",\n    \"\\n\",\n    \"**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"352b738e1a2442b0a997467aaf6eb0ad\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"import torch\\n\",\n    \"from transformers import AutoModelForCausalLM, AutoTokenizer\\n\",\n    \"\\n\",\n    \"model = AutoModelForCausalLM.from_pretrained(\\n\",\n    \"    model_path,\\n\",\n    \"    torch_dtype=torch.bfloat16,\\n\",\n    \"    device_map=\\\"auto\\\"\\n\",\n    \")\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\\\"left\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"system_prompt = {\\n\",\n    \"    \\\"role\\\": \\\"system\\\",\\n\",\n    \"    \\\"content\\\": \\\"You are a knowledgeable AI assistant that provides accurate answers to questions.\\\"\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"queries = [\\n\",\n    \"    \\\"How many vowels are in 'elephant'?\\\",\\n\",\n    \"    \\\"What is the square root of 16?\\\",\\n\",\n    \"    \\\"How many planets are in our solar 
system?\\\"\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"prompts = [\\n\",\n    \"    [\\n\",\n    \"        system_prompt,\\n\",\n    \"        {\\\"role\\\": \\\"user\\\", \\\"content\\\": query}\\n\",\n    \"    ] for query in queries\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"text = tokenizer.apply_chat_template(\\n\",\n    \"    prompts,\\n\",\n    \"    tokenize=False,\\n\",\n    \"    add_generation_prompt=True,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"model_inputs = tokenizer(text, return_tensors=\\\"pt\\\", padding=True).to(model.device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"generated_ids = model.generate(\\n\",\n    \"    **model_inputs,\\n\",\n    \"    max_new_tokens=256,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"outputs = tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[1]:], skip_special_tokens = True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Q: How many vowels are in 'elephant'?\\n\",\n      \"A: The word \\\"elephant\\\" contains 3 vowels. The vowels are 'e', 'e', and 'a'.\\n\",\n      \"\\n\",\n      \"Q: What is the square root of 16?\\n\",\n      \"A: The square root of 16 is 4, because \\\\(4 \\\\times 4 = 16\\\\).\\n\",\n      \"\\n\",\n      \"Q: How many planets are in our solar system?\\n\",\n      \"A: There are eight planets in our solar system. They are, in order from the Sun:\\n\",\n      \"\\n\",\n      \"1. Mercury\\n\",\n      \"2. Venus\\n\",\n      \"3. Earth\\n\",\n      \"4. Mars\\n\",\n      \"5. Jupiter\\n\",\n      \"6. Saturn\\n\",\n      \"7. Uranus\\n\",\n      \"8. 
Neptune\\n\",\n      \"\\n\",\n      \"Pluto was previously considered the ninth planet but is now classified as a dwarf planet.\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for query, output in zip(queries, outputs):\\n\",\n    \"    print(f\\\"Q: {query}\\\\nA: {output}\\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"del model\\n\",\n    \"torch.cuda.empty_cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import pandas as pd\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.functions import pandas_udf, col, struct, length, lit, concat\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import datasets\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"datasets.disable_progress_bars()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
\\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/16 11:48:57 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/16 11:48:57 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/16 11:48:57 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.maxFailures\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n   
 \"\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and Preprocess DataFrame\\n\",\n    \"\\n\",\n    \"Load the first 500 samples of the [PUBMED abstracts dataset](https://huggingface.co/datasets/casinca/PUBMED_title_abstracts_2019_baseline) from Huggingface and store in a Spark Dataframe.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"pubmed_dataset = load_dataset(\\\"casinca/PUBMED_title_abstracts_2019_baseline\\\", split=\\\"train\\\", streaming=True)\\n\",\n    \"pubmed_pds = pd.Series([sample[\\\"text\\\"] for sample in pubmed_dataset.take(500)])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.createDataFrame(pubmed_pds, schema=StringType())\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                               value|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|Epidemiology of hypoxaemia in children with acute lower respiratory infection.\\\\nTo determine the ...|\\n\",\n      \"|Clinical signs of hypoxaemia in children with acute lower respiratory infection: indicators of ox...|\\n\",\n      \"|Hypoxaemia in children with severe pneumonia in Papua New Guinea.\\\\nTo investigate the severity an...|\\n\",\n      
\"|Oxygen concentrators and cylinders.\\\\nA comparison is made between oxygen cylinders and oxygen con...|\\n\",\n      \"|Oxygen supply in rural africa: a personal experience.\\\\nOxygen is one of the essential medical sup...|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df.show(5, truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Format each sample into the chat template, including a system prompt to guide generation.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"system_prompt = '''You are a knowledgeable AI assistant. 
Your job is to create a 2-3 sentence summary \\n\",\n    \"of a research abstract that captures the main objective, methodology, and key findings, using clear \\n\",\n    \"language while preserving technical accuracy and quantitative results.'''\\n\",\n    \"\\n\",\n    \"df = df.select(\\n\",\n    \"    concat(\\n\",\n    \"        lit(\\\"<|im_start|>system\\\\n\\\"),\\n\",\n    \"        lit(system_prompt),\\n\",\n    \"        lit(\\\"<|im_end|>\\\\n<|im_start|>user\\\\n\\\"),\\n\",\n    \"        col(\\\"value\\\"),\\n\",\n    \"        lit(\\\"<|im_end|>\\\\n<|im_start|>assistant\\\\n\\\")\\n\",\n    \"    ).alias(\\\"prompt\\\")\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"<|im_start|>system\\n\",\n      \"You are a knowledgeable AI assistant. Your job is to create a 2-3 sentence summary \\n\",\n      \"of a research abstract that captures the main objective, methodology, and key findings, using clear \\n\",\n      \"language while preserving technical accuracy and quantitative results.<|im_end|>\\n\",\n      \"<|im_start|>user\\n\",\n      \"Epidemiology of hypoxaemia in children with acute lower respiratory infection.\\n\",\n      \"To determine the prevalence of hypoxaemia in children aged under 5 years suffering acute lower respiratory infections (ALRI), the risk factors for hypoxaemia in children under 5 years of age with ALRI, and the association of hypoxaemia with an increased risk of dying in children of the same age. Systematic review of the published literature. Out-patient clinics, emergency departments and hospitalisation wards in 23 health centres from 10 countries. Cohort studies reporting the frequency of hypoxaemia in children under 5 years of age with ALRI, and the association between hypoxaemia and the risk of dying. 
Prevalence of hypoxaemia measured in children with ARI and relative risks for the association between the severity of illness and the frequency of hypoxaemia, and between hypoxaemia and the risk of dying. Seventeen published studies were found that included 4,021 children under 5 with acute respiratory infections (ARI) and reported the prevalence of hypoxaemia. Out-patient children and those with a clinical diagnosis of upper ARI had a low risk of hypoxaemia (pooled estimate of 6% to 9%). The prevalence increased to 31% and to 43% in patients in emergency departments and in cases with clinical pneumonia, respectively, and it was even higher among hospitalised children (47%) and in those with radiographically confirmed pneumonia (72%). The cumulated data also suggest that hypoxaemia is more frequent in children living at high altitude. Three papers reported an association between hypoxaemia and death, with relative risks varying between 1.4 and 4.6. Papers describing predictors of hypoxaemia have focused on clinical signs for detecting hypoxaemia rather than on identifying risk factors for developing this complication. Hypoxaemia is a common and potentially lethal complication of ALRI in children under 5, particularly among those with severe disease and those living at high altitude. 
Given the observed high prevalence of hypoxaemia and its likely association with increased mortality, efforts should be made to improve the detection of hypoxaemia and to provide oxygen earlier to more children with severe ALRI.<|im_end|>\\n\",\n      \"<|im_start|>assistant\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(df.take(1)[0].prompt)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/pubmed_abstracts\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  
\\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports, model_path):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import torch\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"    from transformers import 
AutoModelForCausalLM, AutoTokenizer\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"    model = AutoModelForCausalLM.from_pretrained(\\n\",\n    \"        model_path,\\n\",\n    \"        torch_dtype=torch.bfloat16,\\n\",\n    \"        device_map=\\\"auto\\\"\\n\",\n    \"    )\\n\",\n    \"    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\\\"left\\\")\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        prompts = np.squeeze(inputs[\\\"prompts\\\"]).tolist()\\n\",\n    \"        print(f\\\"SERVER: Received batch of size {len(prompts)}\\\")\\n\",\n    \"        decoded_prompts = [p.decode(\\\"utf-8\\\") for p in prompts]\\n\",\n    \"        tokenized_inputs = tokenizer(decoded_prompts, padding=True, return_tensors=\\\"pt\\\").to(model.device)\\n\",\n    \"        generated_ids = model.generate(**tokenized_inputs, max_new_tokens=256)\\n\",\n    \"        outputs = tokenizer.batch_decode(generated_ids[:, tokenized_inputs.input_ids.shape[1]:], skip_special_tokens = True)\\n\",\n    \"        return {\\n\",\n    \"            \\\"outputs\\\": np.array(outputs).reshape(-1, 1)\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    workspace_path = f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"qwen-2.5\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"prompts\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"outputs\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n  
  \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=64,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"qwen-2.5\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name, model_path=model_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-16 11:49:25,237 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-16 11:49:25,239 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                        
        \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (3490378, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"execution_count\": 19,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server, wait_retries=24)  # allow up to 2 minutes for model loading\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_grpc_url = server_manager.host_to_grpc_url  # or server_manager.host_to_http_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"Connecting to Triton model {model_name} at {url}.\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=500) as client:\\n\",\n    \"            flattened = np.squeeze(inputs).tolist()\\n\",\n    \"            # Encode batch\\n\",\n    \"            encoded_batch = [[text.encode(\\\"utf-8\\\")] for text in flattened]\\n\",\n    \"            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\\n\",\n    \"            # Run inference\\n\",\n    \"            result_data = 
client.infer_batch(encoded_batch_np)\\n\",\n    \"            result_data = np.squeeze(result_data[\\\"outputs\\\"], -1)\\n\",\n    \"            return result_data\\n\",\n    \"        \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_grpc_url),\\n\",\n    \"                             return_type=StringType(),\\n\",\n    \"                             input_tensor_shapes=[[1]],\\n\",\n    \"                             batch_size=8)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load DataFrame\\n\",\n    \"\\n\",\n    \"We'll parallelize over a small set of prompts for demonstration.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.parquet(data_path).limit(64).repartition(8)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 10:=====================>                                    (3 + 5) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 10.5 ms, sys: 6.63 ms, total: 17.1 ms\\n\",\n      \"Wall time: 23.7 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches 
model/fn\\n\",\n    \"preds = df.withColumn(\\\"outputs\\\", generate(col(\\\"prompt\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 16:=====================>                                    (3 + 5) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 8.1 ms, sys: 4.47 ms, total: 12.6 ms\\n\",\n      \"Wall time: 21.7 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"outputs\\\", generate(col(\\\"prompt\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Q: <|im_start|>system\\n\",\n      \"You are a knowledgeable AI assistant. Your job is to create a 2-3 sentence summary \\n\",\n      \"of a research abstract that captures the main objective, methodology, and key findings, using clear \\n\",\n      \"language while preserving technical accuracy and quantitative results.<|im_end|>\\n\",\n      \"<|im_start|>user\\n\",\n      \"Oral health promotion evaluation--time for development.\\n\",\n      \"Increasing emphasis is now being placed upon the evaluation of health service interventions to demonstrate their effects. A series of effectiveness reviews of the oral health education and promotion literature has demonstrated that many of these interventions are poorly and inadequately evaluated. 
It is therefore difficult to determine the effectiveness of many interventions. Based upon developments from the field of health promotion research this paper explores options for improving the quality of oral health promotion evaluation. It is essential that the methods and measures used in the evaluation of oral health promotion are appropriate to the intervention. For many oral health promotion interventions clinical measures and methods of evaluation may not be appropriate. This paper outlines an evaluation framework which can be used to assess the range of effects of oral health promotion programmes. Improving the quality of oral health promotion evaluation is a shared responsibility between researchers and those involved in the provision of programmes. The provision of adequate resources and training are essential requirements for this to be successfully achieved.<|im_end|>\\n\",\n      \"<|im_start|>assistant\\n\",\n      \" \\n\",\n      \"\\n\",\n      \"A: This research aims to improve the evaluation of oral health promotion programs by developing an appropriate framework. It explores how methods and measures should align with the specific nature of these interventions, emphasizing that both researchers and program providers must collaborate to ensure adequate resources and training are available for high-quality evaluations. 
\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(f\\\"Q: {results[0].prompt} \\\\n\\\")\\n\",\n    \"print(f\\\"A: {results[0].outputs} \\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Shut down server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-16 11:51:42,365 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-16 11:51:47,609 - INFO - Sucessfully stopped 1 servers.                 \\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 29,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-torch\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.11\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 
2\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/sentence_transformers_torch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"777fc40d\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark Huggingface Inferencing\\n\",\n    \"### Sentence Transformers with PyTorch\\n\",\n    \"\\n\",\n    \"In this notebook, we demonstrate distributed inference with the Huggingface SentenceTransformer library for sentence embedding.  \\n\",\n    \"From: https://huggingface.co/sentence-transformers\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"c5f0d0a8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"from sentence_transformers import SentenceTransformer\\n\",\n    \"\\n\",\n    \"# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\\n\",\n    \"# See (https://github.com/huggingface/transformers/issues/5486) for more info. 
\\n\",\n    \"import os\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"true\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"731faab7-a700-46f8-bba5-1c8764e5eacb\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"device = torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\")\\n\",\n    \"model = SentenceTransformer(\\\"paraphrase-MiniLM-L6-v2\\\", device=device)\\n\",\n    \"\\n\",\n    \"sentence = ['This framework generates embeddings for each input sentence']\\n\",\n    \"embedding = model.encode(sentence)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"96eea5ca-3cf7-46e3-b40c-598538112d24\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[-0.17621444  0.1206013  -0.29362372 -0.22985819 -0.08229247  0.2377093\\n\",\n      \"  0.33998525 -0.7809643   0.11812777  0.16337365 -0.13771524  0.24028276\\n\",\n      \"  0.4251256   0.17241786  0.10527937  0.5181643   0.062222    0.39928585\\n\",\n      \" -0.18165241 -0.58557856]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(embedding[0][:20])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"546eabe0\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"dbda3e66-005a-4ad0-8017-c1cc7cbf0058\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.functions import pandas_udf, col, struct\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"b525c5c4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"import json\\n\",\n    \"import pandas as pd\\n\",\n    \"import datasets\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"datasets.disable_progress_bars()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"58e7c1bc\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific Spark configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"5a013217\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ad3c003d\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"23ec67ba\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:40:01 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/04 13:40:01 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/04 13:40:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", \\\"1000\\\")\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"4cfd1394\",\n   \"metadata\": {},\n   \"source\": [\n    \"Load the IMDB Movie Reviews dataset from Huggingface.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"9bc1edb5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataset = load_dataset(\\\"imdb\\\", split=\\\"test\\\")\\n\",\n    \"dataset = dataset.to_pandas().drop(columns=\\\"label\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"59c71bff\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create PySpark DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   
\"id\": \"836e5f84-12c6-4c95-838e-53de7e46a20b\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType([StructField('text', StringType(), True)])\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df = spark.createDataFrame(dataset).repartition(8)\\n\",\n    \"df.schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"36703d23-37a3-40df-b09a-c68206d285b6\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"25000\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df.count()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"id\": \"1f122ae3\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:40:08 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[Row(text=\\\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.<br /><br />The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.<br /><br />The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... 
The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.<br /><br />I really got nothing much left to say except, give us back CKY2K, cause Bam suck..<br /><br />I enjoy watching Steve-o, Knoxville etc. a thousand times more.\\\")]\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df.take(1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"14fd59fb\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:40:08 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/imdb_test\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6bb083ec\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\\n\",\n    \"\\n\",\n    \"Define our preprocess function. 
We'll take the first sentence from each sample as our input for embedding.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"2510bdd1\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@pandas_udf(\\\"string\\\")\\n\",\n    \"def preprocess(text: pd.Series) -> pd.Series:\\n\",\n    \"    return pd.Series([s.split(\\\".\\\")[0] for s in text])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"5bb28548\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                               input|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\\n\",\n      \"|                          There were two things I hated about WASTED : The directing and the script |\\n\",\n      \"|                                I'm rather surprised that anybody found this film touching or moving|\\n\",\n      \"|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\\n\",\n      \"|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\\n\",\n      \"|                                                                     This movie has been done before|\\n\",\n      \"|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\\n\",\n      \"|This movie is over hyped!! 
I am sad to say that I manage to watch the first 15 minutes of this mo...|\\n\",\n      \"|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\\n\",\n      \"|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did such talented actors get involved in such mindles...|\\n\",\n      \"|                              There is not one character on this sitcom with any redeeming qualities|\\n\",\n      \"|                                  Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\\n\",\n      \"|                                 My wife rented this movie and then conveniently never got to see it|\\n\",\n      \"|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\\n\",\n      \"|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\\n\",\n      \"|                           you will likely be sorely disappointed by this sequel that's not a sequel|\\n\",\n      \"|                          If I was British, I would be embarrassed by this portrayal of incompetence|\\n\",\n      \"|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\\n\",\n      \"|                        This show is like watching someone who is in training to someday host a show|\\n\",\n      \"|                                                                                                Sigh|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Limit to N rows, since this can be slow\\n\",\n    \"df = spark.read.parquet(data_path).limit(256).repartition(8)\\n\",\n    \"df = df.select(preprocess(col(\\\"text\\\")).alias(\\\"input\\\")).cache()\\n\",\n    \"df.show(truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"markdown\",\n   \"id\": \"014eae88\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Inference using Spark DL API\\n\",\n    \"\\n\",\n    \"Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\\n\",\n    \"\\n\",\n    \"- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \\n\",\n    \"- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"f780c026-0f3f-4aea-8b61-5b3dbae83fb7\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def predict_batch_fn():\\n\",\n    \"    import torch\\n\",\n    \"    from sentence_transformers import SentenceTransformer\\n\",\n    \"\\n\",\n    \"    device = torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\")\\n\",\n    \"    model = SentenceTransformer(\\\"paraphrase-MiniLM-L6-v2\\\", device=device)\\n\",\n    \"    def predict(inputs):\\n\",\n    \"        return model.encode(inputs.tolist())\\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"id\": \"f5c88ddc-ca19-4430-8b0e-b9fae143b237\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"encode = predict_batch_udf(predict_batch_fn,\\n\",\n    \"                           return_type=ArrayType(FloatType()),\\n\",\n    \"                           batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"id\": \"85344c22-4a4d-4cb0-8771-5836ae2794db\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 18:=====================>                                    (3 + 5) 
/ 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 10.6 ms, sys: 4.83 ms, total: 15.4 ms\\n\",\n      \"Wall time: 4.23 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"embeddings = df.withColumn(\\\"embedding\\\", encode(struct(\\\"input\\\")))\\n\",\n    \"results = embeddings.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"id\": \"c23bb885-6ab0-4471-943d-4c10414100fa\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 6.7 ms, sys: 2.44 ms, total: 9.15 ms\\n\",\n      \"Wall time: 163 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"embeddings = df.withColumn(\\\"embedding\\\", encode(\\\"input\\\"))\\n\",\n    \"results = embeddings.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"id\": \"93bc6da3-d853-4233-b805-cb4a46f4f9b9\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 5.37 ms, sys: 2.73 ms, total: 8.1 ms\\n\",\n      \"Wall time: 232 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"embeddings = df.withColumn(\\\"embedding\\\", encode(col(\\\"input\\\")))\\n\",\n    \"results = embeddings.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"id\": \"2073616f-7151-4760-92f2-441dd0bfe9fe\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      
\"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|                                             input|                                         embedding|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|Doesn't anyone bother to check where this kind ...|[0.118947476, -0.053823642, -0.29726124, 0.0720...|\\n\",\n      \"|There were two things I hated about WASTED : Th...|[0.18953452, 0.11079162, 0.07503566, 0.01050696...|\\n\",\n      \"|I'm rather surprised that anybody found this fi...|[-0.0010759671, -0.14203517, -0.06649738, 0.129...|\\n\",\n      \"|Cultural Vandalism Is the new Hallmark producti...|[0.34815887, -0.2966917, -0.10905265, 0.1051652...|\\n\",\n      \"|I was at Wrestlemania VI in Toronto as a 10 yea...|[0.45902696, 0.019472413, 0.28720972, -0.070724...|\\n\",\n      \"|                   This movie has been done before|[-0.062292397, -0.025909504, -0.031942524, 0.01...|\\n\",\n      \"|[ as a new resolution for this year 2005, i dec...|[0.3469342, -0.14378615, 0.30223376, -0.1102267...|\\n\",\n      \"|This movie is over hyped!! 
I am sad to say that...|[0.13230576, -0.06588756, 0.0472389, 0.08353163...|\\n\",\n      \"|This show had a promising start as sort of the ...|[-0.19361982, -0.14412567, 0.15149693, -0.17715...|\\n\",\n      \"|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did...|[-0.048036292, 0.050720096, -0.04668727, -0.316...|\\n\",\n      \"|There is not one character on this sitcom with ...|[0.13720773, -0.5963504, 0.30331734, -0.3830607...|\\n\",\n      \"|Tommy Lee Jones was the best Woodroe and no one...|[-0.20960267, -0.15760122, -0.30596405, -0.5181...|\\n\",\n      \"|My wife rented this movie and then conveniently...|[0.46534792, -0.40655977, 0.054217298, -0.03414...|\\n\",\n      \"|This is one of those star-filled over-the-top c...|[0.14433198, -0.016140658, 0.3775344, 0.0659043...|\\n\",\n      \"|This excruciatingly boring and unfunny movie ma...|[0.056464806, 0.01144963, -0.51797307, 0.089813...|\\n\",\n      \"|you will likely be sorely disappointed by this ...|[-0.44146675, -0.17866582, 0.49889183, -0.26819...|\\n\",\n      \"|If I was British, I would be embarrassed by thi...|[0.1191261, -0.15379854, 0.17487673, -0.5123498...|\\n\",\n      \"|One of those movies in which there are no big t...|[-0.016174048, -0.5558219, -0.024818476, 0.1543...|\\n\",\n      \"|This show is like watching someone who is in tr...|[0.033776704, -0.6682203, 0.30547586, -0.581407...|\\n\",\n      \"|                                              Sigh|[-0.119870394, 0.40893683, 0.4174831, -0.010004...|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"embeddings.show(truncate=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0c9c6535\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton 
Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  \\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"772e337e-1098-4c7b-ba81-8cb221a518e2\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"759385ac\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"id\": \"485fb0de\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ece5c38a\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"id\": \"69d0c93a-bb0b-46c5-9d28-7b08a2e70964\",\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    import torch\\n\",\n    \"    from sentence_transformers import SentenceTransformer\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing sentence transformer on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"    device = torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\")\\n\",\n    \"    model = SentenceTransformer(\\\"paraphrase-MiniLM-L6-v2\\\", device=device)\\n\",\n    \"    print(f\\\"SERVER: Using {device} device.\\\")\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        sentences = np.squeeze(inputs[\\\"text\\\"])\\n\",\n    \"        print(f\\\"SERVER: Received batch of size {len(sentences)}\\\")\\n\",\n    \"        decoded_sentences = [s.decode(\\\"utf-8\\\") for s in sentences]\\n\",\n    \"        embeddings = model.encode(decoded_sentences)\\n\",\n    \"        return {\\n\",\n    \"            \\\"embeddings\\\": embeddings,\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    workspace_path = f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"SentenceTransformer\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"text\\\", dtype=object, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    
\"            outputs=[\\n\",\n    \"                Tensor(name=\\\"embeddings\\\", dtype=np.float32, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=64,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"79532110\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1b0371c8\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"id\": \"e66e8927\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"SentenceTransformer\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"040df0dd\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n  
   \"text\": [\n      \"2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1fd19fae\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ddeadc74\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"c42d1578\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"id\": \"807dbc45\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"Connecting to Triton model {model_name} at {url}.\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=240) as 
client:\\n\",\n    \"            flattened = np.squeeze(inputs).tolist()\\n\",\n    \"            # Encode batch\\n\",\n    \"            encoded_batch = [[text.encode(\\\"utf-8\\\")] for text in flattened]\\n\",\n    \"            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\\n\",\n    \"            # Run inference\\n\",\n    \"            result_data = client.infer_batch(encoded_batch_np)\\n\",\n    \"            return result_data[\\\"embeddings\\\"]\\n\",\n    \"        \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 31,\n   \"id\": \"9c712b8f-6eb4-4fb8-9f0a-04feef847fea\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"encode = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                           return_type=ArrayType(FloatType()),\\n\",\n    \"                           input_tensor_shapes=[[1]],\\n\",\n    \"                           batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"af174106\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"id\": \"2969d502-e97b-49d6-bf80-7d177ae867cf\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@pandas_udf(\\\"string\\\")\\n\",\n    \"def preprocess(text: pd.Series) -> pd.Series:\\n\",\n    \"    return pd.Series([s.split(\\\".\\\")[0] for s in text])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"id\": \"c8f1e6d6-6519-49e7-8465-4419547633b8\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:40:22 WARN CacheManager: Asked to cache already cached data.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df = 
spark.read.parquet(data_path).limit(256).repartition(8)\\n\",\n    \"df = df.select(preprocess(col(\\\"text\\\")).alias(\\\"input\\\")).cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"cf0ee731\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 32,\n   \"id\": \"934c1a1f-b126-45b0-9c15-265236820ad3\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 5.59 ms, sys: 5.1 ms, total: 10.7 ms\\n\",\n      \"Wall time: 605 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"embeddings = df.withColumn(\\\"embedding\\\", encode(struct(\\\"input\\\")))\\n\",\n    \"results = embeddings.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 33,\n   \"id\": \"f84cd3f6-b6a8-4142-859a-91f3c183457b\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 2.57 ms, sys: 4.36 ms, total: 6.93 ms\\n\",\n      \"Wall time: 161 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"embeddings = df.withColumn(\\\"embedding\\\", encode(\\\"input\\\"))\\n\",\n    \"results = embeddings.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 34,\n   \"id\": \"921a4c01-e296-4406-be90-86f20c8c582d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 7.06 ms, sys: 605 μs, total: 7.67 ms\\n\",\n      \"Wall time: 191 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"embeddings = df.withColumn(\\\"embedding\\\", encode(col(\\\"input\\\")))\\n\",\n    \"results = embeddings.collect()\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   \"execution_count\": 35,\n   \"id\": \"9f67584e-9c4e-474f-b6ea-7811b14d116e\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|                                             input|                                         embedding|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"|Doesn't anyone bother to check where this kind ...|[0.118947476, -0.053823642, -0.29726124, 0.0720...|\\n\",\n      \"|There were two things I hated about WASTED : Th...|[0.18953452, 0.11079162, 0.07503566, 0.01050696...|\\n\",\n      \"|I'm rather surprised that anybody found this fi...|[-0.0010759671, -0.14203517, -0.06649738, 0.129...|\\n\",\n      \"|Cultural Vandalism Is the new Hallmark producti...|[0.34815887, -0.2966917, -0.10905265, 0.1051652...|\\n\",\n      \"|I was at Wrestlemania VI in Toronto as a 10 yea...|[0.45902696, 0.019472413, 0.28720972, -0.070724...|\\n\",\n      \"|                   This movie has been done before|[-0.062292397, -0.025909504, -0.031942524, 0.01...|\\n\",\n      \"|[ as a new resolution for this year 2005, i dec...|[0.3469342, -0.14378615, 0.30223376, -0.1102267...|\\n\",\n      \"|This movie is over hyped!! 
I am sad to say that...|[0.13230576, -0.06588756, 0.0472389, 0.08353163...|\\n\",\n      \"|This show had a promising start as sort of the ...|[-0.19361982, -0.14412567, 0.15149693, -0.17715...|\\n\",\n      \"|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did...|[-0.048036292, 0.050720096, -0.04668727, -0.316...|\\n\",\n      \"|There is not one character on this sitcom with ...|[0.13720773, -0.5963504, 0.30331734, -0.3830607...|\\n\",\n      \"|Tommy Lee Jones was the best Woodroe and no one...|[-0.20960267, -0.15760122, -0.30596405, -0.5181...|\\n\",\n      \"|My wife rented this movie and then conveniently...|[0.46534792, -0.40655977, 0.054217298, -0.03414...|\\n\",\n      \"|This is one of those star-filled over-the-top c...|[0.14433198, -0.016140658, 0.3775344, 0.0659043...|\\n\",\n      \"|This excruciatingly boring and unfunny movie ma...|[0.056464806, 0.01144963, -0.51797307, 0.089813...|\\n\",\n      \"|you will likely be sorely disappointed by this ...|[-0.44146675, -0.17866582, 0.49889183, -0.26819...|\\n\",\n      \"|If I was British, I would be embarrassed by thi...|[0.1191261, -0.15379854, 0.17487673, -0.5123498...|\\n\",\n      \"|One of those movies in which there are no big t...|[-0.016174048, -0.5558219, -0.024818476, 0.1543...|\\n\",\n      \"|This show is like watching someone who is in tr...|[0.033776704, -0.6682203, 0.30547586, -0.581407...|\\n\",\n      \"|                                              Sigh|[-0.119870394, 0.40893683, 0.4174831, -0.010004...|\\n\",\n      \"+--------------------------------------------------+--------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"embeddings.show(truncate=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e3b0077c-785f-41af-9fa9-812e7fb63810\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"#### Stop Triton Server on each executor\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   \"execution_count\": 36,\n   \"id\": \"ef780e30\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:40:23,196 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-04 13:40:28,390 - INFO - Sucessfully stopped 1 servers.                 \\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 36,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 37,\n   \"id\": \"e82b9518-da7b-4ebc-8990-c8ab909bec18\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"33a60f2d-295a-4270-a2fd-16559962edda\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-torch\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.11\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/housing_regression_torch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"792d95f9\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark PyTorch Inference\\n\",\n    \"\\n\",\n    \"### Regression\\n\",\n    \"\\n\",\n    \"In this notebook, we will train an MLP to perform regression on the California housing dataset, and load it for distributed inference with Spark.  \\n\",\n    \"\\n\",\n    \"Based on: https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-create-a-neural-network-for-regression-with-pytorch.md  \\n\",\n    \"\\n\",\n    \"We also demonstrate accelerated inference via Torch-TensorRT model compilation.   \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"75930360-c5ce-49ef-a69a-da88fa69a2ef\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"import os\\n\",\n    \"import shutil\\n\",\n    \"import numpy as np\\n\",\n    \"from torch import nn\\n\",\n    \"from torch.utils.data import DataLoader\\n\",\n    \"from sklearn.datasets import fetch_california_housing\\n\",\n    \"from sklearn.preprocessing import StandardScaler\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"1de685f4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"os.mkdir('models') if not os.path.exists('models') else None\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"6d5bc0c7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"'2.5.1+cu124'\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"torch.__version__\"\n   ]\n  },\n  {\n   
\"cell_type\": \"markdown\",\n   \"id\": \"8754b174\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Load Dataset\\n\",\n    \"\\n\",\n    \"Each label corresponds to the average house value in units of 100,000, which we'll try to predict using the following features:  \\n\",\n    \"['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population', 'AveOccup', 'Latitude', 'Longitude']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"2bee64cf-a44a-4aff-82db-c64ee3a8b0e8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"X, y = fetch_california_housing(return_X_y=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"8644e508-5e4c-4cdd-9ed1-9235887d9659\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"class HousingDataset(torch.utils.data.Dataset):\\n\",\n    \"    def __init__(self, X, y, scale_data=True):\\n\",\n    \"        if not torch.is_tensor(X) and not torch.is_tensor(y):\\n\",\n    \"            # Apply scaling if necessary\\n\",\n    \"            if scale_data:\\n\",\n    \"                X = StandardScaler().fit_transform(X)\\n\",\n    \"            self.X = torch.from_numpy(X.astype(np.float32))\\n\",\n    \"            self.y = torch.from_numpy(y.astype(np.float32))\\n\",\n    \"\\n\",\n    \"    def __len__(self):\\n\",\n    \"        return len(self.X)\\n\",\n    \"\\n\",\n    \"    def __getitem__(self, i):\\n\",\n    \"        return self.X[i], self.y[i]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"cc6b55c3-dc7b-4831-9943-83efd48091bf\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataset = HousingDataset(X, y)\\n\",\n    \"trainloader = torch.utils.data.DataLoader(\\n\",\n    \"    dataset, batch_size=10, shuffle=True, num_workers=1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": 
\"d868f39d-4695-4110-91d2-6f7a09d73b93\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[tensor([[ 6.5799e-01,  4.2594e-01, -1.4755e-01, -2.3638e-01, -4.0221e-01,\\n\",\n       \"          -5.6793e-02,  8.8868e-01, -1.3528e+00],\\n\",\n       \"         [ 6.7288e-01, -1.0043e+00,  5.7486e-01, -1.6537e-01, -3.3422e-01,\\n\",\n       \"          -6.4971e-02, -1.2790e+00,  1.2327e+00],\\n\",\n       \"         [-1.1616e-01,  2.8646e-02, -1.7830e-01, -2.3817e-01, -6.7154e-01,\\n\",\n       \"          -3.6429e-02, -1.3258e+00,  1.2726e+00],\\n\",\n       \"         [-3.2513e-01, -6.8648e-01, -3.4226e-01, -8.2805e-02,  5.1239e+00,\\n\",\n       \"           2.6689e-02, -7.7338e-01,  8.3340e-01],\\n\",\n       \"         [ 1.0892e-01, -1.2427e+00,  2.7819e-01, -8.7150e-02,  3.0158e-01,\\n\",\n       \"          -1.8564e-02, -1.1245e+00,  1.1628e+00],\\n\",\n       \"         [-8.6416e-02,  5.8485e-01, -7.8085e-02,  8.1655e-02, -6.7154e-01,\\n\",\n       \"          -1.6053e-02, -3.4733e-01,  1.2577e+00],\\n\",\n       \"         [-1.2463e-01,  1.0810e-01,  2.6662e-01, -1.0883e-01,  3.4839e-01,\\n\",\n       \"          -2.3125e-02, -7.7338e-01,  1.3325e+00],\\n\",\n       \"         [-9.2662e-01, -1.6400e+00, -2.4824e-01,  6.0041e-01,  6.3361e-01,\\n\",\n       \"          -1.0926e-01, -8.8574e-01,  1.2826e+00],\\n\",\n       \"         [ 2.0038e+00, -6.0702e-01,  8.4770e-01, -2.1254e-01,  1.3745e+00,\\n\",\n       \"          -5.0489e-03, -6.5165e-01,  2.5441e-01],\\n\",\n       \"         [-3.9250e-01,  1.0616e+00, -1.8614e-01, -1.7073e-01, -3.8543e-01,\\n\",\n       \"          -8.1186e-02,  1.0806e+00, -1.3827e+00]]),\\n\",\n       \" tensor([3.1090, 1.8430, 1.6890, 1.8670, 1.9600, 0.6200, 0.9860, 0.9440, 3.9120,\\n\",\n       \"         1.4390])]\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    
\"next(iter(trainloader))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1e817b9a\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Create and Train Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"9a441b60-dca4-44d2-bc1c-aa7336d704bb\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"class MLP(nn.Module):\\n\",\n    \"    def __init__(self):\\n\",\n    \"        super().__init__()\\n\",\n    \"        self.layers = nn.Sequential(\\n\",\n    \"            nn.Linear(8, 64),\\n\",\n    \"            nn.ReLU(),\\n\",\n    \"            nn.Linear(64, 32),\\n\",\n    \"            nn.ReLU(),\\n\",\n    \"            nn.Linear(32, 1)\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"    def forward(self, x):\\n\",\n    \"        return self.layers(x)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"15cff2b4-9d23-4d2b-808a-a5edb8eda135\",\n   \"metadata\": {\n    \"scrolled\": true,\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Using cuda device\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Initialize the MLP\\n\",\n    \"device = \\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\"\\n\",\n    \"print(f\\\"Using {device} device\\\")\\n\",\n    \"mlp = MLP().to(device)\\n\",\n    \"\\n\",\n    \"# Define the loss function and optimizer\\n\",\n    \"loss_function = nn.L1Loss()\\n\",\n    \"optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"5e2db3f9-5db8-4b42-89ad-e77f23c4c1fe\",\n   \"metadata\": {\n    \"scrolled\": true,\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Starting epoch 1\\n\",\n      \"Loss after mini-batch     1: 0.004\\n\",\n      
\"Loss after mini-batch   201: 0.701\\n\",\n      \"Loss after mini-batch   401: 0.463\\n\",\n      \"Loss after mini-batch   601: 0.329\\n\",\n      \"Loss after mini-batch   801: 0.285\\n\",\n      \"Loss after mini-batch  1001: 0.253\\n\",\n      \"Loss after mini-batch  1201: 0.247\\n\",\n      \"Loss after mini-batch  1401: 0.234\\n\",\n      \"Loss after mini-batch  1601: 0.232\\n\",\n      \"Loss after mini-batch  1801: 0.217\\n\",\n      \"Loss after mini-batch  2001: 0.211\\n\",\n      \"Starting epoch 2\\n\",\n      \"Loss after mini-batch     1: 0.001\\n\",\n      \"Loss after mini-batch   201: 0.205\\n\",\n      \"Loss after mini-batch   401: 0.212\\n\",\n      \"Loss after mini-batch   601: 0.206\\n\",\n      \"Loss after mini-batch   801: 0.205\\n\",\n      \"Loss after mini-batch  1001: 0.202\\n\",\n      \"Loss after mini-batch  1201: 0.202\\n\",\n      \"Loss after mini-batch  1401: 0.204\\n\",\n      \"Loss after mini-batch  1601: 0.198\\n\",\n      \"Loss after mini-batch  1801: 0.188\\n\",\n      \"Loss after mini-batch  2001: 0.188\\n\",\n      \"Starting epoch 3\\n\",\n      \"Loss after mini-batch     1: 0.001\\n\",\n      \"Loss after mini-batch   201: 0.197\\n\",\n      \"Loss after mini-batch   401: 0.193\\n\",\n      \"Loss after mini-batch   601: 0.196\\n\",\n      \"Loss after mini-batch   801: 0.189\\n\",\n      \"Loss after mini-batch  1001: 0.183\\n\",\n      \"Loss after mini-batch  1201: 0.191\\n\",\n      \"Loss after mini-batch  1401: 0.193\\n\",\n      \"Loss after mini-batch  1601: 0.181\\n\",\n      \"Loss after mini-batch  1801: 0.185\\n\",\n      \"Loss after mini-batch  2001: 0.181\\n\",\n      \"Starting epoch 4\\n\",\n      \"Loss after mini-batch     1: 0.001\\n\",\n      \"Loss after mini-batch   201: 0.190\\n\",\n      \"Loss after mini-batch   401: 0.181\\n\",\n      \"Loss after mini-batch   601: 0.189\\n\",\n      \"Loss after mini-batch   801: 0.180\\n\",\n      \"Loss after mini-batch  1001: 0.184\\n\",\n      
\"Loss after mini-batch  1201: 0.180\\n\",\n      \"Loss after mini-batch  1401: 0.180\\n\",\n      \"Loss after mini-batch  1601: 0.184\\n\",\n      \"Loss after mini-batch  1801: 0.186\\n\",\n      \"Loss after mini-batch  2001: 0.179\\n\",\n      \"Starting epoch 5\\n\",\n      \"Loss after mini-batch     1: 0.000\\n\",\n      \"Loss after mini-batch   201: 0.181\\n\",\n      \"Loss after mini-batch   401: 0.177\\n\",\n      \"Loss after mini-batch   601: 0.185\\n\",\n      \"Loss after mini-batch   801: 0.179\\n\",\n      \"Loss after mini-batch  1001: 0.178\\n\",\n      \"Loss after mini-batch  1201: 0.173\\n\",\n      \"Loss after mini-batch  1401: 0.185\\n\",\n      \"Loss after mini-batch  1601: 0.177\\n\",\n      \"Loss after mini-batch  1801: 0.181\\n\",\n      \"Loss after mini-batch  2001: 0.178\\n\",\n      \"Training process has finished.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run the training loop\\n\",\n    \"for epoch in range(0, 5):  # 5 epochs at maximum\\n\",\n    \"\\n\",\n    \"    # Print epoch\\n\",\n    \"    print(f'Starting epoch {epoch+1}')\\n\",\n    \"\\n\",\n    \"    # Set current loss value\\n\",\n    \"    current_loss = 0.0\\n\",\n    \"\\n\",\n    \"    # Iterate over the DataLoader for training data\\n\",\n    \"    for i, data in enumerate(trainloader, 0):\\n\",\n    \"\\n\",\n    \"        # Get and prepare inputs\\n\",\n    \"        inputs, targets = data\\n\",\n    \"        inputs, targets = inputs.to(device), targets.to(device)\\n\",\n    \"        targets = targets.reshape((targets.shape[0], 1))\\n\",\n    \"\\n\",\n    \"        # Zero the gradients\\n\",\n    \"        optimizer.zero_grad()\\n\",\n    \"\\n\",\n    \"        # Perform forward pass\\n\",\n    \"        outputs = mlp(inputs)\\n\",\n    \"\\n\",\n    \"        # Compute loss\\n\",\n    \"        loss = loss_function(outputs, targets)\\n\",\n    \"\\n\",\n    \"        # Perform backward pass\\n\",\n    \"        loss.backward()\\n\",\n    
\"\\n\",\n    \"        # Perform optimization\\n\",\n    \"        optimizer.step()\\n\",\n    \"\\n\",\n    \"        # Print statistics\\n\",\n    \"        current_loss += loss.item()\\n\",\n    \"        if i % 200 == 0:\\n\",\n    \"            print('Loss after mini-batch %5d: %.3f' %\\n\",\n    \"                  (i + 1, current_loss / 500))\\n\",\n    \"            current_loss = 0.0\\n\",\n    \"\\n\",\n    \"# Process is complete.\\n\",\n    \"print('Training process has finished.')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"352539f5\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save Model State Dict\\n\",\n    \"This saves the serialized object to disk using pickle.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"id\": \"b950a3ed-ffe1-477f-a84f-f71c85dbf9ce\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Saved PyTorch Model State to models/housing_model.pt\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"torch.save(mlp.state_dict(), \\\"models/housing_model.pt\\\")\\n\",\n    \"print(\\\"Saved PyTorch Model State to models/housing_model.pt\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0060fcca\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save Model as TorchScript\\n\",\n    \"This saves an [intermediate representation of the compute graph](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format), which does not require pickle (or even python). 
\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"20fedb5d-c59e-4b0b-ba91-3dd15df1f09e\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Saved TorchScript Model to models/ts_housing_model.pt\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"scripted = torch.jit.script(mlp)\\n\",\n    \"scripted.save(\\\"models/ts_housing_model.pt\\\")\\n\",\n    \"print(\\\"Saved TorchScript Model to models/ts_housing_model.pt\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3101c0fe-65f1-411e-9192-e8a6b585ba0d\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Load and Test from Model State\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"7411b00f-88d2-40f5-b716-a26733c968ff\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<All keys matched successfully>\"\n      ]\n     },\n     \"execution_count\": 13,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"loaded_mlp = MLP().to(device)\\n\",\n    \"loaded_mlp.load_state_dict(torch.load(\\\"models/housing_model.pt\\\", weights_only=True))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"e226f449-2931-4492-9003-503cdc61f061\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"testX, testY = next(iter(trainloader))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d46af47e-db7e-42ee-9bd3-6e7d93850be3\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Predictions:\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"tensor([[2.3652],\\n\",\n       \"        [1.8444],\\n\",\n       \"        [2.4587],\\n\",\n      
 \"        [3.1243],\\n\",\n       \"        [2.2726],\\n\",\n       \"        [2.1818],\\n\",\n       \"        [1.5222],\\n\",\n       \"        [0.5554],\\n\",\n       \"        [2.2508],\\n\",\n       \"        [3.5971]], device='cuda:0', grad_fn=<AddmmBackward0>)\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(\\\"Predictions:\\\")\\n\",\n    \"loaded_mlp(testX.to(device))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"id\": \"13ae2c0f-1da5-45a4-bf32-ed8b562d7907\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Labels:\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"tensor([2.7370, 2.2110, 2.5360, 2.6330, 1.6540, 2.3360, 1.4600, 0.6590, 2.6380,\\n\",\n       \"        3.6220])\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(\\\"Labels:\\\")\\n\",\n    \"testY\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3bcd329d\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Load and Test from TorchScript\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"id\": \"422e317f-c9bd-4f76-9463-7af2935d401d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"scripted_mlp = torch.jit.load(\\\"models/ts_housing_model.pt\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"0cda8ec8-644e-4888-bfa0-b79425ece7c3\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Predictions:\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"tensor([2.3652, 
1.8444, 2.4587, 3.1243, 2.2726, 2.1818, 1.5222, 0.5554, 2.2508,\\n\",\n       \"        3.5971], device='cuda:0', grad_fn=<ViewBackward0>)\"\n      ]\n     },\n     \"execution_count\": 18,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"print(\\\"Predictions:\\\")\\n\",\n    \"scripted_mlp(testX.to(device)).flatten()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2a3b64e4\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Compile using the Torch JIT Compiler\\n\",\n    \"This leverages the [Torch-TensorRT inference compiler](https://pytorch.org/TensorRT/) for accelerated inference on GPUs using the `torch.compile` JIT interface under the hood. The compiler stack returns a [boxed-function](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) that triggers compilation on the first call.  \\n\",\n    \"\\n\",\n    \"Modules compiled in this fashion are [not serializable with pickle](https://github.com/pytorch/pytorch/issues/101107#issuecomment-1542688089), so we cannot send the compiled model directly to Spark.  \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c613f24e\",\n   \"metadata\": {},\n   \"source\": [\n    \"(You may see a warning about modelopt quantization. This is safe to ignore, as [implicit quantization](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#intro-quantization) is deprecated in the latest TensorRT. 
See [this link](https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq.html) for a guide to explicit quantization.)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"id\": \"9ffb27fc\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch_tensorrt as trt\\n\",\n    \"import time\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"id\": \"e0c10f90\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Optional: set the filename for the TensorRT timing cache\\n\",\n    \"timestamp = time.time()\\n\",\n    \"timing_cache = f\\\"/tmp/timing_cache-{timestamp}.bin\\\"\\n\",\n    \"with open(timing_cache, \\\"wb\\\") as f:\\n\",\n    \"    pass\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"b4aa2523\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"inputs_bs1 = torch.randn((10, 8), dtype=torch.float).to(\\\"cuda\\\")\\n\",\n    \"# This indicates dimension 0 of inputs_bs1 is dynamic with a range of values [1, 50]. No recompilation will happen when the batch size changes.\\n\",\n    \"torch._dynamo.mark_dynamic(inputs_bs1, 0, min=1, max=50)\\n\",\n    \"trt_model = trt.compile(\\n\",\n    \"    loaded_mlp,\\n\",\n    \"    ir=\\\"torch_compile\\\",\\n\",\n    \"    inputs=inputs_bs1,\\n\",\n    \"    enabled_precisions={torch.float},\\n\",\n    \"    timing_cache_path=timing_cache,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"id\": \"a5da8cab\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"WARNING:torch_tensorrt.dynamo._compiler:Node linear_default of op type call_function does not have metadata. 
This could sometimes lead to undefined behavior.\\n\",\n      \"WARNING:torch_tensorrt.dynamo._compiler:Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Predictions:\\n\",\n      \"tensor([[2.3652],\\n\",\n      \"        [1.8444],\\n\",\n      \"        [2.4587],\\n\",\n      \"        [3.1243],\\n\",\n      \"        [2.2726],\\n\",\n      \"        [2.1818],\\n\",\n      \"        [1.5222],\\n\",\n      \"        [0.5554],\\n\",\n      \"        [2.2508],\\n\",\n      \"        [3.5971]], device='cuda:0')\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"stream = torch.cuda.Stream()\\n\",\n    \"with torch.no_grad(), torch.cuda.stream(stream):\\n\",\n    \"    testX = testX.to(device)\\n\",\n    \"    print(\\\"Predictions:\\\")\\n\",\n    \"    print(trt_model(testX))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d2c55e07\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Compile using the Torch-TensorRT AOT Compiler\\n\",\n    \"Alternatively, use the Torch-TensorRT Dynamo backend for Ahead-of-Time (AOT) compilation to eagerly optimize the model in an explicit compilation phase. We first export the model to produce a traced graph representing the Tensor computation in an AOT fashion, which produces a `ExportedProgram` object which can be [serialized and reloaded](https://pytorch.org/TensorRT/user_guide/saving_models.html). We can then compile this IR using the Torch-TensorRT AOT compiler for inference.   
\\n\",\n    \"\\n\",\n    \"[Read the docs](https://pytorch.org/TensorRT/user_guide/torch_tensorrt_explained.html) for more information on JIT vs AOT compilation.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"id\": \"bf36a50d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"example_inputs = (torch.randn((10, 8), dtype=torch.float).to(\\\"cuda\\\"),)\\n\",\n    \"\\n\",\n    \"# Mark dim 0 (batch size) as dynamic\\n\",\n    \"batch = torch.export.Dim(\\\"batch\\\", min=1, max=64)\\n\",\n    \"# Produce traced graph in ExportedProgram format\\n\",\n    \"exp_program = torch.export.export(loaded_mlp, args=example_inputs, dynamic_shapes={\\\"x\\\": {0: batch}})\\n\",\n    \"# Compile the traced graph to produce an optimized module\\n\",\n    \"trt_gm = trt.dynamo.compile(exp_program,\\n\",\n    \"                            tuple(example_inputs),\\n\",\n    \"                            enabled_precisions={torch.float},\\n\",\n    \"                            timing_cache_path=timing_cache,\\n\",\n    \"                            workspace_size=1<<30)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"id\": \"4fc4efd5\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"<class 'torch.export.exported_program.ExportedProgram'>\\n\",\n      \"<class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(type(exp_program))\\n\",\n    \"print(type(trt_gm))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"id\": \"1bcf9c47\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Predictions:\\n\",\n      \"tensor([[2.3653],\\n\",\n      \"        [1.8443],\\n\",\n      \"        [2.4586],\\n\",\n      \"        
[3.1242],\\n\",\n      \"        [2.2725],\\n\",\n      \"        [2.1815],\\n\",\n      \"        [1.5221],\\n\",\n      \"        [0.5556],\\n\",\n      \"        [2.2508],\\n\",\n      \"        [3.5971]], device='cuda:0')\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"stream = torch.cuda.Stream()\\n\",\n    \"with torch.no_grad(), torch.cuda.stream(stream):\\n\",\n    \"    print(\\\"Predictions:\\\")\\n\",\n    \"    testX = testX.to(device)\\n\",\n    \"    print(trt_gm(testX))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0eeb957a\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can run the optimized module with a few different batch sizes (without recompilation!):\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"49f72c14\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Output shapes:\\n\",\n      \"torch.Size([10, 1])\\n\",\n      \"torch.Size([1, 1])\\n\",\n      \"torch.Size([50, 1])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"inputs = (torch.randn((10, 8), dtype=torch.float).cuda(),)\\n\",\n    \"inputs_bs1 = (torch.randn((1, 8), dtype=torch.float).cuda(),)\\n\",\n    \"inputs_bs50 = (torch.randn((50, 8), dtype=torch.float).cuda(),)\\n\",\n    \"\\n\",\n    \"stream = torch.cuda.Stream()\\n\",\n    \"with torch.no_grad(), torch.cuda.stream(stream):\\n\",\n    \"    print(\\\"Output shapes:\\\")\\n\",\n    \"    print(trt_gm(*inputs).shape)\\n\",\n    \"    print(trt_gm(*inputs_bs1).shape)\\n\",\n    \"    print(trt_gm(*inputs_bs50).shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b4fef57d\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can serialize the ExportedProgram (a traced graph representing the model's forward function) using `torch.export.save` to be recompiled at a later date.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"id\": \"876fea4a\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Saved ExportedProgram to models/trt_housing_model.ep\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"torch.export.save(exp_program, \\\"models/trt_housing_model.ep\\\")\\n\",\n    \"print(\\\"Saved ExportedProgram to models/trt_housing_model.ep\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"13631d1f-2c71-4bee-afcb-bd3b55ec87c5\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"bb71dd36\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.functions import col, struct, pandas_udf, array\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"import json\\n\",\n    \"import pandas as pd\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6769c060\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific Spark configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"id\": \"f7727b58\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a3b7d360\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
\\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"id\": \"52e9dbdb\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:46:28 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/04 13:46:28 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/04 13:46:28 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"    \\n\",\n    \"spark = 
SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e3b9937e-2c70-4d67-b95f-4d9d5ab17c12\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Create Spark DataFrame from Pandas DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 31,\n   \"id\": \"cf35da14-61a3-4e7b-9d4f-086bf5e931b3\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"housing = fetch_california_housing()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 32,\n   \"id\": \"95148019-ea95-40e5-a529-fcdb9a06f928\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"X = StandardScaler().fit_transform(housing.data.astype(np.float32))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 33,\n   \"id\": \"f82d957c-6747-4408-aac8-45305afbfe5e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"pdf = pd.DataFrame(X, columns=housing.feature_names)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 34,\n   \"id\": \"881afee9\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+------------+------------+-----------+------------+-----------+------------+----------+------------+\\n\",\n      \"|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|\\n\",\n      \"+------------+------------+-----------+------------+-----------+------------+----------+------------+\\n\",\n      \"|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|\\n\",\n      \"|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742|\\n\",\n      \"| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  
-1.3827378|\\n\",\n      \"|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787|\\n\",\n      \"|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614|\\n\",\n      \"| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.0083897|\\n\",\n      \"|  0.59671736|   0.5848523| 0.19346413|  -0.1371872|-0.19645879| 0.009964322|0.96827507|  -1.2928978|\\n\",\n      \"|  -0.9612035|  -1.5605159|-0.56329846| 0.027148023|-0.71127874| -0.08471591| 0.5328614| -0.13990337|\\n\",\n      \"| -0.74344087|  -1.2426835| 0.27282518|   0.4037246| -0.9841421| -0.05610115| 1.2257773| -0.42940006|\\n\",\n      \"|   0.9784464|  -0.2891866| 0.24374022| -0.24670053| 0.28922042| -0.01102468| 1.1087307|  -1.2280084|\\n\",\n      \"|  -0.5070446|  -1.0043093|-0.78254056|0.0122275995|  2.8465424|-0.060435444| 0.8980464|  -1.2080427|\\n\",\n      \"| -0.18690155|   1.2205169|0.015323491|  0.12183313|-0.41015765|  0.04452552|  1.010412|  -1.3228445|\\n\",\n      \"|  -1.2551856|   1.6178073| -0.3341509|-0.060125165| -0.7554314| -0.08777025| 1.0291398|  -1.3477987|\\n\",\n      \"|   4.9607058|  -1.9578062|  1.4854684| -0.03948475|  2.1833694|0.0029250523|  1.024457|  -1.1581304|\\n\",\n      \"|  0.73652315|  -1.6399739|  0.7913185| -0.05238397|    1.67738|  0.01944797| 1.0993668|  -1.1331724|\\n\",\n      \"|   -0.505834|  0.18756187|-0.47093546| -0.24297306|-0.60619545| -0.10791535|  0.977639|  -1.2879055|\\n\",\n      \"| -0.88477343|-0.050812364| -0.6318951| -0.15244243| -0.5258376| -0.15618815| 0.9823201|  -1.2879055|\\n\",\n      \"| -0.42840376|   0.9821427| -0.2266495| -0.36083496| -0.6883194| -0.08552282| 0.5328614| -0.12493005|\\n\",\n      \"|   0.9369153|  -1.4810578|  0.6722208|-0.121177554|  0.3996021|  0.01291408| 1.1040496|  -1.1082181|\\n\",\n      \"| -0.80702734| -0.92485124|-0.26602685|  -0.1560743|  1.4398388| -0.09314839|0.55627036| 
-0.09498342|\\n\",\n      \"+------------+------------+-----------+------------+-----------+------------+----------+------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"schema = StructType([\\n\",\n    \"    StructField(\\\"MedInc\\\",FloatType(),True),\\n\",\n    \"    StructField(\\\"HouseAge\\\",FloatType(),True),\\n\",\n    \"    StructField(\\\"AveRooms\\\",FloatType(),True),\\n\",\n    \"    StructField(\\\"AveBedrms\\\",FloatType(),True),\\n\",\n    \"    StructField(\\\"Population\\\",FloatType(),True),\\n\",\n    \"    StructField(\\\"AveOccup\\\",FloatType(),True),\\n\",\n    \"    StructField(\\\"Latitude\\\",FloatType(),True),\\n\",\n    \"    StructField(\\\"Longitude\\\",FloatType(),True)\\n\",\n    \"])\\n\",\n    \"\\n\",\n    \"df = spark.createDataFrame(pdf, schema=schema).repartition(8)\\n\",\n    \"df.show(truncate=12)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 35,\n   \"id\": \"7b33d367-fbf9-4918-b755-5447125547c4\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType([StructField('MedInc', FloatType(), True), StructField('HouseAge', FloatType(), True), StructField('AveRooms', FloatType(), True), StructField('AveBedrms', FloatType(), True), StructField('Population', FloatType(), True), StructField('AveOccup', FloatType(), True), StructField('Latitude', FloatType(), True), StructField('Longitude', FloatType(), True)])\"\n      ]\n     },\n     \"execution_count\": 35,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df.schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 36,\n   \"id\": \"751bff7a-b687-4184-b3fa-b5f5b46ef5d1\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/california_housing\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    
dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"88c3cd75\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Inference using Spark DL API\\n\",\n    \"\\n\",\n    \"Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\\n\",\n    \"\\n\",\n    \"- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \\n\",\n    \"- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 37,\n   \"id\": \"1e40c266-24de-454d-a776-f3716ba50e90\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 38,\n   \"id\": \"5b144c17\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"columns = df.columns\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 39,\n   \"id\": \"3d608e2f-66a8-44b5-9cde-5f7837bf4247\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# get absolute path to model\\n\",\n    \"model_path = \\\"{}/models/trt_housing_model.ep\\\".format(os.getcwd())\\n\",\n    \"\\n\",\n    \"# For cloud environments, copy the model to the distributed file system.\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-models\\\")\\n\",\n    \"    dbfs_model_path = \\\"/dbfs/FileStore/spark-dl-models/trt_housing_model.ep\\\"\\n\",\n    \"    shutil.copy(model_path, dbfs_model_path)\\n\",\n    \"    model_path = 
dbfs_model_path\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    # GCS is mounted at /mnt/gcs by the init script\\n\",\n    \"    models_dir = \\\"/mnt/gcs/spark-dl/models\\\"\\n\",\n    \"    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\\n\",\n    \"    gcs_model_path = models_dir + \\\"/trt_housing_model.ep\\\"\\n\",\n    \"    shutil.copy(model_path, gcs_model_path)\\n\",\n    \"    model_path = gcs_model_path\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2fd143e7\",\n   \"metadata\": {},\n   \"source\": [\n    \"For inference on Spark, we'll load the ExportedProgram and compile the model with the Torch-TensorRT AOT compiler and cache on the executor. \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 40,\n   \"id\": \"fc400771\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# A resource warning may occur due to unclosed file descriptors used by TensorRT across multiple PySpark daemon processes.\\n\",\n    \"# These can be safely ignored as the resources will be cleaned up when the worker processes terminate.\\n\",\n    \"\\n\",\n    \"import warnings\\n\",\n    \"warnings.simplefilter(\\\"ignore\\\", ResourceWarning)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 41,\n   \"id\": \"a2f45f5d-c941-4197-a274-1eec2af3fca4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def predict_batch_fn():\\n\",\n    \"    import torch\\n\",\n    \"    import torch_tensorrt as trt\\n\",\n    \"\\n\",\n    \"    device = \\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\"\\n\",\n    \"    if device != \\\"cuda\\\":\\n\",\n    \"        raise ValueError(\\\"This function uses the TensorRT model which requires a GPU device\\\")\\n\",\n    \"\\n\",\n    \"    example_inputs = (torch.randn((50, 8), dtype=torch.float).to(\\\"cuda\\\"),)\\n\",\n    \"    exp_program = torch.export.load(model_path)\\n\",\n    \"    trt_gm = 
trt.dynamo.compile(exp_program,\\n\",\n    \"                            tuple(example_inputs),\\n\",\n    \"                            enabled_precisions={torch.float},\\n\",\n    \"                            timing_cache_path=timing_cache,\\n\",\n    \"                            workspace_size=1<<30)\\n\",\n    \"\\n\",\n    \"    print(\\\"Model compiled.\\\")\\n\",\n    \"    \\n\",\n    \"    def predict(inputs):\\n\",\n    \"        stream = torch.cuda.Stream()\\n\",\n    \"        with torch.no_grad(), torch.cuda.stream(stream), trt.logging.errors():\\n\",\n    \"            print(f\\\"Predict {inputs.shape}\\\")\\n\",\n    \"            torch_inputs = torch.from_numpy(inputs).to(device)\\n\",\n    \"            outputs = trt_gm(torch_inputs) # .flatten()\\n\",\n    \"            return outputs.detach().cpu().numpy()\\n\",\n    \"\\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"id\": \"220a00a4-e842-4f5d-a4b3-7693d09e2d31\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"regress = predict_batch_udf(predict_batch_fn,\\n\",\n    \"                             return_type=FloatType(),\\n\",\n    \"                             input_tensor_shapes=[[8]],\\n\",\n    \"                             batch_size=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 43,\n   \"id\": \"0f3bf287-8ffc-4456-8772-e97c418d6aee\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 7:==============>                                            (2 + 6) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 30.4 ms, sys: 13.1 ms, total: 43.5 ms\\n\",\n      \"Wall time: 10.1 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"    
                                                                            \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", regress(struct(*columns)))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 44,\n   \"id\": \"6cd23b71-296d-4ce7-b56c-567cc2eec79c\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 31.6 ms, sys: 7.39 ms, total: 39 ms\\n\",\n      \"Wall time: 263 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", regress(array(*columns)))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 45,\n   \"id\": \"75d16bd5\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 28.7 ms, sys: 6.67 ms, total: 35.4 ms\\n\",\n      \"Wall time: 296 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", regress(array(*columns)))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 46,\n   \"id\": \"764a40d8-25f7-425c-ba03-fe8c45f4b063\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+\\n\",\n      \"|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|    preds|\\n\",\n      \"+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+\\n\",\n      \"|  0.20909257|  -1.1632254| 0.38946992|  
0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|1.3746364|\\n\",\n      \"|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742|1.8087528|\\n\",\n      \"| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378|1.4245079|\\n\",\n      \"|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787|2.3895802|\\n\",\n      \"|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614|1.3616933|\\n\",\n      \"| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.0083897|0.7539238|\\n\",\n      \"|  0.59671736|   0.5848523| 0.19346413|  -0.1371872|-0.19645879| 0.009964322|0.96827507|  -1.2928978|2.6816423|\\n\",\n      \"|  -0.9612035|  -1.5605159|-0.56329846| 0.027148023|-0.71127874| -0.08471591| 0.5328614| -0.13990337|1.1731354|\\n\",\n      \"| -0.74344087|  -1.2426835| 0.27282518|   0.4037246| -0.9841421| -0.05610115| 1.2257773| -0.42940006|1.0198532|\\n\",\n      \"|   0.9784464|  -0.2891866| 0.24374022| -0.24670053| 0.28922042| -0.01102468| 1.1087307|  -1.2280084| 2.708211|\\n\",\n      \"|  -0.5070446|  -1.0043093|-0.78254056|0.0122275995|  2.8465424|-0.060435444| 0.8980464|  -1.2080427|2.0327075|\\n\",\n      \"| -0.18690155|   1.2205169|0.015323491|  0.12183313|-0.41015765|  0.04452552|  1.010412|  -1.3228445|1.9909104|\\n\",\n      \"|  -1.2551856|   1.6178073| -0.3341509|-0.060125165| -0.7554314| -0.08777025| 1.0291398|  -1.3477987|1.2702764|\\n\",\n      \"|   4.9607058|  -1.9578062|  1.4854684| -0.03948475|  2.1833694|0.0029250523|  1.024457|  -1.1581304| 5.975229|\\n\",\n      \"|  0.73652315|  -1.6399739|  0.7913185| -0.05238397|    1.67738|  0.01944797| 1.0993668|  -1.1331724|1.9309721|\\n\",\n      \"|   -0.505834|  0.18756187|-0.47093546| -0.24297306|-0.60619545| -0.10791535|  0.977639|  -1.2879055|1.7610806|\\n\",\n      \"| 
-0.88477343|-0.050812364| -0.6318951| -0.15244243| -0.5258376| -0.15618815| 0.9823201|  -1.2879055| 1.655031|\\n\",\n      \"| -0.42840376|   0.9821427| -0.2266495| -0.36083496| -0.6883194| -0.08552282| 0.5328614| -0.12493005|1.1175063|\\n\",\n      \"|   0.9369153|  -1.4810578|  0.6722208|-0.121177554|  0.3996021|  0.01291408| 1.1040496|  -1.1082181|2.1779811|\\n\",\n      \"| -0.80702734| -0.92485124|-0.26602685|  -0.1560743|  1.4398388| -0.09314839|0.55627036| -0.09498342|0.9102398|\\n\",\n      \"+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 47,\n   \"id\": \"0aa85f81\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# This will clear the engine cache (containing previously compiled TensorRT engines) and reset the CUDA Context.\\n\",\n    \"torch._dynamo.reset()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"53536808\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  
\\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 48,\n   \"id\": \"a9ab4cdf-8103-447e-9ac8-944e2e527239\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1b77dc96\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 49,\n   \"id\": \"1ac83062\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a4cc5d81\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 50,\n   \"id\": \"6632636e-67a3-406c-832c-758aac4245fd\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports, model_path):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    import torch\\n\",\n    \"    import torch_tensorrt as trt\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"    device = torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\")\\n\",\n    \"    \\n\",\n    \"    exp_program = torch.export.load(model_path)\\n\",\n    \"    example_inputs = (torch.randn((50, 8), dtype=torch.float).to(\\\"cuda\\\"),)\\n\",\n    \"    trt_gm = trt.dynamo.compile(exp_program,\\n\",\n    \"                            tuple(example_inputs),\\n\",\n    \"                            enabled_precisions={torch.float},\\n\",\n    \"                            workspace_size=1<<30)\\n\",\n    \"\\n\",\n    \"    print(\\\"SERVER: Compiled model.\\\")\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        features = inputs[\\\"features\\\"]\\n\",\n    \"        if len(inputs[\\\"features\\\"]) != 1:\\n\",\n    \"            features = np.squeeze(features)\\n\",\n    \"        stream = torch.cuda.Stream()\\n\",\n    \"        with torch.no_grad(), torch.cuda.stream(stream):\\n\",\n    \"            torch_inputs = torch.from_numpy(features).to(device)\\n\",\n    \"            outputs = trt_gm(torch_inputs)\\n\",\n    \"            return {\\n\",\n    \"                \\\"preds\\\": outputs.cpu().numpy(),\\n\",\n    \"            }\\n\",\n    \"\\n\",\n    \"    workspace_path = f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"HousingModel\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            
inputs=[\\n\",\n    \"                Tensor(name=\\\"features\\\", dtype=np.float32, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"preds\\\", dtype=np.float32, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=50,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"74121cd7\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6d6b7143\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 54,\n   \"id\": \"2fb22db8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"HousingModel\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name, model_path=model_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"id\": \"e067aa14\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"9a1ac038\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d4ac45ef\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"92760dac\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"122ebe7c\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 57,\n   \"id\": \"1ae91c54\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import 
socket\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"Connecting to Triton model {model_name} at {url}.\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=240) as client:\\n\",\n    \"            result_data = client.infer_batch(inputs)\\n\",\n    \"            return result_data[\\\"preds\\\"]\\n\",\n    \"        \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 60,\n   \"id\": \"d3e64fda-117b-4810-a9a2-dd498239496f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"regress = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                               input_tensor_shapes=[[8]],\\n\",\n    \"                               return_type=FloatType(),\\n\",\n    \"                               batch_size=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"20b8514e-01de-481f-86aa-75afd99bcc7c\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Run Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 58,\n   \"id\": \"5eae04bc-75ca-421a-87c8-ac507ce1f2f5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 59,\n   \"id\": \"b350bd8e-9b8f-4511-9ddf-76d917b21b5f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"columns = df.columns\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 61,\n   \"id\": \"a24149a5-3adc-4089-8769-13cf1e44547a\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 16:>                                                
         (0 + 8) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 25.8 ms, sys: 6.21 ms, total: 32.1 ms\\n\",\n      \"Wall time: 2.37 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"predictions = df.withColumn(\\\"preds\\\", regress(struct(*columns)))\\n\",\n    \"preds = predictions.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 62,\n   \"id\": \"df2ce39f-30af-491a-8472-800fb1ce8458\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 17:>                                                         (0 + 8) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 171 ms, sys: 3.76 ms, total: 174 ms\\n\",\n      \"Wall time: 2.5 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions = df.withColumn(\\\"preds\\\", regress(array(*columns)))\\n\",\n    \"preds = predictions.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 63,\n   \"id\": \"ca6f3eaa-9569-45d0-88bf-9aa0757e1ecb\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 18:>                                                         (0 + 8) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     
\"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 24.4 ms, sys: 4.83 ms, total: 29.2 ms\\n\",\n      \"Wall time: 1.97 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions = df.withColumn(\\\"preds\\\", regress(array(*columns)))\\n\",\n    \"preds = predictions.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 64,\n   \"id\": \"b79c62c8-e1e8-4467-8aef-8939c31833b8\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+\\n\",\n      \"|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|    preds|\\n\",\n      \"+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+\\n\",\n      \"|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|1.3746364|\\n\",\n      \"|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742|1.8087528|\\n\",\n      \"| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378|1.4245079|\\n\",\n      \"|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787|2.3895802|\\n\",\n      \"|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614|1.3616933|\\n\",\n      \"| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.0083897|0.7539238|\\n\",\n      \"|  0.59671736|   0.5848523| 0.19346413|  
-0.1371872|-0.19645879| 0.009964322|0.96827507|  -1.2928978|2.6816423|\\n\",\n      \"|  -0.9612035|  -1.5605159|-0.56329846| 0.027148023|-0.71127874| -0.08471591| 0.5328614| -0.13990337|1.1731354|\\n\",\n      \"| -0.74344087|  -1.2426835| 0.27282518|   0.4037246| -0.9841421| -0.05610115| 1.2257773| -0.42940006|1.0198532|\\n\",\n      \"|   0.9784464|  -0.2891866| 0.24374022| -0.24670053| 0.28922042| -0.01102468| 1.1087307|  -1.2280084| 2.708211|\\n\",\n      \"|  -0.5070446|  -1.0043093|-0.78254056|0.0122275995|  2.8465424|-0.060435444| 0.8980464|  -1.2080427|2.0327075|\\n\",\n      \"| -0.18690155|   1.2205169|0.015323491|  0.12183313|-0.41015765|  0.04452552|  1.010412|  -1.3228445|1.9909104|\\n\",\n      \"|  -1.2551856|   1.6178073| -0.3341509|-0.060125165| -0.7554314| -0.08777025| 1.0291398|  -1.3477987|1.2702764|\\n\",\n      \"|   4.9607058|  -1.9578062|  1.4854684| -0.03948475|  2.1833694|0.0029250523|  1.024457|  -1.1581304| 5.975229|\\n\",\n      \"|  0.73652315|  -1.6399739|  0.7913185| -0.05238397|    1.67738|  0.01944797| 1.0993668|  -1.1331724|1.9309721|\\n\",\n      \"|   -0.505834|  0.18756187|-0.47093546| -0.24297306|-0.60619545| -0.10791535|  0.977639|  -1.2879055|1.7610806|\\n\",\n      \"| -0.88477343|-0.050812364| -0.6318951| -0.15244243| -0.5258376| -0.15618815| 0.9823201|  -1.2879055| 1.655031|\\n\",\n      \"| -0.42840376|   0.9821427| -0.2266495| -0.36083496| -0.6883194| -0.08552282| 0.5328614| -0.12493005|1.1175063|\\n\",\n      \"|   0.9369153|  -1.4810578|  0.6722208|-0.121177554|  0.3996021|  0.01291408| 1.1040496|  -1.1082181|2.1779811|\\n\",\n      \"| -0.80702734| -0.92485124|-0.26602685|  -0.1560743|  1.4398388| -0.09314839|0.55627036| -0.09498342|0.9102398|\\n\",\n      \"+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"predictions.show()\"\n   ]\n  },\n  {\n   
\"cell_type\": \"markdown\",\n   \"id\": \"3fec23b0-eaf2-4b6a-aa38-7a09873ed6eb\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"#### Stop Triton Server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 65,\n   \"id\": \"8084bdef\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 65,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 66,\n   \"id\": \"0138a029-87c5-497f-ac5c-3eed0e11b0f6\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d24147e7-5695-44a0-9961-b94bfba1cfff\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-torch\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.11\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/image_classification_torch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"9e87c927\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark PyTorch Inference\\n\",\n    \"\\n\",\n    \"### Image Classification\\n\",\n    \"\\n\",\n    \"In this notebook, we will train an MLP to perform image classification on FashionMNIST, and load it for distributed inference with Spark.\\n\",\n    \"\\n\",\n    \"Based on: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html  \\n\",\n    \"\\n\",\n    \"We also demonstrate accelerated inference via Torch-TensorRT model compilation.   \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"91d7ec98\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"import os\\n\",\n    \"import shutil\\n\",\n    \"from torch import nn\\n\",\n    \"from torch.utils.data import DataLoader\\n\",\n    \"from torchvision import datasets\\n\",\n    \"from torchvision.transforms import ToTensor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"f71f801d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"os.mkdir('models') if not os.path.exists('models') else None\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"d714f40d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"'2.5.1+cu124'\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"torch.__version__\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d0f6fb37\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Load Dataset\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"1c942a46\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Download training data from open datasets.\\n\",\n    \"training_data = datasets.FashionMNIST(\\n\",\n    \"    root=\\\"datasets/data\\\",\\n\",\n    \"    train=True,\\n\",\n    \"    download=True,\\n\",\n    \"    transform=ToTensor(),\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Download test data from open datasets.\\n\",\n    \"test_data = datasets.FashionMNIST(\\n\",\n    \"    root=\\\"datasets/data\\\",\\n\",\n    \"    train=False,\\n\",\n    \"    download=True,\\n\",\n    \"    transform=ToTensor(),\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"4a89aa8e-ef62-4aac-8260-4b004f2c1b55\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"classes = [\\n\",\n    \"    \\\"T-shirt/top\\\",\\n\",\n    \"    \\\"Trouser\\\",\\n\",\n    \"    \\\"Pullover\\\",\\n\",\n    \"    \\\"Dress\\\",\\n\",\n    \"    \\\"Coat\\\",\\n\",\n    \"    \\\"Sandal\\\",\\n\",\n    \"    \\\"Shirt\\\",\\n\",\n    \"    \\\"Sneaker\\\",\\n\",\n    \"    \\\"Bag\\\",\\n\",\n    \"    \\\"Ankle boot\\\",\\n\",\n    \"]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"10a97111\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28]) torch.float32\\n\",\n      \"Shape of y: torch.Size([64]) torch.int64\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"batch_size = 64\\n\",\n    \"\\n\",\n    \"# Create data loaders.\\n\",\n    \"train_dataloader = DataLoader(training_data, batch_size=batch_size)\\n\",\n    \"test_dataloader = DataLoader(test_data, batch_size=batch_size)\\n\",\n    \"\\n\",\n    \"for X, y in test_dataloader:\\n\",\n    \"    print(f\\\"Shape of X [N, C, H, W]: {X.shape} 
{X.dtype}\\\")\\n\",\n    \"    print(f\\\"Shape of y: {y.shape} {y.dtype}\\\")\\n\",\n    \"    break\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ca7af350\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Create model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"512d0bc7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Using cuda device\\n\",\n      \"NeuralNetwork(\\n\",\n      \"  (linear_relu_stack): Sequential(\\n\",\n      \"    (0): Linear(in_features=784, out_features=512, bias=True)\\n\",\n      \"    (1): ReLU()\\n\",\n      \"    (2): Linear(in_features=512, out_features=512, bias=True)\\n\",\n      \"    (3): ReLU()\\n\",\n      \"    (4): Linear(in_features=512, out_features=10, bias=True)\\n\",\n      \"  )\\n\",\n      \")\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Get cpu or gpu device for training.\\n\",\n    \"device = \\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\"\\n\",\n    \"print(f\\\"Using {device} device\\\")\\n\",\n    \"\\n\",\n    \"# Define model\\n\",\n    \"class NeuralNetwork(nn.Module):\\n\",\n    \"    def __init__(self):\\n\",\n    \"        super(NeuralNetwork, self).__init__()\\n\",\n    \"        self.linear_relu_stack = nn.Sequential(\\n\",\n    \"            nn.Linear(28*28, 512),\\n\",\n    \"            nn.ReLU(),\\n\",\n    \"            nn.Linear(512, 512),\\n\",\n    \"            nn.ReLU(),\\n\",\n    \"            nn.Linear(512, 10)\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"    def forward(self, x):\\n\",\n    \"        logits = self.linear_relu_stack(x)\\n\",\n    \"        return logits\\n\",\n    \"\\n\",\n    \"model = NeuralNetwork().to(device)\\n\",\n    \"print(model)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"4573c1b7\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Train Model\"\n   ]\n  },\n  
{\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"4d4f5538\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"loss_fn = nn.CrossEntropyLoss()\\n\",\n    \"optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"92d9076a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def train(dataloader, model, loss_fn, optimizer):\\n\",\n    \"    size = len(dataloader.dataset)\\n\",\n    \"    model.train()\\n\",\n    \"    for batch, (X, y) in enumerate(dataloader):\\n\",\n    \"        X, y = X.to(device), y.to(device)\\n\",\n    \"        X = torch.flatten(X, start_dim=1, end_dim=-1)\\n\",\n    \"\\n\",\n    \"        # Zero gradients\\n\",\n    \"        optimizer.zero_grad()\\n\",\n    \"\\n\",\n    \"        # Compute prediction error\\n\",\n    \"        pred = model(X)\\n\",\n    \"        loss = loss_fn(pred, y)\\n\",\n    \"\\n\",\n    \"        # Backpropagation\\n\",\n    \"        loss.backward()\\n\",\n    \"        optimizer.step()\\n\",\n    \"\\n\",\n    \"        if batch % 100 == 0:\\n\",\n    \"            loss, current = loss.item(), (batch + 1) * len(X)\\n\",\n    \"            print(f\\\"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"11c5650d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def test(dataloader, model, loss_fn):\\n\",\n    \"    size = len(dataloader.dataset)\\n\",\n    \"    num_batches = len(dataloader)\\n\",\n    \"    model.eval()\\n\",\n    \"    test_loss, correct = 0, 0\\n\",\n    \"    with torch.no_grad():\\n\",\n    \"        for X, y in dataloader:\\n\",\n    \"            X, y = X.to(device), y.to(device)\\n\",\n    \"            X = torch.flatten(X, start_dim=1, end_dim=-1)\\n\",\n    \"            pred = model(X)\\n\",\n    \"            test_loss += 
loss_fn(pred, y).item()\\n\",\n    \"            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\\n\",\n    \"    test_loss /= num_batches\\n\",\n    \"    correct /= size\\n\",\n    \"    print(f\\\"Test Error: \\\\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"id\": \"854608e6\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Epoch 1\\n\",\n      \"-------------------------------\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"loss: 2.298206  [   64/60000]\\n\",\n      \"loss: 2.283203  [ 6464/60000]\\n\",\n      \"loss: 2.262282  [12864/60000]\\n\",\n      \"loss: 2.259791  [19264/60000]\\n\",\n      \"loss: 2.240928  [25664/60000]\\n\",\n      \"loss: 2.218922  [32064/60000]\\n\",\n      \"loss: 2.225280  [38464/60000]\\n\",\n      \"loss: 2.193091  [44864/60000]\\n\",\n      \"loss: 2.194699  [51264/60000]\\n\",\n      \"loss: 2.157922  [57664/60000]\\n\",\n      \"Test Error: \\n\",\n      \" Accuracy: 38.6%, Avg loss: 2.149652 \\n\",\n      \"\\n\",\n      \"Epoch 2\\n\",\n      \"-------------------------------\\n\",\n      \"loss: 2.164765  [   64/60000]\\n\",\n      \"loss: 2.153999  [ 6464/60000]\\n\",\n      \"loss: 2.094229  [12864/60000]\\n\",\n      \"loss: 2.107332  [19264/60000]\\n\",\n      \"loss: 2.060189  [25664/60000]\\n\",\n      \"loss: 2.009164  [32064/60000]\\n\",\n      \"loss: 2.033063  [38464/60000]\\n\",\n      \"loss: 1.954014  [44864/60000]\\n\",\n      \"loss: 1.968186  [51264/60000]\\n\",\n      \"loss: 1.892358  [57664/60000]\\n\",\n      \"Test Error: \\n\",\n      \" Accuracy: 54.1%, Avg loss: 1.883826 \\n\",\n      \"\\n\",\n      \"Epoch 3\\n\",\n      \"-------------------------------\\n\",\n      \"loss: 1.922989  [   64/60000]\\n\",\n      
\"loss: 1.895849  [ 6464/60000]\\n\",\n      \"loss: 1.767882  [12864/60000]\\n\",\n      \"loss: 1.804950  [19264/60000]\\n\",\n      \"loss: 1.702711  [25664/60000]\\n\",\n      \"loss: 1.664090  [32064/60000]\\n\",\n      \"loss: 1.682484  [38464/60000]\\n\",\n      \"loss: 1.577310  [44864/60000]\\n\",\n      \"loss: 1.613093  [51264/60000]\\n\",\n      \"loss: 1.510797  [57664/60000]\\n\",\n      \"Test Error: \\n\",\n      \" Accuracy: 59.5%, Avg loss: 1.517127 \\n\",\n      \"\\n\",\n      \"Epoch 4\\n\",\n      \"-------------------------------\\n\",\n      \"loss: 1.588409  [   64/60000]\\n\",\n      \"loss: 1.558777  [ 6464/60000]\\n\",\n      \"loss: 1.393466  [12864/60000]\\n\",\n      \"loss: 1.465835  [19264/60000]\\n\",\n      \"loss: 1.350062  [25664/60000]\\n\",\n      \"loss: 1.359687  [32064/60000]\\n\",\n      \"loss: 1.370576  [38464/60000]\\n\",\n      \"loss: 1.287119  [44864/60000]\\n\",\n      \"loss: 1.330430  [51264/60000]\\n\",\n      \"loss: 1.238912  [57664/60000]\\n\",\n      \"Test Error: \\n\",\n      \" Accuracy: 62.4%, Avg loss: 1.254357 \\n\",\n      \"\\n\",\n      \"Epoch 5\\n\",\n      \"-------------------------------\\n\",\n      \"loss: 1.333722  [   64/60000]\\n\",\n      \"loss: 1.322049  [ 6464/60000]\\n\",\n      \"loss: 1.143545  [12864/60000]\\n\",\n      \"loss: 1.250494  [19264/60000]\\n\",\n      \"loss: 1.123120  [25664/60000]\\n\",\n      \"loss: 1.166146  [32064/60000]\\n\",\n      \"loss: 1.181268  [38464/60000]\\n\",\n      \"loss: 1.112326  [44864/60000]\\n\",\n      \"loss: 1.155791  [51264/60000]\\n\",\n      \"loss: 1.079376  [57664/60000]\\n\",\n      \"Test Error: \\n\",\n      \" Accuracy: 64.0%, Avg loss: 1.092456 \\n\",\n      \"\\n\",\n      \"Done!\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"epochs = 5\\n\",\n    \"for t in range(epochs):\\n\",\n    \"    print(f\\\"Epoch {t+1}\\\\n-------------------------------\\\")\\n\",\n    \"    train(train_dataloader, model, loss_fn, 
optimizer)\\n\",\n    \"    test(test_dataloader, model, loss_fn)\\n\",\n    \"print(\\\"Done!\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"85d97839\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save Model State Dict\\n\",\n    \"This saves the serialized object to disk using pickle.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"5d5d24de\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Saved PyTorch Model State to models/model.pt\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"torch.save(model.state_dict(), \\\"models/model.pt\\\")\\n\",\n    \"print(\\\"Saved PyTorch Model State to models/model.pt\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ac221ca7-e227-4c8c-8577-1eeda4a61fc7\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save Model as TorchScript\\n\",\n    \"This saves an [intermediate representation of the compute graph](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format), which does not require pickle (or even python). 
\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"6d9b3a45-7618-43e4-8bd3-8bb317a484d3\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Saved TorchScript Model to models/ts_model.pt\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"scripted = torch.jit.script(model)\\n\",\n    \"scripted.save(\\\"models/ts_model.pt\\\")\\n\",\n    \"print(\\\"Saved TorchScript Model to models/ts_model.pt\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"12ee8916-f437-4a2a-9bf4-14ff5376d305\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Load Model State\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"8fe3b5d1\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<All keys matched successfully>\"\n      ]\n     },\n     \"execution_count\": 14,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"model_from_state = NeuralNetwork().to(device)\\n\",\n    \"model_from_state.load_state_dict(torch.load(\\\"models/model.pt\\\", weights_only=True))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"0c405bd0\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Predicted: \\\"Ankle boot\\\", Actual: \\\"Ankle boot\\\"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"model_from_state.eval()\\n\",\n    \"x, y = test_data[0][0], test_data[0][1]\\n\",\n    \"with torch.no_grad():\\n\",\n    \"    x = torch.flatten(x.to(device), start_dim=1, end_dim=-1)\\n\",\n    \"    pred = model_from_state(x)\\n\",\n    \"    predicted, actual = classes[pred[0].argmax(0)], classes[y]\\n\",\n    \"    print(f'Predicted: \\\"{predicted}\\\", Actual: \\\"{actual}\\\"')\"\n   ]\n  
},\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"290c482a-1c5d-4bf2-bc3f-8a4e53d442b5\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Load Torchscript Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"id\": \"ef3c419e-d384-446c-b07b-1af93e07d6c0\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"ts_model = torch.jit.load(\\\"models/ts_model.pt\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"id\": \"c92d6cdb\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"x, y = test_data[0][0], test_data[0][1]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"038af830-a360-45eb-ab4e-b1adab0af164\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Predicted: \\\"Ankle boot\\\", Actual: \\\"Ankle boot\\\"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"with torch.no_grad():\\n\",\n    \"    pred = ts_model(torch.flatten(x.to(device), start_dim=1, end_dim=-1))\\n\",\n    \"    predicted, actual = classes[pred[0].argmax(0)], classes[y]\\n\",\n    \"    print(f'Predicted: \\\"{predicted}\\\", Actual: \\\"{actual}\\\"')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"76980495\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Compile using the Torch JIT Compiler\\n\",\n    \"This leverages the [Torch-TensorRT inference compiler](https://pytorch.org/TensorRT/) for accelerated inference on GPUs using the `torch.compile` JIT interface under the hood. The compiler stack returns a [boxed-function](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) that triggers compilation on the first call.  
\\n\",\n    \"\\n\",\n    \"Modules compiled in this fashion are [not serializable with pickle](https://github.com/pytorch/pytorch/issues/101107#issuecomment-1542688089), so we cannot send the compiled model directly to Spark. \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"414bc856\",\n   \"metadata\": {},\n   \"source\": [\n    \"(You may see a warning about modelopt quantization. This is safe to ignore, as [implicit quantization](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#intro-quantization) is deprecated in the latest TensorRT. See [this link](https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq.html) for a guide to explicit quantization.)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"id\": \"362b266b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch_tensorrt as trt\\n\",\n    \"import time\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"id\": \"f0ac1362\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Optional: set the filename for the TensorRT timing cache\\n\",\n    \"timestamp = time.time()\\n\",\n    \"timing_cache = f\\\"/tmp/timing_cache-{timestamp}.bin\\\"\\n\",\n    \"with open(timing_cache, \\\"wb\\\") as f:\\n\",\n    \"    pass\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"f3e3bdc4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"inputs_bs1 = torch.randn((10, 784), dtype=torch.float).to(\\\"cuda\\\")\\n\",\n    \"# This indicates dimension 0 of inputs_bs1 is dynamic whose range of values is [1, 50]. 
\\n\",\n    \"torch._dynamo.mark_dynamic(inputs_bs1, 0, min=1, max=64)\\n\",\n    \"trt_model = trt.compile(\\n\",\n    \"    model,\\n\",\n    \"    ir=\\\"torch_compile\\\",\\n\",\n    \"    inputs=inputs_bs1,\\n\",\n    \"    enabled_precisions={torch.float},\\n\",\n    \"    timing_cache_path=timing_cache,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"id\": \"66f61302\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"WARNING:torch_tensorrt.dynamo._compiler:Node linear_default of op type call_function does not have metadata. This could sometimes lead to undefined behavior.\\n\",\n      \"WARNING:torch_tensorrt.dynamo._compiler:Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Predicted: \\\"Ankle boot\\\", Actual: \\\"Ankle boot\\\"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"stream = torch.cuda.Stream()\\n\",\n    \"with torch.no_grad(), torch.cuda.stream(stream):\\n\",\n    \"    pred = trt_model(torch.flatten(x.to(device), start_dim=1, end_dim=-1))\\n\",\n    \"    predicted, actual = classes[pred[0].argmax(0)], classes[y]\\n\",\n    \"    print(f'Predicted: \\\"{predicted}\\\", Actual: \\\"{actual}\\\"')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"9ec04be8\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Compile using the Torch-TensorRT AOT Compiler\\n\",\n    \"Alternatively, use the Torch-TensorRT Dynamo backend for Ahead-of-Time (AOT) compilation to eagerly optimize the model in an explicit compilation phase. 
We first export the model to produce a traced graph representing the Tensor computation in an AOT fashion, which produces a `ExportedProgram` object which can be [serialized and reloaded](https://pytorch.org/TensorRT/user_guide/saving_models.html). We can then compile this IR using the Torch-TensorRT AOT compiler for inference.   \\n\",\n    \"\\n\",\n    \"[Read the docs](https://pytorch.org/TensorRT/user_guide/torch_tensorrt_explained.html) for more information on JIT vs AOT compilation.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"id\": \"3e7e7689\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"example_inputs = (torch.randn((10, 784), dtype=torch.float).to(\\\"cuda\\\"),)\\n\",\n    \"\\n\",\n    \"# Mark dim 0 (batch size) as dynamic\\n\",\n    \"batch = torch.export.Dim(\\\"batch\\\", min=1, max=64)\\n\",\n    \"# Produce traced graph in ExportedProgram format\\n\",\n    \"exp_program = torch.export.export(model_from_state, args=example_inputs, dynamic_shapes={\\\"x\\\": {0: batch}})\\n\",\n    \"# Compile the traced graph to produce an optimized module\\n\",\n    \"trt_gm = trt.dynamo.compile(exp_program, tuple(example_inputs), enabled_precisions={torch.float}, timing_cache_path=timing_cache)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"id\": \"6fda0c0e\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"<class 'torch.export.exported_program.ExportedProgram'>\\n\",\n      \"<class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(type(exp_program))\\n\",\n    \"print(type(trt_gm))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"id\": \"5ed9e4c5\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     
\"text\": [\n      \"Predicted: \\\"Ankle boot\\\", Actual: \\\"Ankle boot\\\"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"stream = torch.cuda.Stream()\\n\",\n    \"with torch.no_grad(), torch.cuda.stream(stream):\\n\",\n    \"    pred = trt_gm(torch.flatten(x.to(device), start_dim=1, end_dim=-1))\\n\",\n    \"    predicted, actual = classes[pred[0].argmax(0)], classes[y]\\n\",\n    \"    print(f'Predicted: \\\"{predicted}\\\", Actual: \\\"{actual}\\\"')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"38697a06\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can run the optimized module with a few different batch sizes (without recompilation!):\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"27871156\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Output shapes:\\n\",\n      \"torch.Size([10, 10])\\n\",\n      \"torch.Size([1, 10])\\n\",\n      \"torch.Size([50, 10])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"inputs = (torch.randn((10, 784), dtype=torch.float).cuda(),)\\n\",\n    \"inputs_bs1 = (torch.randn((1, 784), dtype=torch.float).cuda(),)\\n\",\n    \"inputs_bs50 = (torch.randn((50, 784), dtype=torch.float).cuda(),)\\n\",\n    \"\\n\",\n    \"stream = torch.cuda.Stream()\\n\",\n    \"with torch.no_grad(), torch.cuda.stream(stream):\\n\",\n    \"    print(\\\"Output shapes:\\\")\\n\",\n    \"    print(trt_gm(*inputs).shape)\\n\",\n    \"    print(trt_gm(*inputs_bs1).shape)\\n\",\n    \"    print(trt_gm(*inputs_bs50).shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ab974244\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can serialize the ExportedProgram (a traced graph representing the model's forward function) using `torch.export.save` to be recompiled at a later date.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": 
\"d87e4b20\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Saved ExportedProgram to models/trt_model.ep\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"torch.export.save(exp_program, \\\"models/trt_model.ep\\\")\\n\",\n    \"print(\\\"Saved ExportedProgram to models/trt_model.ep\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ad918393\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"id\": \"42c5feba\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.functions import col, struct, pandas_udf, array\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark import SparkConf\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"id\": \"ef97321d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import pandas as pd\\n\",\n    \"import numpy as np\\n\",\n    \"import json\\n\",\n    \"import os\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ece094d6\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific Spark configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"id\": \"10eb841f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"425e94ac\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n  
  \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 31,\n   \"id\": \"60ba6e74\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:50:47 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/04 13:50:47 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/04 13:50:48 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"    \\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2cd11476\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark DataFrame from Pandas DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 32,\n   \"id\": \"f063cbe7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"((10000, 28, 28), dtype('uint8'))\"\n      ]\n     },\n     \"execution_count\": 32,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"data = test_data.data.numpy()\\n\",\n    \"data.shape, data.dtype\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"8c828393\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     
\"data\": {\n      \"text/plain\": [\n       \"((10000, 784), dtype('float64'))\"\n      ]\n     },\n     \"execution_count\": 33,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"data = data.reshape(10000, 784) / 255.0\\n\",\n    \"data.shape, data.dtype\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 34,\n   \"id\": \"7760bdbe\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>0</th>\\n\",\n       \"      <th>1</th>\\n\",\n       \"      <th>2</th>\\n\",\n       \"      <th>3</th>\\n\",\n       \"      <th>4</th>\\n\",\n       \"      <th>5</th>\\n\",\n       \"      <th>6</th>\\n\",\n       \"      <th>7</th>\\n\",\n       \"      <th>8</th>\\n\",\n       \"      <th>9</th>\\n\",\n       \"      <th>...</th>\\n\",\n       \"      <th>774</th>\\n\",\n       \"      <th>775</th>\\n\",\n       \"      <th>776</th>\\n\",\n       \"      <th>777</th>\\n\",\n       \"      <th>778</th>\\n\",\n       \"      <th>779</th>\\n\",\n       \"      <th>780</th>\\n\",\n       \"      <th>781</th>\\n\",\n       \"      <th>782</th>\\n\",\n       \"      <th>783</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>0</th>\\n\",\n  
     \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>1</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.007843</td>\\n\",\n       \"      <td>0.011765</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.011765</td>\\n\",\n       \"      <td>0.682353</td>\\n\",\n       \"      <td>0.741176</td>\\n\",\n       \"      <td>0.262745</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>2</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      
<td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.003922</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.643137</td>\\n\",\n       \"      <td>0.227451</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>3</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.082353</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.003922</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>4</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.007843</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.003922</td>\\n\",\n       \"      <td>0.003922</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.278431</td>\\n\",\n       \"      
<td>0.047059</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>...</th>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9995</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"   
   <td>0.0</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9996</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.121569</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9997</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.105882</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9998</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      
<td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9999</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.000000</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"<p>10000 rows × 784 columns</p>\\n\",\n       \"</div>\"\n      ],\n      \"text/plain\": [\n       \"      0    1    2         3    4         5         6    7         8    \\\\\\n\",\n       \"0     0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \\n\",\n       \"1     0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \\n\",\n   
    \"2     0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.003922   \\n\",\n       \"3     0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \\n\",\n       \"4     0.0  0.0  0.0  0.007843  0.0  0.003922  0.003922  0.0  0.000000   \\n\",\n       \"...   ...  ...  ...       ...  ...       ...       ...  ...       ...   \\n\",\n       \"9995  0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \\n\",\n       \"9996  0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \\n\",\n       \"9997  0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \\n\",\n       \"9998  0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \\n\",\n       \"9999  0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \\n\",\n       \"\\n\",\n       \"           9    ...       774       775  776       777       778       779  \\\\\\n\",\n       \"0     0.000000  ...  0.000000  0.000000  0.0  0.000000  0.000000  0.000000   \\n\",\n       \"1     0.000000  ...  0.007843  0.011765  0.0  0.011765  0.682353  0.741176   \\n\",\n       \"2     0.000000  ...  0.643137  0.227451  0.0  0.000000  0.000000  0.000000   \\n\",\n       \"3     0.082353  ...  0.003922  0.000000  0.0  0.000000  0.000000  0.000000   \\n\",\n       \"4     0.000000  ...  0.278431  0.047059  0.0  0.000000  0.000000  0.000000   \\n\",\n       \"...        ...  ...       ...       ...  ...       ...       ...       ...   \\n\",\n       \"9995  0.000000  ...  0.000000  0.000000  0.0  0.000000  0.000000  0.000000   \\n\",\n       \"9996  0.121569  ...  0.000000  0.000000  0.0  0.000000  0.000000  0.000000   \\n\",\n       \"9997  0.000000  ...  0.105882  0.000000  0.0  0.000000  0.000000  0.000000   \\n\",\n       \"9998  0.000000  ...  0.000000  0.000000  0.0  0.000000  0.000000  0.000000   \\n\",\n       \"9999  0.000000  ...  
0.000000  0.000000  0.0  0.000000  0.000000  0.000000   \\n\",\n       \"\\n\",\n       \"           780  781  782  783  \\n\",\n       \"0     0.000000  0.0  0.0  0.0  \\n\",\n       \"1     0.262745  0.0  0.0  0.0  \\n\",\n       \"2     0.000000  0.0  0.0  0.0  \\n\",\n       \"3     0.000000  0.0  0.0  0.0  \\n\",\n       \"4     0.000000  0.0  0.0  0.0  \\n\",\n       \"...        ...  ...  ...  ...  \\n\",\n       \"9995  0.000000  0.0  0.0  0.0  \\n\",\n       \"9996  0.000000  0.0  0.0  0.0  \\n\",\n       \"9997  0.000000  0.0  0.0  0.0  \\n\",\n       \"9998  0.000000  0.0  0.0  0.0  \\n\",\n       \"9999  0.000000  0.0  0.0  0.0  \\n\",\n       \"\\n\",\n       \"[10000 rows x 784 columns]\"\n      ]\n     },\n     \"execution_count\": 34,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"pdf784 = pd.DataFrame(data)\\n\",\n    \"pdf784\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 35,\n   \"id\": \"f7d2bc0d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>data</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>0</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>1</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>2</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>3</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>4</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.00784313725490196, 0.0, 0.00...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>...</th>\\n\",\n       \"      <td>...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9995</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9996</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9997</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9998</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9999</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"<p>10000 rows × 1 columns</p>\\n\",\n       \"</div>\"\n      ],\n      \"text/plain\": [\n       \"                                                   data\\n\",\n       \"0     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\\n\",\n       \"1     [0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, ...\\n\",\n       \"2     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003...\\n\",\n       \"3     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\\n\",\n       \"4     [0.0, 0.0, 0.0, 0.00784313725490196, 0.0, 0.00...\\n\",\n       \"...                                                 ...\\n\",\n       \"9995  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\\n\",\n       \"9996  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\\n\",\n       \"9997  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\\n\",\n       \"9998  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\\n\",\n       \"9999  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\\n\",\n       \"\\n\",\n       \"[10000 rows x 1 columns]\"\n      ]\n     },\n     \"execution_count\": 35,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"# 1 column of array<float>\\n\",\n    \"pdf1 = pd.DataFrame()\\n\",\n    \"pdf1['data'] = pdf784.values.tolist()\\n\",\n    \"pdf1\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"07b2a70b\",\n   \"metadata\": {},\n   \"source\": [\n    \"Create dataframes with a single column of 784 floats and 784 separate columns.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 36,\n   \"id\": \"4863d5ff\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 185 ms, sys: 28.9 ms, total: 214 ms\\n\",\n      \"Wall time: 1.5 s\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType([StructField('data', ArrayType(FloatType(), True), True)])\"\n      ]\n     },\n     \"execution_count\": 36,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# force FloatType since Spark defaults to DoubleType\\n\",\n    \"schema = 
StructType([StructField(\\\"data\\\",ArrayType(FloatType()), True)])\\n\",\n    \"df = spark.createDataFrame(pdf1, schema).repartition(8)\\n\",\n    \"df.schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 37,\n   \"id\": \"831f4a01-3a49-4114-b9a0-2ae54526d72d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 66.9 ms, sys: 11.2 ms, total: 78.1 ms\\n\",\n      \"Wall time: 875 ms\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType([StructField('data', ArrayType(FloatType(), True), True)])\"\n      ]\n     },\n     \"execution_count\": 37,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# force FloatType since Spark defaults to DoubleType\\n\",\n    \"schema = StructType([StructField(str(x), FloatType(), True) for x in range(784)])\\n\",\n    \"df784 = spark.createDataFrame(pdf784, schema).repartition(8)\\n\",\n    \"df784.schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 38,\n   \"id\": \"e8ebae46\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:50:51 WARN TaskSetManager: Stage 0 contains a task of very large size (4030 KiB). 
The maximum recommended task size is 1000 KiB.\\n\",\n      \"[Stage 0:=======>                                                   (1 + 7) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 2.09 ms, sys: 1.6 ms, total: 3.69 ms\\n\",\n      \"Wall time: 1.71 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"data_path_1 = \\\"spark-dl-datasets/fashion_mnist_1\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path_1 = \\\"dbfs:/FileStore/\\\" + data_path_1\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path_1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 39,\n   \"id\": \"922314ce-2996-4666-9fc9-bcd98d16bb56\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:50:53 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:50:53 WARN TaskSetManager: Stage 3 contains a task of very large size (7847 KiB). 
The maximum recommended task size is 1000 KiB.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 2.94 ms, sys: 61 μs, total: 3 ms\\n\",\n      \"Wall time: 943 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"data_path_784 = \\\"spark-dl-datasets/fashion_mnist_784\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path_784 = \\\"dbfs:/FileStore/\\\" + data_path_784\\n\",\n    \"\\n\",\n    \"df784.write.mode(\\\"overwrite\\\").parquet(data_path_784)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"fce89cb0\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Inference using Spark DL API\\n\",\n    \"\\n\",\n    \"Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\\n\",\n    \"\\n\",\n    \"- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \\n\",\n    \"- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"59395856-a588-43c6-93c8-c83100716ac1\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"### 1 column of 784 float\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 40,\n   \"id\": \"79b151d9-d112-43b6-a479-887e2fd0e2b1\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"1\"\n      ]\n     },\n     \"execution_count\": 40,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df = spark.read.parquet(data_path_1)\\n\",\n    \"len(df.columns)\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": 41,\n   \"id\": \"3e6a4dbb\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# A resource warning may occur due to unclosed file descriptors used by TensorRT across multiple PySpark daemon processes.\\n\",\n    \"# These can be safely ignored as the resources will be cleaned up when the worker processes terminate.\\n\",\n    \"\\n\",\n    \"import warnings\\n\",\n    \"warnings.simplefilter(\\\"ignore\\\", ResourceWarning)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"id\": \"16e523c2\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# get absolute path to model\\n\",\n    \"model_path = \\\"{}/models/trt_model.ep\\\".format(os.getcwd())\\n\",\n    \"\\n\",\n    \"# For cloud environments, copy the model to the distributed file system.\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-models\\\")\\n\",\n    \"    dbfs_model_path = \\\"/dbfs/FileStore/spark-dl-models/trt_model.ep\\\"\\n\",\n    \"    shutil.copy(model_path, dbfs_model_path)\\n\",\n    \"    model_path = dbfs_model_path\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    # GCS is mounted at /mnt/gcs by the init script\\n\",\n    \"    models_dir = \\\"/mnt/gcs/spark-dl/models\\\"\\n\",\n    \"    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\\n\",\n    \"    gcs_model_path = models_dir + \\\"/trt_model.ep\\\"\\n\",\n    \"    shutil.copy(model_path, gcs_model_path)\\n\",\n    \"    model_path = gcs_model_path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 43,\n   \"id\": \"73dc73cb-25e3-4798-a019-e1abd684eaa1\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def predict_batch_fn():\\n\",\n    \"    import torch\\n\",\n    \"    import torch_tensorrt as trt\\n\",\n    \"    \\n\",\n    \"    device = \\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\"\\n\",\n    \"    if device != 
\\\"cuda\\\":\\n\",\n    \"        raise ValueError(\\\"This function uses the TensorRT model which requires a GPU device\\\")\\n\",\n    \"\\n\",\n    \"    example_inputs = (torch.randn((50, 784), dtype=torch.float).to(\\\"cuda\\\"),)\\n\",\n    \"    exp_program = torch.export.load(model_path)\\n\",\n    \"    trt_gm = trt.dynamo.compile(exp_program,\\n\",\n    \"                                tuple(example_inputs),\\n\",\n    \"                                enabled_precisions={torch.float},\\n\",\n    \"                                workspace_size=1<<30)\\n\",\n    \"\\n\",\n    \"    def predict(inputs: np.ndarray):\\n\",\n    \"        stream = torch.cuda.Stream()\\n\",\n    \"        with torch.no_grad(), torch.cuda.stream(stream):\\n\",\n    \"            # use array to combine columns into tensors\\n\",\n    \"            torch_inputs = torch.from_numpy(inputs).to(device)\\n\",\n    \"            outputs = trt_gm(torch_inputs)\\n\",\n    \"            return outputs.detach().cpu().numpy()\\n\",\n    \"\\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 44,\n   \"id\": \"df68cca1-2d47-4e88-8aad-9899402aee97\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"mnist = predict_batch_udf(predict_batch_fn,\\n\",\n    \"                          input_tensor_shapes=[[784]],\\n\",\n    \"                          return_type=ArrayType(FloatType()),\\n\",\n    \"                          batch_size=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 45,\n   \"id\": \"63555b3b-3673-4712-97aa-fd728c6c4979\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 167 ms, 
sys: 76.2 ms, total: 243 ms\\n\",\n      \"Wall time: 10.9 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass compiles and caches model/fn\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(struct(df.columns))).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 46,\n   \"id\": \"5dbf058a-70d6-4199-af9d-13843d078950\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 234 ms, sys: 64.1 ms, total: 298 ms\\n\",\n      \"Wall time: 685 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(*df.columns)).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 47,\n   \"id\": \"3f5ed801-6ca5-43a0-bf9c-2535a0dfe2e8\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 403 ms, sys: 60.1 ms, total: 463 ms\\n\",\n      \"Wall time: 809 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(*[col(c) for c in df.columns])).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c6dbec03-9b64-46c4-a748-f889be571384\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"### Check predictions\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 48,\n   \"id\": \"f1f1e5fd-5866-4b78-b9d3-709e6b383a0c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"predictions = preds[0].preds\\n\",\n    \"img = preds[0].data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 49,\n   \"id\": \"76b76502-adb7-45ec-a365-2e61cdd576fc\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import numpy as np\"\n   ]\n  
},\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 50,\n   \"id\": \"c163953a-1504-444f-b39f-86b61d34e440\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"img = np.array(img).reshape(28,28)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 51,\n   \"id\": \"bc0fad05-50ab-4ae5-b9fd-e50133c4c92a\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"image/png\": \"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJJ5JREFUeJzt3Xt0lfWd7/HPTkg2t2THEHKTQAMotHLplErKqBRLDpDOuEA5HW9zDrg6MNLgqlKrJz1Watuz0uIa66lDca2zWqir4oVzREbHYgUljAp0QBjGXlLAKGEgoWCTDQlJdrJ/5w/GzERB+P5M8kvC+7XWXovs/Xx4fnnyJJ882TvfRJxzTgAA9LKU0AsAAFyaKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQQwKvYAPSyaTOnr0qDIyMhSJREIvBwBg5JzTqVOnVFhYqJSU81/n9LkCOnr0qIqKikIvAwDwCdXW1mrUqFHnfbzPFVBGRoYk6Vp9WYOUFng16Ha9dVXLhCkgmHYl9Lpe6vx6fj49VkCrV6/Www8/rLq6Ok2dOlWPPfaYpk+ffsHcBz92G6Q0DYpQQANOr/1YlQICgvn3T78LPY3SIy9CeOaZZ7RixQqtXLlSb731lqZOnaq5c+fq+PHjPbE7AEA/1CMF9Mgjj2jJkiW644479JnPfEaPP/64hg4dqp/97Gc9sTsAQD/U7QXU1tamPXv2qLS09D92kpKi0tJS7dix4yPbt7a2Kh6Pd7kBAAa+bi+gEydOqKOjQ3l5eV3uz8vLU11d3Ue2r6ysVCwW67zxCjgAuDQE/0XUiooKNTY2dt5qa2tDLwkA0Au6/VVwOTk5Sk1NVX19fZf76+vrlZ+f/5Hto9GootFody8DANDHdfsVUHp6uqZNm6atW7d23pdMJrV161bNmDGju3cHAOineuT3gFasWKFFixbp85//vKZPn65HH31UTU1NuuOOO3pidwCAfqhHCujmm2/WH//4Rz344IOqq6vTZz/7WW3evPkjL0wAAFy6Is71rZkl8XhcsVhMszSfSQgYsFLHF5szx/7O/lxp7vzfmzPAJ9XuEtqmTWpsbFRmZuZ5twv+KjgAwKWJAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEH0yDRsoL+q+YH9b1bdP3+jOVPVcP4BjeczPu2MOZO/356RpF9smG3OFH3vTa994dLFFRAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCYBp2HxYZZP/wuI4O+46cs2c8RdLSzRmXaDNnBhWPMWckacttD5szX/zV3ebMlX+z25ypNyek3Tdf75GSHvzuU+bM2u/5HXOzSMSe6cVzHBePKyAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACCLiXN+a0hePxxWLxTRL8zUokhZ6Od3H
Y4BiJDXVnPEaRuqrb506Xfxh7TSv3NiiP5ozg0oPe+2rLzvzcrE5c8Pl+82ZLZMyzBn0fe0uoW3apMbGRmVmZp53O66AAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACCIQaEXcMnwGNzp2tvt+/EYetqXh4pKUsrUT5szW770v732VfrLFebMlfIYRppiHzQbSbF/bL3OIUlpP8w2Z25e9y/mzIbF3zRnLlu3w5xB38QVEAAgCAoIABBEtxfQd77zHUUikS63iRMndvduAAD9XI88B3TVVVdpy5Yt/7GTQTzVBADoqkeaYdCgQcrPz++J/xoAMED0yHNABw4cUGFhocaOHavbb79dhw+f/1VCra2tisfjXW4AgIGv2wuopKRE69at0+bNm7VmzRrV1NTouuuu06lTp865fWVlpWKxWOetqKiou5cEAOiDur2AysrK9JWvfEVTpkzR3Llz9dJLL6mhoUHPPvvsObevqKhQY2Nj5622tra7lwQA6IN6/NUBWVlZuvLKK3Xw4MFzPh6NRhWNRnt6GQCAPqbHfw/o9OnTOnTokAoKCnp6VwCAfqTbC+jee+9VVVWV3n33Xb355pu68cYblZqaqltvvbW7dwUA6Me6/UdwR44c0a233qqTJ09q5MiRuvbaa7Vz506NHDmyu3cFAOjHur2Ann766e7+L2HhM1jUYzCmJCnZYY7Eb/2COTNm+R/MmcdPXmfOSFLhq700ncolzZFIdKh9N57DSOu+YH9e9p1EpjnzzEMPmzMr//bL5kz9DH69oy9iFhwAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABNHjf5Cu10Qi9ozP4M7e5DMk1GPIpc9QUV9Xf2OPOVPdmGfONLenmzOSNPzZnV45q0iq5wDYXpJ2yp55s+kKc+bRP33KnLlr1BZz5oHblpgzkpS53uN86K2vRT778d1XD+EKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEEMnGnYfZ3H5Fqficku0XuTrYdtH2nOtDv7mOXUFPuE73c3jTVnJKlAdV45K9fh8XFqS3T/Qs4j77E3zZlvVVSbM9cevcqcWXlgvjlz8//cbM5I0ssvFJkzyVMeo8R9+E617kN/OYArIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIYuAMI/UYlhdJS/fbVaLNI2Rfn9d+PBy978+9ct/MfdacefLfvmDOJGUfnljwiH2YZq/qw+eDr181p5kz/33MTnNm1d455kxqkd8wzYxf2r9GNF7rtaveE/G47nA9M+SYKyAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACCLinMdUxB4Uj8cVi8U0S/M1KGIfbjiQHF9uHxLammXfz6+WrrKHJK1r+Lw5Myr9fXPme7+8yZyJ/cE+wFSSrv+bXebMxt981pwZ9G9RcyalzeN9ivh9evvsq3VE0pwZOeGEOROLtpgzZ9r9vpasHP8P5syyDUvNmeL/scOc6cvaXULbtEmNjY3KzMw873ZcAQEAgqCAAABBmAto+/btuuGGG1RYWKhIJKLnn3++y+POOT344IMqKCjQkCFDVFpaqgMHDnTXegEAA4S5gJqamjR16lStXr36nI+vWrVKP/7xj/X4449r165dGjZsmObOnauWFvvPbQEAA5f5L6KWlZWprKzsnI855/Too4/qgQce0Pz58yVJTzzxhPLy8vT888/rlltu+WSrBQAMGN36HFBNTY3q6upUWlraeV8sFlNJSYl27Dj3qzxaW1sVj8e73AAAA1+3FlBdXZ0k
KS8vr8v9eXl5nY99WGVlpWKxWOetqKioO5cEAOijgr8KrqKiQo2NjZ232tra0EsCAPSCbi2g/Px8SVJ9fX2X++vr6zsf+7BoNKrMzMwuNwDAwNetBVRcXKz8/Hxt3bq18754PK5du3ZpxowZ3bkrAEA/Z34V3OnTp3Xw4MHOt2tqarRv3z5lZ2dr9OjRuvvuu/X9739fV1xxhYqLi/Xtb39bhYWFWrBgQXeuGwDQz5kLaPfu3br++us7316xYoUkadGiRVq3bp3uu+8+NTU1aenSpWpoaNC1116rzZs3a/Dgwd23agBAv3dJDyN9Z5XfjwX/7safmzMr/vmvzJn09HZz5gdTnzNnNjdMMWckKUX2U+d3jXkX3uhD3tt7uTnTcVnCnJGkwbXp5kzmO/bjkNJuz3Sk2weEJs3fYp7lUu2ZZJp9fZGk/Ticntlszny26Ig5I0lHT8fMmWvy3jFn3nrf/urfI+9nmTOSNPor/+qVs2AYKQCgT6OAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAIz1m5A8Nvb/97r9zCg39hzuT8o/3PUZxeeMqc+dnR68yZxja/P5VxR9Eb5syJtmHmTM3gpDmjDvtkZklqu8y+r8RX/mTOjIo1mjMjo6fNmWiqfaK6JGUNsk+cTniM0G7qiJozMzOr7ftJ2vcjSW8OGm/OpMp+Dg0b1GbOvDR9jTkjSbfefq85E3typ9e+LoQrIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIYsAMI225Ybo5kxbZ57evb+WbMzn/6z1z5uXxz5kzD5+wH4ehKfZBiJK08vUF5kxK3H7KuSyPgZoe80vP7ithzsQPXGbOHGwYYc7U2meeKrXV2UOeOqL2AbDOY2bs6+mfM2duX/yKfUeSrsv6gznzucGHzZmX064yZ/7in+80ZyTpbyp+Zc68/GSm174uhCsgAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAhiwAwjrf98770rWT+sNWf+MudfzJmfNtgHFP5V1j+bM9+t/UtzRpIu/2WqOZMY6jF9UmnmRCTpN4TTpdjX15Fu308yzb4+n7W1Zvkcb0kesYjHzFif/Qz/N/uk2cf/6Xr7jiT9Yf4ac+Y3bfZ3alFsvznzj5mTzRlJ+uvYv5ozv/qzpabtIx2t0r9suuB2XAEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBADZhhp2une29e9l282Zzb8abo5k5seN2e+eei/mjPNf3+5OSNJZ3Lt37/4DMdMbTNHlEz1G8KZ0ksDNX0Gdzr77FevtUlS+xC/nJXPcTg12n7eFVTZ9yNJX5s205wpHNxgzlSfzjNnFl6+15yRpBEp9g/umcuHmbZvT6RKFzF/mSsgAEAQFBAAIAhzAW3fvl033HCDCgsLFYlE9Pzzz3d5fPHixYpEIl1u8+bN6671AgAGCHMBNTU1aerUqVq9evV5t5k3b56OHTvWeXvqqac+0SIBAAOP+UUIZWVlKisr+9htotGo8vPzvRcFABj4euQ5oG3btik3N1cTJkzQsmXLdPLkyfNu29raqng83uUGABj4ur2A5s2bpyeeeEJbt27VD3/4Q1VVVamsrEwdHR3n3L6yslKxWKzzVlRU1N1LAgD0Qd3+e0C33HJL578nT56sKVOmaNy4cdq2bZtmz579ke0rKiq0YsWKzrfj8TglBACXgB5/GfbYsWOVk5OjgwcPnvPxaDSqzMzMLjcAwMDX4wV05MgRnTx5UgUFBT29KwBAP2L+Edzp06e7XM3U1NRo3759ys7OVnZ2th566CEtXLhQ+fn5OnTokO677z6N
Hz9ec+fO7daFAwD6N3MB7d69W9dff33n2x88f7No0SKtWbNG+/fv189//nM1NDSosLBQc+bM0fe+9z1Fo9HuWzUAoN8zF9CsWbPknDvv4y+//PInWpCvQc29t68DbfbfcXo43z448P+dtj8f1viz/2LOJGN+EysTw+y5yLlfDPmxOtLtGd8hnB9zap9/V700WNRrQKjncRh0xp5J8Rga63M++By7jqjfgfj1+qnmzIYVD5sz/5Q+zpy5esi75owkxZNJc2ZY9QnT9u0drRe1HbPgAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEES3/0nuUIacsE949TUu7bg5c6TdPl74vpfuNGcyRti/p0gMM0ckSelxe8Z5fMvTMdiekcdUa0lqH+qxK4/pzD7r85m67Ssx3L7ApMdXk9Q2+5TqlIsbtNxF/FN+07AzDtuPw7cOzzdn/u+4LebMtjMeJ6ukwtRT5kzHgXds27vERW3HFRAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABDFghpFmHjptziRch9e+xqa1mDMzdywzZ/LfsA9C/NNEc8R7yGVLjj3TEbW/T9H3PQZJen5r5TMstX2o/X1qz7CfeynDL27A43+WbPWZlCpFj6aZM2mn7R8nr2PnMZw2tdVvGOmZXHvunfVXmDO/u/8fzBlpuEdGuixliDmTOmG8aXvX0SoduPB2XAEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBADZhhpSnObOZMW8RvU+K9tmebM8F/ZBwf+aaLHAMWkPdIxxD4QUpLasuw7i56wH/P0U/b1pS84bs5I0snGYeZMR5v90yj9cNScyd5u/34x4veh1Zls+7nXXOgxWNRjGKl85or6zSJVIsO+viEp9o/T/J13mjO/nPETc0aS2mU/9yIJ28TiSPLitucKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCGDDDSNXeYY40Js947epvd37NnEkbZZ+G2Jpjf59SWu37iST8JjWm5zebM2N/FDdn2t87Ys7UR0vMGUkav/WkPXS01p7JHWGOvHdTrjnTPMY2RPIDkaH24b7ujMdw36TH+dpuz3Sk+k1ldYPsueZR9kzGTvuw4rrpQ80ZSRqXZr/uaH/nXdv2LnFR23EFBAAIggICAARhKqDKykpdffXVysjIUG5urhYsWKDq6uou27S0tKi8vFwjRozQ8OHDtXDhQtXX13frogEA/Z+pgKqqqlReXq6dO3fqlVdeUSKR0Jw5c9TU1NS5zT333KMXXnhBGzZsUFVVlY4ePaqbbrqp2xcOAOjfTC9C2Lx5c5e3161bp9zcXO3Zs0czZ85UY2OjfvrTn2r9+vX60pe+JElau3atPv3pT2vnzp36whe+0H0rBwD0a5/oOaDGxkZJUnZ2tiRpz549SiQSKi0t7dxm4sSJGj16tHbs2HHO/6O1tVXxeLzLDQAw8HkXUDKZ1N13361rrrlGkyZNkiTV1dUpPT1dWVlZXbbNy8tTXV3dOf+fyspKxWKxzltRUZHvkgAA/Yh3AZWXl+vtt9/W008//YkWUFFRocbGxs5bba3H71QAAPodr19EXb58uV588UVt375do0aN6rw/Pz9fbW1tamho6HIVVF9fr/z8/HP+X9FoVNFo1GcZAIB+zHQF5JzT8uXLtXHjRr366qsqLi7u8vi0adOUlpamrVu3dt5XXV2tw4cPa8aMGd2zYgDAgGC6AiovL9f69eu1adMmZWRkdD6vE4vFNGTIEMViMX31q1/VihUrlJ2drczMTN11112aMWMGr4ADAHRhKqA1a9ZIkmbNmtXl/rVr
12rx4sWSpB/96EdKSUnRwoUL1draqrlz5+onP/lJtywWADBwmArIuQsP2Rs8eLBWr16t1atXey/KS6r99RQvNo268Ebn4JIeGY95n6nNHkMDh3sMMI34vRYl2Wp/CrH5ypHmTHrNe+bM5c8eMmckqf4vxpozJ6/JMGeyRpw2Z1pP24fnRt5PN2ckSQ1p9n35zLT1OfV8Mn6zSBVJ2HfmhtsHwDYX2vfz11v+1pyRpJq//D/mTOqIbNP2LtkmvX/h7ZgFBwAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCC8/iJqX9TxuwPmzLCUVq99PXHNT82Zv04sNWcizanmTGosYc4kO/wmJrsW++kTX95ozqTcdaU509xqn+YsSc6dsofeH2KONL6bZc5E7IPOpTTPMdD2U89r4rRL9Qh5TKOP+Iyjl+QG2w96aoP986JjmP2dSq/vvS/fLX9WfOGN/pP29hbptQtvxxUQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAAQxYIaR+shKafbKpXpMXfzRdU+bM4WD/mTO/OLkn5szL74xzZyRJHXYBzy+fyTLnElttn+f5HrxW6uIx0BNl2YfPun8ZsZ6ibR7DO/0mffpMyvVYz8uxXMoq8c5nox67ssoJeE3YLU52WbOtGXZqqI9cXHbcwUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFc0sNIt5y6yis3Zehhc+Zyj8GiE9LazZlPDT5pzvy3Wf9kzkjS+t9cbQ8dHmKOpNgPgxIx+7BPSYp4DJ+MJO2ZlDN+gyStXO/sRpIU8ZjB6VLsC/QZNOuztrP78nmnfM4h+24GtdgzklTT3mHODD6RMG3f3n5x23MFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBXNrDSI9O8MrFRjebM+PS/mjO/CJ+pTmTFrEPGizL2G/OSNKUz9eaMyNKTpszean2zMb4n5kzkvTSUfuA2vak/fu45rY0cybpsZ9omm2I5AeGeAzC9TkOg1LsUzh95oqebol6pKTYEPvEz8LhjeZMW0eqOZP0mcoq6Y0z48yZYzMGm7bvaJV0ETOOuQICAARBAQEAgjAVUGVlpa6++mplZGQoNzdXCxYsUHV1dZdtZs2apUgk0uV25513duuiAQD9n6mAqqqqVF5erp07d+qVV15RIpHQnDlz1NTU1GW7JUuW6NixY523VatWdeuiAQD9n+lFCJs3b+7y9rp165Sbm6s9e/Zo5syZnfcPHTpU+fn53bNCAMCA9ImeA2psPPtqj+zs7C73P/nkk8rJydGkSZNUUVGh5ubzv2qstbVV8Xi8yw0AMPB5vww7mUzq7rvv1jXXXKNJkyZ13n/bbbdpzJgxKiws1P79+3X//ferurpazz333Dn/n8rKSj300EO+ywAA9FPeBVReXq63335br7/+epf7ly5d2vnvyZMnq6CgQLNnz9ahQ4c0btxHX39eUVGhFStWdL4dj8dVVFTkuywAQD/hVUDLly/Xiy++qO3bt2vUqFEfu21JSYkk6eDBg+csoGg0qmjU75fEAAD9l6mAnHO66667tHHjRm3btk3FxcUXzOzbt0+SVFBQ4LVAAMDAZCqg8vJyrV+/Xps2bVJGRobq6uokSbFYTEOGDNGhQ4e0fv16ffnLX9aIESO0f/9+3XPPPZo5c6amTJnSI+8AAKB/MhXQmjVrJJ39ZdP/bO3atVq8eLHS09O1ZcsWPfroo2pqalJRUZEWLlyoBx54oNsWDAAYGMw/gvs4RUVFqqqq+kQLAgBcGi7padhDPCcFL8v6jTlzMBExZ8qz7NOm/dgn8Z5l
/52t5mSbOTM0Zag58+mc6gtvdA7LLttrzlyWal+fjzda7JOjG5J+a8tPtX9sB3tMYu+Q/fOixdnP15EpreaMJBWnDTdn9rTaz/HxafZjt6Mly5yRpPV/LDFnRlW+adq+3SV04CK2YxgpACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARxaQ8j/brfX2Itm/B1cybaYB98mkzrne8POqJ++zlyvT0XybMPhcx4c4g5U/jLo+aMJLlU+/vUcdkwcyb1VIs5o2PHzRGXaLfvR1JkqH2IaWS4x+DTC0zYP6d2++BOX81X2f+Qps/n7fA9h82Z9mN15sxZ9kGzPYUrIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEESfmwXn/n02VLsSkseYKNO+OuxzySSpPWGf45Xa7jELLtJLs+BS/PaTbPGYBddsP+YdbRFzpj3p97F1Ht+TdbSn2vfjc+65NnvEec6CS9q/NESS9uPgNQsu2Xuz4Nrb7Z/rSY9zqD1p/9i2O/vXlN7SrrNrcxf4+EbchbboZUeOHFFRUVHoZQAAPqHa2lqNGjXqvI/3uQJKJpM6evSoMjIyFIl0/c43Ho+rqKhItbW1yszMDLTC8DgOZ3EczuI4nMVxOKsvHAfnnE6dOqXCwkKlfMxPWPrcj+BSUlI+tjElKTMz85I+wT7AcTiL43AWx+EsjsNZoY9DLBa74Da8CAEAEAQFBAAIol8VUDQa1cqVKxWN+v0l04GC43AWx+EsjsNZHIez+tNx6HMvQgAAXBr61RUQAGDgoIAAAEFQQACAICggAEAQ/aaAVq9erU996lMaPHiwSkpK9Otf/zr0knrdd77zHUUikS63iRMnhl5Wj9u+fbtuuOEGFRYWKhKJ6Pnnn+/yuHNODz74oAoKCjRkyBCVlpbqwIEDYRbbgy50HBYvXvyR82PevHlhFttDKisrdfXVVysjI0O5ublasGCBqquru2zT0tKi8vJyjRgxQsOHD9fChQtVX18faMU942KOw6xZsz5yPtx5552BVnxu/aKAnnnmGa1YsUIrV67UW2+9palTp2ru3Lk6fvx46KX1uquuukrHjh3rvL3++uuhl9TjmpqaNHXqVK1evfqcj69atUo//vGP9fjjj2vXrl0aNmyY5s6dq5YW+yDJvuxCx0GS5s2b1+X8eOqpp3pxhT2vqqpK5eXl2rlzp1555RUlEgnNmTNHTU1Nndvcc889euGFF7RhwwZVVVXp6NGjuummmwKuuvtdzHGQpCVLlnQ5H1atWhVoxefh+oHp06e78vLyzrc7OjpcYWGhq6ysDLiq3rdy5Uo3derU0MsISpLbuHFj59vJZNLl5+e7hx9+uPO+hoYGF41G3VNPPRVghb3jw8fBOecWLVrk5s+fH2Q9oRw/ftxJclVVVc65sx/7tLQ0t2HDhs5tfve73zlJbseOHaGW2eM+fBycc+6LX/yi+/rXvx5uURehz18BtbW1ac+ePSotLe28LyUlRaWlpdqxY0fAlYVx4MABFRYWauzYsbr99tt1+PDh0EsKqqamRnV1dV3Oj1gsppKSkkvy/Ni2bZtyc3M1YcIELVu2TCdPngy9pB7V2NgoScrOzpYk7dmzR4lEosv5MHHiRI0ePXpAnw8fPg4fePLJJ5WTk6NJkyapoqJCzc3NIZZ3Xn1uGOmHnThxQh0dHcrLy+tyf15enn7/+98HWlUYJSUlWrdunSZMmKBjx47poYce0nXXXae3335bGRkZoZcXRF1dnSSd8/z44LFLxbx583TTTTepuLhYhw4d0re+9S2VlZVpx44dSk31+Fs9fVwymdTdd9+ta665RpMmTZJ09nxIT09XVlZWl20H8vlwruMgSbfddpvGjBmjwsJC7d+/X/fff7+qq6v13HPPBVxtV32+gPAfysrKOv89ZcoUlZSU
aMyYMXr22Wf11a9+NeDK0Bfccsstnf+ePHmypkyZonHjxmnbtm2aPXt2wJX1jPLycr399tuXxPOgH+d8x2Hp0qWd/548ebIKCgo0e/ZsHTp0SOPGjevtZZ5Tn/8RXE5OjlJTUz/yKpb6+nrl5+cHWlXfkJWVpSuvvFIHDx4MvZRgPjgHOD8+auzYscrJyRmQ58fy5cv14osv6rXXXuvy51vy8/PV1tamhoaGLtsP1PPhfMfhXEpKSiSpT50Pfb6A0tPTNW3aNG3durXzvmQyqa1bt2rGjBkBVxbe6dOndejQIRUUFIReSjDFxcXKz8/vcn7E43Ht2rXrkj8/jhw5opMnTw6o88M5p+XLl2vjxo169dVXVVxc3OXxadOmKS0trcv5UF1drcOHDw+o8+FCx+Fc9u3bJ0l963wI/SqIi/H000+7aDTq1q1b537729+6pUuXuqysLFdXVxd6ab3qG9/4htu2bZurqalxb7zxhistLXU5OTnu+PHjoZfWo06dOuX27t3r9u7d6yS5Rx55xO3du9e99957zjnnfvCDH7isrCy3adMmt3//fjd//nxXXFzszpw5E3jl3evjjsOpU6fcvffe63bs2OFqamrcli1b3Oc+9zl3xRVXuJaWltBL7zbLli1zsVjMbdu2zR07dqzz1tzc3LnNnXfe6UaPHu1effVVt3v3bjdjxgw3Y8aMgKvufhc6DgcPHnTf/e533e7du11NTY3btGmTGzt2rJs5c2bglXfVLwrIOecee+wxN3r0aJeenu6mT5/udu7cGXpJve7mm292BQUFLj093V1++eXu5ptvdgcPHgy9rB732muvOUkfuS1atMg5d/al2N/+9rddXl6ei0ajbvbs2a66ujrsonvAxx2H5uZmN2fOHDdy5EiXlpbmxowZ45YsWTLgvkk71/svya1du7ZzmzNnzrivfe1r7rLLLnNDhw51N954ozt27Fi4RfeACx2Hw4cPu5kzZ7rs7GwXjUbd+PHj3Te/+U3X2NgYduEfwp9jAAAE0eefAwIADEwUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACOL/AyBNQnoqGwl/AAAAAElFTkSuQmCC\",\n      \"text/plain\": [\n       \"<Figure size 640x480 with 1 Axes>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"plt.figure()\\n\",\n    \"plt.imshow(img)\\n\",\n    \"plt.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 52,\n   \"id\": \"56f36efb-e3a2-49f9-b9fb-1657bc25e5c5\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[-1.0776339769363403, -3.4281859397888184, 1.0321333408355713, -2.1151161193847656, 0.7665405869483948, 0.7089913487434387, 0.6775667071342468, 0.3138602077960968, 2.9969606399536133, 0.7927607297897339]\\n\",\n      \"predicted label: Bag\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(predictions)\\n\",\n    
\"print(\\\"predicted label:\\\", classes[np.argmax(predictions)])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"56ca1195-ea0f-405f-87fe-857e5c0c76a5\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 784 columns of float\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 53,\n   \"id\": \"e0ab0af6-b5c9-4b74-9dd6-baa7737cc986\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"784\"\n      ]\n     },\n     \"execution_count\": 53,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df = spark.read.parquet(data_path_784)\\n\",\n    \"len(df.columns)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 54,\n   \"id\": \"13ae45dc-85a0-4864-8a58-9dc29ae4efd7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 225 ms, sys: 91.1 ms, total: 316 ms\\n\",\n      \"Wall time: 3.16 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(struct(df.columns))).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 55,\n   \"id\": \"0b3fb48b-f871-41f2-ac57-346899a6fe48\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 283 ms, sys: 67.8 ms, total: 351 ms\\n\",\n      \"Wall time: 1.47 s\\n\"\n     ]\n    }\n   ],\n   \"source\": 
[\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(array(*df.columns))).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 56,\n   \"id\": \"b59114ad\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 543 ms, sys: 65.1 ms, total: 608 ms\\n\",\n      \"Wall time: 1.36 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(array(*df.columns))).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"dc48ec42-0df6-4e6a-b019-1270ab71d2cf\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"### Check predictions\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 57,\n   \"id\": \"d815c701-9f5b-422c-b3f9-fbc30456953c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"preds = df.withColumn(\\\"preds\\\", mnist(array(*df.columns))).limit(10).toPandas()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 58,\n   \"id\": \"b571b742-5079-42b2-8524-9181a0dec2c7\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sample = preds.iloc[0]\\n\",\n    \"predictions = sample.preds\\n\",\n    \"img = sample.drop('preds').to_numpy(dtype=float)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 59,\n   \"id\": \"d33d6a4e-e6b9-489d-ac21-c4eddc801784\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import numpy as np\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 60,\n   \"id\": \"6d10061e-aca6-4f81-bdfe-72e327ed7349\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"img = np.array(img).reshape(28,28)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 61,\n   \"id\": 
\"01f70e08-2c1d-419f-8676-3f6f4aba760f\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"image/png\": \"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHpRJREFUeJzt3X1wlPW99/HPZpNsAoSNIeSpBBpQoZWHnlJJuVWKJQOkczyiTMenP8BxYLTBKVKrk46K2s6kxRnr6FA8f7RQ71t8mhEYPb3pKJowtoEOKDeH0zZCTix4QoLS5oGEPJD9nT84bu+FAP1dbPLdhPdr5prJ7l7fXF+uvchnr+y134Scc04AAAyzNOsGAABXJgIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJtKtGzhXLBZTc3OzcnJyFAqFrNsBAHhyzqmzs1MlJSVKS7vweU7KBVBzc7NKS0ut2wAAXKZjx45p0qRJF3w85QIoJydHknSjvqN0ZRh3g4tJn/wl75qy/33Cu2b3f03zrsnKOONdI0lnBobnt9IDbnjO7mOxYP+ejPCAd01He7Z3zdypR71r2v+527vG9fZ61yC4M+rXB/pN/Of5hQxZAG3cuFHPPPOMWlpaNGfOHL3wwguaN2/eJeu++LVbujKUHiKAUll6WsS7JnOc/3MaHuO/nXBG2LtGktwwBZCGKYBCAQMoHCCA0vqyvGsyxmZ616SH+r1rXCjmXYPL8D8TRi/1NsqQ/G977bXXtG7dOq1fv14ffvih5syZoyVLlujECf9XvwCA0WlIAujZZ5/VqlWrdO+99+qrX/2qXnzxRY0ZM0a/+tWvhmJzAIARKOkB1NfXp/3796uiouLvG0lLU0VFherr689bv7e3Vx0dHQkLAGD0S3oAff755xoYGFBhYWHC/YWFhWppaTlv/ZqaGkWj0fjCFXAAcGUw/yBqdXW12tvb48uxY8esWwIADIOkXwWXn5+vcDis1tbWhPtbW1tVVFR03vqRSESRiP9VTgCAkS3pZ0CZmZmaO3eudu3aFb8vFotp165dmj9/frI3BwAYoYbkc0Dr1q3TihUr9I1vfEPz5s3Tc889p66uLt17771DsTkAwAg0JAF0xx136LPPPtMTTzyhlpYWfe1rX9POnTvPuzABAHDlGrJJCGvWrNGaNWuG6tsjBfz1Bv9RPG8Wv+ld80xmp3dNWeQz7xpJCsv/E/NpAT5lPzbNfzTMgPP/jXk44ASA/9c9xbtmz9/KvGv+dcq/edf8y+Lve9dkvfUH7xoMPfOr4AAAVyYCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmhmwYKUa/9qn+r1/29Ya9a5p7c71r0uS8ayQpppB3TU8sw7smmt7tXZMVOuNdE2RQqiQ1dk/0rjnWlutd89vu8/9I5aW0Xe3/Y8t/KxgOnAEBAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwwDRuB9RQNeNdMzzjtXVMSafOuyc/o9K6RpBP9471rMkL++6E/5v9fr9tFvGtywj3eNZJUnNXuXfONIv8J5LMyj3vX9OV4lyBFcQYEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABMNIEVhmQbd3TY/zH1g5LsBAzX4X9q6RpDFpfd41nQNZ3jXtA9neNZ/3jvOuuSn3Y+8aKdg+D4diAWr8j4dYxL8GqYkzIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggAC
AJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYYRorAxmb3etc0D0S8a2LO/3VSv3fFWRmhgWGp6Y35/9crG/O5d03bwBjvmqBOnfF/bpvP5HjX9BUHfXaRajgDAgCYIIAAACaSHkBPPvmkQqFQwjJjxoxkbwYAMMINyXtA1113nd59992/bySdt5oAAImGJBnS09NVVFQ0FN8aADBKDMl7QIcPH1ZJSYmmTp2qe+65R0ePHr3gur29vero6EhYAACjX9IDqLy8XFu2bNHOnTu1adMmNTU16aabblJnZ+eg69fU1CgajcaX0tLSZLcEAEhBSQ+gyspKffe739Xs2bO1ZMkS/eY3v1FbW5tef/31Qdevrq5We3t7fDl27FiyWwIApKAhvzogNzdX1157rY4cOTLo45FIRJGI/wfYAAAj25B/DujUqVNqbGxUcXHxUG8KADCCJD2AHn74YdXV1emTTz7R73//e912220Kh8O66667kr0pAMAIlvRfwX366ae66667dPLkSU2cOFE33nij9uzZo4kTJyZ7UwCAESzpAfTqq68m+1siReVE+rxrMhXzrkkL+ddkhYINrOx3/v8lpkZOeNdcndXiXdPcf5V3TXeA4a+SlJXmv/96YxneNR2xLO+azLH+xx1SE7PgAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmBjyP0iH0SsUct41Xc5/YGX7mTHeNUrv9q9RsCGmOeHT3jU//vifvWu+N7XWu+bj7iLvGknKDbD/ugYyA23LVyQSbNAsUg9nQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAE0zDRmA9Z/wPn4EAr3mae6PeNUFNzGr1rsnQgHdN9DtHvGtmNB33rtnTebV3jSS1BZhA3tkf8a4JMn08FuN182jBMwkAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEw0gR2Kke/+GTY0N93jUx5/866ZPTE7xrJOmfxnziXRMbptdxnw3keNeUZX8WaFt7/1bmXXOi27+/zJD/INeengzvGqQmzoAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYYBgpAuvu8h9GmhGKDUEn5xtwoUB1BeFO75r/c/J/BdhSr3fF7s4Z3jX/Ev3Qu0aS/q15lndNzxn/Hyc5aT3eNQOn+bE1WnAGBAAwQQABAEx4B9Du3bt1yy23qKSkRKFQSNu3b0943DmnJ554QsXFxcrOzlZFRYUOHz6crH4BAKOEdwB1dXVpzpw52rhx46CPb9iwQc8//7xefPFF7d27V2PHjtWSJUvU0+P/u14AwOjl/W5eZWWlKisrB33MOafnnntOjz32mG699VZJ0ksvvaTCwkJt375dd9555+V1CwAYNZL6HlBTU5NaWlpUUVERvy8ajaq8vFz19fWD1vT29qqjoyNhAQCMfkkNoJaWFklSYWFhwv2FhYXxx85VU1OjaDQaX0pLS5PZEgAgRZlfBVddXa329vb4cuzYMeuWAADDIKkBVFRUJElqbW1NuL+1tTX+2LkikYjGjx+fsAAARr+kBlBZWZmKioq0a9eu+H0dHR3au3ev5s+fn8xNAQBGOO+r4E6dOqUjR47Ebzc1NenAgQPKy8vT5MmTtXbtWv3kJz/RNddco7KyMj3++OMqKSnRsmXLktk3AGCE8w6gffv26eabb47fXrdunSRpxYoV2rJlix555BF1dXVp9erVamtr04033qidO3cqKysreV0DAEY87wBauHChnHMXfDwUCunpp5/W008/fVmNIfXFuvyHQmZoeIaRBjUzs9+7ZufHX/WumaaPvGvqT5R519yf94F3jST1x/x/O5+V
fsa/JjTgXRPqDnvXIDWZXwUHALgyEUAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBM+I8zBv5HWpf/VOKxacMzDftMLNjE5HFp/n82JKMhO9C2fB1rzvOuKbwuM9C2unv96/LGdnvXjAkwDTt8mtfNowXPJADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMMI0Vg4Z6Qd01WyL/m9ECGd01+5JR3TVDR/xyeAavZjRHvmvBi//0dVFrIeddkBGgvdMa/BqmJMyAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmGEaKwMK9/pMk+5z/wMpImv/0ySCDMYMa39QzLNsZ2zx8/6YxkT7vmrHp/jVh74pgQ3CRmjgDAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIJhpAgs/fTwbCcm/+GT0YDN9bsB75r0j//Lu8Z/K9L4T/yHfQY1Jedv3jV9Mf/Rohkh/+c2PHy7AUOMMyAAgAkCCABgwjuAdu/erVtuuUUlJSUKhULavn17wuMrV65UKBRKWJYuXZqsfgEAo4R3AHV1dWnOnDnauHHjBddZunSpjh8/Hl9eeeWVy2oSADD6eF+EUFlZqcrKyouuE4lEVFRUFLgpAMDoNyTvAdXW1qqgoEDTp0/XAw88oJMnT15w3d7eXnV0dCQsAIDRL+kBtHTpUr300kvatWuXfvazn6murk6VlZUaGBj8wtOamhpFo9H4UlpamuyWAAApKOmfA7rzzjvjX8+aNUuzZ8/WtGnTVFtbq0WLFp23fnV1tdatWxe/3dHRQQgBwBVgyC/Dnjp1qvLz83XkyJFBH49EIho/fnzCAgAY/YY8gD799FOdPHlSxcXFQ70pAMAI4v0ruFOnTiWczTQ1NenAgQPKy8tTXl6ennrqKS1fvlxFRUVqbGzUI488oquvvlpLlixJauMAgJHNO4D27dunm2++OX77i/dvVqxYoU2bNungwYP69a9/rba2NpWUlGjx4sX68Y9/rEgkkryuAQAjnncALVy4UM65Cz7+29/+9rIawsgRDjDvMyfNf2Bl74D/tTLjwj3eNUENfPbZsGwn65MLf5wh2fIyu71r/to3Zgg6OV8oNiybwTBgFhwAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwETS/yQ3rhzppy88Ff1CMuQ/DTuIsPx7S3UDx5qHbVsTMk951wSZhp0R8n8NHDrjXYIUxRkQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwwjRWCZnf4DP8OhkHfNV3OOe9ekhWLeNZKUERqeYalBuP6+YdvWVeld3jXTx7V612SF/H8EZXSNvkGzVyrOgAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJhgGCkCS+/1H/iZFuA1TzR82rsmktbvXSNJ/W4gUF2q6o4F2w85aT3eNeH04RkSyjDS0YMzIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYYRoqUlxE6410zNq030LZi8h+wmsr6FWxwZ5D91+/CgbblK6ObYaSjBWdAAAATBBAAwIRXANXU1Oj6669XTk6OCgoKtGzZMjU0NCSs09PTo6qqKk2YMEHjxo3T8uXL1dramtSmAQAjn1cA1dXVqaqqSnv27NE777yj/v5+LV68WF1dXfF1HnroIb311lt64403VFdXp+bm
Zt1+++1JbxwAMLJ5XYSwc+fOhNtbtmxRQUGB9u/frwULFqi9vV2//OUvtXXrVn3729+WJG3evFlf+cpXtGfPHn3zm99MXucAgBHtst4Dam9vlyTl5eVJkvbv36/+/n5VVFTE15kxY4YmT56s+vr6Qb9Hb2+vOjo6EhYAwOgXOIBisZjWrl2rG264QTNnzpQktbS0KDMzU7m5uQnrFhYWqqWlZdDvU1NTo2g0Gl9KS0uDtgQAGEECB1BVVZUOHTqkV1999bIaqK6uVnt7e3w5duzYZX0/AMDIEOiDqGvWrNHbb7+t3bt3a9KkSfH7i4qK1NfXp7a2toSzoNbWVhUVFQ36vSKRiCKRSJA2AAAjmNcZkHNOa9as0bZt2/Tee++prKws4fG5c+cqIyNDu3btit/X0NCgo0ePav78+cnpGAAwKnidAVVVVWnr1q3asWOHcnJy4u/rRKNRZWdnKxqN6r777tO6deuUl5en8ePH68EHH9T8+fO5Ag4AkMArgDZt2iRJWrhwYcL9mzdv1sqVKyVJP//5z5WWlqbly5ert7dXS5Ys0S9+8YukNAsAGD28Asi5Sw8BzMrK0saNG7Vx48bATWFkSO8ensGdAwGulQkywFSSelywulTVGQs2uDMt5P/cZoQG/LcT4LnNbB9dz9GVjFlwAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATgf4iKiBJ4e7UnUocVrAp0H8d8J/onMo+G8gOVBd0/w2HzJOnvWuGZ247fHEGBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwATDSJHyYs7/dVJGKNig1H/vKwhUl6paBqKB6rLS+rxrwrGsQNvylfZZm3cNw0hTE2dAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATDCMFIGFBpx3TX1veAg6SZ7/HGXDSBt6igPV/VP2J941aQFGfv7f7hzvGtfV5V2D1MQZEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMMI0VgsSz/waIT0k571xRmtHnXTE7/m3eNJLWcyQ1Ul6r+41SwYaSVOf/uXdPtIt41ueFu7xql82NrtOAMCABgggACAJjwCqCamhpdf/31ysnJUUFBgZYtW6aGhoaEdRYuXKhQKJSw3H///UltGgAw8nkFUF1dnaqqqrRnzx6988476u/v1+LFi9V1zh+IWrVqlY4fPx5fNmzYkNSmAQAjn9e7eTt37ky4vWXLFhUUFGj//v1asGBB/P4xY8aoqKgoOR0CAEaly3oPqL29XZKUl5eXcP/LL7+s/Px8zZw5U9XV1eruvvCVLr29vero6EhYAACjX+DrGWOxmNauXasbbrhBM2fOjN9/9913a8qUKSopKdHBgwf16KOPqqGhQW+++eag36empkZPPfVU0DYAACNU4ACqqqrSoUOH9MEHHyTcv3r16vjXs2bNUnFxsRYtWqTGxkZNmzbtvO9TXV2tdevWxW93dHSotLQ0aFsAgBEiUACtWbNGb7/9tnbv3q1JkyZddN3y8nJJ0pEjRwYNoEgkokjE/wNsAICRzSuAnHN68MEHtW3bNtXW1qqsrOySNQcOHJAkFRcH+0Q2AGB08gqgqqoqbd26VTt27FBOTo5aWlokSdFoVNnZ2WpsbNTWrVv1ne98RxMmTNDBgwf10EMPacGCBZo9e/aQ/AMAACOTVwBt2rRJ0tkPm/7/Nm/erJUrVyozM1PvvvuunnvuOXV1dam0tFTLly/XY489lrSGAQCjg/ev4C6mtLRUdXV1l9UQAODKwFhZBBaKXfwFyWCuy8z2rvnVyfMvXrmU5qyrvGsk6eWmed41efo40LaGQ3a4
P1Ddr07e6F2TERrwrvnhxA8uvdK5LvFCGCMHw0gBACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYCLlLjbgeZh0dHYpGo1qoW5UeyrBuBwDg6YzrV612qL29XePHj7/gepwBAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMBEunUD5/piNN0Z9UspNaUOAPCPOKN+SX//eX4hKRdAnZ2dkqQP9BvjTgAAl6Ozs1PRaPSCj6fcNOxYLKbm5mbl5OQoFAolPNbR0aHS0lIdO3bsohNWRzv2w1nsh7PYD2exH85Khf3gnFNnZ6dKSkqUlnbhd3pS7gwoLS1NkyZNuug648ePv6IPsC+wH85iP5zFfjiL/XCW9X642JnPF7gIAQBgggACAJgYUQEUiUS0fv16RSIR61ZMsR/OYj+cxX44i/1w1kjaDyl3EQIA4Mowos6AAACjBwEEADBBAAEATBBAAAATIyaANm7cqC9/+cvKyspSeXm5/vCHP1i3NOyefPJJhUKhhGXGjBnWbQ253bt365ZbblFJSYlCoZC2b9+e8LhzTk888YSKi4uVnZ2tiooKHT582KbZIXSp/bBy5crzjo+lS5faNDtEampqdP311ysnJ0cFBQVatmyZGhoaEtbp6elRVVWVJkyYoHHjxmn58uVqbW016nho/CP7YeHChecdD/fff79Rx4MbEQH02muvad26dVq/fr0+/PBDzZkzR0uWLNGJEyesWxt21113nY4fPx5fPvjgA+uWhlxXV5fmzJmjjRs3Dvr4hg0b9Pzzz+vFF1/U3r17NXbsWC1ZskQ9PT3D3OnQutR+kKSlS5cmHB+vvPLKMHY49Orq6lRVVaU9e/bonXfeUX9/vxYvXqyurq74Og899JDeeustvfHGG6qrq1Nzc7Nuv/12w66T7x/ZD5K0atWqhONhw4YNRh1fgBsB5s2b56qqquK3BwYGXElJiaupqTHsavitX7/ezZkzx7oNU5Lctm3b4rdjsZgrKipyzzzzTPy+trY2F4lE3CuvvGLQ4fA4dz8459yKFSvcrbfeatKPlRMnTjhJrq6uzjl39rnPyMhwb7zxRnydP/3pT06Sq6+vt2pzyJ27H5xz7lvf+pb7/ve/b9fUPyDlz4D6+vq0f/9+VVRUxO9LS0tTRUWF6uvrDTuzcfjwYZWUlGjq1Km65557dPToUeuWTDU1NamlpSXh+IhGoyovL78ij4/a2loVFBRo+vTpeuCBB3Ty5EnrloZUe3u7JCkvL0+StH//fvX39yccDzNmzNDkyZNH9fFw7n74wssvv6z8/HzNnDlT1dXV6u7utmjvglJuGOm5Pv/8cw0MDKiwsDDh/sLCQv35z3826spGeXm5tmzZounTp+v48eN66qmndNNNN+nQoUPKycmxbs9ES0uLJA16fHzx2JVi6dKluv3221VWVqbGxkb96Ec/UmVlperr6xUOh63bS7pYLKa1a9fqhhtu0MyZMyWdPR4yMzOVm5ubsO5oPh4G2w+SdPfdd2vKlCkqKSnRwYMH9eijj6qhoUFvvvmmYbeJUj6A8HeVlZXxr2fPnq3y8nJNmTJFr7/+uu677z7DzpAK7rzzzvjXs2bN0uzZszVt2jTV1tZq0aJFhp0NjaqqKh06dOiKeB/0Yi60H1avXh3/etasWSouLtaiRYvU2NioadOmDXebg0r5X8Hl5+crHA6fdxVLa2urioqKjLpKDbm5ubr22mt15MgR61bMfHEMcHycb+rUqcrPzx+Vx8eaNWv09ttv6/3330/48y1FRUXq6+tTW1tbwvqj9Xi40H4YTHl5uSSl1PGQ8gGUmZmpuXPnateuXfH7YrGYdu3apfnz5xt2Zu/UqVNqbGxUcXGxdStmysrKVFRUlHB8dHR0aO/evVf88fHpp5/q5MmTo+r4cM5pzZo12rZtm957
7z2VlZUlPD537lxlZGQkHA8NDQ06evToqDoeLrUfBnPgwAFJSq3jwfoqiH/Eq6++6iKRiNuyZYv74x//6FavXu1yc3NdS0uLdWvD6gc/+IGrra11TU1N7ne/+52rqKhw+fn57sSJE9atDanOzk730UcfuY8++shJcs8++6z76KOP3F/+8hfnnHM//elPXW5urtuxY4c7ePCgu/XWW11ZWZk7ffq0cefJdbH90NnZ6R5++GFXX1/vmpqa3Lvvvuu+/vWvu2uuucb19PRYt540DzzwgItGo662ttYdP348vnR3d8fXuf/++93kyZPde++95/bt2+fmz5/v5s+fb9h18l1qPxw5csQ9/fTTbt++fa6pqcnt2LHDTZ061S1YsMC480QjIoCcc+6FF15wkydPdpmZmW7evHluz5491i0NuzvuuMMVFxe7zMxM96Uvfcndcccd7siRI9ZtDbn333/fSTpvWbFihXPu7KXYjz/+uCssLHSRSMQtWrTINTQ02DY9BC62H7q7u93ixYvdxIkTXUZGhpsyZYpbtWrVqHuRNti/X5LbvHlzfJ3Tp0+7733ve+6qq65yY8aMcbfddps7fvy4XdND4FL74ejRo27BggUuLy/PRSIRd/XVV7sf/vCHrr293bbxc/DnGAAAJlL+PSAAwOhEAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADAxH8DR+VYJGWV2nkAAAAASUVORK5CYII=\",\n      \"text/plain\": [\n       \"<Figure size 640x480 with 1 Axes>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"plt.figure()\\n\",\n    \"plt.imshow(img)\\n\",\n    \"plt.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 62,\n   \"id\": \"8e1c07cc-b2bc-4902-a9a6-4ac7f02c5fe4\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[ 2.44647     3.8623989   0.14587203  3.2146688   1.0799949  -2.5363288\\n\",\n      \"  0.86715794 -3.8287208  -2.02238    -2.9016623 ]\\n\",\n      \"predicted label: Trouser\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(predictions)\\n\",\n    \"print(\\\"predicted label:\\\", classes[np.argmax(predictions)])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 63,\n   \"id\": \"3d47a8ec\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# This will clear the engine cache (containing previously compiled TensorRT engines) and resets the CUDA Context.\\n\",\n    \"torch._dynamo.reset()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n  
 \"id\": \"281c7889\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  \\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 64,\n   \"id\": \"53ca290a-ccc3-4923-a292-944921bab36d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d8abea75\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 65,\n   \"id\": \"e616b207\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"606934ac\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server 
function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 66,\n   \"id\": \"8fa92fe4-2e04-4d82-a357-bfdfca38bd8c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports, model_path):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    import torch\\n\",\n    \"    from torch import nn\\n\",\n    \"    import torch_tensorrt as trt\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"    device = torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\")\\n\",\n    \"    \\n\",\n    \"    exp_program = torch.export.load(model_path)\\n\",\n    \"    example_inputs = (torch.randn((50, 784), dtype=torch.float).to(\\\"cuda\\\"),)\\n\",\n    \"    trt_gm = trt.dynamo.compile(exp_program,\\n\",\n    \"                                tuple(example_inputs),\\n\",\n    \"                                enabled_precisions={torch.float},\\n\",\n    \"                                workspace_size=1<<30)\\n\",\n    \"\\n\",\n    \"    print(\\\"SERVER: Compiled model.\\\")\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        images = inputs[\\\"images\\\"]\\n\",\n    \"        if len(images) != 1:\\n\",\n    \"            images = np.squeeze(images)\\n\",\n    \"        stream = torch.cuda.Stream()\\n\",\n    \"        with torch.no_grad(), torch.cuda.stream(stream):\\n\",\n    \"            torch_inputs = torch.from_numpy(images).to(device)\\n\",\n    \"            outputs = trt_gm(torch_inputs)\\n\",\n    \"            return {\\n\",\n    \"                
\\\"labels\\\": outputs.cpu().numpy(),\\n\",\n    \"            }\\n\",\n    \"        \\n\",\n    \"    workspace_path = f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"ImageClassifier\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"images\\\", dtype=np.float32, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"labels\\\", dtype=np.float32, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=64,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. 
Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8fea6e5e\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers  \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f837300c\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 68,\n   \"id\": \"f72c53d6\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"ImageClassifier\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name, model_path=model_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"65d3f7be\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Returns 
{'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"90ed191b\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"86c1545a\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"c4c2833f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c6771c93\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 71,\n   \"id\": \"cec9a48c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"Connecting to Triton model {model_name} at {url}.\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=240) as client:\\n\",\n    \"            result_data = client.infer_batch(inputs)\\n\",\n    \"            return result_data[\\\"labels\\\"]\\n\",\n    \"        \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 73,\n   \"id\": \"0262fd4a-9845-44b9-8c75-1c105e7deeca\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"mnist = predict_batch_udf(partial(triton_fn, 
model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                          input_tensor_shapes=[[784]],\\n\",\n    \"                          return_type=ArrayType(FloatType()),\\n\",\n    \"                          batch_size=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"30a4362d-7514-4b84-b238-f704a97e1e72\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 72,\n   \"id\": \"ab94d4d1-dac6-4474-9eb0-59478aa98f7d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType([StructField('data', ArrayType(FloatType(), True), True)])\"\n      ]\n     },\n     \"execution_count\": 72,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df = spark.read.parquet(data_path_1)\\n\",\n    \"df.schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 74,\n   \"id\": \"fc5f6baa-052e-4b89-94b6-4821cf01952a\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 157 ms, sys: 47.6 ms, total: 205 ms\\n\",\n      \"Wall time: 2.49 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(struct(df.columns))).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 75,\n   \"id\": \"a85dea35-e41d-482d-8a8f-52d3c108f038\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n    
 ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 183 ms, sys: 60.3 ms, total: 243 ms\\n\",\n      \"Wall time: 1.49 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(*df.columns)).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 76,\n   \"id\": \"bc3f0dbe-c52b-41d6-8097-8cebaa5ee5a8\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 383 ms, sys: 43.9 ms, total: 427 ms\\n\",\n      \"Wall time: 1.6 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(*[col(c) for c in df.columns])).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 77,\n   \"id\": \"99fb5e8d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Predicted label: Bag\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"image/png\": 
\"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJJ5JREFUeJzt3Xt0lfWd7/HPTkg2t2THEHKTQAMotHLplErKqBRLDpDOuEA5HW9zDrg6MNLgqlKrJz1Watuz0uIa66lDca2zWqir4oVzREbHYgUljAp0QBjGXlLAKGEgoWCTDQlJdrJ/5w/GzERB+P5M8kvC+7XWXovs/Xx4fnnyJJ882TvfRJxzTgAA9LKU0AsAAFyaKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQQwKvYAPSyaTOnr0qDIyMhSJREIvBwBg5JzTqVOnVFhYqJSU81/n9LkCOnr0qIqKikIvAwDwCdXW1mrUqFHnfbzPFVBGRoYk6Vp9WYOUFng16Ha9dVXLhCkgmHYl9Lpe6vx6fj49VkCrV6/Www8/rLq6Ok2dOlWPPfaYpk+ffsHcBz92G6Q0DYpQQANOr/1YlQICgvn3T78LPY3SIy9CeOaZZ7RixQqtXLlSb731lqZOnaq5c+fq+PHjPbE7AEA/1CMF9Mgjj2jJkiW644479JnPfEaPP/64hg4dqp/97Gc9sTsAQD/U7QXU1tamPXv2qLS09D92kpKi0tJS7dix4yPbt7a2Kh6Pd7kBAAa+bi+gEydOqKOjQ3l5eV3uz8vLU11d3Ue2r6ysVCwW67zxCjgAuDQE/0XUiooKNTY2dt5qa2tDLwkA0Au6/VVwOTk5Sk1NVX19fZf76+vrlZ+f/5Hto9GootFody8DANDHdfsVUHp6uqZNm6atW7d23pdMJrV161bNmDGju3cHAOineuT3gFasWKFFixbp85//vKZPn65HH31UTU1NuuOOO3pidwCAfqhHCujmm2/WH//4Rz344IOqq6vTZz/7WW3evPkjL0wAAFy6Is71rZkl8XhcsVhMszSfSQgYsFLHF5szx/7O/lxp7vzfmzPAJ9XuEtqmTWpsbFRmZuZ5twv+KjgAwKWJAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEH0yDRsoL+q+YH9b1bdP3+jOVPVcP4BjeczPu2MOZO/356RpF9smG3OFH3vTa994dLFFRAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCYBp2HxYZZP/wuI4O+46cs2c8RdLSzRmXaDNnBhWPMWckacttD5szX/zV3ebMlX+z25ypNyek3Tdf75GSHvzuU+bM2u/5HXOzSMSe6cVzHBePKyAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACCLiXN+a0hePxxWLxTRL8zUokhZ6Od3HY4BiJDXVnPEaRuqrb506Xfxh7TSv3NiiP5ozg0oPe+2rLzvzcrE5c8Pl+82ZLZMyzBn0fe0uoW3apMbGRmVmZp53O66AAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACCIQaEXcMnwGNzp2tvt+/EYetqXh4pKUsrUT5szW770v732VfrLFebMlfIYRppiHzQbSbF/bL3OIUlpP8w2Z25e9y/mzIbF3zRnLlu3w5xB38QVEAAgCAoIABBEtxfQd77zHUUikS63iRMndvduAAD9XI88B3TVVVdpy5Yt/7GTQTzVBADoqkeaYdCgQcrPz++J/xoAMED0yHNABw4cUGFhocaOHavbb79dhw+f/1VCra2tisfjXW4AgIGv2wuopKRE69at0+bNm7VmzRrV1N
Touuuu06lTp865fWVlpWKxWOetqKiou5cEAOiDur2AysrK9JWvfEVTpkzR3Llz9dJLL6mhoUHPPvvsObevqKhQY2Nj5622tra7lwQA6IN6/NUBWVlZuvLKK3Xw4MFzPh6NRhWNRnt6GQCAPqbHfw/o9OnTOnTokAoKCnp6VwCAfqTbC+jee+9VVVWV3n33Xb355pu68cYblZqaqltvvbW7dwUA6Me6/UdwR44c0a233qqTJ09q5MiRuvbaa7Vz506NHDmyu3cFAOjHur2Ann766e7+L2HhM1jUYzCmJCnZYY7Eb/2COTNm+R/MmcdPXmfOSFLhq700ncolzZFIdKh9N57DSOu+YH9e9p1EpjnzzEMPmzMr//bL5kz9DH69oy9iFhwAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABNHjf5Cu10Qi9ozP4M7e5DMk1GPIpc9QUV9Xf2OPOVPdmGfONLenmzOSNPzZnV45q0iq5wDYXpJ2yp55s+kKc+bRP33KnLlr1BZz5oHblpgzkpS53uN86K2vRT778d1XD+EKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEEMnGnYfZ3H5Fqficku0XuTrYdtH2nOtDv7mOXUFPuE73c3jTVnJKlAdV45K9fh8XFqS3T/Qs4j77E3zZlvVVSbM9cevcqcWXlgvjlz8//cbM5I0ssvFJkzyVMeo8R9+E617kN/OYArIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIYuAMI/UYlhdJS/fbVaLNI2Rfn9d+PBy978+9ct/MfdacefLfvmDOJGUfnljwiH2YZq/qw+eDr181p5kz/33MTnNm1d455kxqkd8wzYxf2r9GNF7rtaveE/G47nA9M+SYKyAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACCLinMdUxB4Uj8cVi8U0S/M1KGIfbjiQHF9uHxLammXfz6+WrrKHJK1r+Lw5Myr9fXPme7+8yZyJ/cE+wFSSrv+bXebMxt981pwZ9G9RcyalzeN9ivh9evvsq3VE0pwZOeGEOROLtpgzZ9r9vpasHP8P5syyDUvNmeL/scOc6cvaXULbtEmNjY3KzMw873ZcAQEAgqCAAABBmAto+/btuuGGG1RYWKhIJKLnn3++y+POOT344IMqKCjQkCFDVFpaqgMHDnTXegEAA4S5gJqamjR16lStXr36nI+vWrVKP/7xj/X4449r165dGjZsmObOnauWFvvPbQEAA5f5L6KWlZWprKzsnI855/Too4/qgQce0Pz58yVJTzzxhPLy8vT888/rlltu+WSrBQAMGN36HFBNTY3q6upUWlraeV8sFlNJSYl27Dj3qzxaW1sVj8e73AAAA1+3FlBdXZ0kKS8vr8v9eXl5nY99WGVlpWKxWOetqKioO5cEAOijgr8KrqKiQo2NjZ232tra0EsCAPSCbi2g/Px8SVJ9fX2X++vr6zsf+7BoNKrMzMwuNwDAwNetBVRcXKz8/Hxt3bq18754PK5du3ZpxowZ3bkrAEA/Z34V3OnTp3Xw4MHOt2tqarRv3z5lZ2dr9OjRuvvuu/X9739fV1xxhYqLi/Xtb39bhYWFWrBgQXeuGwDQz5kLaPfu3br++us7316xYoUkadGiRVq3bp3uu+8+NTU1aenSpWpoaNC1116rzZs3a/Dgwd23agBAv3dJDyN9Z5XfjwX/7safmzMr/vmvzJn09HZz5gdTnzNnNjdMMWckKUX2U+d3jXkX3uhD3tt7uTnTcVnCnJGkwbXp5kzmO/bjkNJuz3Sk2weEJs3fYp7lUu2ZZJp9fZGk/Ticnt
lszny26Ig5I0lHT8fMmWvy3jFn3nrf/urfI+9nmTOSNPor/+qVs2AYKQCgT6OAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAIz1m5A8Nvb/97r9zCg39hzuT8o/3PUZxeeMqc+dnR68yZxja/P5VxR9Eb5syJtmHmTM3gpDmjDvtkZklqu8y+r8RX/mTOjIo1mjMjo6fNmWiqfaK6JGUNsk+cTniM0G7qiJozMzOr7ftJ2vcjSW8OGm/OpMp+Dg0b1GbOvDR9jTkjSbfefq85E3typ9e+LoQrIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIYsAMI225Ybo5kxbZ57evb+WbMzn/6z1z5uXxz5kzD5+wH4ehKfZBiJK08vUF5kxK3H7KuSyPgZoe80vP7ithzsQPXGbOHGwYYc7U2meeKrXV2UOeOqL2AbDOY2bs6+mfM2duX/yKfUeSrsv6gznzucGHzZmX064yZ/7in+80ZyTpbyp+Zc68/GSm174uhCsgAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAhiwAwjrf98770rWT+sNWf+MudfzJmfNtgHFP5V1j+bM9+t/UtzRpIu/2WqOZMY6jF9UmnmRCTpN4TTpdjX15Fu308yzb4+n7W1Zvkcb0kesYjHzFif/Qz/N/uk2cf/6Xr7jiT9Yf4ac+Y3bfZ3alFsvznzj5mTzRlJ+uvYv5ozv/qzpabtIx2t0r9suuB2XAEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBADZhhp2une29e9l282Zzb8abo5k5seN2e+eei/mjPNf3+5OSNJZ3Lt37/4DMdMbTNHlEz1G8KZ0ksDNX0Gdzr77FevtUlS+xC/nJXPcTg12n7eFVTZ9yNJX5s205wpHNxgzlSfzjNnFl6+15yRpBEp9g/umcuHmbZvT6RKFzF/mSsgAEAQFBAAIAhzAW3fvl033HCDCgsLFYlE9Pzzz3d5fPHixYpEIl1u8+bN6671AgAGCHMBNTU1aerUqVq9evV5t5k3b56OHTvWeXvqqac+0SIBAAOP+UUIZWVlKisr+9htotGo8vPzvRcFABj4euQ5oG3btik3N1cTJkzQsmXLdPLkyfNu29raqng83uUGABj4ur2A5s2bpyeeeEJbt27VD3/4Q1VVVamsrEwdHR3n3L6yslKxWKzzVlRU1N1LAgD0Qd3+e0C33HJL578nT56sKVOmaNy4cdq2bZtmz579ke0rKiq0YsWKzrfj8TglBACXgB5/GfbYsWOVk5OjgwcPnvPxaDSqzMzMLjcAwMDX4wV05MgRnTx5UgUFBT29KwBAP2L+Edzp06e7XM3U1NRo3759ys7OVnZ2th566CEtXLhQ+fn5OnTokO677z6NHz9ec+fO7daFAwD6N3MB7d69W9dff33n2x88f7No0SKtWbNG+/fv189//nM1NDSosLBQc+bM0fe+9z1Fo9HuWzUAoN8zF9CsWbPknDvv4y+//PInWpCvQc29t68DbfbfcXo43z448P+dtj8f1viz/2LOJGN+EysTw+y5yLlfDPmxOtLtGd8hnB9zap9/V700WNRrQKjncRh0xp5J8Rga63M++By7jqjfgfj1+qnmzIYVD5sz/5Q+zpy5esi75owkxZNJc2ZY9QnT9u0drRe1HbPgAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEES3/0nuUIacsE949TUu7bg5c6TdPl74vpfuNGcyRti/p0gMM0ckSelxe8Z5fMvTMdiekcdUa0lqH+qxK4
/pzD7r85m67Ssx3L7ApMdXk9Q2+5TqlIsbtNxF/FN+07AzDtuPw7cOzzdn/u+4LebMtjMeJ6ukwtRT5kzHgXds27vERW3HFRAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABDFghpFmHjptziRch9e+xqa1mDMzdywzZ/LfsA9C/NNEc8R7yGVLjj3TEbW/T9H3PQZJen5r5TMstX2o/X1qz7CfeynDL27A43+WbPWZlCpFj6aZM2mn7R8nr2PnMZw2tdVvGOmZXHvunfVXmDO/u/8fzBlpuEdGuixliDmTOmG8aXvX0SoduPB2XAEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBADZhhpSnObOZMW8RvU+K9tmebM8F/ZBwf+aaLHAMWkPdIxxD4QUpLasuw7i56wH/P0U/b1pS84bs5I0snGYeZMR5v90yj9cNScyd5u/34x4veh1Zls+7nXXOgxWNRjGKl85or6zSJVIsO+viEp9o/T/J13mjO/nPETc0aS2mU/9yIJ28TiSPLitucKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCGDDDSNXeYY40Js947epvd37NnEkbZZ+G2Jpjf59SWu37iST8JjWm5zebM2N/FDdn2t87Ys7UR0vMGUkav/WkPXS01p7JHWGOvHdTrjnTPMY2RPIDkaH24b7ujMdw36TH+dpuz3Sk+k1ldYPsueZR9kzGTvuw4rrpQ80ZSRqXZr/uaH/nXdv2LnFR23EFBAAIggICAARhKqDKykpdffXVysjIUG5urhYsWKDq6uou27S0tKi8vFwjRozQ8OHDtXDhQtXX13frogEA/Z+pgKqqqlReXq6dO3fqlVdeUSKR0Jw5c9TU1NS5zT333KMXXnhBGzZsUFVVlY4ePaqbbrqp2xcOAOjfTC9C2Lx5c5e3161bp9zcXO3Zs0czZ85UY2OjfvrTn2r9+vX60pe+JElau3atPv3pT2vnzp36whe+0H0rBwD0a5/oOaDGxkZJUnZ2tiRpz549SiQSKi0t7dxm4sSJGj16tHbs2HHO/6O1tVXxeLzLDQAw8HkXUDKZ1N13361rrrlGkyZNkiTV1dUpPT1dWVlZXbbNy8tTXV3dOf+fyspKxWKxzltRUZHvkgAA/Yh3AZWXl+vtt9/W008//YkWUFFRocbGxs5bba3H71QAAPodr19EXb58uV588UVt375do0aN6rw/Pz9fbW1tamho6HIVVF9fr/z8/HP+X9FoVNFo1GcZAIB+zHQF5JzT8uXLtXHjRr366qsqLi7u8vi0adOUlpamrVu3dt5XXV2tw4cPa8aMGd2zYgDAgGC6AiovL9f69eu1adMmZWRkdD6vE4vFNGTIEMViMX31q1/VihUrlJ2drczMTN11112aMWMGr4ADAHRhKqA1a9ZIkmbNmtXl/rVr12rx4sWSpB/96EdKSUnRwoUL1draqrlz5+onP/lJtywWADBwmArIuQsP2Rs8eLBWr16t1atXey/KS6r99RQvNo268Ebn4JIeGY95n6nNHkMDh3sMMI34vRYl2Wp/CrH5ypHmTHrNe+bM5c8eMmckqf4vxpozJ6/JMGeyRpw2Z1pP24fnRt5PN2ckSQ1p9n35zLT1OfV8Mn6zSBVJ2HfmhtsHwDYX2vfz11v+1pyRpJq//D/mTOqIbNP2LtkmvX/h7ZgFBwAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCC8/iJqX9TxuwPmzLCUVq99PXHNT82Zv04sNWcizanmTGosYc4kO/wmJrsW++kTX95ozqTcdaU509xqn+YsSc6dsofeH2KONL6bZc
5E7IPOpTTPMdD2U89r4rRL9Qh5TKOP+Iyjl+QG2w96aoP986JjmP2dSq/vvS/fLX9WfOGN/pP29hbptQtvxxUQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAAQxYIaR+shKafbKpXpMXfzRdU+bM4WD/mTO/OLkn5szL74xzZyRJHXYBzy+fyTLnElttn+f5HrxW6uIx0BNl2YfPun8ZsZ6ibR7DO/0mffpMyvVYz8uxXMoq8c5nox67ssoJeE3YLU52WbOtGXZqqI9cXHbcwUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFc0sNIt5y6yis3Zehhc+Zyj8GiE9LazZlPDT5pzvy3Wf9kzkjS+t9cbQ8dHmKOpNgPgxIx+7BPSYp4DJ+MJO2ZlDN+gyStXO/sRpIU8ZjB6VLsC/QZNOuztrP78nmnfM4h+24GtdgzklTT3mHODD6RMG3f3n5x23MFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBXNrDSI9O8MrFRjebM+PS/mjO/CJ+pTmTFrEPGizL2G/OSNKUz9eaMyNKTpszean2zMb4n5kzkvTSUfuA2vak/fu45rY0cybpsZ9omm2I5AeGeAzC9TkOg1LsUzh95oqebol6pKTYEPvEz8LhjeZMW0eqOZP0mcoq6Y0z48yZYzMGm7bvaJV0ETOOuQICAARBAQEAgjAVUGVlpa6++mplZGQoNzdXCxYsUHV1dZdtZs2apUgk0uV25513duuiAQD9n6mAqqqqVF5erp07d+qVV15RIpHQnDlz1NTU1GW7JUuW6NixY523VatWdeuiAQD9n+lFCJs3b+7y9rp165Sbm6s9e/Zo5syZnfcPHTpU+fn53bNCAMCA9ImeA2psPPtqj+zs7C73P/nkk8rJydGkSZNUUVGh5ubzv2qstbVV8Xi8yw0AMPB5vww7mUzq7rvv1jXXXKNJkyZ13n/bbbdpzJgxKiws1P79+3X//ferurpazz333Dn/n8rKSj300EO+ywAA9FPeBVReXq63335br7/+epf7ly5d2vnvyZMnq6CgQLNnz9ahQ4c0btxHX39eUVGhFStWdL4dj8dVVFTkuywAQD/hVUDLly/Xiy++qO3bt2vUqFEfu21JSYkk6eDBg+csoGg0qmjU75fEAAD9l6mAnHO66667tHHjRm3btk3FxcUXzOzbt0+SVFBQ4LVAAMDAZCqg8vJyrV+/Xps2bVJGRobq6uokSbFYTEOGDNGhQ4e0fv16ffnLX9aIESO0f/9+3XPPPZo5c6amTJnSI+8AAKB/MhXQmjVrJJ39ZdP/bO3atVq8eLHS09O1ZcsWPfroo2pqalJRUZEWLlyoBx54oNsWDAAYGMw/gvs4RUVFqqqq+kQLAgBcGi7padhDPCcFL8v6jTlzMBExZ8qz7NOm/dgn8Z5l/52t5mSbOTM0Zag58+mc6gtvdA7LLttrzlyWal+fjzda7JOjG5J+a8tPtX9sB3tMYu+Q/fOixdnP15EpreaMJBWnDTdn9rTaz/HxafZjt6Mly5yRpPV/LDFnRlW+adq+3SV04CK2YxgpACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARxaQ8j/brfX2Itm/B1cybaYB98mkzrne8POqJ++zlyvT0XybMPhcx4c4g5U/jLo+aMJLlU+/vUcdkwcyb1VIs5o2PHzRGXaLfvR1JkqH2IaWS4x+DTC0zYP6d2++BOX81X2f+Qps/n7fA9h82Z9mN15sxZ9kGzPYUrIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICg
gAEESfmwXn/n02VLsSkseYKNO+OuxzySSpPWGf45Xa7jELLtJLs+BS/PaTbPGYBddsP+YdbRFzpj3p97F1Ht+TdbSn2vfjc+65NnvEec6CS9q/NESS9uPgNQsu2Xuz4Nrb7Z/rSY9zqD1p/9i2O/vXlN7SrrNrcxf4+EbchbboZUeOHFFRUVHoZQAAPqHa2lqNGjXqvI/3uQJKJpM6evSoMjIyFIl0/c43Ho+rqKhItbW1yszMDLTC8DgOZ3EczuI4nMVxOKsvHAfnnE6dOqXCwkKlfMxPWPrcj+BSUlI+tjElKTMz85I+wT7AcTiL43AWx+EsjsNZoY9DLBa74Da8CAEAEAQFBAAIol8VUDQa1cqVKxWN+v0l04GC43AWx+EsjsNZHIez+tNx6HMvQgAAXBr61RUQAGDgoIAAAEFQQACAICggAEAQ/aaAVq9erU996lMaPHiwSkpK9Otf/zr0knrdd77zHUUikS63iRMnhl5Wj9u+fbtuuOEGFRYWKhKJ6Pnnn+/yuHNODz74oAoKCjRkyBCVlpbqwIEDYRbbgy50HBYvXvyR82PevHlhFttDKisrdfXVVysjI0O5ublasGCBqquru2zT0tKi8vJyjRgxQsOHD9fChQtVX18faMU942KOw6xZsz5yPtx5552BVnxu/aKAnnnmGa1YsUIrV67UW2+9palTp2ru3Lk6fvx46KX1uquuukrHjh3rvL3++uuhl9TjmpqaNHXqVK1evfqcj69atUo//vGP9fjjj2vXrl0aNmyY5s6dq5YW+yDJvuxCx0GS5s2b1+X8eOqpp3pxhT2vqqpK5eXl2rlzp1555RUlEgnNmTNHTU1Nndvcc889euGFF7RhwwZVVVXp6NGjuummmwKuuvtdzHGQpCVLlnQ5H1atWhVoxefh+oHp06e78vLyzrc7OjpcYWGhq6ysDLiq3rdy5Uo3derU0MsISpLbuHFj59vJZNLl5+e7hx9+uPO+hoYGF41G3VNPPRVghb3jw8fBOecWLVrk5s+fH2Q9oRw/ftxJclVVVc65sx/7tLQ0t2HDhs5tfve73zlJbseOHaGW2eM+fBycc+6LX/yi+/rXvx5uURehz18BtbW1ac+ePSotLe28LyUlRaWlpdqxY0fAlYVx4MABFRYWauzYsbr99tt1+PDh0EsKqqamRnV1dV3Oj1gsppKSkkvy/Ni2bZtyc3M1YcIELVu2TCdPngy9pB7V2NgoScrOzpYk7dmzR4lEosv5MHHiRI0ePXpAnw8fPg4fePLJJ5WTk6NJkyapoqJCzc3NIZZ3Xn1uGOmHnThxQh0dHcrLy+tyf15enn7/+98HWlUYJSUlWrdunSZMmKBjx47poYce0nXXXae3335bGRkZoZcXRF1dnSSd8/z44LFLxbx583TTTTepuLhYhw4d0re+9S2VlZVpx44dSk31+Fs9fVwymdTdd9+ta665RpMmTZJ09nxIT09XVlZWl20H8vlwruMgSbfddpvGjBmjwsJC7d+/X/fff7+qq6v13HPPBVxtV32+gPAfysrKOv89ZcoUlZSUaMyYMXr22Wf11a9+NeDK0Bfccsstnf+ePHmypkyZonHjxmnbtm2aPXt2wJX1jPLycr399tuXxPOgH+d8x2Hp0qWd/548ebIKCgo0e/ZsHTp0SOPGjevtZZ5Tn/8RXE5OjlJTUz/yKpb6+nrl5+cHWlXfkJWVpSuvvFIHDx4MvZRgPjgHOD8+auzYscrJyRmQ58fy5cv14osv6rXXXuvy51vy8/PV1tamhoaGLtsP1PPhfMfhXEpKSiSpT50Pfb6A0tPTNW3aNG3durXzvmQyqa1bt2rGjBkBVxbe6dOndejQIRUUFIReSjDFxcXKz8/vcn7E43Ht2rXrkj8/jhw5opMnTw6o88M5p+XLl2vjxo169dVXVVxc3OXxadOmKS0trcv5UF1drcOHDw+o8+FCx+Fc9u3bJ0l963wI/SqIi/H000+7aDTq1q1b53
7729+6pUuXuqysLFdXVxd6ab3qG9/4htu2bZurqalxb7zxhistLXU5OTnu+PHjoZfWo06dOuX27t3r9u7d6yS5Rx55xO3du9e99957zjnnfvCDH7isrCy3adMmt3//fjd//nxXXFzszpw5E3jl3evjjsOpU6fcvffe63bs2OFqamrcli1b3Oc+9zl3xRVXuJaWltBL7zbLli1zsVjMbdu2zR07dqzz1tzc3LnNnXfe6UaPHu1effVVt3v3bjdjxgw3Y8aMgKvufhc6DgcPHnTf/e533e7du11NTY3btGmTGzt2rJs5c2bglXfVLwrIOecee+wxN3r0aJeenu6mT5/udu7cGXpJve7mm292BQUFLj093V1++eXu5ptvdgcPHgy9rB732muvOUkfuS1atMg5d/al2N/+9rddXl6ei0ajbvbs2a66ujrsonvAxx2H5uZmN2fOHDdy5EiXlpbmxowZ45YsWTLgvkk71/svya1du7ZzmzNnzrivfe1r7rLLLnNDhw51N954ozt27Fi4RfeACx2Hw4cPu5kzZ7rs7GwXjUbd+PHj3Te/+U3X2NgYduEfwp9jAAAE0eefAwIADEwUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACOL/AyBNQnoqGwl/AAAAAElFTkSuQmCC\",\n      \"text/plain\": [\n       \"<Figure size 640x480 with 1 Axes>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Sample prediction\\n\",\n    \"sample = preds[0]\\n\",\n    \"predictions = sample.preds\\n\",\n    \"img = sample.data\\n\",\n    \"\\n\",\n    \"img = np.array(img).reshape(28,28)\\n\",\n    \"plt.figure()\\n\",\n    \"plt.imshow(img)\\n\",\n    \"\\n\",\n    \"print(\\\"Predicted label:\\\", classes[np.argmax(predictions)])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7a26690a-9dc4-4c36-9904-568d73e2be3c\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"#### Stop Triton Server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e02838ba\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 14:00:18,330 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-04 14:00:28,520 - INFO - Sucessfully stopped 1 servers.                 
\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 79,\n   \"id\": \"a0608fff-7cfb-489e-96c9-8e1d92e57562\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"08de2664-3d60-487b-90da-6d0f3b8b9203\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-torch\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.11\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/requirements.txt",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nnumpy\npandas\nmatplotlib\nportalocker\npyarrow\nh5py\npydot\nscikit-learn\njupyterlab\npyspark>=3.4.0\nhuggingface\ndatasets\ntransformers\nipywidgets\nnvidia-pytriton\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/server_utils.py",
    "content": "#\n# Copyright (c) 2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nimport inspect\nimport logging\nimport os\nimport socket\nimport subprocess\nimport sys\nimport time\nfrom multiprocessing import Process\nfrom typing import Any, Callable, Dict, List, Optional, Set, Tuple\n\nimport psutil\nimport requests\nfrom pyspark import RDD\nfrom pyspark.sql import SparkSession\n\nlogging.basicConfig(\n    level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\"\n)\nlogger = logging.getLogger(\"ServerManager\")\n\n# -----------------------------------------------------------------------------\n# Helper Functions\n# -----------------------------------------------------------------------------\n\n\ndef _find_ports(num_ports: int, start_port: int = 7000) -> List[int]:\n    \"\"\"Find available ports on executor for server services.\"\"\"\n    ports = []\n    conns = {conn.laddr.port for conn in psutil.net_connections(kind=\"inet\")}\n    i = start_port\n\n    while len(ports) < num_ports:\n        if i not in conns:\n            ports.append(i)\n        i += 1\n\n    return ports\n\n\ndef _get_valid_vllm_parameters_task() -> Set[str]:\n    \"\"\"Task to get valid vLLM parameters on executor.\"\"\"\n    from vllm.entrypoints.openai.cli_args import create_parser_for_docs\n\n    parser = create_parser_for_docs()\n    valid_args = set()\n    for action in parser._actions:\n        if action.dest not in [\n            
\"help\",\n            \"host\",\n            \"port\",\n            \"served-model-name\",\n            \"model\",\n        ]:\n            valid_args.add(action.dest)\n\n    return valid_args\n\n\ndef _start_triton_server_task(\n    triton_server_fn: Callable,\n    model_name: str,\n    wait_retries: int,\n    wait_timeout: int,\n    model_path: Optional[str] = None,\n) -> List[tuple]:\n    \"\"\"Task to start Triton server process on a Spark executor.\"\"\"\n\n    from pyspark import BarrierTaskContext\n\n    from pytriton.client import ModelClient\n\n    def _prepare_pytriton_env():\n        \"\"\"Expose PyTriton to correct libpython3.11.so and Triton bundled libraries.\"\"\"\n        ld_library_paths = []\n\n        # Add nvidia_pytriton.libs to LD_LIBRARY_PATH\n        for path in sys.path:\n            if os.path.isdir(path) and \"site-packages\" in path:\n                libs_path = os.path.join(path, \"nvidia_pytriton.libs\")\n                if os.path.isdir(libs_path):\n                    ld_library_paths.append(libs_path)\n                    break\n\n        # Add ${CONDA_PREFIX}/lib to LD_LIBRARY_PATH for conda environments\n        if os.path.exists(os.path.join(sys.prefix, \"conda-meta\")):\n            conda_lib = os.path.join(sys.prefix, \"lib\")\n            if os.path.isdir(conda_lib):\n                ld_library_paths.append(conda_lib)\n\n        if \"LD_LIBRARY_PATH\" in os.environ:\n            ld_library_paths.append(os.environ[\"LD_LIBRARY_PATH\"])\n\n        os.environ[\"LD_LIBRARY_PATH\"] = \":\".join(ld_library_paths)\n\n        return None\n\n    # Setup server function arguments\n    tc = BarrierTaskContext.get()\n    ports = _find_ports(num_ports=3)\n    sig = inspect.signature(triton_server_fn)\n    params = sig.parameters\n\n    if model_path is not None:\n        assert (\n            len(params) == 2\n        ), \"Server function must accept (ports, model_path) when model_path is provided\"\n        args = (ports, model_path)\n   
 else:\n        assert len(params) == 1, \"Server function must accept (ports) argument\"\n        args = (ports,)\n\n    # Prepare and start server process\n    _prepare_pytriton_env()\n    hostname = socket.gethostname()\n    process = Process(target=triton_server_fn, args=args)\n    process.start()\n\n    client = ModelClient(f\"http://localhost:{ports[0]}\", model_name)\n\n    # Wait for server to start\n    for _ in range(wait_retries):\n        try:\n            client.wait_for_model(wait_timeout)\n            tc.barrier()\n            client.close()\n            return [(hostname, (process.pid, ports))]\n        except Exception:\n            if not process.is_alive():\n                # If process terminated due to an error, stop waiting\n                break\n            pass\n\n    client.close()\n    if process.is_alive():\n        # Terminate if timeout is exceeded to avoid dangling server processes\n        process.terminate()\n\n    raise TimeoutError(\n        \"Failure: Triton server startup failed or timed out. 
Check the executor logs for more info.\"\n    )\n\n\ndef _start_vllm_server_task(\n    model_name: str,\n    model_path: str,\n    wait_retries: int,\n    wait_timeout: int,\n    **kwargs,\n) -> List[tuple]:\n    \"\"\"Task to start vLLM server process on a Spark executor.\"\"\"\n    from pyspark import BarrierTaskContext\n\n    tc = BarrierTaskContext.get()\n    port = _find_ports(num_ports=1)[0]\n    hostname = socket.gethostname()\n\n    # Build command for vLLM server\n    cmd = [\n        sys.executable,\n        \"-m\",\n        \"vllm.entrypoints.openai.api_server\",\n        \"--model\",\n        model_path,\n        \"--served-model-name\",\n        model_name,\n        \"--port\",\n        str(port),\n    ]\n\n    # Add additional args from kwargs\n    for key, value in kwargs.items():\n        if isinstance(value, bool) and value:\n            cmd.append(f\"--{key}\")\n        elif not isinstance(value, bool):\n            cmd.append(f\"--{key}\")\n            cmd.append(str(value))\n\n    logger.info(f\"Starting vLLM server with command: {' '.join(cmd)}\")\n\n    # vLLM does CUDA init at import time. 
Forking will try to re-initialize CUDA if vLLM was imported before and throw an error.\n    os.environ[\"VLLM_WORKER_MULTIPROC_METHOD\"] = \"spawn\"\n\n    # Start server process\n    process = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stderr, text=True)\n\n    # Wait for server to start\n    health_url = f\"http://localhost:{port}/health\"\n    for _ in range(wait_retries):\n        try:\n            time.sleep(wait_timeout)\n            response = requests.get(health_url)\n            if response.status_code == 200:\n                tc.barrier()\n                return [(hostname, (process.pid, [port]))]\n        except Exception:\n            if process.poll() is not None:\n                # If process terminated due to an error, stop waiting\n                break\n            pass\n\n    if process.poll() is None:\n        # Terminate if timeout is exceeded to avoid dangling server processes\n        process.terminate()\n\n    raise TimeoutError(\n        \"Failure: vLLM server startup failed or timed out. 
Check the executor logs for more info.\"\n    )\n\n\ndef _stop_server_task(\n    server_pids_ports: Dict[str, Tuple[int, List[int]]],\n    wait_retries: int,\n    wait_timeout: int,\n) -> List[bool]:\n    \"\"\"Task to stop a server process on a Spark executor.\"\"\"\n    hostname = socket.gethostname()\n    pid, _ = server_pids_ports.get(hostname, (None, None))\n    assert pid is not None, f\"No server PID found for host {hostname}\"\n\n    try:\n        process = psutil.Process(pid)\n        process.terminate()\n        process.wait(timeout=wait_timeout * wait_retries)\n        return [True]\n    except psutil.NoSuchProcess:\n        return [True]\n    except psutil.TimeoutExpired:\n        try:\n            process.kill()\n            return [True]\n        except:\n            return [False]\n\n\n# -----------------------------------------------------------------------------\n# ServerManager Classes\n# -----------------------------------------------------------------------------\n\n\nclass ServerManager:\n    \"\"\"\n    Base class for server management across a Spark cluster.\n\n    Attributes:\n        spark: Active SparkSession\n        num_executors: Number of servers to manage (= # of executors)\n        model_name: Name of the served model\n        model_path: Optional path to model files\n        server_pids_ports: Dictionary of hostname to (server process ID, ports)\n    \"\"\"\n\n    DEFAULT_WAIT_RETRIES = 24\n    DEFAULT_WAIT_TIMEOUT = 5\n\n    def __init__(self, model_name: str, model_path: Optional[str] = None):\n        \"\"\"\n        Initialize the server manager.\n\n        Args:\n            model_name: Name of the model to serve\n            model_path: Optional path to model file for server function to load from disk\n        \"\"\"\n        self.spark = SparkSession.getActiveSession()\n        self.num_executors = self._get_num_executors()\n        self.model_name = model_name\n        self.model_path = model_path\n        
self._server_pids_ports: Dict[str, Tuple[int, List[int]]] = {}\n\n    def _get_num_executors(self) -> int:\n        \"\"\"Get the number of executors in the cluster.\"\"\"\n        return (\n            len(\n                [\n                    executor.host()\n                    for executor in self.spark._jsc.sc()\n                    .statusTracker()\n                    .getExecutorInfos()\n                ]\n            )\n            - 1\n        )\n\n    @property\n    def host_to_http_url(self) -> Dict[str, str]:\n        \"\"\"Map hostname to client HTTP URL for server on that host.\"\"\"\n        if not self._server_pids_ports:\n            logger.warning(\"No urls available. Start servers first.\")\n            return None\n\n        return {\n            host: f\"http://localhost:{ports[0]}\"\n            for host, (_, ports) in self._server_pids_ports.items()\n        }\n\n    def _get_node_rdd(self) -> RDD:\n        \"\"\"Create and configure RDD with stage-level scheduling for 1 task per executor.\"\"\"\n        sc = self.spark.sparkContext\n        node_rdd = sc.parallelize(list(range(self.num_executors)), self.num_executors)\n        return self._use_stage_level_scheduling(node_rdd)\n\n    def _use_stage_level_scheduling(self, rdd: RDD) -> RDD:\n        \"\"\"\n        Use stage-level scheduling to ensure each server instance maps to 1 executor.\n        Adapted from https://github.com/NVIDIA/spark-rapids-ml/blob/main/python/src/spark_rapids_ml/core.py\n        \"\"\"\n        from pyspark.resource.profile import ResourceProfileBuilder\n        from pyspark.resource.requests import TaskResourceRequests\n\n        executor_cores = self.spark.conf.get(\"spark.executor.cores\")\n        assert executor_cores is not None, \"spark.executor.cores is not set\"\n        executor_gpus = self.spark.conf.get(\"spark.executor.resource.gpu.amount\")\n        assert (\n            executor_gpus is not None\n        ), \"spark.executor.resource.gpu.amount is 
not set\"\n\n        spark_plugins = self.spark.conf.get(\"spark.plugins\", \" \")\n        assert spark_plugins is not None\n        spark_rapids_sql_enabled = self.spark.conf.get(\n            \"spark.rapids.sql.enabled\", \"true\"\n        )\n        assert spark_rapids_sql_enabled is not None\n\n        task_cores = (\n            int(executor_cores)\n            if \"com.nvidia.spark.SQLPlugin\" in spark_plugins\n            and \"true\" == spark_rapids_sql_enabled.lower()\n            else (int(executor_cores) // 2) + 1\n        )\n        task_gpus = float(executor_gpus)\n\n        treqs = TaskResourceRequests().cpus(task_cores).resource(\"gpu\", task_gpus)\n        rp = ResourceProfileBuilder().require(treqs).build\n        logger.info(\n            f\"Requesting stage-level resources: (cores={task_cores}, gpu={task_gpus})\"\n        )\n\n        return rdd.withResources(rp)\n\n    def start_servers(\n        self,\n        start_server_fn: Callable,\n        wait_retries: int = DEFAULT_WAIT_RETRIES,\n        wait_timeout: int = DEFAULT_WAIT_TIMEOUT,\n        **kwargs,\n    ) -> Dict[str, Tuple[int, List[int]]]:\n        \"\"\"\n        Start servers across the cluster.\n\n        Args:\n            start_server_fn: Function used to start the server process\n            wait_retries: Number of retries for waiting for server startup\n            wait_timeout: Timeout in seconds for each retry\n            **kwargs: Additional server-specific arguments\n\n        Returns:\n            Dictionary of hostname -> (server PID, [ports])\n        \"\"\"\n        node_rdd = self._get_node_rdd()\n        model_name = self.model_name\n        model_path = self.model_path\n        server_type = self.__class__.__name__.replace(\"ServerManager\", \"\")\n\n        logger.info(f\"Starting {self.num_executors} {server_type} servers.\")\n\n        start_args = {\n            \"model_name\": model_name,\n            \"wait_retries\": wait_retries,\n            
\"wait_timeout\": wait_timeout,\n        }\n\n        if model_path is not None:\n            start_args[\"model_path\"] = model_path\n\n        start_args.update(kwargs)\n\n        self._server_pids_ports = (\n            node_rdd.barrier()\n            .mapPartitions(lambda _: start_server_fn(**start_args))\n            .collectAsMap()\n        )\n\n        return self._server_pids_ports\n\n    def stop_servers(\n        self,\n        wait_retries: int = DEFAULT_WAIT_RETRIES,\n        wait_timeout: int = DEFAULT_WAIT_TIMEOUT,\n    ) -> List[bool]:\n        \"\"\"\n        Stop all servers across the cluster.\n\n        Returns:\n            List of booleans indicating success/failure of stopping each server\n        \"\"\"\n        if not self._server_pids_ports:\n            logger.warning(\"No servers to stop.\")\n            return []\n\n        node_rdd = self._get_node_rdd()\n        server_pids_ports = self._server_pids_ports\n        server_type = self.__class__.__name__.replace(\"ServerManager\", \"\")\n\n        stop_success = (\n            node_rdd.barrier()\n            .mapPartitions(\n                lambda _: _stop_server_task(\n                    server_pids_ports=server_pids_ports,\n                    wait_retries=wait_retries,\n                    wait_timeout=wait_timeout,\n                )\n            )\n            .collect()\n        )\n\n        if all(stop_success):\n            self._server_pids_ports.clear()\n            logger.info(\n                f\"Successfully stopped {self.num_executors} {server_type} servers.\"\n            )\n        else:\n            logger.warning(\n                f\"{server_type} server termination failed or timed out. 
Check executor logs.\"\n            )\n\n        return stop_success\n\n\nclass TritonServerManager(ServerManager):\n    \"\"\"\n    Handle lifecycle of Triton server instances across Spark cluster.\n\n    Example usage:\n    >>> server_manager = TritonServerManager(model_name=\"my_model\", model_path=\"/path/to/my_model\")\n    >>> # Define triton_server(ports, model_path) that contains PyTriton server logic\n    >>> server_pids_ports = server_manager.start_servers(triton_server)\n    >>> print(f\"Servers started with PIDs/Ports: {server_pids_ports}\")\n    >>> host_to_http_url = server_manager.host_to_http_url\n    >>> host_to_grpc_url = server_manager.host_to_grpc_url\n    >>> # Define triton_fn() and predict_batch_udf(triton_fn) and run inference...\n    >>> success = server_manager.stop_servers()\n    >>> print(f\"Server shutdown success: {success}\")\n    \"\"\"\n\n    def __init__(self, model_name: str, model_path: Optional[str] = None):\n        super().__init__(model_name, model_path)\n\n    @property\n    def host_to_grpc_url(self) -> Dict[str, str]:\n        \"\"\"Map hostname to client gRPC URL for Triton server on that host.\"\"\"\n        if not self._server_pids_ports:\n            logger.warning(\"No urls available. 
Start servers first.\")\n            return None\n\n        return {\n            host: f\"grpc://localhost:{ports[1]}\"\n            for host, (_, ports) in self._server_pids_ports.items()\n        }\n\n    def start_servers(\n        self,\n        triton_server_fn: Callable,\n        wait_retries: int = ServerManager.DEFAULT_WAIT_RETRIES,\n        wait_timeout: int = ServerManager.DEFAULT_WAIT_TIMEOUT,\n    ) -> Dict[str, Tuple[int, List[int]]]:\n        \"\"\"\n        Start Triton servers across the cluster.\n\n        Args:\n            triton_server_fn: PyTriton server function defining the model and inference logic\n            wait_retries: Number of retries for waiting for server startup\n            wait_timeout: Timeout in seconds for each retry\n\n        Returns:\n            Dictionary of hostname -> (server PID, [ports])\n        \"\"\"\n        return super().start_servers(\n            start_server_fn=_start_triton_server_task,\n            wait_retries=wait_retries,\n            wait_timeout=wait_timeout,\n            triton_server_fn=triton_server_fn,\n        )\n\n\nclass VLLMServerManager(ServerManager):\n    \"\"\"\n    Handle lifecycle of vLLM server instances across Spark cluster.\n\n    Example usage:\n    >>> server_manager = VLLMServerManager(model_name=\"my_llm\", model_path=\"/path/to/my_llm\")\n    >>> server_manager.start_servers(\n    >>>     tensor_parallel_size=1,\n    >>>     max_num_seqs=1024,\n    >>>     gpu_memory_utilization=0.85,\n    >>> )\n    >>> print(f\"Servers started with PIDs/Ports: {server_pids_ports}\")\n    >>> host_to_http_url = server_manager.host_to_http_url\n    >>> # Define vllm_fn() and predict_batch_udf(vllm_fn) and run inference...\n    >>> success = server_manager.stop_servers()\n    >>> print(f\"Server shutdown success: {success}\")\n    \"\"\"\n\n    def __init__(self, model_name: str, model_path: str = None):\n        super().__init__(model_name, model_path)\n        self.vllm_valid_parameters = 
self._get_valid_vllm_parameters()\n\n    def _get_valid_vllm_parameters(self) -> List[str]:\n        \"\"\"Get valid vLLM parameters on executor.\"\"\"\n        rdd = self.spark.sparkContext.parallelize(list(range(1)), 1)\n        return rdd.mapPartitions(lambda _: _get_valid_vllm_parameters_task()).collect()\n\n    def _validate_vllm_kwargs(self, kwargs: Dict[str, Any]):\n        \"\"\"Validate vLLM parameters.\"\"\"\n        for key in kwargs:\n            if key not in self.vllm_valid_parameters:\n                if key == \"host\" or key == \"port\":\n                    raise ValueError(\n                        f\"Invalid vLLM parameter: {key}. Host and port are set by server manager.\"\n                    )\n                elif key == \"served-model-name\":\n                    raise ValueError(\n                        f\"Invalid vLLM parameter: {key}. Served model name is set via model_name.\"\n                    )\n                elif key == \"model\":\n                    raise ValueError(\n                        f\"Invalid vLLM parameter: {key}. Model path is set via model_path.\"\n                    )\n                else:\n                    raise ValueError(f\"Invalid vLLM parameter: {key}\")\n\n    def start_servers(\n        self,\n        wait_retries: int = ServerManager.DEFAULT_WAIT_RETRIES,\n        wait_timeout: int = ServerManager.DEFAULT_WAIT_TIMEOUT,\n        **kwargs,\n    ) -> Dict[str, Tuple[int, List[int]]]:\n        \"\"\"\n        Start vLLM OpenAI-compatible servers across the cluster.\n\n        Args:\n            wait_retries: Number of retries for waiting for server startup\n            wait_timeout: Timeout in seconds for each retry\n            **kwargs: Additional arguments to pass to vLLM server command line\n                e.g. 
tensor_parallel_size, max_num_seqs, gpu_memory_utilization, etc.\n                See https://docs.vllm.ai/en/stable/serving/openai_compatible_server.html#vllm-serve\n\n        Returns:\n            Dictionary of hostname -> (server PID, [port])\n        \"\"\"\n        self._validate_vllm_kwargs(kwargs)\n\n        return super().start_servers(\n            start_server_fn=_start_vllm_server_task,\n            wait_retries=wait_retries,\n            wait_timeout=wait_timeout,\n            **kwargs,\n        )\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/image_classification_tf.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"52d55e3f\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# Pyspark TensorFlow Inference\\n\",\n    \"\\n\",\n    \"## Image classification\\n\",\n    \"This notebook demonstrates training and distributed inference for image classification on MNIST.  \\n\",\n    \"Based on: https://www.tensorflow.org/tutorials/keras/save_and_load\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5233632d\",\n   \"metadata\": {},\n   \"source\": [\n    \"Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075)  \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"c8b28f02\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:58:23.275397: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\\n\",\n      \"2025-02-04 13:58:23.282713: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\\n\",\n      \"2025-02-04 13:58:23.290717: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\\n\",\n      \"2025-02-04 13:58:23.293187: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\\n\",\n      \"2025-02-04 13:58:23.299616: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\\n\",\n      \"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\\n\",\n      \"2025-02-04 13:58:23.677341: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import numpy as np\\n\",\n    \"import subprocess\\n\",\n    \"import shutil\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"import tensorflow as tf\\n\",\n    \"from tensorflow import keras\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"e2e67086\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2.17.0\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\\n\",\n      
\"I0000 00:00:1738706304.084788 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706304.107153 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706304.109954 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(tf.version.VERSION)\\n\",\n    \"\\n\",\n    \"# Enable GPU memory growth\\n\",\n    \"gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"if gpus:\\n\",\n    \"    try:\\n\",\n    \"        for gpu in gpus:\\n\",\n    \"            tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"    except RuntimeError as e:\\n\",\n    \"        print(e)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7e0c7ad6\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Load and preprocess dataset\\n\",\n    \"\\n\",\n    \"Load MNIST and create a train/test split.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"5b007f7c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"((60000, 28, 28), (10000, 28, 28))\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    
\"(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\\n\",\n    \"train_images.shape, test_images.shape\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"7b7cedd1\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"((1000, 784), (1000, 784))\"\n      ]\n     },\n     \"execution_count\": 4,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_labels = train_labels[:1000]\\n\",\n    \"test_labels = test_labels[:1000]\\n\",\n    \"\\n\",\n    \"train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0\\n\",\n    \"test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0\\n\",\n    \"\\n\",\n    \"train_images.shape, test_images.shape\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"867a4403\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Define a model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"746d94db\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/keras/src/layers/core/dense.py:87: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\\n\",\n      \"  super().__init__(activity_regularizer=activity_regularizer, **kwargs)\\n\",\n      \"I0000 00:00:1738706304.278396 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"I0000 00:00:1738706304.281131 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706304.283741 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706304.403175 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706304.404296 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706304.405232 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"2025-02-04 13:58:24.406153: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 40769 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\"><span style=\\\"font-weight: bold\\\">Model: \\\"sequential\\\"</span>\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"\\u001b[1mModel: \\\"sequential\\\"\\u001b[0m\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\\n\",\n       \"┃<span style=\\\"font-weight: bold\\\"> Layer (type)                    </span>┃<span style=\\\"font-weight: bold\\\"> Output Shape           </span>┃<span style=\\\"font-weight: bold\\\">       Param # </span>┃\\n\",\n       \"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\\n\",\n       \"│ dense (<span style=\\\"color: #0087ff; text-decoration-color: #0087ff\\\">Dense</span>)                   │ (<span style=\\\"color: #00d7ff; text-decoration-color: #00d7ff\\\">None</span>, <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">512</span>)            │       <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">401,920</span> │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ dropout (<span 
style=\\\"color: #0087ff; text-decoration-color: #0087ff\\\">Dropout</span>)               │ (<span style=\\\"color: #00d7ff; text-decoration-color: #00d7ff\\\">None</span>, <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">512</span>)            │             <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">0</span> │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ dense_1 (<span style=\\\"color: #0087ff; text-decoration-color: #0087ff\\\">Dense</span>)                 │ (<span style=\\\"color: #00d7ff; text-decoration-color: #00d7ff\\\">None</span>, <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">10</span>)             │         <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">5,130</span> │\\n\",\n       \"└─────────────────────────────────┴────────────────────────┴───────────────┘\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\\n\",\n       \"┃\\u001b[1m \\u001b[0m\\u001b[1mLayer (type)                   \\u001b[0m\\u001b[1m \\u001b[0m┃\\u001b[1m \\u001b[0m\\u001b[1mOutput Shape          \\u001b[0m\\u001b[1m \\u001b[0m┃\\u001b[1m \\u001b[0m\\u001b[1m      Param #\\u001b[0m\\u001b[1m \\u001b[0m┃\\n\",\n       \"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\\n\",\n       \"│ dense (\\u001b[38;5;33mDense\\u001b[0m)                   │ (\\u001b[38;5;45mNone\\u001b[0m, \\u001b[38;5;34m512\\u001b[0m)            │       \\u001b[38;5;34m401,920\\u001b[0m │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ dropout (\\u001b[38;5;33mDropout\\u001b[0m)               │ (\\u001b[38;5;45mNone\\u001b[0m, \\u001b[38;5;34m512\\u001b[0m)            │             \\u001b[38;5;34m0\\u001b[0m │\\n\",\n       
\"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ dense_1 (\\u001b[38;5;33mDense\\u001b[0m)                 │ (\\u001b[38;5;45mNone\\u001b[0m, \\u001b[38;5;34m10\\u001b[0m)             │         \\u001b[38;5;34m5,130\\u001b[0m │\\n\",\n       \"└─────────────────────────────────┴────────────────────────┴───────────────┘\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\"><span style=\\\"font-weight: bold\\\"> Total params: </span><span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">407,050</span> (1.55 MB)\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"\\u001b[1m Total params: \\u001b[0m\\u001b[38;5;34m407,050\\u001b[0m (1.55 MB)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\"><span style=\\\"font-weight: bold\\\"> Trainable params: </span><span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">407,050</span> (1.55 MB)\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"\\u001b[1m Trainable params: \\u001b[0m\\u001b[38;5;34m407,050\\u001b[0m (1.55 MB)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\"><span style=\\\"font-weight: bold\\\"> Non-trainable params: </span><span style=\\\"color: #00af00; text-decoration-color: 
#00af00\\\">0</span> (0.00 B)\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"\\u001b[1m Non-trainable params: \\u001b[0m\\u001b[38;5;34m0\\u001b[0m (0.00 B)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Define a simple sequential model\\n\",\n    \"def create_model():\\n\",\n    \"    model = tf.keras.Sequential([\\n\",\n    \"    keras.layers.Dense(512, activation='relu', input_shape=(784,)),\\n\",\n    \"    keras.layers.Dropout(0.2),\\n\",\n    \"    keras.layers.Dense(10)\\n\",\n    \"    ])\\n\",\n    \"\\n\",\n    \"    model.compile(optimizer='adam',\\n\",\n    \"                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\\n\",\n    \"                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\\n\",\n    \"\\n\",\n    \"    return model\\n\",\n    \"\\n\",\n    \"# Create a basic model instance\\n\",\n    \"model = create_model()\\n\",\n    \"\\n\",\n    \"# Display the model's architecture\\n\",\n    \"model.summary()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"605d082a\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save checkpoints during training\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"dde1a855\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"os.mkdir(\\\"models\\\") if not os.path.exists(\\\"models\\\") else None\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"244746be\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Epoch 1/10\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\\n\",\n      \"I0000 00:00:1738706304.982690 3671754 
service.cc:146] XLA service 0x7f1464019260 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\\n\",\n      \"I0000 00:00:1738706304.982718 3671754 service.cc:154]   StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6\\n\",\n      \"2025-02-04 13:58:24.999594: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\\n\",\n      \"2025-02-04 13:58:25.043847: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[1m 1/32\\u001b[0m \\u001b[37m━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[1m26s\\u001b[0m 868ms/step - loss: 2.4638 - sparse_categorical_accuracy: 0.0625\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"I0000 00:00:1738706305.619913 3671754 device_compiler.h:188] Compiled cluster using XLA!  
This line is logged at most once for the lifetime of the process.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[1m32/32\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 17ms/step - loss: 1.6323 - sparse_categorical_accuracy: 0.4913  \"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:58:26.791107: I external/local_xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:393] ptxas warning : Registers are spilled to local memory in function 'gemm_fusion_dot_33', 4 bytes spill stores, 4 bytes spill loads\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"Epoch 1: val_sparse_categorical_accuracy improved from -inf to 0.76100, saving model to models/training_1/checkpoint.model.keras\\n\",\n      \"\\u001b[1m32/32\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m2s\\u001b[0m 48ms/step - loss: 1.6179 - sparse_categorical_accuracy: 0.4965 - val_loss: 0.7533 - val_sparse_categorical_accuracy: 0.7610\\n\",\n      \"Epoch 2/10\\n\",\n      \"\\u001b[1m 1/32\\u001b[0m \\u001b[37m━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[1m0s\\u001b[0m 13ms/step - loss: 0.3965 - sparse_categorical_accuracy: 0.9062\\n\",\n      \"Epoch 2: val_sparse_categorical_accuracy improved from 0.76100 to 0.80400, saving model to models/training_1/checkpoint.model.keras\\n\",\n      \"\\u001b[1m32/32\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 1ms/step - loss: 0.4549 - sparse_categorical_accuracy: 0.8773 - val_loss: 0.6002 - val_sparse_categorical_accuracy: 0.8040\\n\",\n      \"Epoch 3/10\\n\",\n      \"\\u001b[1m 1/32\\u001b[0m \\u001b[37m━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[1m0s\\u001b[0m 13ms/step - loss: 0.4427 - 
sparse_categorical_accuracy: 0.8438\\n\",\n      \"Epoch 3: val_sparse_categorical_accuracy improved from 0.80400 to 0.85100, saving model to models/training_1/checkpoint.model.keras\\n\",\n      \"\\u001b[1m32/32\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 1ms/step - loss: 0.2924 - sparse_categorical_accuracy: 0.9289 - val_loss: 0.4876 - val_sparse_categorical_accuracy: 0.8510\\n\",\n      \"Epoch 4/10\\n\",\n      \"\\u001b[1m 1/32\\u001b[0m \\u001b[37m━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[1m0s\\u001b[0m 13ms/step - loss: 0.3644 - sparse_categorical_accuracy: 0.9375\\n\",\n      \"Epoch 4: val_sparse_categorical_accuracy did not improve from 0.85100\\n\",\n      \"\\u001b[1m32/32\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 1ms/step - loss: 0.2790 - sparse_categorical_accuracy: 0.9275 - val_loss: 0.4981 - val_sparse_categorical_accuracy: 0.8430\\n\",\n      \"Epoch 5/10\\n\",\n      \"\\u001b[1m 1/32\\u001b[0m \\u001b[37m━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[1m0s\\u001b[0m 12ms/step - loss: 0.2368 - sparse_categorical_accuracy: 0.9375\\n\",\n      \"Epoch 5: val_sparse_categorical_accuracy did not improve from 0.85100\\n\",\n      \"\\u001b[1m32/32\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 1ms/step - loss: 0.1794 - sparse_categorical_accuracy: 0.9645 - val_loss: 0.4893 - val_sparse_categorical_accuracy: 0.8450\\n\",\n      \"Epoch 6/10\\n\",\n      \"\\u001b[1m 1/32\\u001b[0m \\u001b[37m━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[1m0s\\u001b[0m 12ms/step - loss: 0.0830 - sparse_categorical_accuracy: 1.0000\\n\",\n      \"Epoch 6: val_sparse_categorical_accuracy improved from 0.85100 to 0.85400, saving model to models/training_1/checkpoint.model.keras\\n\",\n      \"\\u001b[1m32/32\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 1ms/step - loss: 0.1430 - 
sparse_categorical_accuracy: 0.9739 - val_loss: 0.4338 - val_sparse_categorical_accuracy: 0.8540\\n\",\n      \"Epoch 7/10\\n\",\n      \"\\u001b[1m 1/32\\u001b[0m \\u001b[37m━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[1m0s\\u001b[0m 13ms/step - loss: 0.1518 - sparse_categorical_accuracy: 1.0000\\n\",\n      \"Epoch 7: val_sparse_categorical_accuracy improved from 0.85400 to 0.86200, saving model to models/training_1/checkpoint.model.keras\\n\",\n      \"\\u001b[1m32/32\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 1ms/step - loss: 0.0876 - sparse_categorical_accuracy: 0.9909 - val_loss: 0.4194 - val_sparse_categorical_accuracy: 0.8620\\n\",\n      \"Epoch 8/10\\n\",\n      \"\\u001b[1m 1/32\\u001b[0m \\u001b[37m━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[1m0s\\u001b[0m 12ms/step - loss: 0.0209 - sparse_categorical_accuracy: 1.0000\\n\",\n      \"Epoch 8: val_sparse_categorical_accuracy improved from 0.86200 to 0.86800, saving model to models/training_1/checkpoint.model.keras\\n\",\n      \"\\u001b[1m32/32\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 1ms/step - loss: 0.0669 - sparse_categorical_accuracy: 0.9938 - val_loss: 0.4038 - val_sparse_categorical_accuracy: 0.8680\\n\",\n      \"Epoch 9/10\\n\",\n      \"\\u001b[1m 1/32\\u001b[0m \\u001b[37m━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[1m0s\\u001b[0m 13ms/step - loss: 0.0211 - sparse_categorical_accuracy: 1.0000\\n\",\n      \"Epoch 9: val_sparse_categorical_accuracy improved from 0.86800 to 0.86900, saving model to models/training_1/checkpoint.model.keras\\n\",\n      \"\\u001b[1m32/32\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 1ms/step - loss: 0.0429 - sparse_categorical_accuracy: 0.9998 - val_loss: 0.4062 - val_sparse_categorical_accuracy: 0.8690\\n\",\n      \"Epoch 10/10\\n\",\n      \"\\u001b[1m 1/32\\u001b[0m \\u001b[37m━━━━━━━━━━━━━━━━━━━━\\u001b[0m 
\\u001b[1m0s\\u001b[0m 13ms/step - loss: 0.0283 - sparse_categorical_accuracy: 1.0000\\n\",\n      \"Epoch 10: val_sparse_categorical_accuracy did not improve from 0.86900\\n\",\n      \"\\u001b[1m32/32\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 1ms/step - loss: 0.0387 - sparse_categorical_accuracy: 0.9992 - val_loss: 0.4069 - val_sparse_categorical_accuracy: 0.8680\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<keras.src.callbacks.history.History at 0x7f1673724c50>\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"checkpoint_path = \\\"models/training_1/checkpoint.model.keras\\\"\\n\",\n    \"checkpoint_dir = os.path.dirname(checkpoint_path)\\n\",\n    \"\\n\",\n    \"# Create a callback that saves the model's weights\\n\",\n    \"cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\\n\",\n    \"                                                 monitor='val_sparse_categorical_accuracy',\\n\",\n    \"                                                 mode='max',\\n\",\n    \"                                                 save_best_only=True,\\n\",\n    \"                                                 verbose=1)\\n\",\n    \"\\n\",\n    \"# Train the model with the new callback\\n\",\n    \"model.fit(train_images, \\n\",\n    \"          train_labels,  \\n\",\n    \"          epochs=10,\\n\",\n    \"          validation_data=(test_images, test_labels),\\n\",\n    \"          callbacks=[cp_callback])  # Pass callback to training\\n\",\n    \"\\n\",\n    \"# This may generate warnings related to saving the state of the optimizer.\\n\",\n    \"# These warnings (and similar warnings throughout this notebook)\\n\",\n    \"# are in place to discourage outdated usage, and can be ignored.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 
8,\n   \"id\": \"310eae08\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"['checkpoint.model.keras']\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"os.listdir(checkpoint_dir)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"50eeb6e5\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"INFO:tensorflow:Assets written to: models/mnist_model/assets\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"INFO:tensorflow:Assets written to: models/mnist_model/assets\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Saved artifact at 'models/mnist_model'. The following endpoints are available:\\n\",\n      \"\\n\",\n      \"* Endpoint 'serve'\\n\",\n      \"  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 784), dtype=tf.float32, name='keras_tensor')\\n\",\n      \"Output Type:\\n\",\n      \"  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)\\n\",\n      \"Captures:\\n\",\n      \"  139734758151120: TensorSpec(shape=(), dtype=tf.resource, name=None)\\n\",\n      \"  139734413261904: TensorSpec(shape=(), dtype=tf.resource, name=None)\\n\",\n      \"  139739081696528: TensorSpec(shape=(), dtype=tf.resource, name=None)\\n\",\n      \"  139734413262096: TensorSpec(shape=(), dtype=tf.resource, name=None)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Export model in saved_model format\\n\",\n    \"model.export(\\\"models/mnist_model\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"6d3bba9e\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": 
\"stream\",\n     \"text\": [\n      \"/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/keras/src/layers/core/dense.py:87: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\\n\",\n      \"  super().__init__(activity_regularizer=activity_regularizer, **kwargs)\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"32/32 - 0s - 10ms/step - loss: 2.3876 - sparse_categorical_accuracy: 0.0840\\n\",\n      \"Untrained model, accuracy:  8.40%\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Create a basic model instance\\n\",\n    \"model = create_model()\\n\",\n    \"\\n\",\n    \"# Evaluate the model\\n\",\n    \"loss, acc = model.evaluate(test_images, test_labels, verbose=2)\\n\",\n    \"print(\\\"Untrained model, accuracy: {:5.2f}%\\\".format(100 * acc))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"22ad1708\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"32/32 - 0s - 704us/step - loss: 0.4062 - sparse_categorical_accuracy: 0.8690\\n\",\n      \"Restored model, accuracy: 86.90%\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:713: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 10 variables. 
\\n\",\n      \"  saveable.load_own_variables(weights_store.get(inner_path))\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Load the weights from the checkpoint\\n\",\n    \"model.load_weights(checkpoint_path)\\n\",\n    \"\\n\",\n    \"# Re-evaluate the model\\n\",\n    \"loss, acc = model.evaluate(test_images, test_labels, verbose=2)\\n\",\n    \"print(\\\"Restored model, accuracy: {:5.2f}%\\\".format(100 * acc))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1c097d63\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Checkpoint callback options\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"cb336e89\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"os.mkdir(\\\"models/training_2\\\") if not os.path.exists(\\\"models/training_2\\\") else None\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"750b6deb\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"Epoch 5: saving model to models/training_2/cp-0005.weights.h5\\n\",\n      \"\\n\",\n      \"Epoch 10: saving model to models/training_2/cp-0010.weights.h5\\n\",\n      \"\\n\",\n      \"Epoch 15: saving model to models/training_2/cp-0015.weights.h5\\n\",\n      \"\\n\",\n      \"Epoch 20: saving model to models/training_2/cp-0020.weights.h5\\n\",\n      \"\\n\",\n      \"Epoch 25: saving model to models/training_2/cp-0025.weights.h5\\n\",\n      \"\\n\",\n      \"Epoch 30: saving model to models/training_2/cp-0030.weights.h5\\n\",\n      \"\\n\",\n      \"Epoch 35: saving model to models/training_2/cp-0035.weights.h5\\n\",\n      \"\\n\",\n      \"Epoch 40: saving model to models/training_2/cp-0040.weights.h5\\n\",\n      \"\\n\",\n      \"Epoch 45: saving model to models/training_2/cp-0045.weights.h5\\n\",\n      \"\\n\",\n      \"Epoch 50: saving model to 
models/training_2/cp-0050.weights.h5\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<keras.src.callbacks.history.History at 0x7f1672f47510>\"\n      ]\n     },\n     \"execution_count\": 13,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"# Include the epoch in the file name (uses `str.format`)\\n\",\n    \"checkpoint_path = \\\"models/training_2/cp-{epoch:04d}.weights.h5\\\"\\n\",\n    \"checkpoint_dir = os.path.dirname(checkpoint_path)\\n\",\n    \"\\n\",\n    \"batch_size = 32\\n\",\n    \"\\n\",\n    \"# Calculate the number of batches per epoch\\n\",\n    \"import math\\n\",\n    \"n_batches = len(train_images) / batch_size\\n\",\n    \"n_batches = math.ceil(n_batches)    # round up the number of batches to the nearest whole integer\\n\",\n    \"\\n\",\n    \"# Create a callback that saves the model's weights every 5 epochs\\n\",\n    \"cp_callback = tf.keras.callbacks.ModelCheckpoint(\\n\",\n    \"    filepath=checkpoint_path, \\n\",\n    \"    verbose=1, \\n\",\n    \"    save_weights_only=True,\\n\",\n    \"    save_freq=5*n_batches)\\n\",\n    \"\\n\",\n    \"# Create a new model instance\\n\",\n    \"model = create_model()\\n\",\n    \"\\n\",\n    \"# Save the weights using the `checkpoint_path` format\\n\",\n    \"model.save_weights(checkpoint_path.format(epoch=0))\\n\",\n    \"\\n\",\n    \"# Train the model with the new callback\\n\",\n    \"model.fit(train_images, \\n\",\n    \"          train_labels,\\n\",\n    \"          epochs=50, \\n\",\n    \"          batch_size=batch_size, \\n\",\n    \"          callbacks=[cp_callback],\\n\",\n    \"          validation_data=(test_images, test_labels),\\n\",\n    \"          verbose=0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"1c43fd3d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"['cp-0000.weights.h5',\\n\",\n     
  \" 'cp-0015.weights.h5',\\n\",\n       \" 'cp-0010.weights.h5',\\n\",\n       \" 'cp-0035.weights.h5',\\n\",\n       \" 'cp-0020.weights.h5',\\n\",\n       \" 'cp-0040.weights.h5',\\n\",\n       \" 'cp-0050.weights.h5',\\n\",\n       \" 'cp-0005.weights.h5',\\n\",\n       \" 'cp-0045.weights.h5',\\n\",\n       \" 'cp-0025.weights.h5',\\n\",\n       \" 'cp-0030.weights.h5']\"\n      ]\n     },\n     \"execution_count\": 14,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"os.listdir(checkpoint_dir)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"0d7ae715\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"latest = \\\"models/training_2/cp-0030.weights.h5\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"id\": \"d345c6f7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"32/32 - 0s - 11ms/step - loss: 0.4827 - sparse_categorical_accuracy: 0.8740\\n\",\n      \"Restored model, accuracy: 87.40%\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Create a new model instance\\n\",\n    \"model = create_model()\\n\",\n    \"\\n\",\n    \"# Load the previously saved weights\\n\",\n    \"model.load_weights(latest)\\n\",\n    \"\\n\",\n    \"# Re-evaluate the model from the latest checkpoint\\n\",\n    \"loss, acc = model.evaluate(test_images, test_labels, verbose=2)\\n\",\n    \"print(\\\"Restored model, accuracy: {:5.2f}%\\\".format(100 * acc))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a86f4700\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"id\": \"7fcf07bb\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.ml.functions import predict_batch_udf\\n\",\n    \"from 
pyspark.sql.functions import struct, col, array, pandas_udf\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"import pandas as pd\\n\",\n    \"import json\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"50f02919\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific Spark configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"id\": \"4c81d510\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c58f4df7\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"id\": \"2c022c24\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:58:33 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/04 13:58:33 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). 
For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/04 13:58:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        \\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"    elif on_dataproc:\\n\",\n    \"        conf.set(\\\"spark.executorEnv.TF_GPU_ALLOCATOR\\\", \\\"cuda_malloc_async\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", \\\"1000\\\")\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c81d0b1b\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Create Spark Dataframe\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"id\": \"49ff5203\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(1000, 784)\"\n      ]\n     },\n     
\"execution_count\": 20,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"# numpy array to pandas DataFrame\\n\",\n    \"test_pdf = pd.DataFrame(test_images)\\n\",\n    \"test_pdf.shape\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"182ee0c7\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.createDataFrame(test_pdf).repartition(8)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d4e1c7ec-64fa-43c4-9bcf-0868a401d1f2\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save as Parquet (784 columns of float)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"id\": \"0061c39a-0871-429e-a4ff-751d26bf4b04\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:58:35 WARN package: Truncated the string representation of a plan since it was too large. 
This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\\n\",\n      \"[Stage 0:>                                                          (0 + 8) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 3.05 ms, sys: 1.22 ms, total: 4.26 ms\\n\",\n      \"Wall time: 1.93 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"data_path_784 = \\\"spark-dl-datasets/mnist_784\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path_784 = \\\"dbfs:/FileStore/\\\" + data_path_784\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path_784)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"18315afb-3fa2-4953-9297-52c04dd70c32\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save as Parquet (1 column of 784 float)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"id\": \"302c73ec\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(1000, 1)\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_pdf['data'] = test_pdf.values.tolist()\\n\",\n    \"pdf = test_pdf[['data']]\\n\",\n    \"pdf.shape\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"id\": \"5495901b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.createDataFrame(pdf)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"id\": \"5fa7faa8-c6bd-41b0-b5f7-fb121f0332e6\",\n   \"metadata\": {},\n   
\"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 875 μs, sys: 187 μs, total: 1.06 ms\\n\",\n      \"Wall time: 196 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"data_path_1 = \\\"spark-dl-datasets/mnist_1\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path_1 = \\\"dbfs:/FileStore/\\\" + data_path_1\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path_1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b366aaeb\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Inference using Spark DL API\\n\",\n    \"\\n\",\n    \"Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\\n\",\n    \"\\n\",\n    \"- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \\n\",\n    \"- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"4238fb28-d002-4b4d-9aa1-8af1fbd5d569\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 1 column of 784 float\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"id\": \"b9cf62f8-96b2-4716-80bd-bb93d5f939bd\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_path = \\\"{}/models/training_1/checkpoint.model.keras\\\".format(os.getcwd())\\n\",\n    \"\\n\",\n    \"# For cloud environments, copy the model to the distributed file system.\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-models\\\")\\n\",\n    \"    dbfs_model_path = 
\\\"/dbfs/FileStore/spark-dl-models/checkpoint.model.keras\\\"\\n\",\n    \"    shutil.copy(model_path, dbfs_model_path)\\n\",\n    \"    model_path = dbfs_model_path\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    # GCS is mounted at /mnt/gcs by the init script\\n\",\n    \"    models_dir = \\\"/mnt/gcs/spark-dl/models\\\"\\n\",\n    \"    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\\n\",\n    \"    gcs_model_path = models_dir + \\\"/checkpoint.model.keras\\\"\\n\",\n    \"    shutil.copy(model_path, gcs_model_path)\\n\",\n    \"    model_path = gcs_model_path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 27,\n   \"id\": \"b81fa297-d9d0-4600-880d-dbdcdf8bccc6\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def predict_batch_fn():\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"\\n\",\n    \"    # Enable GPU memory growth to avoid CUDA OOM\\n\",\n    \"    gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"\\n\",\n    \"    model = tf.keras.models.load_model(model_path)\\n\",\n    \"    def predict(inputs: np.ndarray) -> np.ndarray:\\n\",\n    \"        return model.predict(inputs)\\n\",\n    \"        \\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"id\": \"72a689bd-dd82-492e-8740-1738a215325f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"mnist = predict_batch_udf(predict_batch_fn,\\n\",\n    \"                          return_type=ArrayType(FloatType()),\\n\",\n    \"                          batch_size=128,\\n\",\n    \"                          input_tensor_shapes=[[784]])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 29,\n   \"id\": \"60a70150-26b1-4145-9e7d-6e17389216b7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"1\"\n      ]\n     },\n     \"execution_count\": 29,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df = spark.read.parquet(data_path_1)\\n\",\n    \"len(df.columns)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"id\": \"e027f0d2-0f65-47b7-a562-2f0965faceec\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------+\\n\",\n      \"|                data|\\n\",\n      \"+--------------------+\\n\",\n      \"|[0.0, 0.0, 0.0, 0...|\\n\",\n      \"|[0.0, 0.0, 0.0, 0...|\\n\",\n      \"|[0.0, 0.0, 0.0, 0...|\\n\",\n      \"|[0.0, 0.0, 0.0, 0...|\\n\",\n      \"|[0.0, 0.0, 0.0, 0...|\\n\",\n      \"+--------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df.show(5)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 31,\n   \"id\": \"f0c3fb2e-469e-47bc-b948-8f6b0d7f6513\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 6:===================================================>       (7 + 1) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 24.1 ms, sys: 11 ms, total: 35.2 ms\\n\",\n      \"Wall time: 5.52 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"preds = 
df.withColumn(\\\"preds\\\", mnist(struct(df.columns))).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 32,\n   \"id\": \"cdfa229a-f4a9-4c11-a410-de4a21c02c82\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 21.1 ms, sys: 14.7 ms, total: 35.8 ms\\n\",\n      \"Wall time: 277 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(*df.columns)).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 33,\n   \"id\": \"5586ce49-6f93-4343-9b66-0dbb64972179\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 37.1 ms, sys: 8.46 ms, total: 45.6 ms\\n\",\n      \"Wall time: 216 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(*[col(c) for c in df.columns])).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"004f1599-3c62-499e-9fd8-ed5cb0c90de4\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"#### Check predictions\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 34,\n   \"id\": \"4f947dc0-6b18-4605-810b-e83250a161db\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    
}\\n\",\n       \"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>data</th>\\n\",\n       \"      <th>preds</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>0</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-4.6654954, -2.4895542, -0.5886033, 13.380537...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>1</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-2.273215, -7.5127845, 1.1983701, -3.540661, ...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>2</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-2.28909, 0.8308607, 0.31311005, 1.1683632, -...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>3</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-1.0551968, -6.5028114, 12.420729, 0.45280308...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>4</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-3.7887802, 3.9983602, -1.5343361, -0.3698440...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>5</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-4.499274, -1.7618222, 1.1183227, 3.946932, -...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       
\"      <th>6</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-2.7540536, 4.8684144, 0.25152916, -0.4730078...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>7</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-1.8887109, 0.02717152, -6.0508857, 0.0875094...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>8</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[0.9541265, -2.113048, -1.7508972, -5.4303794,...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-1.612412, -0.7655784, -4.473859, 2.0609212, ...</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"</div>\"\n      ],\n      \"text/plain\": [\n       \"                                                data  \\\\\\n\",\n       \"0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"3  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"4  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"5  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"6  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"7  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"8  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"9  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
\\n\",\n       \"\\n\",\n       \"                                               preds  \\n\",\n       \"0  [-4.6654954, -2.4895542, -0.5886033, 13.380537...  \\n\",\n       \"1  [-2.273215, -7.5127845, 1.1983701, -3.540661, ...  \\n\",\n       \"2  [-2.28909, 0.8308607, 0.31311005, 1.1683632, -...  \\n\",\n       \"3  [-1.0551968, -6.5028114, 12.420729, 0.45280308...  \\n\",\n       \"4  [-3.7887802, 3.9983602, -1.5343361, -0.3698440...  \\n\",\n       \"5  [-4.499274, -1.7618222, 1.1183227, 3.946932, -...  \\n\",\n       \"6  [-2.7540536, 4.8684144, 0.25152916, -0.4730078...  \\n\",\n       \"7  [-1.8887109, 0.02717152, -6.0508857, 0.0875094...  \\n\",\n       \"8  [0.9541265, -2.113048, -1.7508972, -5.4303794,...  \\n\",\n       \"9  [-1.612412, -0.7655784, -4.473859, 2.0609212, ...  \"\n      ]\n     },\n     \"execution_count\": 34,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"preds = df.withColumn(\\\"preds\\\", mnist(*df.columns)).limit(10).toPandas()\\n\",\n    \"preds\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 35,\n   \"id\": \"de4964e0-d1f8-4753-afa1-a8f95ca3f151\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"array([-4.6654954, -2.4895542, -0.5886033, 13.380537 , -6.652599 ,\\n\",\n       \"        2.8400383, -7.9901567, -0.7500452, -2.4487166, -4.349809 ],\\n\",\n       \"      dtype=float32)\"\n      ]\n     },\n     \"execution_count\": 35,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"sample = preds.iloc[0]\\n\",\n    \"sample.preds\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 36,\n   \"id\": \"44e9a874-e301-4b72-8df7-bf1c5133c287\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import numpy as np\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": 37,\n   \"id\": \"c60e5af4-fc1e-4575-a717-f304664235be\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prediction = np.argmax(sample.preds)\\n\",\n    \"img = np.array(sample.data).reshape(28,28)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"eb45ecc9-d376-40c4-ad7b-2bd08ca5aaf6\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"image/png\": \"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkfElEQVR4nO3dfXQUdZ7v8U/nqSEkaR7yLAFCFHRAcAYly/AgSiQEZUCYGUG9F7gziJiggI6KR0Udzsksrg7qIHjcHVhHEGWOyMoiDg9JGBRwwTCIM2QhJ0g4kIBcSYcAIaR/9w+uvbQkQDUdfkl4v86pc+iq37fqm6Lgk+qqrnYZY4wAALjKwmw3AAC4NhFAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAgAPdunXTpEmT/K8LCwvlcrlUWFgYsm24XC698MILIVsf0FwRQGgxlixZIpfL5Z/atGmjHj16KC8vT5WVlbbbc2TNmjUtJmTefvtt3X777UpKSpLb7VZ6eromT56s/fv3224NLVyE7QYAp1566SWlp6fr9OnT2rx5sxYuXKg1a9Zo9+7dio6Ovqq9DBkyRKdOnVJUVJSjujVr1mjBggUNhtCpU6cUEdF8/mkWFxcrPT1dP/vZz9ShQweVlZXp7bff1urVq/W3v/1NqamptltEC9V8jnLgMuXk5OjWW2+VJP36179Wp06d9Oqrr2rVqlWaMGFCgzU1NTVq165dyHsJCwtTmzZtQrrOUK/vSr355psXzBszZoxuvfVWvfPOO3r66actdIXWgLfg0OLdeeedkqSysjJJ0qRJkxQTE6PS0lKNHDlSsbGxeuCBByRJPp9P8+fPV69evdSmTRslJSVp6tSp+u677wLWaYzR3Llz1blzZ0VHR+uOO+7Q119/fcG2G7sGtG3bNo0cOVIdOnRQu3bt1KdPH7322mv+/hYsWCBJAW8pfq+ha0DFxcXKyclRXFycYmJiNGzYMG3dujVgzPdvUX722WeaNWuWEhIS1K5dO9177706evRowNiqqirt2bNHVVVVl7OLL9CtWzdJ0vHjx4OqByTOgNAKlJaWSpI6derkn3f27FllZ2dr0KBB+pd/+Rf/W3NTp07VkiVLNHnyZD366KMqKyvTH/7wBxUXF+uzzz5TZGSkJOn555/X3LlzNXLkSI0cOVJffvmlhg8frjNnzlyyn3Xr1umee+5RSkqKHnvsMSUnJ+sf//iHVq9erccee0xTp07VoUOHtG7dOv3pT3+65Pq+/vprDR48WHFxcXryyScVGRmpt956S0OHDlVRUZEyMzMDxk+fPl0dOnTQnDlztH//fs2fP195eXl6//33/WNWrlypyZMna/HixQE3VVzMsWPHVF9frwMHDuill16SJA0bNuyyaoEGGaCFWLx4sZFk1q9fb44ePWrKy8vN8uXLTadOnUzbtm3NwYMHjTHGTJw40UgyTz/9dED9X//6VyPJLF26NGD+2rVrA+YfOXLEREVFmbvvvtv4fD7/uGeeecZIM
hMnTvTPKygoMJJMQUGBMcaYs2fPmvT0dNO1a1fz3XffBWzn/HXl5uaaxv75STJz5szxvx4zZoyJiooypaWl/nmHDh0ysbGxZsiQIRfsn6ysrIBtzZw504SHh5vjx49fMHbx4sUN9tAQt9ttJBlJplOnTub111+/7FqgIbwFhxYnKytLCQkJSktL0/jx4xUTE6OVK1fquuuuCxg3bdq0gNcrVqyQx+PRXXfdpW+//dY/9evXTzExMSooKJAkrV+/XmfOnNH06dMD3hqbMWPGJXsrLi5WWVmZZsyYofbt2wcsO39dl6u+vl5/+ctfNGbMGHXv3t0/PyUlRffff782b94sr9cbUPPQQw8FbGvw4MGqr6/XN9984583adIkGWMu++xHkj755BOtWbNGr7zyirp06aKamhrHPw9wPt6CQ4uzYMEC9ejRQxEREUpKSlLPnj0VFhb4u1RERIQ6d+4cMG/v3r2qqqpSYmJig+s9cuSIJPn/o77hhhsClickJKhDhw4X7e37twN79+59+T/QRRw9elQnT55Uz549L1h20003yefzqby8XL169fLP79KlS8C473v+4XUup+644w5J524CGT16tHr37q2YmBjl5eVd0Xpx7SKA0OL079/ffxdcY9xu9wWh5PP5lJiYqKVLlzZYk5CQELIebQoPD29wvjEmZNvIyMjQj3/8Yy1dupQAQtAIIFwzMjIytH79eg0cOFBt27ZtdFzXrl0lnTtjOv9tr6NHj17yLCIjI0OStHv3bmVlZTU67nLfjktISFB0dLRKSkouWLZnzx6FhYUpLS3tstYVaqdOnVJtba2VbaN14BoQrhm//OUvVV9fr9/+9rcXLDt79qz/luKsrCxFRkbqjTfeCDhrmD9//iW38ZOf/ETp6emaP3/+Bbcon7+u7z+TdKnbmMPDwzV8+HCtWrUq4MkDlZWVWrZsmQYNGqS4uLhL9vVDl3sb9tmzZxsM3S+++EJfffXVJc9EgYvhDAjXjNtvv11Tp05Vfn6+du7cqeHDhysyMlJ79+7VihUr9Nprr+nnP/+5EhIS9MQTTyg/P1/33HOPRo4cqeLiYn3yySeKj4+/6DbCwsK0cOFCjRo1SrfccosmT56slJQU7dmzR19//bU+/fRTSVK/fv0kSY8++qiys7MVHh6u8ePHN7jOuXPnat26dRo0aJAeeeQRRURE6K233lJtba3mzZsX1L643NuwT5w4obS0NN13333q1auX2rVrp6+++kqLFy+Wx+PRc889F9T2AYkAwjVm0aJF6tevn9566y0988wzioiIULdu3fTggw9q4MCB/nFz585VmzZttGjRIhUUFCgzM1N/+ctfdPfdd19yG9nZ2SooKNCLL76oV155RT6fTxkZGZoyZYp/zNixYzV9+nQtX75c7777rowxjQZQr1699Ne//lWzZ89Wfn6+fD6fMjMz9e67717wGaBQi46O1q9//WsVFBToz3/+s06dOqXU1FRNmDBBzz77rP8DqUAwXCaUVyYBALhMXAMCAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMCKZvc5IJ/Pp0OHDik2NjaopwcDAOwyxqi6ulqpqakXPJPxfM0ugA4dOmTt2VYAgNApLy+/4Kn052t2ARQbGytJGqSRilCk5W4AAE6dVZ02a43///PGNFkALViwQC+//LIqKirUt29fvfHGG+rfv/8l675/2y1CkYpwEUAA0OL8/+frXOoySpPchPD+++9r1qxZmjNnjr788kv17dtX2dnZ/i/8AgCgSQLo1Vdf1ZQpUzR58mT96Ec/0qJFixQdHa0//vGPTbE5AEALFPIAOnPmjHbs2BHwZVxhYWHKysrSli1bLhhfW1srr9cbMAEAWr+QB9C3336r+vp6JSUlBcxPSkpSRUXFBePz8/Pl8Xj8E3fAAcC1wfoHUWfPnq2qqir/VF5ebrslAMBVEPK74OLj4xUeHq7KysqA+ZWVlUpOTr5gvNvtltvtD
nUbAIBmLuRnQFFRUerXr582bNjgn+fz+bRhwwYNGDAg1JsDALRQTfI5oFmzZmnixIm69dZb1b9/f82fP181NTWaPHlyU2wOANACNUkA3XfffTp69Kief/55VVRU6JZbbtHatWsvuDEBAHDtchljjO0mzuf1euXxeDRUo3kSAgC0QGdNnQq1SlVVVYqLi2t0nPW74AAA1yYCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWBFhuwHYZ37aN7i6cOe/v0RWeh3XlP7vRMc1vu6nHNdI0p7b/+i4JtzlfD9MPTjAcU3hp7c4run2nzWOayRJW3cFVwc4wBkQAMAKAggAYEXIA+iFF16Qy+UKmG688cZQbwYA0MI1yTWgXr16af369f+zkQguNQEAAjVJMkRERCg5ObkpVg0AaCWa5BrQ3r17lZqaqu7du+uBBx7QgQMHGh1bW1srr9cbMAEAWr+QB1BmZqaWLFmitWvXauHChSorK9PgwYNVXV3d4Pj8/Hx5PB7/lJaWFuqWAADNUMgDKCcnR7/4xS/Up08fZWdna82aNTp+/Lg++OCDBsfPnj1bVVVV/qm8vDzULQEAmqEmvzugffv26tGjh/bt29fgcrfbLbfb3dRtAACamSb/HNCJEydUWlqqlJSUpt4UAKAFCXkAPfHEEyoqKtL+/fv1+eef695771V4eLgmTJgQ6k0BAFqwkL8Fd/DgQU2YMEHHjh1TQkKCBg0apK1btyohISHUmwIAtGAuY4yx3cT5vF6vPB6Phmq0IlyRttuxqubnmY5rKm91flK7dsLLjmskqUtEW8c1/2v/XY5r/tRtneManFN8xhdU3eOP5zmuif5wW1DbQutz1tSpUKtUVVWluLi4RsfxLDgAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIKHkV4lR/J+6rim8OlXHNdEu6Ic1zR339afclzTxhXc71Z1cv7PYcaBexzX/DLxvxzX3B1d5bgmWPvqah3XPDH4l45rzpYfdFyD5o+HkQIAmjUCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsiLDdwLXCF+68pjU+2frlYz9yXLNhxiDHNfVtg/vd6rvrnT+B/br/POy45s2EcY5r7v7zHx3XBGvsf011XNPt+P7QN4JWjTMgAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCh5FeJan/+jfHNR88kui4Jjv6gOOanDlPOK6RpLoYl+Oa6/7joOOaiP07nNc4rjgnOYia+iBqKu/5aRBVV8+uny5xXDMmiAes+qqrHdeg9eAMCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCs4GGkV4mvpsZxzTs90xzXvJ0z1nFNfEGx4xpJ8p0+7bjmbFBbunrCExIc13x3V4bjmscf/sBxDdDacAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcByY4yef/55paSkqG3btsrKytLevXtD1S8AoJVwHEA1NTXq27evFixY0ODyefPm6fXXX9eiRYu0bds2tWvXTtnZ2
TodxPUCAEDr5fgmhJycHOXk5DS4zBij+fPn69lnn9Xo0aMlSe+8846SkpL00Ucfafz48VfWLQCg1QjpNaCysjJVVFQoKyvLP8/j8SgzM1NbtmxpsKa2tlZerzdgAgC0fiENoIqKCklSUlJSwPykpCT/sh/Kz8+Xx+PxT2lpzm89BgC0PNbvgps9e7aqqqr8U3l5ue2WAABXQUgDKDk5WZJUWVkZML+ystK/7Ifcbrfi4uICJgBA6xfSAEpPT1dycrI2bNjgn+f1erVt2zYNGDAglJsCALRwju+CO3HihPbt2+d/XVZWpp07d6pjx47q0qWLZsyYoblz5+qGG25Qenq6nnvuOaWmpmrMmDGh7BsA0MI5DqDt27frjjvu8L+eNWuWJGnixIlasmSJnnzySdXU1Oihhx7S8ePHNWjQIK1du1Zt2rQJXdcAgBbPZYwxtps4n9frlcfj0VCNVoQr0nY7aKHC23uCqnt8x2bHNUPanAlqW1eDT76g6n73bV/HNduGpTiuqf/2mOMaNH9nTZ0KtUpVVVUXva5v/S44AMC1iQACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACscfx0D0BKUPdorqLohbTaGuBO7VtXEB1X3ed+oIKp4sjWc4QwIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKzgYaRAK3ZX28NB1c2d9YDjmrqYoDblWGLxWcc1bT7+ogk6wZXiDAgAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArOBhpGiV0jacDKpuxyTnNf3cQW3qqogJC665HY+/EeJOQmfOkR87rtnxMb9rN0f8rQAArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFTyMFK2S67OdQdW9NOp+xzW1STGOa6of9zqu+eyW5Y5rWqNnE7Y7rrnzwUeD2pbn3a1B1eHycAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcDySZMmyeVyBUwjRowIVb8AgFbCcQDV1NSob9++WrBgQaNjRowYocOHD/un995774qaBAC0Po5vQsjJyVFOTs5Fx7jdbiUnJwfdFACg9WuSa0CFhYVKTExUz549NW3aNB07dqzRsbW1tfJ6vQETAKD1C3kAjRgxQu+88442bNigf/7nf1ZRUZFycnJUX1/f4Pj8/Hx5PB7/lJaWFuqWAADNUMg/BzR+/Hj/n2+++Wb16dNHGRkZKiws1LBhwy4YP3v2bM2aNcv/2uv1EkIAcA1o8tuwu3fvrvj4eO3bt6/B5W63W3FxcQETAKD1a/IAOnjwoI4dO6aUlJSm3hQAoAVx/BbciRMnAs5mysrKtHPnTnXs2FEdO3bUiy++qHHjxik5OVmlpaV68skndf311ys7OzukjQMAWjbHAbR9+3bdcccd/tffX7+ZOHGiFi5cqF27dunf//3fdfz4caWmpmr48OH67W9/K7fbHbquAQAtnuMAGjp0qIwxjS7/9NNPr6ghwKb6r0sc10R87Xw7HQpcjmtGRf3Ucc3+P/VwXCNJn2QudFzTOaJtUNtyKtIV7rjmdMfgrjZ4gqrC5eJZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAi5F/JDeAyXOSJ8o2W1NY6run6y68c10jSnW/NdFzz3/csCmpbuHZxBgQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVvAwUqAVc0VGBVfXtj7EnYTOrjPOe0vcXtMEneBKcQYEALCCAAIAWEEAA
QCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFbwMFKgFSt545ag6v572MLQNhJCMx6f7rgm+vNtTdAJrhRnQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQ8jRasUFhsbXF17T4g7adiRu9Ic19w1/TPHNf+RuMBxzTlX53fTD04kOq6J2/KN45qzjitwNXAGBACwggACAFjhKIDy8/N12223KTY2VomJiRozZoxKSkoCxpw+fVq5ubnq1KmTYmJiNG7cOFVWVoa0aQBAy+cogIqKipSbm6utW7dq3bp1qqur0/Dhw1VTU+MfM3PmTH388cdasWKFioqKdOjQIY0dOzbkjQMAWjZHNyGsXbs24PWSJUuUmJioHTt2aMiQIaqqqtK//du/admyZbrzzjslSYsXL9ZNN92krVu36p/+6Z9C1zkAoEW7omtAVVVVkqSOHTtKknbs2KG6ujplZWX5x9x4443q0qWLtmzZ0uA6amtr5fV6AyYAQOsXdAD5fD7NmDFDAwcOVO/evSVJFRUVioqKUvv27QPGJiUlqaKiosH15Ofny+Px+Ke0NOe3pwIAWp6gAyg3N1e7d+/W8uXLr6iB2bNnq6qqyj+Vl5df0foAAC1DUB9EzcvL0+rVq7Vp0yZ17tzZPz85OVlnzpzR8ePHA86CKisrlZyc3OC63G633G53MG0AAFowR2dAxhjl5eVp5cqV2rhxo9LT0wOW9+vXT5GRkdqwYYN/XklJiQ4cOKABAwaEpmMAQKvg6AwoNzdXy5Yt06pVqxQbG+u/ruPxeNS2bVt5PB796le/0qxZs9SxY0fFxcVp+vTpGjBgAHfAAQACOAqghQsXSpKGDh0aMH/x4sWaNGmSJOn3v/+9wsLCNG7cONXW1io7O1tvvvlmSJoFALQeLmOMsd3E+bxerzwej4ZqtCJckbbbuSaE9b0pqLo9uTGOa5LT/q/jmiMlCY5rJt9Z6LhGkp7q9HVQdQhOn88nOa7p8ouvQt8IQuqsqVOhVqmqqkpxcXGNjuNZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAiqG9ERfPl6tfLcU3b3x8Jalv/nfFuUHWO9bk6m2nuak2d45pIV3hQ26qsr3VcM+dQjuOazq8F1x9aB86AAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKHkbaytR0jXFc8373fw1ya1FB1jU9n3xB1c08NNhxzW8S1zuuyf4813FNbGG045rqbo5LJEnps7cEUVXtuCJMO4PYDloLzoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoeRtrKRH+4zXHNLw4/HNS2jv64neMaXxDPL61z/nxVvf1//uC8SFLpbacd10z78VTHNek7dzmukTGOS+KdbwW4ajgDAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArXMYE8YTDJuT1euXxeDRUoxXhirTdDgDAobOmToVapaqqKsXFxTU6jjMgAIAVBBAAwApHAZSfn6/bbrtNsbGxSkxM1JgxY1RSUhIwZujQoXK5XAHTww8H930zAIDWy1EAFRUVKTc3V1u3btW6detUV1en4cOHq6amJmDclClTdPjwYf80b968kDYNAGj5HH0j6tq1awNeL1myRImJidqxY4eGDBninx8dHa3k5OTQdAgAaJWu6BpQVVWVJKljx44B85cuXar4+Hj17t1bs2fP1smTJxtdR21trbxeb8AEAGj9HJ0Bnc/n82nGjBkaOHCgevfu7Z9//
/33q2vXrkpNTdWuXbv01FNPqaSkRB9++GGD68nPz9eLL74YbBsAgBYq6M8BTZs2TZ988ok2b96szp07Nzpu48aNGjZsmPbt26eMjIwLltfW1qq2ttb/2uv1Ki0tjc8BAUALdbmfAwrqDCgvL0+rV6/Wpk2bLho+kpSZmSlJjQaQ2+2W2+0Opg0AQAvmKICMMZo+fbpWrlypwsJCpaenX7Jm586dkqSUlJSgGgQAtE6OAig3N1fLli3TqlWrFBsbq4qKCkmSx+NR27ZtVVpaqmXLlmnkyJHq1KmTdu3apZkzZ2rIkCHq06dPk/wAAICWydE1IJfL1eD8xYsXa9KkSSovL9eDDz6o3bt3q6amRmlpabr33nv17LPPXvR9wPPxLDgAaNma5BrQpbIqLS1NRUVFTlYJALhG8Sw4AIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVEbYb+CFjjCTprOokY7kZAIBjZ1Un6X/+P29Mswug6upqSdJmrbHcCQDgSlRXV8vj8TS63GUuFVFXmc/n06FDhxQbGyuXyxWwzOv1Ki0tTeXl5YqLi7PUoX3sh3PYD+ewH85hP5zTHPaDMUbV1dVKTU1VWFjjV3qa3RlQWFiYOnfufNExcXFx1/QB9j32wznsh3PYD+ewH86xvR8udubzPW5CAABYQQABAKxoUQHkdrs1Z84cud1u261YxX44h/1wDvvhHPbDOS1pPzS7mxAAANeGFnUGBABoPQggAIAVBBAAwAoCCABgBQEEALCixQTQggUL1K1bN7Vp00aZmZn64osvbLd01b3wwgtyuVwB04033mi7rSa3adMmjRo1SqmpqXK5XProo48Clhtj9PzzzyslJUVt27ZVVlaW9u7da6fZJnSp/TBp0qQLjo8RI0bYabaJ5Ofn67bbblNsbKwSExM1ZswYlZSUBIw5ffq0cnNz1alTJ8XExGjcuHGqrKy01HHTuJz9MHTo0AuOh4cffthSxw1rEQH0/vvva9asWZozZ46+/PJL9e3bV9nZ2Tpy5Ijt1q66Xr166fDhw/5p8+bNtltqcjU1Nerbt68WLFjQ4PJ58+bp9ddf16JFi7Rt2za1a9dO2dnZOn369FXutGldaj9I0ogRIwKOj/fee+8qdtj0ioqKlJubq61bt2rdunWqq6vT8OHDVVNT4x8zc+ZMffzxx1qxYoWKiop06NAhjR071mLXoXc5+0GSpkyZEnA8zJs3z1LHjTAtQP/+/U1ubq7/dX19vUlNTTX5+fkWu7r65syZY/r27Wu7DaskmZUrV/pf+3w+k5ycbF5++WX/vOPHjxu3223ee+89Cx1eHT/cD8YYM3HiRDN69Ggr/dhy5MgRI8kUFRUZY8793UdGRpoVK1b4x/zjH/8wksyWLVtstdnkfrgfjDHm9ttvN4899pi9pi5Dsz8DOnPmjHbs2KGsrCz/vLCwMGVlZWnLli0WO7Nj7969Sk1NVffu3fXAAw/owIEDtluyqqysTBUVFQHHh8fjUWZm5jV5fBQWFioxMVE9e/bUtGnTdOzYMdstNamqqipJUseOHSVJO3bsUF1dXcDxcOONN6pLly6t+nj44X743tKlSxUfH6/evXtr9uzZOnnypI32GtXsnob9Q99++63q6+uVlJQUMD8pKUl79uyx1JUdmZmZWrJkiXr27KnDhw/rxRdf1ODBg7V7927Fxsbabs+KiooKSWrw+Ph+2bVixIgRGjt2rNLT01VaWqpnnnlGOTk52rJli8LDw223F3I+n08zZszQwIED1bt3b0nnjoeoqCi1b98+YGxrPh4a2g+SdP/996tr165KTU3Vrl279NRTT6mkpEQffvihxW4DNfsAwv/Iycnx/7lPnz7KzMxU1
65d9cEHH+hXv/qVxc7QHIwfP97/55tvvll9+vRRRkaGCgsLNWzYMIudNY3c3Fzt3r37mrgOejGN7YeHHnrI/+ebb75ZKSkpGjZsmEpLS5WRkXG122xQs38LLj4+XuHh4RfcxVJZWank5GRLXTUP7du3V48ePbRv3z7brVjz/THA8XGh7t27Kz4+vlUeH3l5eVq9erUKCgoCvj8sOTlZZ86c0fHjxwPGt9bjobH90JDMzExJalbHQ7MPoKioKPXr108bNmzwz/P5fNqwYYMGDBhgsTP7Tpw4odLSUqWkpNhuxZr09HQlJycHHB9er1fbtm275o+PgwcP6tixY63q+DDGKC8vTytXrtTGjRuVnp4esLxfv36KjIwMOB5KSkp04MCBVnU8XGo/NGTnzp2S1LyOB9t3QVyO5cuXG7fbbZYsWWL+/ve/m4ceesi0b9/eVFRU2G7tqnr88cdNYWGhKSsrM5999pnJysoy8fHx5siRI7Zba1LV1dWmuLjYFBcXG0nm1VdfNcXFxeabb74xxhjzu9/9zrRv396sWrXK7Nq1y4wePdqkp6ebU6dOWe48tC62H6qrq80TTzxhtmzZYsrKysz69evNT37yE3PDDTeY06dP2249ZKZNm2Y8Ho8pLCw0hw8f9k8nT570j3n44YdNly5dzMaNG8327dvNgAEDzIABAyx2HXqX2g/79u0zL730ktm+fbspKyszq1atMt27dzdDhgyx3HmgFhFAxhjzxhtvmC5dupioqCjTv39/s3XrVtstXXX33XefSUlJMVFRUea6664z9913n9m3b5/ttppcQUGBkXTBNHHiRGPMuVuxn3vuOZOUlGTcbrcZNmyYKSkpsdt0E7jYfjh58qQZPny4SUhIMJGRkaZr165mypQpre6XtIZ+fklm8eLF/jGnTp0yjzzyiOnQoYOJjo429957rzl8+LC9ppvApfbDgQMHzJAhQ0zHjh2N2+02119/vfnNb35jqqqq7Db+A3wfEADAimZ/DQgA0DoRQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAV/w/hgVLrpVGHsAAAAABJRU5ErkJggg==\",\n      \"text/plain\": [\n       \"<Figure size 640x480 with 1 Axes>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"plt.figure()\\n\",\n    \"plt.title(\\\"Prediction: {}\\\".format(prediction))\\n\",\n    \"plt.imshow(img)\\n\",\n    \"plt.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"39167347-0b99-4972-998c-e1230bf1d4d5\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 784 columns of float\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 39,\n   \"id\": \"6bea332e-f6de-494f-a0db-795d9fe3e134\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def predict_batch_fn():\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"    # Enable GPU memory growth\\n\",\n    \"    gpus = 
tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"            \\n\",\n    \"    model = tf.keras.models.load_model(model_path)\\n\",\n    \"    def predict(inputs: np.ndarray) -> np.ndarray:\\n\",\n    \"        return model.predict(inputs)\\n\",\n    \"        \\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 40,\n   \"id\": \"731d234c-549f-4df3-8a2b-312e63195396\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"mnist = predict_batch_udf(predict_batch_fn,\\n\",\n    \"                          return_type=ArrayType(FloatType()),\\n\",\n    \"                          batch_size=128,\\n\",\n    \"                          input_tensor_shapes=[[784]])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"a40fe207-6246-4b0e-abde-823979878d97\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"784\"\n      ]\n     },\n     \"execution_count\": 41,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df = spark.read.parquet(data_path_784)\\n\",\n    \"len(df.columns)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"id\": \"10904f12-03e7-4518-8f12-2aa11989ddf5\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 12:==============>                                           (2 + 6) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 52.5 ms, sys: 22 ms, total: 74.5 ms\\n\",\n    
  \"Wall time: 5.72 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(struct(*df.columns))).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 43,\n   \"id\": \"671128df-f0f4-4f54-b35c-d63a78c7f89a\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 13:===========================================>              (6 + 2) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 49.4 ms, sys: 31.9 ms, total: 81.2 ms\\n\",\n      \"Wall time: 1.34 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", mnist(array(*df.columns))).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 44,\n   \"id\": \"ce35deaf-7d49-4f34-9bf9-b4e6fc5761f4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# should raise ValueError\\n\",\n    \"# preds = df.withColumn(\\\"preds\\\", mnist(*df.columns)).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"01709833-484b-451f-9aa8-37be5b7baf14\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Check prediction\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 45,\n   \"id\": \"f9119632-b284-45d7-a262-c262e034c15c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     
\"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>0</th>\\n\",\n       \"      <th>1</th>\\n\",\n       \"      <th>2</th>\\n\",\n       \"      <th>3</th>\\n\",\n       \"      <th>4</th>\\n\",\n       \"      <th>5</th>\\n\",\n       \"      <th>6</th>\\n\",\n       \"      <th>7</th>\\n\",\n       \"      <th>8</th>\\n\",\n       \"      <th>9</th>\\n\",\n       \"      <th>...</th>\\n\",\n       \"      <th>775</th>\\n\",\n       \"      <th>776</th>\\n\",\n       \"      <th>777</th>\\n\",\n       \"      <th>778</th>\\n\",\n       \"      <th>779</th>\\n\",\n       \"      <th>780</th>\\n\",\n       \"      <th>781</th>\\n\",\n       \"      <th>782</th>\\n\",\n       \"      <th>783</th>\\n\",\n       \"      <th>preds</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>0</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      
<td>0.0</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>[-6.9618006, 1.2047814, -0.09570807, 0.0462105...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>1</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>[-5.2882323, 5.902014, -2.0389183, -1.2460864,...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>2</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n     
  \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>[-5.822013, -2.3333628, -2.4322102, -8.040086,...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>3</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>[-0.57203317, -1.2920653, -2.7234774, 0.914070...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>4</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>[-3.689301, 5.0702505, -0.23930073, -0.7988689...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>5</th>\\n\",\n     
  \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>[8.268821, -2.070008, 1.722378, -1.8471404, -8...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>6</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>[5.59269, -3.1613479, 0.4734843, -0.7772096, -...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>7</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      
<td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>[1.9852623, -5.166985, 0.86473066, -6.491789, ...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>8</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>[-2.800528, -4.2984514, 10.887824, -3.1346364,...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9</th>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>...</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n     
  \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>0.0</td>\\n\",\n       \"      <td>[-3.7827752, -4.51145, -5.354035, 9.399383, -6...</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"<p>10 rows × 785 columns</p>\\n\",\n       \"</div>\"\n      ],\n      \"text/plain\": [\n       \"     0    1    2    3    4    5    6    7    8    9  ...  775  776  777  778  \\\\\\n\",\n       \"0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \\n\",\n       \"1  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \\n\",\n       \"2  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \\n\",\n       \"3  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \\n\",\n       \"4  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \\n\",\n       \"5  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \\n\",\n       \"6  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \\n\",\n       \"7  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \\n\",\n       \"8  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \\n\",\n       \"9  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \\n\",\n       \"\\n\",\n       \"   779  780  781  782  783                                              preds  \\n\",\n       \"0  0.0  0.0  0.0  0.0  0.0  [-6.9618006, 1.2047814, -0.09570807, 0.0462105...  \\n\",\n       \"1  0.0  0.0  0.0  0.0  0.0  [-5.2882323, 5.902014, -2.0389183, -1.2460864,...  \\n\",\n       \"2  0.0  0.0  0.0  0.0  0.0  [-5.822013, -2.3333628, -2.4322102, -8.040086,...  \\n\",\n       \"3  0.0  0.0  0.0  0.0  0.0  [-0.57203317, -1.2920653, -2.7234774, 0.914070...  
\\n\",\n       \"4  0.0  0.0  0.0  0.0  0.0  [-3.689301, 5.0702505, -0.23930073, -0.7988689...  \\n\",\n       \"5  0.0  0.0  0.0  0.0  0.0  [8.268821, -2.070008, 1.722378, -1.8471404, -8...  \\n\",\n       \"6  0.0  0.0  0.0  0.0  0.0  [5.59269, -3.1613479, 0.4734843, -0.7772096, -...  \\n\",\n       \"7  0.0  0.0  0.0  0.0  0.0  [1.9852623, -5.166985, 0.86473066, -6.491789, ...  \\n\",\n       \"8  0.0  0.0  0.0  0.0  0.0  [-2.800528, -4.2984514, 10.887824, -3.1346364,...  \\n\",\n       \"9  0.0  0.0  0.0  0.0  0.0  [-3.7827752, -4.51145, -5.354035, 9.399383, -6...  \\n\",\n       \"\\n\",\n       \"[10 rows x 785 columns]\"\n      ]\n     },\n     \"execution_count\": 45,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"preds = df.withColumn(\\\"preds\\\", mnist(struct(df.columns))).limit(10).toPandas()\\n\",\n    \"preds\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 46,\n   \"id\": \"7c067c62-03a6-461e-a1ff-4653276fbea1\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import numpy as np\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 47,\n   \"id\": \"a7084ad0-c021-4296-bad0-7a238971f53b\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"array([-6.9618006 ,  1.2047814 , -0.09570807,  0.04621054, -5.8169513 ,\\n\",\n       \"       -4.148872  , -5.17938   ,  6.382909  , -0.11228667,  0.6022302 ],\\n\",\n       \"      dtype=float32)\"\n      ]\n     },\n     \"execution_count\": 47,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"sample = preds.iloc[0]\\n\",\n    \"sample.preds\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 48,\n   \"id\": \"8167c832-93ef-4f50-873b-07b67c19ef53\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"prediction = np.argmax(sample.preds)\\n\",\n    \"img = sample.drop('preds').to_numpy(dtype=float)\\n\",\n    \"img = np.array(img).reshape(28,28)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"297811e1-aecb-4afd-9a6a-30c49e8881cc\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"image/png\": \"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiTklEQVR4nO3dfXBV9b3v8c/O0+YpCYQ8S8CAAhYET1FyuCCipAlBHVF6KmrvBY4FpQHFHGsPTgVRZtJDTzmoTcE59xTaUxAP0wK3lKKAJBQKdEAYBqu5kMYCAwnImAQChIf9u39w2cdNArg2O3zz8H7NrJnstX7ftb5ZLPiw9lp7bZ9zzgkAgFssyroBAED7RAABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEe3H777Zo0aVLwdWlpqXw+n0pLSyO2DZ/Pp9dffz1i6wNaKgIIrcbSpUvl8/mCU4cOHdS3b19Nnz5d1dXV1u15sm7dulYTMl/d51dP3/rWt6zbQysWY90A4NUbb7yh7OxsnTt3Tlu3btWiRYu0bt067d+/X506dbqlvYwcOVJnz55VXFycp7p169appKSkyRA6e/asYmJazl/N//zP/2w0b9euXXrrrbeUl5dn0BHaipZzlANfU0FBge69915J0ve+9z11795dCxYs0Jo1a/TUU081WVNfX6/OnTtHvJeoqCh16NAhouuM9Ppu1ne/+91G86689Xit/Q18HbwFh1bvoYcekiRVVlZKkiZNmqQuXbqooqJCY8eOVXx8vJ555hlJUiAQ0MKFCzVgwAB16NBBaWlpeu655/Tll1+GrNM5p3nz5qlHjx7q1KmTHnzwQX3yySeNtn2ta0A7d+7U2LFj1a1bN3Xu3FmDBg3SW2+9FeyvpKREUujbW1c0dQ1oz549KigoUEJCgrp06aLRo0drx44dIWOuvEW5bds2FRUVKSUlRZ07d9bjjz+uEydOhIytra3VZ599ptra2q+zi0M0NDToN7/5jR544AH16NHDcz1wBWdAaPUqKiokSd27dw/Ou3jxovLz8zVixAj967/+a/Ctueeee05Lly7V5MmT9cILL6iyslI/+9nPtGfPHm3btk2xsbGSpNmzZ2vevHkaO3asxo4dq48//lh5eXk6f/78DfvZsGGDHnnkEWVkZOjFF19Uenq6Pv30U61du1YvvviinnvuOR09elQbNmxo8u2tq33yySe6//77lZCQoFdeeUWxsbF69913NWrUKJWVlSknJydk/IwZM9StWzfNmTNHn3/+uRYuXKjp06fr/fffD45ZtWqVJk+erCVLloTcVPF1rFu3TjU1NcFQB8LmgFZiyZIlTpLbuHGjO3HihDt8+LBbsWKF6969u+vYsaM7cuSIc865iRMnOknun//5n0Pq//jHPzpJbtmyZSHz169fHzL/+PHjLi4uzj388MMuEAgEx7366qtOkps4cWJw3ubNm50kt3nzZueccxcvXnTZ2dmuV69e7ssvvwzZzlfXVVhY6K7110+SmzNnTvD1uHHjXFxcnKuoqAjOO3r0qIuPj3cjR45stH9yc3NDtvXSSy+56OhoV1NT02jskiVLmuzhesaPH+/8fn+j3w/w
irfg0Ork5uYqJSVFWVlZmjBhgrp06aJVq1bptttuCxk3bdq0kNcrV65UYmKivvWtb+mLL74ITkOGDFGXLl20efNmSdLGjRt1/vx5zZgxI+StsZkzZ96wtz179qiyslIzZ85U165dQ5Z9dV1f16VLl/Thhx9q3Lhx6t27d3B+RkaGnn76aW3dulV1dXUhNVOnTg3Z1v33369Lly7pb3/7W3DepEmT5JzzfPZTV1en3//+9xo7dmyj3w/wirfg0OqUlJSob9++iomJUVpamvr166eoqND/S8XExDS6PnHgwAHV1tYqNTW1yfUeP35ckoL/UN95550hy1NSUtStW7fr9nbl7cCBAwd+/V/oOk6cOKEzZ86oX79+jZbdddddCgQCOnz4sAYMGBCc37Nnz5BxV3q++jpXOH7zm9/o3LlzvP2GiCCA0OoMHTo0eBfctfj9/kahFAgElJqaqmXLljVZk5KSErEeLUVHRzc53zl30+tetmyZEhMT9cgjj9z0ugACCO1Gnz59tHHjRg0fPlwdO3a85rhevXpJunzG9NW3vU6cOHHDs4g+ffpIkvbv36/c3Nxrjvu6b8elpKSoU6dOKi8vb7Tss88+U1RUlLKysr7Wum7WsWPHtHnzZk2aNEl+v/+WbBNtG9eA0G585zvf0aVLl/Tmm282Wnbx4kXV1NRIunyNKTY2Vu+8807IWcPChQtvuI1vfvObys7O1sKFC4Pru+Kr67rymaSrx1wtOjpaeXl5WrNmjT7//PPg/Orqai1fvlwjRoxQQkLCDfu6Wji3Ya9YsUKBQIC33xAxnAGh3XjggQf03HPPqbi4WHv37lVeXp5iY2N14MABrVy5Um+99Za+/e1vKyUlRS+//LKKi4v1yCOPaOzYsdqzZ4/+8Ic/KDk5+brbiIqK0qJFi/Too4/qnnvu0eTJk5WRkaHPPvtMn3zyiT744ANJ0pAhQyRJL7zwgvLz8xUdHa0JEyY0uc558+Zpw4YNGjFihL7//e8rJiZG7777rhoaGjR//vyw9kU4t2EvW7ZMmZmZGjVqVFjbBK5GAKFdWbx4sYYMGaJ3331Xr776qmJiYnT77bfru9/9roYPHx4cN2/ePHXo0EGLFy/W5s2blZOTow8//FAPP/zwDbeRn5+vzZs3a+7cufrpT3+qQCCgPn36aMqUKcExTzzxhGbMmKEVK1bo17/+tZxz1wygAQMG6I9//KNmzZql4uJiBQIB5eTk6Ne//nWjzwA1l/Lycu3evVtFRUWNrq0B4fK5SFyZBADAI/4rAwAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMtLjPAQUCAR09elTx8fFhPT0YAGDLOadTp04pMzPzup8ba3EBdPTo0Vv2bCsAQPM5fPjwdb81t8UFUHx8vCRphMYqRrHG3QAAvLqoC9qqdcF/z6+l2QKopKREP/nJT1RVVaXBgwfrnXfe0dChQ29Yd+VttxjFKsZHAAFAq/P/n69zo8sozXITwvvvv6+ioiLNmTNHH3/8sQYPHqz8/PzgF34BANAsAbRgwQJNmTJFkydP1je+8Q0tXrxYnTp10i9+8Yvm2BwAoBWKeACdP39eu3fvDvkyrqioKOXm5mr79u2Nxjc0NKiuri5kAgC0fREPoC+++EKXLl1SWlpayPy0tDRVVVU1Gl9cXKzExMTgxB1wANA+mH8QddasWaqtrQ1Ohw8ftm4JAHALRPwuuOTkZEVHR6u6ujpkfnV1tdLT0xuN9/v9fL88ALRDET8DiouL05AhQ7Rp06bgvEAgoE2bNmnYsGGR3hwAoJVqls8BFRUVaeLEibr33ns1dOhQLVy4UPX19Zo8eXJzbA4A0Ao1SwA9+eSTOnHihGbPnq2qqirdc889Wr9+faMbEwAA7ZfPOeesm/iquro6JSYmapQe40kIANAKXXQXVKo1qq2tVUJCwjXHmd8FBwBonwggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCA
AAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGAi4gH0+uuvy+fzhUz9+/eP9GYAAK1cTHOsdMCAAdq4ceN/bySmWTYDAGjFmiUZYmJilJ6e3hyrBgC0Ec1yDejAgQPKzMxU79699cwzz+jQoUPXHNvQ0KC6urqQCQDQ9kU8gHJycrR06VKtX79eixYtUmVlpe6//36dOnWqyfHFxcVKTEwMTllZWZFuCQDQAvmcc645N1BTU6NevXppwYIFevbZZxstb2hoUENDQ/B1XV2dsrKyNEqPKcYX25ytAQCawUV3QaVao9raWiUkJFxzXLPfHdC1a1f17dtXBw8ebHK53++X3+9v7jYAAC1Ms38O6PTp06qoqFBGRkZzbwoA0IpEPIBefvlllZWV6fPPP9ef/vQnPf7444qOjtZTTz0V6U0BAFqxiL8Fd+TIET311FM6efKkUlJSNGLECO3YsUMpKSmR3hQAoBWLeACtWLEi0qtEOxc9oJ/nmpqB3cLa1qkJ3j8G8D9uq/Rcs+1Ib881w3v81XPN1lV/57lGknq+tddzTeDMmbC2hfaLZ8EBAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAw0exfSAd8VfQd2Z5rpq7+veeahzvVeq6RpCj5PNcEFMaXCt+21XtNGKKmbwurrl9SoeeaPj/YHta20H5xBgQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMMHTsHFLueovPNcU/eEZzzUPj/+55xpJ+jJw1nPNfRtf8FwTdyTOc83+f/yZ55pw/fzx/+255q238z3XXDx8xHMN2g7OgAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJjgYaS4pQKnTnmu6f/mXz3X3HPb//JcI0kd1yd4run779s918Rk9/Jco3/0XhKu1OjTnmtcpw7N0AnaMs6AAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmOBhpGjxLp044bmmx3jvNbdSQ6/unmui5GuGTq6xLZ+7ZdtC+8UZEADABAEEADDhOYC2bNmiRx99VJmZmfL5fFq9enXIcuecZs+erYyMDHXs2FG5ubk6cOBApPoFALQRngOovr5egwcPVklJSZPL58+fr7fffluLFy/Wzp071blzZ+Xn5+vcuXM33SwAoO3wfBNCQUGBCgoKmlzmnNPChQv1ox/9SI899pgk6Ve/+pXS0tK0evVqTZgw4ea6BQC0GRG9BlRZWamqqirl5uYG5yUmJionJ0fbtzf9tcUNDQ2qq6sLmQAAbV9EA6iqqkqSlJaWFjI/LS0tuOxqxcXFSkxMDE5ZWVmRbAkA0EKZ3wU3a9Ys1dbWBqfDhw9btwQAuAUiGkDp6emSpOrq6pD51dXVwWVX8/v9SkhICJkAAG1fRAMoOztb6enp2rRpU3BeXV2ddu7cqWHDhkVyUwCAVs7zXXCnT5/WwYMHg68rKyu1d+9eJSUlqWfPnpo5c6bmzZunO++8U9nZ2XrttdeUmZmpcePGRbJvAEAr5zmAdu3apQcffDD4uqioSJI0ceJELV26VK+88orq6+s1depU1dTUaMSIEVq/fr06dOgQua4B
AK2e5wAaNWqUnLv2gwp9Pp/eeOMNvfHGGzfVGNCWHc71e64JyPsDQsN9gGlS1EXPNYEu3n8ntG/md8EBANonAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJz0/DBnDzfH1PW7dwXfOPP3jjQVdxuz9phk7QlnEGBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQPIwVu0ul/yPFc839yFoSxpQ5h1ITngz/c67nmdm1vhk7QlnEGBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQPIwVu0tHcgOeaPjEdm6GTyMncdtG6BbQDnAEBAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwwcNIga+I7p7kueahwZ96rgnIea4JR9/fPx9e3YcfR7gToDHOgAAAJgggAIAJzwG0ZcsWPfroo8rMzJTP59Pq1atDlk+aNEk+ny9kGjNmTKT6BQC0EZ4DqL6+XoMHD1ZJSck1x4wZM0bHjh0LTu+9995NNQkAaHs834RQUFCggoKC647x+/1KT08PuykAQNvXLNeASktLlZqaqn79+mnatGk6efLkNcc2NDSorq4uZAIAtH0RD6AxY8boV7/6lTZt2qR/+Zd/UVlZmQoKCnTp0qUmxxcXFysxMTE4ZWVlRbolAEALFPHPAU2YMCH48913361BgwapT58+Ki0t1ejRoxuNnzVrloqKioKv6+rqCCEAaAea/Tbs3r17Kzk5WQcPHmxyud/vV0JCQsgEAGj7mj2Ajhw5opMnTyojI6O5NwUAaEU8vwV3+vTpkLOZyspK7d27V0lJSUpKStLcuXM1fvx4paenq6KiQq+88oruuOMO5efnR7RxAEDr5jmAdu3apQcffDD4+sr1m4kTJ2rRokXat2+ffvnLX6qmpkaZmZnKy8vTm2++Kb/fH7muAQCtnucAGjVqlJy79oMUP/jgg5tqCLBUOaO/55o1We80QyeR8Y3Zh8Kquxho+q5VIJJ4FhwAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwETEv5IbaM1Gjt1j3cI13VX6Pc81fapa7u8DcAYEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABA8jBb7i57dtC6PK57ni/14457mm32tfeq656LkCuHU4AwIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCh5GiTTr9DzlhVn7suSIg57nmO3u+57km869/8VwDtGScAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDBw0jR4kV3TfRc8z/nrm2GTiIn/adx1i0A5jgDAgCYIIAAACY8BVBxcbHuu+8+xcfHKzU1VePGjVN5eXnImHPnzqmwsFDdu3dXly5dNH78eFVXV0e0aQBA6+cpgMrKylRYWKgdO3Zow4YNunDhgvLy8lRfXx8c89JLL+l3v/udVq5cqbKyMh09elRPPPFExBsHALRunm5CWL9+fcjrpUuXKjU1Vbt379bIkSNVW1ur//iP/9Dy5cv10EMPSZKWLFmiu+66Szt27NDf//3fR65zAECrdlPXgGprayVJSUlJkqTdu3frwoULys3NDY7p37+/evbsqe3btze5joaGBtXV1YVMAIC2L+wACgQCmjlzpoYPH66BAwdKkqqqqhQXF6euXbuGjE1LS1NVVVWT6ykuLlZiYmJwysrKCrcl
AEArEnYAFRYWav/+/VqxYsVNNTBr1izV1tYGp8OHD9/U+gAArUNYH0SdPn261q5dqy1btqhHjx7B+enp6Tp//rxqampCzoKqq6uVnp7e5Lr8fr/8fn84bQAAWjFPZ0DOOU2fPl2rVq3SRx99pOzs7JDlQ4YMUWxsrDZt2hScV15erkOHDmnYsGGR6RgA0CZ4OgMqLCzU8uXLtWbNGsXHxwev6yQmJqpjx45KTEzUs88+q6KiIiUlJSkhIUEzZszQsGHDuAMOABDCUwAtWrRIkjRq1KiQ+UuWLNGkSZMkSf/2b/+mqKgojR8/Xg0NDcrPz9fPf/7ziDQLAGg7PAWQc+6GYzp06KCSkhKVlJSE3RTwVb5uXT3XPJt4KNythVkHwCueBQcAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMBHWN6ICLV1UmE+1jvaF8X8yFwhrW0B7xxkQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEzyMFC1e5TO3ea4JyIW3sTAeLJr36TjPNbE7/+K5JszfCGixOAMCAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABggoeRosVL3n/Rc83imt5hbevb8Z94rhmZctBzzZ8uxHmuAdoazoAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCY4GGkaPE6rv6z55r1+/4urG0t+EG+55r4g97/GmXoT55rgLaGMyAAgAkCCABgwlMAFRcX67777lN8fLxSU1M1btw4lZeXh4wZNWqUfD5fyPT8889HtGkAQOvnKYDKyspUWFioHTt2aMOGDbpw4YLy8vJUX18fMm7KlCk6duxYcJo/f35EmwYAtH6erp6uX78+5PXSpUuVmpqq3bt3a+TIkcH5nTp1Unp6emQ6BAC0STd1Dai2tlaSlJSUFDJ/2bJlSk5O1sCBAzVr1iydOXPmmutoaGhQXV1dyAQAaPvCvg07EAho5syZGj58uAYOHBic//TTT6tXr17KzMzUvn379MMf/lDl5eX67W9/2+R6iouLNXfu3HDbAAC0UmEHUGFhofbv36+tW7eGzJ86dWrw57vvvlsZGRkaPXq0Kioq1KdPn0brmTVrloqKioKv6+rqlJWVFW5bAIBWIqwAmj59utauXastW7aoR48e1x2bk5MjSTp48GCTAeT3++X3+8NpAwDQinkKIOecZsyYoVWrVqm0tFTZ2dk3rNm7d68kKSMjI6wGAQBtk6cAKiws1PLly7VmzRrFx8erqqpKkpSYmKiOHTuqoqJCy5cv19ixY9W9e3ft27dPL730kkaOHKlBgwY1yy8AAGidPAXQokWLJF3+sOlXLVmyRJMmTVJcXJw2btyohQsXqr6+XllZWRo/frx+9KMfRaxhAEDb4PktuOvJyspSWVnZTTUEAGgfeBo22qSLf/08rLq+08KrA+AdDyMFAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgIsa6gas55yRJF3VBcsbNAAA8u6gLkv773/NraXEBdOrUKUnSVq0z7gQAcDNOnTqlxMTEay73uRtF1C0WCAR09OhRxcfHy+fzhSyrq6tTVlaWDh8+rISEBKMO7bEfLmM/XMZ+uIz9cFlL2A/OOZ06dUqZmZmKirr2lZ4WdwYUFRWlHj16XHdMQkJCuz7ArmA/XMZ+uIz9cBn74TLr/XC9M58ruAkBAGCCAAIAmGhVAeT3+zVnzhz5/X7rVkyxHy5j
P1zGfriM/XBZa9oPLe4mBABA+9CqzoAAAG0HAQQAMEEAAQBMEEAAABMEEADARKsJoJKSEt1+++3q0KGDcnJy9Oc//9m6pVvu9ddfl8/nC5n69+9v3Vaz27Jlix599FFlZmbK5/Np9erVIcudc5o9e7YyMjLUsWNH5ebm6sCBAzbNNqMb7YdJkyY1Oj7GjBlj02wzKS4u1n333af4+HilpqZq3LhxKi8vDxlz7tw5FRYWqnv37urSpYvGjx+v6upqo46bx9fZD6NGjWp0PDz//PNGHTetVQTQ+++/r6KiIs2ZM0cff/yxBg8erPz8fB0/fty6tVtuwIABOnbsWHDaunWrdUvNrr6+XoMHD1ZJSUmTy+fPn6+3335bixcv1s6dO9W5c2fl5+fr3Llzt7jT5nWj/SBJY8aMCTk+3nvvvVvYYfMrKytTYWGhduzYoQ0bNujChQvKy8tTfX19cMxLL72k3/3ud1q5cqXKysp09OhRPfHEE4ZdR97X2Q+SNGXKlJDjYf78+UYdX4NrBYYOHeoKCwuDry9duuQyMzNdcXGxYVe33pw5c9zgwYOt2zAlya1atSr4OhAIuPT0dPeTn/wkOK+mpsb5/X733nvvGXR4a1y9H5xzbuLEie6xxx4z6cfK8ePHnSRXVlbmnLv8Zx8bG+tWrlwZHPPpp586SW779u1WbTa7q/eDc8498MAD7sUXX7Rr6mto8WdA58+f1+7du5WbmxucFxUVpdzcXG3fvt2wMxsHDhxQZmamevfurWeeeUaHDh2ybslUZWWlqqqqQo6PxMRE5eTktMvjo7S0VKmpqerXr5+mTZumkydPWrfUrGprayVJSUlJkqTdu3frwoULIcdD//791bNnzzZ9PFy9H65YtmyZkpOTNXDgQM2aNUtnzpyxaO+aWtzTsK/2xRdf6NKlS0pLSwuZn5aWps8++8yoKxs5OTlaunSp+vXrp2PHjmnu3Lm6//77tX//fsXHx1u3Z6KqqkqSmjw+rixrL8aMGaMnnnhC2dnZqqio0KuvvqqCggJt375d0dHR1u1FXCAQ0MyZMzV8+HANHDhQ0uXjIS4uTl27dg0Z25aPh6b2gyQ9/fTT6tWrlzIzM7Vv3z798Ic/VHl5uX77298adhuqxQcQ/ltBQUHw50GDBiknJ0e9evXSf/3Xf+nZZ5817AwtwYQJE4I/33333Ro0aJD69Omj0tJSjR492rCz5lFYWKj9+/e3i+ug13Ot/TB16tTgz3fffbcyMjI0evRoVVRUqE+fPre6zSa1+LfgkpOTFR0d3egulurqaqWnpxt11TJ07dpVffv21cGDB61bMXPlGOD4aKx3795KTk5uk8fH9OnTtXbtWm3evDnk+8PS09N1/vx51dTUhIxvq8fDtfZDU3JyciSpRR0PLT6A4uLiNGTIEG3atCk4LxAIaNOmTRo2bJhhZ/ZOnz6tiooKZWRkWLdiJjs7W+np6SHHR11dnXbu3Nnuj48jR47o5MmTber4cM5p+vTpWrVqlT766CNlZ2eHLB8yZIhiY2NDjofy8nIdOnSoTR0PN9oPTdm7d68ktazjwfouiK9jxYoVzu/3u6VLl7q//OUvburUqa5r166uqqrKurVb6p/+6Z9caWmpq6ysdNu2bXO5ubkuOTnZHT9+3Lq1ZnXq1Cm3Z88et2fPHifJLViwwO3Zs8f97W9/c8459+Mf/9h17drVrVmzxu3bt8899thjLjs72509e9a488i63n44deqUe/nll9327dtdZWWl27hxo/vmN7/p7rzzTnfu3Dnr1iNm2rRpLjEx0ZWWlrpjx44FpzNnzgTHPP/8865nz57uo48+crt27XLDhg1zw4YNM+w68m60Hw4ePOjeeOMNt2vXLldZWenWrFnjevfu7UaOHGnceahWEUDOOffOO++4nj17uri4ODd06FC3Y8cO65ZuuSeffNJlZGS4uLg4d9ttt7knn3zSHTx40LqtZrd582YnqdE0ceJE59zl
W7Ffe+01l5aW5vx+vxs9erQrLy+3bboZXG8/nDlzxuXl5bmUlBQXGxvrevXq5aZMmdLm/pPW1O8vyS1ZsiQ45uzZs+773/++69atm+vUqZN7/PHH3bFjx+yabgY32g+HDh1yI0eOdElJSc7v97s77rjD/eAHP3C1tbW2jV+F7wMCAJho8deAAABtEwEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBM/D/AaY3Zb7z6aAAAAABJRU5ErkJggg==\",\n      \"text/plain\": [\n       \"<Figure size 640x480 with 1 Axes>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"plt.figure()\\n\",\n    \"plt.title(\\\"Prediction: {}\\\".format(prediction))\\n\",\n    \"plt.imshow(img)\\n\",\n    \"plt.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d3dc87a7\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  
\\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 50,\n   \"id\": \"cfc841c3\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d1e63867\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 51,\n   \"id\": \"d7af3599\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"32cbe1cb\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 52,\n   \"id\": \"c3539d1b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports, model_path):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, 
TritonConfig\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"    from tensorflow import keras\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"\\n\",\n    \"    # Enable GPU memory growth\\n\",\n    \"    gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"\\n\",\n    \"    model = keras.models.load_model(model_path)\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        images = np.squeeze(inputs[\\\"images\\\"])\\n\",\n    \"        print(f\\\"SERVER: Received batch of size {len(images)}.\\\")\\n\",\n    \"        return {\\n\",\n    \"            \\\"labels\\\": model.predict(images)\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    workspace_path = f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"ImageClassifier\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"images\\\", dtype=np.float64, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"labels\\\", dtype=np.float32, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=128,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 
5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ce4c7701\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2695d9ab\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"4deae3b1\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"ImageClassifier\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name, model_path=model_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e56c84f4\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"            
                                                                    \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"77847814\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e278fde0\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"68a9606e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"4d70bd6f\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 57,\n   \"id\": \"92ba2e26\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"Connecting to Triton model {model_name} at {url}.\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=240) as client:\\n\",\n    \"            result_data = client.infer_batch(inputs)\\n\",\n    \"            return 
result_data[\\\"labels\\\"]\\n\",\n    \"        \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 59,\n   \"id\": \"6658d2a1-ef7b-4ca1-9fb6-f2ac9050f3e5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"predict = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                            input_tensor_shapes=[[784]],\\n\",\n    \"                            return_type=ArrayType(FloatType()),\\n\",\n    \"                            batch_size=128)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3842c263\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 58,\n   \"id\": \"43b93753-1d52-4060-9986-f24c30a67528\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType([StructField('data', ArrayType(DoubleType(), True), True)])\"\n      ]\n     },\n     \"execution_count\": 58,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df = spark.read.parquet(data_path_1)\\n\",\n    \"df.schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 60,\n   \"id\": \"8397aa14-82fd-4351-a477-dc8e8b321fa2\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 19:>                                                         (0 + 8) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 19.8 ms, sys: 2.89 ms, total: 22.7 ms\\n\",\n      \"Wall time: 1.67 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                          
      \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", predict(struct(\\\"data\\\"))).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 61,\n   \"id\": \"82698bd9-377a-4415-8971-835487f876cc\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 19.8 ms, sys: 5.99 ms, total: 25.7 ms\\n\",\n      \"Wall time: 399 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", predict(\\\"data\\\")).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 62,\n   \"id\": \"419ad7bd-fa28-49d3-b98d-db9fba5aeaef\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 21:====================================>                     (5 + 3) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 9.07 ms, sys: 1.34 ms, total: 10.4 ms\\n\",\n      \"Wall time: 888 ms\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" 
class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>data</th>\\n\",\n       \"      <th>preds</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>0</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-4.6654444, -2.4893682, -0.5888205, 13.380681...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>1</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-2.2732146, -7.5127845, 1.1983705, -3.540661,...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>2</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-2.2890894, 0.8308606, 0.31311002, 1.1683631,...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>3</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-1.055197, -6.502811, 12.420727, 0.4528031, -...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>4</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-3.7887795, 3.9983597, -1.5343359, -0.3698441...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>5</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-4.4992743, -1.7618219, 1.1183226, 3.9469318,...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>6</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-2.754053, 4.868414, 0.2515293, -0.47300792, ...</td>\\n\",\n    
   \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>7</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-1.888711, 0.02717158, -6.050885, 0.08750934,...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>8</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[0.9541264, -2.113048, -1.7508973, -5.4303784,...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>9</th>\\n\",\n       \"      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\\n\",\n       \"      <td>[-1.612412, -0.7655782, -4.473859, 2.0609212, ...</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"</div>\"\n      ],\n      \"text/plain\": [\n       \"                                                data  \\\\\\n\",\n       \"0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"3  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"4  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"5  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"6  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"7  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"8  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"9  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \\n\",\n       \"\\n\",\n       \"                                               preds  \\n\",\n       \"0  [-4.6654444, -2.4893682, -0.5888205, 13.380681...  \\n\",\n       \"1  [-2.2732146, -7.5127845, 1.1983705, -3.540661,...  \\n\",\n       \"2  [-2.2890894, 0.8308606, 0.31311002, 1.1683631,...  
\\n\",\n       \"3  [-1.055197, -6.502811, 12.420727, 0.4528031, -...  \\n\",\n       \"4  [-3.7887795, 3.9983597, -1.5343359, -0.3698441...  \\n\",\n       \"5  [-4.4992743, -1.7618219, 1.1183226, 3.9469318,...  \\n\",\n       \"6  [-2.754053, 4.868414, 0.2515293, -0.47300792, ...  \\n\",\n       \"7  [-1.888711, 0.02717158, -6.050885, 0.08750934,...  \\n\",\n       \"8  [0.9541264, -2.113048, -1.7508973, -5.4303784,...  \\n\",\n       \"9  [-1.612412, -0.7655782, -4.473859, 2.0609212, ...  \"\n      ]\n     },\n     \"execution_count\": 62,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", predict(col(\\\"data\\\"))).limit(10).toPandas()\\n\",\n    \"preds\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 63,\n   \"id\": \"79d90a26\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import numpy as np\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 64,\n   \"id\": \"4ca495f5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sample = preds.iloc[0]\\n\",\n    \"sample.preds\\n\",\n    \"\\n\",\n    \"prediction = np.argmax(sample.preds)\\n\",\n    \"img = np.array(sample.data).reshape(28,28)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 65,\n   \"id\": \"a5d10903\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"image/png\": 
\"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkfElEQVR4nO3dfXQUdZ7v8U/nqSEkaR7yLAFCFHRAcAYly/AgSiQEZUCYGUG9F7gziJiggI6KR0Udzsksrg7qIHjcHVhHEGWOyMoiDg9JGBRwwTCIM2QhJ0g4kIBcSYcAIaR/9w+uvbQkQDUdfkl4v86pc+iq37fqm6Lgk+qqrnYZY4wAALjKwmw3AAC4NhFAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAgAPdunXTpEmT/K8LCwvlcrlUWFgYsm24XC698MILIVsf0FwRQGgxlixZIpfL5Z/atGmjHj16KC8vT5WVlbbbc2TNmjUtJmTefvtt3X777UpKSpLb7VZ6eromT56s/fv3224NLVyE7QYAp1566SWlp6fr9OnT2rx5sxYuXKg1a9Zo9+7dio6Ovqq9DBkyRKdOnVJUVJSjujVr1mjBggUNhtCpU6cUEdF8/mkWFxcrPT1dP/vZz9ShQweVlZXp7bff1urVq/W3v/1NqamptltEC9V8jnLgMuXk5OjWW2+VJP36179Wp06d9Oqrr2rVqlWaMGFCgzU1NTVq165dyHsJCwtTmzZtQrrOUK/vSr355psXzBszZoxuvfVWvfPOO3r66actdIXWgLfg0OLdeeedkqSysjJJ0qRJkxQTE6PS0lKNHDlSsbGxeuCBByRJPp9P8+fPV69evdSmTRslJSVp6tSp+u677wLWaYzR3Llz1blzZ0VHR+uOO+7Q119/fcG2G7sGtG3bNo0cOVIdOnRQu3bt1KdPH7322mv+/hYsWCBJAW8pfq+ha0DFxcXKyclRXFycYmJiNGzYMG3dujVgzPdvUX722WeaNWuWEhIS1K5dO9177706evRowNiqqirt2bNHVVVVl7OLL9CtWzdJ0vHjx4OqByTOgNAKlJaWSpI6derkn3f27FllZ2dr0KBB+pd/+Rf/W3NTp07VkiVLNHnyZD366KMqKyvTH/7wBxUXF+uzzz5TZGSkJOn555/X3LlzNXLkSI0cOVJffvmlhg8frjNnzlyyn3Xr1umee+5RSkqKHnvsMSUnJ+sf//iHVq9erccee0xTp07VoUOHtG7dOv3pT3+65Pq+/vprDR48WHFxcXryyScVGRmpt956S0OHDlVRUZEyMzMDxk+fPl0dOnTQnDlztH//fs2fP195eXl6//33/WNWrlypyZMna/HixQE3VVzMsWPHVF9frwMHDuill16SJA0bNuyyaoEGGaCFWLx4sZFk1q9fb44ePWrKy8vN8uXLTadOnUzbtm3NwYMHjTHGTJw40UgyTz/9dED9X//6VyPJLF26NGD+2rVrA+YfOXLEREVFmbvvvtv4fD7/uGeeecZIMhMnTvTPKygoMJJMQUGBMcaYs2fPmvT0dNO1a1fz3XffBWzn/HXl5uaaxv75STJz5szxvx4zZoyJiooypaWl/nmHDh0ysbGxZsiQIRfsn6ysrIBtzZw504SHh5vjx49fMHbx4sUN9tAQt9ttJBlJplOnTub111+/7FqgIbwFhxYnKytLCQkJSktL0/jx4xUTE6OVK1fquuuuCxg3bdq0gNcrVqyQx+PRXXfdpW+//dY/9evXTzExMSooKJAkrV+/XmfOnNH06dMD3hqbMWPGJXsrLi5WWVmZZsyYofbt2wcsO39dl6u+vl5/+ctfNGbMGHXv3t0/PyUlRffff782b94sr9cbUPPQQw8FbGvw4MGqr6/XN9984583adIkGWMu++xHkj755BOtWbNGr7zyirp06aKamhrHPw9wPt6CQ4uzYMEC9ejRQxEREUpKSlLPnj0VFhb4u1RERI
Q6d+4cMG/v3r2qqqpSYmJig+s9cuSIJPn/o77hhhsClickJKhDhw4X7e37twN79+59+T/QRRw9elQnT55Uz549L1h20003yefzqby8XL169fLP79KlS8C473v+4XUup+644w5J524CGT16tHr37q2YmBjl5eVd0Xpx7SKA0OL079/ffxdcY9xu9wWh5PP5lJiYqKVLlzZYk5CQELIebQoPD29wvjEmZNvIyMjQj3/8Yy1dupQAQtAIIFwzMjIytH79eg0cOFBt27ZtdFzXrl0lnTtjOv9tr6NHj17yLCIjI0OStHv3bmVlZTU67nLfjktISFB0dLRKSkouWLZnzx6FhYUpLS3tstYVaqdOnVJtba2VbaN14BoQrhm//OUvVV9fr9/+9rcXLDt79qz/luKsrCxFRkbqjTfeCDhrmD9//iW38ZOf/ETp6emaP3/+Bbcon7+u7z+TdKnbmMPDwzV8+HCtWrUq4MkDlZWVWrZsmQYNGqS4uLhL9vVDl3sb9tmzZxsM3S+++EJfffXVJc9EgYvhDAjXjNtvv11Tp05Vfn6+du7cqeHDhysyMlJ79+7VihUr9Nprr+nnP/+5EhIS9MQTTyg/P1/33HOPRo4cqeLiYn3yySeKj4+/6DbCwsK0cOFCjRo1SrfccosmT56slJQU7dmzR19//bU+/fRTSVK/fv0kSY8++qiys7MVHh6u8ePHN7jOuXPnat26dRo0aJAeeeQRRURE6K233lJtba3mzZsX1L643NuwT5w4obS0NN13333q1auX2rVrp6+++kqLFy+Wx+PRc889F9T2AYkAwjVm0aJF6tevn9566y0988wzioiIULdu3fTggw9q4MCB/nFz585VmzZttGjRIhUUFCgzM1N/+ctfdPfdd19yG9nZ2SooKNCLL76oV155RT6fTxkZGZoyZYp/zNixYzV9+nQtX75c7777rowxjQZQr1699Ne//lWzZ89Wfn6+fD6fMjMz9e67717wGaBQi46O1q9//WsVFBToz3/+s06dOqXU1FRNmDBBzz77rP8DqUAwXCaUVyYBALhMXAMCAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMCKZvc5IJ/Pp0OHDik2NjaopwcDAOwyxqi6ulqpqakXPJPxfM0ugA4dOmTt2VYAgNApLy+/4Kn052t2ARQbGytJGqSRilCk5W4AAE6dVZ02a43///PGNFkALViwQC+//LIqKirUt29fvfHGG+rfv/8l675/2y1CkYpwEUAA0OL8/+frXOoySpPchPD+++9r1qxZmjNnjr788kv17dtX2dnZ/i/8AgCgSQLo1Vdf1ZQpUzR58mT96Ec/0qJFixQdHa0//vGPTbE5AEALFPIAOnPmjHbs2BHwZVxhYWHKysrSli1bLhhfW1srr9cbMAEAWr+QB9C3336r+vp6JSUlBcxPSkpSRUXFBePz8/Pl8Xj8E3fAAcC1wfoHUWfPnq2qqir/VF5ebrslAMBVEPK74OLj4xUeHq7KysqA+ZWVlUpOTr5gvNvtltvtDnUbAIBmLuRnQFFRUerXr582bNjgn+fz+bRhwwYNGDAg1JsDALRQTfI5oFmzZmnixIm69dZb1b9/f82fP181NTWaPHlyU2wOANACNUkA3XfffTp69Kief/55VVRU6JZbbtHatWsvuDEBAHDtchljjO0mzuf1euXxeDRUo3kSAgC0QGdNnQq1SlVVVYqLi2t0nPW74AAA1yYCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwA
oCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWBFhuwHYZ37aN7i6cOe/v0RWeh3XlP7vRMc1vu6nHNdI0p7b/+i4JtzlfD9MPTjAcU3hp7c4run2nzWOayRJW3cFVwc4wBkQAMAKAggAYEXIA+iFF16Qy+UKmG688cZQbwYA0MI1yTWgXr16af369f+zkQguNQEAAjVJMkRERCg5ObkpVg0AaCWa5BrQ3r17lZqaqu7du+uBBx7QgQMHGh1bW1srr9cbMAEAWr+QB1BmZqaWLFmitWvXauHChSorK9PgwYNVXV3d4Pj8/Hx5PB7/lJaWFuqWAADNUMgDKCcnR7/4xS/Up08fZWdna82aNTp+/Lg++OCDBsfPnj1bVVVV/qm8vDzULQEAmqEmvzugffv26tGjh/bt29fgcrfbLbfb3dRtAACamSb/HNCJEydUWlqqlJSUpt4UAKAFCXkAPfHEEyoqKtL+/fv1+eef695771V4eLgmTJgQ6k0BAFqwkL8Fd/DgQU2YMEHHjh1TQkKCBg0apK1btyohISHUmwIAtGAuY4yx3cT5vF6vPB6Phmq0IlyRttuxqubnmY5rKm91flK7dsLLjmskqUtEW8c1/2v/XY5r/tRtneManFN8xhdU3eOP5zmuif5wW1DbQutz1tSpUKtUVVWluLi4RsfxLDgAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIKHkV4lR/J+6rim8OlXHNdEu6Ic1zR339afclzTxhXc71Z1cv7PYcaBexzX/DLxvxzX3B1d5bgmWPvqah3XPDH4l45rzpYfdFyD5o+HkQIAmjUCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsiLDdwLXCF+68pjU+2frlYz9yXLNhxiDHNfVtg/vd6rvrnT+B/br/POy45s2EcY5r7v7zHx3XBGvsf011XNPt+P7QN4JWjTMgAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCh5FeJan/+jfHNR88kui4Jjv6gOOanDlPOK6RpLoYl+Oa6/7joOOaiP07nNc4rjgnOYia+iBqKu/5aRBVV8+uny5xXDMmiAes+qqrHdeg9eAMCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCs4GGkV4mvpsZxzTs90xzXvJ0z1nFNfEGx4xpJ8p0+7bjmbFBbunrCExIc13x3V4bjmscf/sBxDdDacAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcByY4yef/55paSkqG3btsrKytLevXtD1S8AoJVwHEA1NTXq27evFixY0ODyefPm6fXXX9eiRYu0bds2tWvXTtnZ2TodxPUCAEDr5fgmhJycHOXk5DS4zBij+fPn69lnn9Xo0aMlSe+8846SkpL00Ucfafz48VfWLQCg1QjpNaCysjJVVFQoKyvLP8/j8SgzM1NbtmxpsKa2tlZerzdgAgC0fiENoIqKCklSUlJSwPykpCT/sh/Kz8+Xx+PxT2lpzm89BgC0PNbvgps9e7aqqqr8U3l5ue2WAABXQUgDKDk5WZJUWVkZML+ystK/7Ifcbrfi4uICJgBA6xfSAEpPT1dycrI2bNjgn+f1erVt2zYNGDAglJsCALRwju+CO3HihPbt2+d/XVZWpp07d6pjx47q0qWLZsyYoblz5+qGG25Qenq6nnvuOaWmpmrMmDGh7BsA0MI5DqDt27frjjvu8L+eNWuWJGnixIlasmSJnnzySdXU1Oihhx7S8ePHNWjQIK1du1Zt2rQJXdcAgBbPZYwxtps4n9frlcfj0V
CNVoQr0nY7aKHC23uCqnt8x2bHNUPanAlqW1eDT76g6n73bV/HNduGpTiuqf/2mOMaNH9nTZ0KtUpVVVUXva5v/S44AMC1iQACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACscfx0D0BKUPdorqLohbTaGuBO7VtXEB1X3ed+oIKp4sjWc4QwIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKzgYaRAK3ZX28NB1c2d9YDjmrqYoDblWGLxWcc1bT7+ogk6wZXiDAgAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArOBhpGiV0jacDKpuxyTnNf3cQW3qqogJC665HY+/EeJOQmfOkR87rtnxMb9rN0f8rQAArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFTyMFK2S67OdQdW9NOp+xzW1STGOa6of9zqu+eyW5Y5rWqNnE7Y7rrnzwUeD2pbn3a1B1eHycAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcDySZMmyeVyBUwjRowIVb8AgFbCcQDV1NSob9++WrBgQaNjRowYocOHD/un995774qaBAC0Po5vQsjJyVFOTs5Fx7jdbiUnJwfdFACg9WuSa0CFhYVKTExUz549NW3aNB07dqzRsbW1tfJ6vQETAKD1C3kAjRgxQu+88442bNigf/7nf1ZRUZFycnJUX1/f4Pj8/Hx5PB7/lJaWFuqWAADNUMg/BzR+/Hj/n2+++Wb16dNHGRkZKiws1LBhwy4YP3v2bM2aNcv/2uv1EkIAcA1o8tuwu3fvrvj4eO3bt6/B5W63W3FxcQETAKD1a/IAOnjwoI4dO6aUlJSm3hQAoAVx/BbciRMnAs5mysrKtHPnTnXs2FEdO3bUiy++qHHjxik5OVmlpaV68skndf311ys7OzukjQMAWjbHAbR9+3bdcccd/tffX7+ZOHGiFi5cqF27dunf//3fdfz4caWmpmr48OH67W9/K7fbHbquAQAtnuMAGjp0qIwxjS7/9NNPr6ghwKb6r0sc10R87Xw7HQpcjmtGRf3Ucc3+P/VwXCNJn2QudFzTOaJtUNtyKtIV7rjmdMfgrjZ4gqrC5eJZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAi5F/JDeAyXOSJ8o2W1NY6run6y68c10jSnW/NdFzz3/csCmpbuHZxBgQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVvAwUqAVc0VGBVfXtj7EnYTOrjPOe0vcXtMEneBKcQYEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFbwMFKgFSt545ag6v572MLQNhJCMx6f7rgm+vNtTdAJrhRnQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQ8jRasUFhsbXF17T4g7adiRu9Ic19w1/TPHNf+RuMBxzTlX53fTD04kOq6J2/KN45qzjitwNXAGBACwggACAFjhKIDy8/N12223KTY2VomJiRozZoxKSkoCxpw+fVq5ubnq1KmTYmJiNG7cOFVWVoa0aQBAy+cogIqKipSbm6utW7dq3bp1qqur0/Dhw1VTU+MfM3PmTH388cdasWKFioqKdOjQIY0dOzbkjQMAWjZHNyGsXbs24PWSJUuUmJioHTt2aMiQIaqqqtK//du/ad
myZbrzzjslSYsXL9ZNN92krVu36p/+6Z9C1zkAoEW7omtAVVVVkqSOHTtKknbs2KG6ujplZWX5x9x4443q0qWLtmzZ0uA6amtr5fV6AyYAQOsXdAD5fD7NmDFDAwcOVO/evSVJFRUVioqKUvv27QPGJiUlqaKiosH15Ofny+Px+Ke0NOe3pwIAWp6gAyg3N1e7d+/W8uXLr6iB2bNnq6qqyj+Vl5df0foAAC1DUB9EzcvL0+rVq7Vp0yZ17tzZPz85OVlnzpzR8ePHA86CKisrlZyc3OC63G633G53MG0AAFowR2dAxhjl5eVp5cqV2rhxo9LT0wOW9+vXT5GRkdqwYYN/XklJiQ4cOKABAwaEpmMAQKvg6AwoNzdXy5Yt06pVqxQbG+u/ruPxeNS2bVt5PB796le/0qxZs9SxY0fFxcVp+vTpGjBgAHfAAQACOAqghQsXSpKGDh0aMH/x4sWaNGmSJOn3v/+9wsLCNG7cONXW1io7O1tvvvlmSJoFALQeLmOMsd3E+bxerzwej4ZqtCJckbbbuSaE9b0pqLo9uTGOa5LT/q/jmiMlCY5rJt9Z6LhGkp7q9HVQdQhOn88nOa7p8ouvQt8IQuqsqVOhVqmqqkpxcXGNjuNZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAiqG9ERfPl6tfLcU3b3x8Jalv/nfFuUHWO9bk6m2nuak2d45pIV3hQ26qsr3VcM+dQjuOazq8F1x9aB86AAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKHkbaytR0jXFc8373fw1ya1FB1jU9n3xB1c08NNhxzW8S1zuuyf4813FNbGG045rqbo5LJEnps7cEUVXtuCJMO4PYDloLzoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoeRtrKRH+4zXHNLw4/HNS2jv64neMaXxDPL61z/nxVvf1//uC8SFLpbacd10z78VTHNek7dzmukTGOS+KdbwW4ajgDAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArXMYE8YTDJuT1euXxeDRUoxXhirTdDgDAobOmToVapaqqKsXFxTU6jjMgAIAVBBAAwApHAZSfn6/bbrtNsbGxSkxM1JgxY1RSUhIwZujQoXK5XAHTww8H930zAIDWy1EAFRUVKTc3V1u3btW6detUV1en4cOHq6amJmDclClTdPjwYf80b968kDYNAGj5HH0j6tq1awNeL1myRImJidqxY4eGDBninx8dHa3k5OTQdAgAaJWu6BpQVVWVJKljx44B85cuXar4+Hj17t1bs2fP1smTJxtdR21trbxeb8AEAGj9HJ0Bnc/n82nGjBkaOHCgevfu7Z9///33q2vXrkpNTdWuXbv01FNPqaSkRB9++GGD68nPz9eLL74YbBsAgBYq6M8BTZs2TZ988ok2b96szp07Nzpu48aNGjZsmPbt26eMjIwLltfW1qq2ttb/2uv1Ki0tjc8BAUALdbmfAwrqDCgvL0+rV6/Wpk2bLho+kpSZmSlJjQaQ2+2W2+0Opg0AQAvmKICMMZo+fbpWrlypwsJCpaenX7Jm586dkqSUlJSgGgQAtE6OAig3N1fLli3TqlWrFBsbq4qKCkmSx+NR27ZtVVpaqmXLlmnkyJHq1KmTdu3apZkzZ2rIkCHq06dPk/wAAICWydE1IJfL1eD8xYsXa9KkSSovL9eDDz6o3bt3q6amRmlpabr33nv17LPPXvR9wPPxLDgAaNma5BrQpbIqLS1NRUVFTlYJALhG8Sw4AIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAF
YQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVEbYb+CFjjCTprOokY7kZAIBjZ1Un6X/+P29Mswug6upqSdJmrbHcCQDgSlRXV8vj8TS63GUuFVFXmc/n06FDhxQbGyuXyxWwzOv1Ki0tTeXl5YqLi7PUoX3sh3PYD+ewH85hP5zTHPaDMUbV1dVKTU1VWFjjV3qa3RlQWFiYOnfufNExcXFx1/QB9j32wznsh3PYD+ewH86xvR8udubzPW5CAABYQQABAKxoUQHkdrs1Z84cud1u261YxX44h/1wDvvhHPbDOS1pPzS7mxAAANeGFnUGBABoPQggAIAVBBAAwAoCCABgBQEEALCixQTQggUL1K1bN7Vp00aZmZn64osvbLd01b3wwgtyuVwB04033mi7rSa3adMmjRo1SqmpqXK5XProo48Clhtj9PzzzyslJUVt27ZVVlaW9u7da6fZJnSp/TBp0qQLjo8RI0bYabaJ5Ofn67bbblNsbKwSExM1ZswYlZSUBIw5ffq0cnNz1alTJ8XExGjcuHGqrKy01HHTuJz9MHTo0AuOh4cffthSxw1rEQH0/vvva9asWZozZ46+/PJL9e3bV9nZ2Tpy5Ijt1q66Xr166fDhw/5p8+bNtltqcjU1Nerbt68WLFjQ4PJ58+bp9ddf16JFi7Rt2za1a9dO2dnZOn369FXutGldaj9I0ogRIwKOj/fee+8qdtj0ioqKlJubq61bt2rdunWqq6vT8OHDVVNT4x8zc+ZMffzxx1qxYoWKiop06NAhjR071mLXoXc5+0GSpkyZEnA8zJs3z1LHjTAtQP/+/U1ubq7/dX19vUlNTTX5+fkWu7r65syZY/r27Wu7DaskmZUrV/pf+3w+k5ycbF5++WX/vOPHjxu3223ee+89Cx1eHT/cD8YYM3HiRDN69Ggr/dhy5MgRI8kUFRUZY8793UdGRpoVK1b4x/zjH/8wksyWLVtstdnkfrgfjDHm9ttvN4899pi9pi5Dsz8DOnPmjHbs2KGsrCz/vLCwMGVlZWnLli0WO7Nj7969Sk1NVffu3fXAAw/owIEDtluyqqysTBUVFQHHh8fjUWZm5jV5fBQWFioxMVE9e/bUtGnTdOzYMdstNamqqipJUseOHSVJO3bsUF1dXcDxcOONN6pLly6t+nj44X743tKlSxUfH6/evXtr9uzZOnnypI32GtXsnob9Q99++63q6+uVlJQUMD8pKUl79uyx1JUdmZmZWrJkiXr27KnDhw/rxRdf1ODBg7V7927Fxsbabs+KiooKSWrw+Ph+2bVixIgRGjt2rNLT01VaWqpnnnlGOTk52rJli8LDw223F3I+n08zZszQwIED1bt3b0nnjoeoqCi1b98+YGxrPh4a2g+SdP/996tr165KTU3Vrl279NRTT6mkpEQffvihxW4DNfsAwv/Iycnx/7lPnz7KzMxU165d9cEHH+hXv/qVxc7QHIwfP97/55tvvll9+vRRRkaGCgsLNWzYMIudNY3c3Fzt3r37mrgOejGN7YeHHnrI/+ebb75ZKSkpGjZsmEpLS5WRkXG122xQs38LLj4+XuHh4RfcxVJZWank5GRLXTUP7du3V48ePbRv3z7brVjz/THA8XGh7t27Kz4+vlUeH3l5eVq9erUKCgoCvj8sOTlZZ86c0fHjxwPGt9bjobH90JDMzExJalbHQ7MPoKioKPXr108bNmzwz/P5fNqwYYMGDBhgsTP7Tpw4odLSUqWkpNhuxZr09HQlJycHHB9er1fbtm275o+PgwcP6tixY63q+DDGKC8vTytXrtTGjRuVnp4esLxfv36KjIwMOB5KSkp04MCBVnU8XGo/NGTnzp2S1LyOB9t3QVyO5cuXG7fbbZYsWWL+/ve/m4ceesi0b9/eVFRU2G7tqnr88c
dNYWGhKSsrM5999pnJysoy8fHx5siRI7Zba1LV1dWmuLjYFBcXG0nm1VdfNcXFxeabb74xxhjzu9/9zrRv396sWrXK7Nq1y4wePdqkp6ebU6dOWe48tC62H6qrq80TTzxhtmzZYsrKysz69evNT37yE3PDDTeY06dP2249ZKZNm2Y8Ho8pLCw0hw8f9k8nT570j3n44YdNly5dzMaNG8327dvNgAEDzIABAyx2HXqX2g/79u0zL730ktm+fbspKyszq1atMt27dzdDhgyx3HmgFhFAxhjzxhtvmC5dupioqCjTv39/s3XrVtstXXX33XefSUlJMVFRUea6664z9913n9m3b5/ttppcQUGBkXTBNHHiRGPMuVuxn3vuOZOUlGTcbrcZNmyYKSkpsdt0E7jYfjh58qQZPny4SUhIMJGRkaZr165mypQpre6XtIZ+fklm8eLF/jGnTp0yjzzyiOnQoYOJjo429957rzl8+LC9ppvApfbDgQMHzJAhQ0zHjh2N2+02119/vfnNb35jqqqq7Db+A3wfEADAimZ/DQgA0DoRQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAV/w/hgVLrpVGHsAAAAABJRU5ErkJggg==\",\n      \"text/plain\": [\n       \"<Figure size 640x480 with 1 Axes>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"plt.figure()\\n\",\n    \"plt.title(\\\"Prediction: {}\\\".format(prediction))\\n\",\n    \"plt.imshow(img)\\n\",\n    \"plt.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6377f41a-5654-410b-8bad-d392e9dce7b8\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"#### Stop Triton Server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d06de00e\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 14:00:18,330 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-04 14:00:28,520 - INFO - Sucessfully stopped 1 servers.                 
\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 67,\n   \"id\": \"f612dc0b-538f-4ecf-81f7-ef6b58c493ab\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"490fc849-e47a-48d7-accc-429ff1cced6b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-tf\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras_preprocessing_tf.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7fcc021a\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# Pyspark TensorFlow Inference\\n\",\n    \"\\n\",\n    \"### Classification using Keras Preprocessing Layers\\n\",\n    \"\\n\",\n    \"In this notebook, we demonstrate distributed inference using Keras preprocessing layers to classify structured data.  \\n\",\n    \"From: https://www.tensorflow.org/tutorials/structured_data/preprocessing_layers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"35203476\",\n   \"metadata\": {},\n   \"source\": [\n    \"Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075)  \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"01162f42-0637-4dfe-8d7d-b577e4ffd017\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:59:29.670948: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\\n\",\n      \"2025-02-04 13:59:29.679838: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\\n\",\n      \"2025-02-04 13:59:29.689914: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\\n\",\n      \"2025-02-04 13:59:29.692851: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\\n\",\n      \"2025-02-04 13:59:29.700499: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\\n\",\n      \"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\\n\",\n      \"2025-02-04 13:59:30.139239: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import os\\n\",\n    \"import shutil\\n\",\n    \"import numpy as np\\n\",\n    \"import pandas as pd\\n\",\n    \"import tensorflow as tf\\n\",\n    \"\\n\",\n    \"from tensorflow.keras import layers\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"0d586fb8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"os.mkdir('models') if not os.path.exists('models') else None\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9fa3e1b7-58cd-45f9-9fee-85f25a31c3c6\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      
\"2.17.0\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\\n\",\n      \"I0000 00:00:1738706370.524690 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706370.550329 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706370.553239 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(tf.__version__)\\n\",\n    \"\\n\",\n    \"# Enable GPU memory growth\\n\",\n    \"gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"if gpus:\\n\",\n    \"    try:\\n\",\n    \"        for gpu in gpus:\\n\",\n    \"            tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"    except RuntimeError as e:\\n\",\n    \"        print(e)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b2402b9a\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Download dataset\\n\",\n    \"\\n\",\n    \"Download the PetFinder dataset from Kaggle, where each row describes a pet and the goal is to predict adoption speed.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"9326b072-a53c-40c4-a6cb-bd4d3d644d03\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/petfinder-mini.zip\\n\",\n      \"\\u001b[1m1668792/1668792\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 0us/step\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import pathlib\\n\",\n    \"import os\\n\",\n    \"dataset_url = 'http://storage.googleapis.com/download.tensorflow.org/data/petfinder-mini.zip'\\n\",\n    \"\\n\",\n    \"data_dir = tf.keras.utils.get_file('petfinder_mini.zip', dataset_url, extract=True, cache_dir='.')\\n\",\n    \"data_dir = pathlib.Path(data_dir)\\n\",\n    \"try:\\n\",\n    \"    # pet-finder-mini might be under a parent directory petfinder_mini_extracted. 
Check if this is the case:\\n\",\n    \"    dataset = os.path.join(os.path.dirname(data_dir), 'petfinder_mini_extracted/petfinder-mini/petfinder-mini.csv')\\n\",\n    \"    dataframe = pd.read_csv(dataset)\\n\",\n    \"except:\\n\",\n    \"    dataset = os.path.join(os.path.dirname(data_dir), 'petfinder-mini/petfinder-mini.csv')\\n\",\n    \"    dataframe = pd.read_csv(dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"e98480ef-d13d-44c0-a227-e9a22f9bf2b0\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>Type</th>\\n\",\n       \"      <th>Age</th>\\n\",\n       \"      <th>Breed1</th>\\n\",\n       \"      <th>Gender</th>\\n\",\n       \"      <th>Color1</th>\\n\",\n       \"      <th>Color2</th>\\n\",\n       \"      <th>MaturitySize</th>\\n\",\n       \"      <th>FurLength</th>\\n\",\n       \"      <th>Vaccinated</th>\\n\",\n       \"      <th>Sterilized</th>\\n\",\n       \"      <th>Health</th>\\n\",\n       \"      <th>Fee</th>\\n\",\n       \"      <th>Description</th>\\n\",\n       \"      <th>PhotoAmt</th>\\n\",\n       \"      <th>AdoptionSpeed</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>0</th>\\n\",\n       \"      
<td>Cat</td>\\n\",\n       \"      <td>3</td>\\n\",\n       \"      <td>Tabby</td>\\n\",\n       \"      <td>Male</td>\\n\",\n       \"      <td>Black</td>\\n\",\n       \"      <td>White</td>\\n\",\n       \"      <td>Small</td>\\n\",\n       \"      <td>Short</td>\\n\",\n       \"      <td>No</td>\\n\",\n       \"      <td>No</td>\\n\",\n       \"      <td>Healthy</td>\\n\",\n       \"      <td>100</td>\\n\",\n       \"      <td>Nibble is a 3+ month old ball of cuteness. He ...</td>\\n\",\n       \"      <td>1</td>\\n\",\n       \"      <td>2</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>1</th>\\n\",\n       \"      <td>Cat</td>\\n\",\n       \"      <td>1</td>\\n\",\n       \"      <td>Domestic Medium Hair</td>\\n\",\n       \"      <td>Male</td>\\n\",\n       \"      <td>Black</td>\\n\",\n       \"      <td>Brown</td>\\n\",\n       \"      <td>Medium</td>\\n\",\n       \"      <td>Medium</td>\\n\",\n       \"      <td>Not Sure</td>\\n\",\n       \"      <td>Not Sure</td>\\n\",\n       \"      <td>Healthy</td>\\n\",\n       \"      <td>0</td>\\n\",\n       \"      <td>I just found it alone yesterday near my apartm...</td>\\n\",\n       \"      <td>2</td>\\n\",\n       \"      <td>0</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>2</th>\\n\",\n       \"      <td>Dog</td>\\n\",\n       \"      <td>1</td>\\n\",\n       \"      <td>Mixed Breed</td>\\n\",\n       \"      <td>Male</td>\\n\",\n       \"      <td>Brown</td>\\n\",\n       \"      <td>White</td>\\n\",\n       \"      <td>Medium</td>\\n\",\n       \"      <td>Medium</td>\\n\",\n       \"      <td>Yes</td>\\n\",\n       \"      <td>No</td>\\n\",\n       \"      <td>Healthy</td>\\n\",\n       \"      <td>0</td>\\n\",\n       \"      <td>Their pregnant mother was dumped by her irresp...</td>\\n\",\n       \"      <td>7</td>\\n\",\n       \"      <td>3</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      
<th>3</th>\\n\",\n       \"      <td>Dog</td>\\n\",\n       \"      <td>4</td>\\n\",\n       \"      <td>Mixed Breed</td>\\n\",\n       \"      <td>Female</td>\\n\",\n       \"      <td>Black</td>\\n\",\n       \"      <td>Brown</td>\\n\",\n       \"      <td>Medium</td>\\n\",\n       \"      <td>Short</td>\\n\",\n       \"      <td>Yes</td>\\n\",\n       \"      <td>No</td>\\n\",\n       \"      <td>Healthy</td>\\n\",\n       \"      <td>150</td>\\n\",\n       \"      <td>Good guard dog, very alert, active, obedience ...</td>\\n\",\n       \"      <td>8</td>\\n\",\n       \"      <td>2</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>4</th>\\n\",\n       \"      <td>Dog</td>\\n\",\n       \"      <td>1</td>\\n\",\n       \"      <td>Mixed Breed</td>\\n\",\n       \"      <td>Male</td>\\n\",\n       \"      <td>Black</td>\\n\",\n       \"      <td>No Color</td>\\n\",\n       \"      <td>Medium</td>\\n\",\n       \"      <td>Short</td>\\n\",\n       \"      <td>No</td>\\n\",\n       \"      <td>No</td>\\n\",\n       \"      <td>Healthy</td>\\n\",\n       \"      <td>0</td>\\n\",\n       \"      <td>This handsome yet cute boy is up for adoption....</td>\\n\",\n       \"      <td>3</td>\\n\",\n       \"      <td>2</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"</div>\"\n      ],\n      \"text/plain\": [\n       \"  Type  Age                Breed1  Gender Color1    Color2 MaturitySize  \\\\\\n\",\n       \"0  Cat    3                 Tabby    Male  Black     White        Small   \\n\",\n       \"1  Cat    1  Domestic Medium Hair    Male  Black     Brown       Medium   \\n\",\n       \"2  Dog    1           Mixed Breed    Male  Brown     White       Medium   \\n\",\n       \"3  Dog    4           Mixed Breed  Female  Black     Brown       Medium   \\n\",\n       \"4  Dog    1           Mixed Breed    Male  Black  No Color       Medium   \\n\",\n       \"\\n\",\n       \"  FurLength 
Vaccinated Sterilized   Health  Fee  \\\\\\n\",\n       \"0     Short         No         No  Healthy  100   \\n\",\n       \"1    Medium   Not Sure   Not Sure  Healthy    0   \\n\",\n       \"2    Medium        Yes         No  Healthy    0   \\n\",\n       \"3     Short        Yes         No  Healthy  150   \\n\",\n       \"4     Short         No         No  Healthy    0   \\n\",\n       \"\\n\",\n       \"                                         Description  PhotoAmt  AdoptionSpeed  \\n\",\n       \"0  Nibble is a 3+ month old ball of cuteness. He ...         1              2  \\n\",\n       \"1  I just found it alone yesterday near my apartm...         2              0  \\n\",\n       \"2  Their pregnant mother was dumped by her irresp...         7              3  \\n\",\n       \"3  Good guard dog, very alert, active, obedience ...         8              2  \\n\",\n       \"4  This handsome yet cute boy is up for adoption....         3              2  \"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"dataframe.head()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"27d844f1\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Prepare dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"e8efce25-a835-4cbd-b8a2-1418ba2c1d31\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# In the original dataset, `'AdoptionSpeed'` of `4` indicates\\n\",\n    \"# a pet was not adopted.\\n\",\n    \"dataframe['target'] = np.where(dataframe['AdoptionSpeed']==4, 0, 1)\\n\",\n    \"\\n\",\n    \"# Drop unused features.\\n\",\n    \"dataframe = dataframe.drop(columns=['AdoptionSpeed', 'Description'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"00d403cf-9ae7-4780-9fac-13d920d8b395\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": 
\"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/numpy/core/fromnumeric.py:59: FutureWarning: 'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.\\n\",\n      \"  return bound(*args, **kwds)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"train, val, test = np.split(dataframe.sample(frac=1), [int(0.8*len(dataframe)), int(0.9*len(dataframe))])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"4206a56e-5403-42a9-805e-e037044e7995\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"9229 training examples\\n\",\n      \"1154 validation examples\\n\",\n      \"1154 test examples\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(len(train), 'training examples')\\n\",\n    \"print(len(val), 'validation examples')\\n\",\n    \"print(len(test), 'test examples')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a7fa64f8\",\n   \"metadata\": {},\n   \"source\": [\n    \"Create an input pipeline which converts each dataset into a tf.data.Dataset with shuffling and batching.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"499ade5f-ac8a-47ca-a021-071239dfe97d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def df_to_dataset(dataframe, shuffle=True, batch_size=32):\\n\",\n    \"    df = dataframe.copy()\\n\",\n    \"    labels = df.pop('target')\\n\",\n    \"    df = {key: value.to_numpy()[:,tf.newaxis] for key, value in dataframe.items()}\\n\",\n    \"    ds = tf.data.Dataset.from_tensor_slices((dict(df), labels))\\n\",\n    \"    if shuffle:\\n\",\n    \"        ds = ds.shuffle(buffer_size=len(dataframe))\\n\",\n    \"    ds = ds.batch(batch_size)\\n\",\n    \"    ds = ds.prefetch(batch_size)\\n\",\n    
\"    return ds\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"96065bed\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the format of the data returned by the pipeline:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b9ec57c9-080e-4626-9e03-acf309cf3736\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"I0000 00:00:1738706370.981571 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706370.984478 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706370.987280 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706371.105121 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706371.106231 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706371.107182 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"2025-02-04 13:59:31.108098: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 40337 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"batch_size = 5\\n\",\n    \"train_ds = df_to_dataset(train, batch_size=batch_size)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"bdc8571c\",\n   \"metadata\": {},\n   \"source\": [\n    \"(Note that OUT_OF_RANGE errors are safe to ignore: https://github.com/tensorflow/tensorflow/issues/62963).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"dfcbf268-4508-4eb8-abe1-acf1dbb97bd5\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Every feature: ['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt', 'target']\\n\",\n      \"A batch of ages: tf.Tensor(\\n\",\n      \"[[ 4]\\n\",\n      \" [60]\\n\",\n      \" [24]\\n\",\n      \" [ 1]\\n\",\n      \" [ 2]], shape=(5, 1), dtype=int64)\\n\",\n      \"A batch of targets: tf.Tensor([1 1 1 1 1], shape=(5,), dtype=int64)\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:59:31.170523: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: 
OUT_OF_RANGE: End of sequence\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"[(train_features, label_batch)] = train_ds.take(1)\\n\",\n    \"print('Every feature:', list(train_features.keys()))\\n\",\n    \"print('A batch of ages:', train_features['Age'])\\n\",\n    \"print('A batch of targets:', label_batch )\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d5a2d10c\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Apply Keras preprocessing layers\\n\",\n    \"\\n\",\n    \"We'll define a normalization layer for numeric features, and a category encoding for categorical features.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"6c09dc4b-3a2a-44f5-b41c-821ec30b87b1\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def get_normalization_layer(name, dataset):\\n\",\n    \"    # Create a Normalization layer for the feature.\\n\",\n    \"    normalizer = layers.Normalization(axis=None)\\n\",\n    \"\\n\",\n    \"    # Prepare a Dataset that only yields the feature.\\n\",\n    \"    feature_ds = dataset.map(lambda x, y: x[name])\\n\",\n    \"\\n\",\n    \"    # Learn the statistics of the data.\\n\",\n    \"    normalizer.adapt(feature_ds)\\n\",\n    \"\\n\",\n    \"    return normalizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"59bb91dc-360a-4a89-a9ea-bebc1ddbf1b7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:59:32.726183: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<tf.Tensor: shape=(5, 1), dtype=float32, numpy=\\n\",\n       \"array([[-0.19333968],\\n\",\n       \"       [-0.19333968],\\n\",\n       \"       [-0.19333968],\\n\",\n       \"       
[-0.51676387],\\n\",\n       \"       [ 1.100357  ]], dtype=float32)>\"\n      ]\n     },\n     \"execution_count\": 13,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"photo_count_col = train_features['PhotoAmt']\\n\",\n    \"layer = get_normalization_layer('PhotoAmt', train_ds)\\n\",\n    \"layer(photo_count_col)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"4623b612-e924-472b-9ef4-c7f14f9f53c5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def get_category_encoding_layer(name, dataset, dtype, max_tokens=None):\\n\",\n    \"    # Create a layer that turns strings into integer indices.\\n\",\n    \"    if dtype == 'string':\\n\",\n    \"        index = layers.StringLookup(max_tokens=max_tokens)\\n\",\n    \"    # Otherwise, create a layer that turns integer values into integer indices.\\n\",\n    \"    else:\\n\",\n    \"        index = layers.IntegerLookup(max_tokens=max_tokens)\\n\",\n    \"\\n\",\n    \"    # Prepare a `tf.data.Dataset` that only yields the feature.\\n\",\n    \"    feature_ds = dataset.map(lambda x, y: x[name])\\n\",\n    \"\\n\",\n    \"    # Learn the set of possible values and assign them a fixed integer index.\\n\",\n    \"    index.adapt(feature_ds)\\n\",\n    \"\\n\",\n    \"    # Encode the integer indices.\\n\",\n    \"    encoder = layers.CategoryEncoding(num_tokens=index.vocabulary_size())\\n\",\n    \"\\n\",\n    \"    # Apply multi-hot encoding to the indices. 
The lambda function captures the\\n\",\n    \"    # layer, so you can use them, or include them in the Keras Functional model later.\\n\",\n    \"    return lambda feature: encoder(index(feature))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"0a40e9ee-20a5-4a42-8543-c267f99af55e\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<tf.Tensor: shape=(5, 3), dtype=float32, numpy=\\n\",\n       \"array([[0., 0., 1.],\\n\",\n       \"       [0., 1., 0.],\\n\",\n       \"       [0., 0., 1.],\\n\",\n       \"       [0., 1., 0.],\\n\",\n       \"       [0., 1., 0.]], dtype=float32)>\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_type_col = train_features['Type']\\n\",\n    \"test_type_layer = get_category_encoding_layer(name='Type',\\n\",\n    \"                                              dataset=train_ds,\\n\",\n    \"                                              dtype='string')\\n\",\n    \"test_type_layer(test_type_col)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"id\": \"ff63a5cc-71f4-428e-9299-a8018edc7648\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:59:34.294276: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<tf.Tensor: shape=(5, 5), dtype=float32, numpy=\\n\",\n       \"array([[0., 0., 0., 0., 1.],\\n\",\n       \"       [1., 0., 0., 0., 0.],\\n\",\n       \"       [1., 0., 0., 0., 0.],\\n\",\n       \"       [0., 0., 0., 1., 0.],\\n\",\n       \"       [0., 1., 0., 0., 0.]], dtype=float32)>\"\n      ]\n     },\n     \"execution_count\": 16,\n     
\"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_age_col = train_features['Age']\\n\",\n    \"test_age_layer = get_category_encoding_layer(name='Age',\\n\",\n    \"                                             dataset=train_ds,\\n\",\n    \"                                             dtype='int64',\\n\",\n    \"                                             max_tokens=5)\\n\",\n    \"test_age_layer(test_age_col)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"afefbcf2\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Preprocess selected features\\n\",\n    \"\\n\",\n    \"Apply the preprocessing helper functions defined earlier. Add all the feature inputs to a list.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"id\": \"2b040b0e-d8ca-4cf0-917c-dd9a272e1f0a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"batch_size = 256\\n\",\n    \"train_ds = df_to_dataset(train, batch_size=batch_size)\\n\",\n    \"val_ds = df_to_dataset(val, shuffle=False, batch_size=batch_size)\\n\",\n    \"test_ds = df_to_dataset(test, shuffle=False, batch_size=batch_size)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"id\": \"19df498e-4dd1-467a-8741-e1f5e15932a5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"all_inputs = {}\\n\",\n    \"encoded_features = []\\n\",\n    \"\\n\",\n    \"# Numerical features.\\n\",\n    \"for header in ['PhotoAmt', 'Fee']:\\n\",\n    \"    numeric_col = tf.keras.Input(shape=(1,), name=header)\\n\",\n    \"    normalization_layer = get_normalization_layer(header, train_ds)\\n\",\n    \"    encoded_numeric_col = normalization_layer(numeric_col)\\n\",\n    \"    all_inputs[header] = numeric_col\\n\",\n    \"    encoded_features.append(encoded_numeric_col)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"id\": 
\"1d12579f-34fb-40b0-a16a-3e13cfea8178\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"age_col = tf.keras.Input(shape=(1,), name='Age', dtype='int64')\\n\",\n    \"\\n\",\n    \"encoding_layer = get_category_encoding_layer(name='Age',\\n\",\n    \"                                             dataset=train_ds,\\n\",\n    \"                                             dtype='int64',\\n\",\n    \"                                             max_tokens=5)\\n\",\n    \"encoded_age_col = encoding_layer(age_col)\\n\",\n    \"all_inputs['Age'] = age_col\\n\",\n    \"encoded_features.append(encoded_age_col)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"id\": \"bff286eb-7ad7-4d3a-8fa4-c729692d1425\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 13:59:34.588989: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\\n\",\n      \"2025-02-04 13:59:35.029267: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"categorical_cols = ['Type', 'Color1', 'Color2', 'Gender', 'MaturitySize',\\n\",\n    \"                    'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Breed1']\\n\",\n    \"\\n\",\n    \"for header in categorical_cols:\\n\",\n    \"    categorical_col = tf.keras.Input(shape=(1,), name=header, dtype='string')\\n\",\n    \"    encoding_layer = get_category_encoding_layer(name=header,\\n\",\n    \"                                                dataset=train_ds,\\n\",\n    \"                                                dtype='string',\\n\",\n    \"                                                max_tokens=5)\\n\",\n    \"    encoded_categorical_col = encoding_layer(categorical_col)\\n\",\n    \"    
all_inputs[header] = categorical_col\\n\",\n    \"    encoded_features.append(encoded_categorical_col)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e0dfac0d\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Create, compile, and train model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"79247436-32d8-4738-a656-3f288c77001c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"all_features = tf.keras.layers.concatenate(encoded_features)\\n\",\n    \"x = tf.keras.layers.Dense(32, activation=\\\"relu\\\")(all_features)\\n\",\n    \"x = tf.keras.layers.Dropout(0.5)(x)\\n\",\n    \"output = tf.keras.layers.Dense(1)(x)\\n\",\n    \"\\n\",\n    \"model = tf.keras.Model(all_inputs, output)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"id\": \"dbc85d3e-6d1e-4167-9516-b1182e880542\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model.compile(optimizer='adam',\\n\",\n    \"              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n    \"              metrics=[\\\"accuracy\\\"],\\n\",\n    \"              run_eagerly=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"id\": \"bc9836c8-3c1a-41ad-8833-a946bafcfb00\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Epoch 1/10\\n\",\n      \"\\u001b[1m37/37\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m1s\\u001b[0m 16ms/step - accuracy: 0.3658 - loss: 0.7746 - val_accuracy: 0.6854 - val_loss: 0.5841\\n\",\n      \"Epoch 2/10\\n\",\n      \"\\u001b[1m37/37\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m1s\\u001b[0m 16ms/step - accuracy: 0.6270 - loss: 0.6023 - val_accuracy: 0.7383 - val_loss: 0.5593\\n\",\n      \"Epoch 3/10\\n\",\n      \"\\u001b[1m37/37\\u001b[0m 
\\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m1s\\u001b[0m 16ms/step - accuracy: 0.6650 - loss: 0.5781 - val_accuracy: 0.7392 - val_loss: 0.5442\\n\",\n      \"Epoch 4/10\\n\",\n      \"\\u001b[1m37/37\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m1s\\u001b[0m 17ms/step - accuracy: 0.6609 - loss: 0.5744 - val_accuracy: 0.7418 - val_loss: 0.5329\\n\",\n      \"Epoch 5/10\\n\",\n      \"\\u001b[1m37/37\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m1s\\u001b[0m 15ms/step - accuracy: 0.6845 - loss: 0.5555 - val_accuracy: 0.7444 - val_loss: 0.5261\\n\",\n      \"Epoch 6/10\\n\",\n      \"\\u001b[1m37/37\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m1s\\u001b[0m 15ms/step - accuracy: 0.6910 - loss: 0.5465 - val_accuracy: 0.7513 - val_loss: 0.5198\\n\",\n      \"Epoch 7/10\\n\",\n      \"\\u001b[1m37/37\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m1s\\u001b[0m 15ms/step - accuracy: 0.7018 - loss: 0.5475 - val_accuracy: 0.7556 - val_loss: 0.5145\\n\",\n      \"Epoch 8/10\\n\",\n      \"\\u001b[1m37/37\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m1s\\u001b[0m 15ms/step - accuracy: 0.7026 - loss: 0.5410 - val_accuracy: 0.7496 - val_loss: 0.5099\\n\",\n      \"Epoch 9/10\\n\",\n      \"\\u001b[1m37/37\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m1s\\u001b[0m 15ms/step - accuracy: 0.7145 - loss: 0.5315 - val_accuracy: 0.7530 - val_loss: 0.5066\\n\",\n      \"Epoch 10/10\\n\",\n      \"\\u001b[1m37/37\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m1s\\u001b[0m 15ms/step - accuracy: 0.7099 - loss: 0.5316 - val_accuracy: 0.7539 - val_loss: 0.5038\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<keras.src.callbacks.history.History at 0x7b15e89123d0>\"\n      ]\n     
},\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"model.fit(train_ds, epochs=10, validation_data=val_ds)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"fbccebaa-fc24-4a58-a032-222cef8fdf08\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[1m5/5\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 7ms/step - accuracy: 0.7416 - loss: 0.5196 \\n\",\n      \"Accuracy 0.7443674206733704\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"loss, accuracy = model.evaluate(test_ds)\\n\",\n    \"print(\\\"Accuracy\\\", accuracy)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7534616c-8561-4869-b6e9-7254ebdb2c3f\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save and reload model\\n\",\n    \"\\n\",\n    \"Demonstrate saving the trained model and reloading it for inference.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"id\": \"6bf0d024\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model.save('models/my_pet_classifier.keras')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"id\": \"d1a7be62\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"reloaded_model = tf.keras.models.load_model('models/my_pet_classifier.keras')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"f3d2a2d5-fd4d-4320-bacc-fd4571cec709\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[1m1/1\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 27ms/step\\n\",\n      \"This particular pet had a 83.2 percent probability 
of getting adopted.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"sample = {\\n\",\n    \"    'Type': 'Cat',\\n\",\n    \"    'Age': 3,\\n\",\n    \"    'Breed1': 'Tabby',\\n\",\n    \"    'Gender': 'Male',\\n\",\n    \"    'Color1': 'Black',\\n\",\n    \"    'Color2': 'White',\\n\",\n    \"    'MaturitySize': 'Small',\\n\",\n    \"    'FurLength': 'Short',\\n\",\n    \"    'Vaccinated': 'No',\\n\",\n    \"    'Sterilized': 'No',\\n\",\n    \"    'Health': 'Healthy',\\n\",\n    \"    'Fee': 100,\\n\",\n    \"    'PhotoAmt': 2,\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}\\n\",\n    \"predictions = reloaded_model.predict(input_dict)\\n\",\n    \"prob = tf.nn.sigmoid(predictions[0])\\n\",\n    \"\\n\",\n    \"print(\\n\",\n    \"    \\\"This particular pet had a %.1f percent probability \\\"\\n\",\n    \"    \\\"of getting adopted.\\\" % (100 * prob)\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f7bbfe69-93ed-4452-8985-c6685e0726c3\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"id\": \"fc8a0536\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.functions import col, struct, pandas_udf\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"import json\\n\",\n    \"import pandas as pd\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"bb5aa875\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific Spark configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"id\": \"7701420e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5e231dbd\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"id\": \"60dff1da\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 13:59:42 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/04 13:59:42 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/04 13:59:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        \\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"    elif on_dataproc:\\n\",\n    \"        conf.set(\\\"spark.executorEnv.TF_GPU_ALLOCATOR\\\", \\\"cuda_malloc_async\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", \\\"1000\\\")\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"fa2333d1\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create PySpark DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 31,\n   \"id\": \"3c64fd7b-3d1e-40f8-ab64-b5c13f8bbe77\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.createDataFrame(dataframe).repartition(8)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 32,\n   \"id\": \"1be8215b-5068-41b4-849c-1c3ea7bb108a\",\n   \"metadata\": 
{},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/petfinder-mini\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7cec4e0e\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 33,\n   \"id\": \"0892f845\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.parquet(data_path).cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 34,\n   \"id\": \"952645dd\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt', 'target']\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"columns = df.columns\\n\",\n    \"print(columns)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 35,\n   \"id\": \"b9c24c0d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt']\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# remove label column\\n\",\n    \"columns.remove(\\\"target\\\")\\n\",\n    \"print(columns)\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": 36,\n   \"id\": \"d4dbde99-cf65-4c15-a163-754a0201a48d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\\n\",\n      \"|Type|Age|              Breed1|Gender|Color1|  Color2|MaturitySize|FurLength|Vaccinated|Sterilized| Health|Fee|PhotoAmt|target|\\n\",\n      \"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\\n\",\n      \"| Dog|  3|         Mixed Breed|  Male| Black|No Color|       Small|   Medium|  Not Sure|  Not Sure|Healthy|  0|       2|     0|\\n\",\n      \"| Dog|  9|         Mixed Breed|  Male|  Gray|No Color|      Medium|    Short|  Not Sure|        No|Healthy|  0|       4|     1|\\n\",\n      \"| Cat|  4| Domestic Short Hair|  Male| Black|    Gray|      Medium|    Short|  Not Sure|  Not Sure|Healthy|  0|       4|     1|\\n\",\n      \"| Cat|  6| Domestic Short Hair|  Male|Yellow|   White|      Medium|    Short|        No|        No|Healthy|  0|       3|     1|\\n\",\n      \"| Cat|  6|Domestic Medium Hair|  Male|  Gray|No Color|       Small|   Medium|        No|        No|Healthy|  0|       4|     1|\\n\",\n      \"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df.show(5)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"824d7f97\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Inference using Spark DL API\\n\",\n    \"\\n\",\n    \"Distributed inference using the PySpark 
[predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\\n\",\n    \"\\n\",\n    \"- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \\n\",\n    \"- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 37,\n   \"id\": \"d62eb95a-54c6-44d2-9279-38fb65e0e160\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# get absolute path to model\\n\",\n    \"model_path = \\\"{}/models/my_pet_classifier.keras\\\".format(os.getcwd())\\n\",\n    \"\\n\",\n    \"# For cloud environments, copy the model to the distributed file system.\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-models\\\")\\n\",\n    \"    dbfs_model_path = \\\"/dbfs/FileStore/spark-dl-models/my_pet_classifier.keras\\\"\\n\",\n    \"    shutil.copy(model_path, dbfs_model_path)\\n\",\n    \"    model_path = dbfs_model_path\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    # GCS is mounted at /mnt/gcs by the init script\\n\",\n    \"    models_dir = \\\"/mnt/gcs/spark-dl/models\\\"\\n\",\n    \"    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\\n\",\n    \"    gcs_model_path = models_dir + \\\"/my_pet_classifier.keras\\\"\\n\",\n    \"    shutil.copy(model_path, gcs_model_path)\\n\",\n    \"    model_path = gcs_model_path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 38,\n   \"id\": \"45665acf-50c8-445b-a985-b3dabd734709\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def predict_batch_fn():\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"    import pandas as pd\\n\",\n    \"    \\n\",\n    \"    # Enable GPU memory growth to avoid CUDA OOM\\n\",\n    \"    gpus = 
tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"\\n\",\n    \"    model = tf.keras.models.load_model(model_path)\\n\",\n    \"\\n\",\n    \"    def predict(t, a, b, g, c1, c2, m, f, v, s, h, fee, p):\\n\",\n    \"        inputs = {\\n\",\n    \"            \\\"Type\\\": t,\\n\",\n    \"            \\\"Age\\\": a,\\n\",\n    \"            \\\"Breed1\\\": b,\\n\",\n    \"            \\\"Gender\\\": g,\\n\",\n    \"            \\\"Color1\\\": c1,\\n\",\n    \"            \\\"Color2\\\": c2,\\n\",\n    \"            \\\"MaturitySize\\\": m,\\n\",\n    \"            \\\"FurLength\\\": f,\\n\",\n    \"            \\\"Vaccinated\\\": v,\\n\",\n    \"            \\\"Sterilized\\\": s,\\n\",\n    \"            \\\"Health\\\": h,\\n\",\n    \"            \\\"Fee\\\": fee,\\n\",\n    \"            \\\"PhotoAmt\\\": p\\n\",\n    \"        }\\n\",\n    \"        # return model.predict(inputs)\\n\",\n    \"        return pd.Series(np.squeeze(model.predict(inputs)))\\n\",\n    \"\\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 39,\n   \"id\": \"815e3b5f-7914-4235-85fa-50153dcd3d30\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# need to pass the list of columns into the model_udf\\n\",\n    \"classify = predict_batch_udf(predict_batch_fn,\\n\",\n    \"                             return_type=FloatType(),\\n\",\n    \"                             batch_size=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 40,\n   \"id\": \"da03a0c6-2d39-425e-a9fa-57c139cca1ed\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      
\"25/02/04 13:59:47 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 5:>                                                          (0 + 8) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 19.8 ms, sys: 9.3 ms, total: 29.1 ms\\n\",\n      \"Wall time: 4.99 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(struct(*columns)))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 41,\n   \"id\": \"03990c76-7198-49a7-bb5d-6870be915fb3\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 6:>                                                          (0 + 8) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 86.9 ms, sys: 13.7 ms, total: 101 ms\\n\",\n      \"Wall time: 1.56 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(*columns))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"id\": 
\"edb93cf3-c248-40c9-b8dc-acc8f51786a9\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 7:>                                                          (0 + 8) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 16.4 ms, sys: 4.46 ms, total: 20.9 ms\\n\",\n      \"Wall time: 1.52 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(*[col(c) for c in columns]))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 43,\n   \"id\": \"a91f19cb-f7f1-4669-aff1-be594bea5378\",\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\\n\",\n      \"|Type|Age|              Breed1|Gender|Color1|  Color2|MaturitySize|FurLength|Vaccinated|Sterilized|      Health|Fee|PhotoAmt|target|      preds|\\n\",\n      \"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\\n\",\n      \"| Dog|  3|         Mixed Breed|  Male| Black|No Color|       Small|   Medium|  Not Sure|  Not Sure|     Healthy|  0|       2|     0|  0.4963937|\\n\",\n      \"| Dog|  9|         Mixed Breed|  Male|  Gray|No Color|      Medium|    Short|  Not Sure|        No|     Healthy|  0|       4|     1|  0.6780287|\\n\",\n      \"| Cat|  4| Domestic Short Hair|  Male| Black|    
Gray|      Medium|    Short|  Not Sure|  Not Sure|     Healthy|  0|       4|     1| 0.58800673|\\n\",\n      \"| Cat|  6| Domestic Short Hair|  Male|Yellow|   White|      Medium|    Short|        No|        No|     Healthy|  0|       3|     1|  0.7378843|\\n\",\n      \"| Cat|  6|Domestic Medium Hair|  Male|  Gray|No Color|       Small|   Medium|        No|        No|     Healthy|  0|       4|     1|  1.2695599|\\n\",\n      \"| Cat|  5|Domestic Medium Hair|Female|  Gray|No Color|      Medium|   Medium|       Yes|  Not Sure|     Healthy|  0|       1|     0|0.060457088|\\n\",\n      \"| Dog| 24|              Beagle|Female| Black|  Golden|      Medium|    Short|  Not Sure|  Not Sure|Minor Injury|  0|       1|     1| 0.28160828|\\n\",\n      \"| Cat| 29|               Tabby|  Male| Brown|  Golden|      Medium|    Short|        No|        No|     Healthy|  0|       1|     0|  0.6928505|\\n\",\n      \"| Dog|  9|         Mixed Breed|Female| Black|   Brown|      Medium|    Short|       Yes|       Yes|     Healthy|  0|       2|     0|-0.10125986|\\n\",\n      \"| Dog|  2|         Mixed Breed|Female| Cream|   White|      Medium|    Short|        No|        No|     Healthy|  0|       1|     0|  1.3703903|\\n\",\n      \"| Dog|  2|         Mixed Breed|  Male| Brown|   White|      Medium|    Short|       Yes|        No|     Healthy|  0|       1|     1|  1.3243997|\\n\",\n      \"| Dog| 60|    Golden Retriever|  Male| Brown|  Yellow|      Medium|   Medium|       Yes|       Yes|     Healthy|  0|       5|     1|  0.9026731|\\n\",\n      \"| Cat|  9|             Siamese|  Male| White|No Color|      Medium|    Short|       Yes|        No|     Healthy|  0|       2|     1|  0.8207382|\\n\",\n      \"| Dog| 19|   Doberman Pinscher|Female| Black|   Brown|       Large|    Short|       Yes|       Yes|     Healthy|500|       2|     1| 0.85343015|\\n\",\n      \"| Cat| 11| Domestic Short Hair|  Male| Cream|No Color|      Medium|    Short|       Yes|       Yes|     Healthy|100|       6|    
 0| 0.53920615|\\n\",\n      \"| Dog| 18|         Mixed Breed|Female| Brown|   White|       Small|    Short|       Yes|        No|     Healthy|  0|       5|     0|   0.718272|\\n\",\n      \"| Dog|  4|         Mixed Breed|Female| Brown|   White|      Medium|   Medium|  Not Sure|  Not Sure|     Healthy|  0|       3|     0| 0.16185221|\\n\",\n      \"| Dog| 96|    Golden Retriever|  Male|Golden|No Color|       Large|     Long|       Yes|       Yes|     Healthy|  0|       2|     1|  0.8156965|\\n\",\n      \"| Dog| 54|    Golden Retriever|  Male|Golden|No Color|       Large|   Medium|       Yes|        No|     Healthy|350|      20|     1|  3.5315154|\\n\",\n      \"| Cat|  5|Domestic Medium Hair|Female| Brown|   White|      Medium|   Medium|        No|        No|     Healthy|  0|       5|     1|  1.1725564|\\n\",\n      \"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0c3e0390\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  
\\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 44,\n   \"id\": \"2605d134-ef75-4d94-9b16-2c6d85f29bef\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ea407357\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 45,\n   \"id\": \"7e1e716f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"fcd28e7d\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 46,\n   \"id\": \"4666e618-8038-4dc5-9be7-793aedbf4500\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports, model_path):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    from pytriton.model_config import 
DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"    # Enable GPU memory growth\\n\",\n    \"    gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"\\n\",\n    \"    model = tf.keras.models.load_model(model_path)\\n\",\n    \"\\n\",\n    \"    def decode(input_tensor):\\n\",\n    \"        return tf.convert_to_tensor(np.vectorize(lambda x: x.decode('utf-8'))(input_tensor))\\n\",\n    \"\\n\",\n    \"    def identity(input_tensor):\\n\",\n    \"        return tf.convert_to_tensor(input_tensor)\\n\",\n    \"\\n\",\n    \"    input_transforms = {\\n\",\n    \"        \\\"Type\\\": decode,\\n\",\n    \"        \\\"Age\\\": identity,\\n\",\n    \"        \\\"Breed1\\\": decode,\\n\",\n    \"        \\\"Gender\\\": decode,\\n\",\n    \"        \\\"Color1\\\": decode,\\n\",\n    \"        \\\"Color2\\\": decode,\\n\",\n    \"        \\\"MaturitySize\\\": decode,\\n\",\n    \"        \\\"FurLength\\\": decode,\\n\",\n    \"        \\\"Vaccinated\\\": decode,\\n\",\n    \"        \\\"Sterilized\\\": decode,\\n\",\n    \"        \\\"Health\\\": decode,\\n\",\n    \"        \\\"Fee\\\": identity,\\n\",\n    \"        \\\"PhotoAmt\\\": identity\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        decoded_inputs = {k: input_transforms[k](v) for k, v in inputs.items()}\\n\",\n    \"        print(f\\\"SERVER: Received batch of size {len(decoded_inputs['Type'])}.\\\")\\n\",\n    \"        return {\\n\",\n    \"    
        \\\"preds\\\": model.predict(decoded_inputs)\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    workspace_path = f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"PetClassifier\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"Type\\\", dtype=np.bytes_, shape=(-1,)),\\n\",\n    \"                Tensor(name=\\\"Age\\\", dtype=np.int64, shape=(-1,)),\\n\",\n    \"                Tensor(name=\\\"Breed1\\\", dtype=np.bytes_, shape=(-1,)),\\n\",\n    \"                Tensor(name=\\\"Gender\\\", dtype=np.bytes_, shape=(-1,)),\\n\",\n    \"                Tensor(name=\\\"Color1\\\", dtype=np.bytes_, shape=(-1,)),\\n\",\n    \"                Tensor(name=\\\"Color2\\\", dtype=np.bytes_, shape=(-1,)),\\n\",\n    \"                Tensor(name=\\\"MaturitySize\\\", dtype=np.bytes_, shape=(-1,)),\\n\",\n    \"                Tensor(name=\\\"FurLength\\\", dtype=np.bytes_, shape=(-1,)),\\n\",\n    \"                Tensor(name=\\\"Vaccinated\\\", dtype=np.bytes_, shape=(-1,)),\\n\",\n    \"                Tensor(name=\\\"Sterilized\\\", dtype=np.bytes_, shape=(-1,)),\\n\",\n    \"                Tensor(name=\\\"Health\\\", dtype=np.bytes_, shape=(-1,)),\\n\",\n    \"                Tensor(name=\\\"Fee\\\", dtype=np.int64, shape=(-1,)),\\n\",\n    \"                Tensor(name=\\\"PhotoAmt\\\", dtype=np.int64, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"preds\\\", dtype=np.float32, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=128,\\n\",\n    \"                
batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"617525a5\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"fc93a43a\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"c9b98208\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"PetClassifier\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name, model_path=model_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"228401f7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     
\"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"cb560288\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5d28b1ca\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d1234a02\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3c9ef706\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 51,\n   \"id\": \"e50b5fc8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"CLIENT: Connecting to {model_name} at {url}\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(t, a, b, g, c1, c2, m, f, v, s, h, fee, p):\\n\",\n    \"        \\n\",\n    \"        def encode(value):\\n\",\n  
  \"            return np.vectorize(lambda x: x.encode(\\\"utf-8\\\"))(value).astype(np.bytes_)\\n\",\n    \"\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=240) as client:\\n\",\n    \"            encoded_inputs = {\\n\",\n    \"                \\\"Type\\\": encode(t), \\n\",\n    \"                \\\"Age\\\": a, \\n\",\n    \"                \\\"Breed1\\\": encode(b), \\n\",\n    \"                \\\"Gender\\\": encode(g),\\n\",\n    \"                \\\"Color1\\\": encode(c1),\\n\",\n    \"                \\\"Color2\\\": encode(c2),\\n\",\n    \"                \\\"MaturitySize\\\": encode(m),\\n\",\n    \"                \\\"FurLength\\\": encode(f),\\n\",\n    \"                \\\"Vaccinated\\\": encode(v),\\n\",\n    \"                \\\"Sterilized\\\": encode(s),\\n\",\n    \"                \\\"Health\\\": encode(h),\\n\",\n    \"                \\\"Fee\\\": fee,\\n\",\n    \"                \\\"PhotoAmt\\\": p\\n\",\n    \"            }\\n\",\n    \"            result_data = client.infer_batch(**encoded_inputs)\\n\",\n    \"            return result_data[\\\"preds\\\"]\\n\",\n    \"            \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 54,\n   \"id\": \"2ffb020e-dc93-456b-bee6-405611eee1e1\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# need to pass the list of columns into the model_udf\\n\",\n    \"classify = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                             input_tensor_shapes=[[1]] * len(columns),\\n\",\n    \"                             return_type=FloatType(),\\n\",\n    \"                             batch_size=64)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2edd887f\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 
52,\n   \"id\": \"fe8dc3e6-f1b1-4a24-85f4-0a5ecabef4c5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 53,\n   \"id\": \"4cfb3f34-a215-4781-91bf-2bec85e15633\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"columns = df.columns\\n\",\n    \"# remove label column\\n\",\n    \"columns.remove(\\\"target\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b75e6f20-f06c-4f4c-ada1-c562e078ed4b\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 55,\n   \"id\": \"e6ff0356-becd-421f-aebb-272497d5ad6a\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 12:>                                                         (0 + 8) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 17.3 ms, sys: 7.75 ms, total: 25.1 ms\\n\",\n      \"Wall time: 6.35 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(struct(*columns)))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 56,\n   \"id\": \"ce18ee7c-5958-4986-b200-6d986fcc6243\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": 
\"stream\",\n     \"text\": [\n      \"CPU times: user 15.2 ms, sys: 4.2 ms, total: 19.4 ms\\n\",\n      \"Wall time: 5.86 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(*columns))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 57,\n   \"id\": \"0888ce40-b2c4-4aed-8ccb-6a8bcd00abc8\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 14:===========================================>              (6 + 2) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 93.4 ms, sys: 3.4 ms, total: 96.8 ms\\n\",\n      \"Wall time: 5.87 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"preds\\\", classify(*[col(c) for c in columns]))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 58,\n   \"id\": \"d45812b5-f584-41a4-a821-2b59e065671c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\\n\",\n      \"|Type|Age|              Breed1|Gender|Color1|  Color2|MaturitySize|FurLength|Vaccinated|Sterilized|      Health|Fee|PhotoAmt|target|      preds|\\n\",\n      \"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\\n\",\n      \"| Dog|  3|         
Mixed Breed|  Male| Black|No Color|       Small|   Medium|  Not Sure|  Not Sure|     Healthy|  0|       2|     0|  0.4963937|\\n\",\n      \"| Dog|  9|         Mixed Breed|  Male|  Gray|No Color|      Medium|    Short|  Not Sure|        No|     Healthy|  0|       4|     1|  0.6780287|\\n\",\n      \"| Cat|  4| Domestic Short Hair|  Male| Black|    Gray|      Medium|    Short|  Not Sure|  Not Sure|     Healthy|  0|       4|     1| 0.58800673|\\n\",\n      \"| Cat|  6| Domestic Short Hair|  Male|Yellow|   White|      Medium|    Short|        No|        No|     Healthy|  0|       3|     1|  0.7378843|\\n\",\n      \"| Cat|  6|Domestic Medium Hair|  Male|  Gray|No Color|       Small|   Medium|        No|        No|     Healthy|  0|       4|     1|  1.2695599|\\n\",\n      \"| Cat|  5|Domestic Medium Hair|Female|  Gray|No Color|      Medium|   Medium|       Yes|  Not Sure|     Healthy|  0|       1|     0|0.060457088|\\n\",\n      \"| Dog| 24|              Beagle|Female| Black|  Golden|      Medium|    Short|  Not Sure|  Not Sure|Minor Injury|  0|       1|     1| 0.28160828|\\n\",\n      \"| Cat| 29|               Tabby|  Male| Brown|  Golden|      Medium|    Short|        No|        No|     Healthy|  0|       1|     0|  0.6928505|\\n\",\n      \"| Dog|  9|         Mixed Breed|Female| Black|   Brown|      Medium|    Short|       Yes|       Yes|     Healthy|  0|       2|     0|-0.10125986|\\n\",\n      \"| Dog|  2|         Mixed Breed|Female| Cream|   White|      Medium|    Short|        No|        No|     Healthy|  0|       1|     0|  1.3703903|\\n\",\n      \"| Dog|  2|         Mixed Breed|  Male| Brown|   White|      Medium|    Short|       Yes|        No|     Healthy|  0|       1|     1|  1.3243997|\\n\",\n      \"| Dog| 60|    Golden Retriever|  Male| Brown|  Yellow|      Medium|   Medium|       Yes|       Yes|     Healthy|  0|       5|     1|  0.9026731|\\n\",\n      \"| Cat|  9|             Siamese|  Male| White|No Color|      Medium|    Short|       Yes|        
No|     Healthy|  0|       2|     1|  0.8207382|\\n\",\n      \"| Dog| 19|   Doberman Pinscher|Female| Black|   Brown|       Large|    Short|       Yes|       Yes|     Healthy|500|       2|     1| 0.85343015|\\n\",\n      \"| Cat| 11| Domestic Short Hair|  Male| Cream|No Color|      Medium|    Short|       Yes|       Yes|     Healthy|100|       6|     0| 0.53920615|\\n\",\n      \"| Dog| 18|         Mixed Breed|Female| Brown|   White|       Small|    Short|       Yes|        No|     Healthy|  0|       5|     0|   0.718272|\\n\",\n      \"| Dog|  4|         Mixed Breed|Female| Brown|   White|      Medium|   Medium|  Not Sure|  Not Sure|     Healthy|  0|       3|     0| 0.16185221|\\n\",\n      \"| Dog| 96|    Golden Retriever|  Male|Golden|No Color|       Large|     Long|       Yes|       Yes|     Healthy|  0|       2|     1|  0.8156965|\\n\",\n      \"| Dog| 54|    Golden Retriever|  Male|Golden|No Color|       Large|   Medium|       Yes|        No|     Healthy|350|      20|     1|  3.5315154|\\n\",\n      \"| Cat|  5|Domestic Medium Hair|Female| Brown|   White|      Medium|   Medium|        No|        No|     Healthy|  0|       5|     1|  1.1725564|\\n\",\n      \"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"63135aa0-b44c-4dda-8050-8cad320afe88\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"#### Stop Triton Server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 59,\n   \"id\": \"6914f44f-677f-4db3-be09-783df8d11b8a\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 14:00:18,330 - INFO - Requesting stage-level 
resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-04 14:00:28,520 - INFO - Sucessfully stopped 1 servers.                 \\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 59,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 60,\n   \"id\": \"f8c6ee43-8891-4446-986e-1447c5d48bac\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e611126e-d8c3-40ac-bf16-b911f6d7b39f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-tf\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras_resnet50_tf.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8e6810cc-5982-4293-bfbd-c91ef0aca204\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark Tensorflow Inference\\n\",\n    \"\\n\",\n    \"### Flower Recognition with Keras Resnet50\\n\",\n    \"\\n\",\n    \"In this notebook, we demonstrate distributed inference with Resnet50 on the Databricks flower photos dataset.  \\n\",\n    \"From: https://docs.databricks.com/_static/notebooks/deep-learning/keras-metadata.html\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"858e3a8d\",\n   \"metadata\": {},\n   \"source\": [\n    \"Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075)  \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"cf329ac8-0763-44bc-b0f6-b634b7dc480e\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 14:00:35.457924: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\\n\",\n      \"2025-02-04 14:00:35.465639: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\\n\",\n      \"2025-02-04 14:00:35.473515: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\\n\",\n      \"2025-02-04 14:00:35.475792: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\\n\",\n      \"2025-02-04 14:00:35.482106: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\\n\",\n      \"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\\n\",\n      \"2025-02-04 14:00:35.843263: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import os\\n\",\n    \"import shutil\\n\",\n    \"import subprocess\\n\",\n    \"import time\\n\",\n    \"import json\\n\",\n    \"import pandas as pd\\n\",\n    \"from PIL import Image\\n\",\n    \"import numpy as np\\n\",\n    \"import uuid\\n\",\n    \" \\n\",\n    \"import tensorflow as tf\\n\",\n    \"from tensorflow.keras.applications.resnet50 import ResNet50\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"532d562d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"os.mkdir('models') if not os.path.exists('models') else None\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"75175140\",\n  
 \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2.17.0\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\\n\",\n      \"I0000 00:00:1738706436.174805 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706436.197467 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706436.200398 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(tf.__version__)\\n\",\n    \"\\n\",\n    \"# Enable GPU memory growth\\n\",\n    \"gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"if gpus:\\n\",\n    \"    try:\\n\",\n    \"        for gpu in gpus:\\n\",\n    \"            tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"    except RuntimeError as e:\\n\",\n    \"        print(e)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"02fe61b8\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"b474339c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.functions import col, struct, pandas_udf, PandasUDFType\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"from typing import Iterator, Tuple\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e182cacb\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific Spark configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"564b1d33\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"016cdd0b\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the 
cluster and create the Spark Session.  \\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"44d72768\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 14:00:36 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/04 14:00:36 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/04 14:00:37 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        \\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"    elif on_dataproc:\\n\",\n    \"        conf.set(\\\"spark.executorEnv.TF_GPU_ALLOCATOR\\\", \\\"cuda_malloc_async\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    
conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.driver.memory\\\", \\\"8g\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.memory\\\", \\\"8g\\\")\\n\",\n    \"\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", \\\"512\\\")\\n\",\n    \"conf.set(\\\"spark.sql.parquet.columnarReaderBatchSize\\\", \\\"1024\\\")\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"61c406fa\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the input and output directories.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"c566dc17\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"os.mkdir(\\\"spark-dl-datasets\\\") if not os.path.exists(\\\"spark-dl-datasets\\\") else None\\n\",\n    \"data_path = \\\"spark-dl-datasets/flowers_{uuid}.parquet\\\".format(uuid=str(uuid.uuid1()))\\n\",\n    \"local_file_path = f\\\"{os.getcwd()}/{data_path}\\\"\\n\",\n    \"output_file_path = \\\"predictions/predictions\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"968d08a7-66b9-444f-b362-d8df692aef1c\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Prepare trained model and data for inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"da083168-137f-492c-8769-d8f1e2111756\",\n   \"metadata\": {},\n   \"source\": [\n    \"Load the ResNet-50 Model and broadcast the weights.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"2ddc715a-cdbc-4c49-93e9-58c9d88511da\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"I0000 00:00:1738706437.771948 3714100 
cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706437.774792 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706437.777387 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706437.894244 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706437.895287 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706437.896207 3714100 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"2025-02-04 14:00:37.897142: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 40337 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"model = ResNet50()\\n\",\n    \"bc_model_weights = sc.broadcast(model.get_weights())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"77dddfa3-e8df-4e8e-8251-64457f1ebf80\",\n   \"metadata\": {},\n   \"source\": [\n    \"Load the data and save the datasets to one Parquet file.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"c0738bec-97d4-4946-8c49-5e6d07ff1afc\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Image count: 3670\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import pathlib\\n\",\n    \"dataset_url = \\\"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\\\"\\n\",\n    \"data_dir = tf.keras.utils.get_file(origin=dataset_url,\\n\",\n    \"                                   fname='flower_photos',\\n\",\n    \"                                   untar=True)\\n\",\n    \"data_dir = pathlib.Path(data_dir)\\n\",\n    \"image_count = len(list(data_dir.glob('*/*.jpg')))\\n\",\n    \"print(f\\\"Image count: {image_count}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"d54f470a-d308-4426-8ed0-33f95155bb4f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(data_dir) for f in filenames if os.path.splitext(f)[1] == '.jpg']\\n\",\n    \"files = files[:2048]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n  
 \"execution_count\": 11,\n   \"id\": \"64f94ee0-f1ea-47f6-a77e-be8da5d1b87a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"image_data = []\\n\",\n    \"for file in files:\\n\",\n    \"    img = Image.open(file)\\n\",\n    \"    img = img.resize([224, 224])\\n\",\n    \"    data = np.asarray(img, dtype=\\\"float32\\\").reshape([224*224*3])\\n\",\n    \"\\n\",\n    \"    image_data.append({\\\"data\\\": data})\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"b4ae1a98\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"pd.DataFrame(image_data, columns=['data']).to_parquet(data_path)\\n\",\n    \"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    shutil.copy(local_file_path, \\\"/dbfs/FileStore/{}\\\".format(data_path))\\n\",\n    \"    data_path = \\\"/dbfs/FileStore/{}\\\".format(data_path)\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    data_dir = \\\"/mnt/gcs/spark-dl/spark-dl-datasets\\\"\\n\",\n    \"    os.mkdir(data_dir) if not os.path.exists(data_dir) else None\\n\",\n    \"    shutil.copy(local_file_path, \\\"/mnt/gcs/spark-dl/\\\" + data_path)\\n\",\n    \"    data_path = \\\"file:///mnt/gcs/spark-dl/\\\" + data_path\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f2414b0f-58f2-4e4a-9d09-8ea95b38d413\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save Model\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"670328e3-7274-4d78-b315-487750166a3f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_path = 'models/resnet50_model.keras'\\n\",\n    \"model.save(model_path)\\n\",\n    \"\\n\",\n    \"# For cloud environments, copy the model to the distributed file system.\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-models\\\")\\n\",\n    \"    dbfs_model_path = 
\\\"/dbfs/FileStore/spark-dl-models/resnet50_model.keras\\\"\\n\",\n    \"    shutil.copy(model_path, dbfs_model_path)\\n\",\n    \"    model_path = dbfs_model_path\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    # GCS is mounted at /mnt/gcs by the init script\\n\",\n    \"    models_dir = \\\"/mnt/gcs/spark-dl/models\\\"\\n\",\n    \"    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\\n\",\n    \"    gcs_model_path = models_dir + \\\"/resnet50_model.keras\\\"\\n\",\n    \"    shutil.copy(model_path, gcs_model_path)\\n\",\n    \"    model_path = gcs_model_path\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b827ad56-1af0-41b7-be68-94bd203a2a70\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Load the data into Spark DataFrames\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"8ddc22d0-b88a-4906-bd47-bf247e34feeb\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2048\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df = spark.read.parquet(data_path)\\n\",\n    \"print(df.count())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"865929b0-b016-4de4-996d-7f16176cf49c\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"### Model inference via Pandas UDF\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b1f5a747\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the function to parse the input data.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"a67b3128-13c1-44f1-a0c0-7cf7a836fee3\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def parse_image(image_data):\\n\",\n    \"    image = tf.image.convert_image_dtype(\\n\",\n    \"        image_data, dtype=tf.float32) * (2. 
/ 255) - 1\\n\",\n    \"    image = tf.reshape(image, [224, 224, 3])\\n\",\n    \"    return image\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"024e4ba2\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the function for model inference.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"id\": \"7b33185f-6d1e-4ca9-9757-fdc3d736496b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@pandas_udf(ArrayType(FloatType()))\\n\",\n    \"def pandas_predict_udf(iter: Iterator[Tuple[pd.Series]]) -> Iterator[pd.Series]:\\n\",\n    \"\\n\",\n    \"    # Enable GPU memory growth to avoid CUDA OOM\\n\",\n    \"    gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"\\n\",\n    \"    batch_size = 64\\n\",\n    \"    model = ResNet50(weights=None)\\n\",\n    \"    model.set_weights(bc_model_weights.value)\\n\",\n    \"    for image_batch in iter:\\n\",\n    \"        images = np.vstack(image_batch)\\n\",\n    \"        dataset = tf.data.Dataset.from_tensor_slices(images)\\n\",\n    \"        dataset = dataset.map(parse_image, num_parallel_calls=8).prefetch(\\n\",\n    \"            5000).batch(batch_size)\\n\",\n    \"        preds = model.predict(dataset)\\n\",\n    \"        yield pd.Series(list(preds))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"08190547\",\n   \"metadata\": {},\n   \"source\": [\n    \"Run model inference and save the results to Parquet.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"id\": \"ad8c05da-db38-45ef-81d0-1f862f575ced\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": 
[\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 49.7 ms, sys: 17.6 ms, total: 67.3 ms\\n\",\n      \"Wall time: 15.1 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions_1 = df.select(pandas_predict_udf(col(\\\"data\\\")).alias(\\\"prediction\\\"))\\n\",\n    \"results = predictions_1.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"id\": \"08cb2a10\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 6:============================================>              (3 + 1) / 4]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                          prediction|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|[1.2964063E-4, 2.4653607E-4, 6.7508765E-5, 1.2236452E-4, 5.7346635E-5, 3.9642912E-4, 7.033199E-6,...|\\n\",\n      \"|[4.4486973E-5, 3.5260408E-4, 4.684452E-5, 8.12069E-5, 3.179397E-5, 1.9187202E-4, 7.887208E-6, 1.3...|\\n\",\n      \"|[1.059436E-4, 2.2737762E-4, 3.0225037E-5, 6.550149E-5, 2.3658315E-5, 3.7172026E-4, 3.353684E-6, 2...|\\n\",\n      \"|[2.0393689E-5, 2.2818097E-4, 7.841931E-5, 6.991323E-5, 4.704759E-5, 9.822018E-5, 5.5858673E-6, 2....|\\n\",\n      \"|[1.13108545E-4, 2.3128217E-4, 5.283139E-5, 1.0866656E-4, 4.0229144E-5, 3.7223354E-4, 5.5677583E-6...|\\n\",\n      \"|[9.1271184E-5, 2.0681013E-4, 4.5193243E-5, 7.6812066E-5, 3.2361808E-5, 3.399333E-4, 
3.8415465E-6,...|\\n\",\n      \"|[1.0792112E-4, 3.7743401E-4, 7.618583E-5, 1.24259E-4, 4.7426664E-5, 3.3307416E-4, 1.0592865E-5, 9...|\\n\",\n      \"|[2.2220212E-5, 2.7357432E-4, 3.8200575E-5, 6.235621E-5, 1.7954999E-5, 1.7249273E-4, 6.021971E-6, ...|\\n\",\n      \"|[1.1044029E-4, 2.8961376E-4, 4.2384647E-5, 1.0728626E-4, 3.0468744E-5, 4.796082E-4, 6.4537376E-6,...|\\n\",\n      \"|[9.68494E-5, 2.0567125E-4, 7.450887E-5, 1.13256065E-4, 4.609738E-5, 2.8675792E-4, 5.603957E-6, 5....|\\n\",\n      \"|[7.420906E-5, 3.2883475E-4, 1.3444667E-4, 1.7758778E-4, 8.4717096E-5, 2.2534849E-4, 1.3623082E-5,...|\\n\",\n      \"|[8.755989E-5, 2.7312606E-4, 3.59614E-5, 7.7967066E-5, 2.3571063E-5, 3.6875304E-4, 3.5629025E-6, 3...|\\n\",\n      \"|[9.7425895E-5, 2.7611412E-4, 5.74094E-5, 1.1035101E-4, 3.8303257E-5, 3.4981826E-4, 6.167147E-6, 4...|\\n\",\n      \"|[6.92996E-5, 2.5326438E-4, 5.063317E-5, 1.1494952E-4, 3.0212495E-5, 2.7857954E-4, 5.0324948E-6, 5...|\\n\",\n      \"|[4.2184765E-5, 2.4904116E-4, 1.237565E-4, 1.4271903E-4, 7.3208634E-5, 1.6054673E-4, 7.938735E-6, ...|\\n\",\n      \"|[2.719573E-5, 3.8372327E-4, 1.291892E-4, 1.5711001E-4, 7.3108524E-5, 8.553368E-5, 1.2617156E-5, 1...|\\n\",\n      \"|[3.0565643E-5, 3.55542E-4, 1.5949155E-4, 2.1368133E-4, 8.043127E-5, 1.02662845E-4, 1.3859853E-5, ...|\\n\",\n      \"|[3.311506E-5, 2.8069926E-4, 1.7956384E-4, 2.0205336E-4, 1.3665091E-4, 1.0115404E-4, 3.409792E-5, ...|\\n\",\n      \"|[4.573667E-5, 2.888326E-4, 2.3792271E-4, 2.460216E-4, 1.2164583E-4, 1.3814335E-4, 1.6352218E-5, 2...|\\n\",\n      \"|[1.2279079E-4, 2.8073761E-4, 6.365874E-5, 1.0251792E-4, 4.3527238E-5, 3.914249E-4, 8.236801E-6, 6...|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                      
                                          \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"predictions_1.show(truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"id\": \"40799f8e-443e-40ca-919b-391f901cb3f4\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"predictions_1.write.mode(\\\"overwrite\\\").parquet(output_file_path + \\\"_1\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e7a69aa9\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Inference using Spark DL API\\n\",\n    \"\\n\",\n    \"Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\\n\",\n    \"\\n\",\n    \"- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \\n\",\n    \"- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"id\": \"dda88b46-6300-4bf7-bc10-7403f4fbbf92\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def predict_batch_fn():\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"    from tensorflow.keras.applications.resnet50 import ResNet50\\n\",\n    \"\\n\",\n    \"    # Enable GPU memory growth\\n\",\n    \"    gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            
print(e)\\n\",\n    \"\\n\",\n    \"    model = ResNet50()\\n\",\n    \"    def predict(inputs):\\n\",\n    \"        inputs = inputs * (2. / 255) - 1\\n\",\n    \"        return model.predict(inputs)\\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"cff0e851-563d-40b6-9d05-509c22b3b7f9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"classify = predict_batch_udf(predict_batch_fn,\\n\",\n    \"                             input_tensor_shapes=[[224, 224, 3]],\\n\",\n    \"                             return_type=ArrayType(FloatType()),\\n\",\n    \"                             batch_size=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"id\": \"aa7c156f-e2b3-4837-9427-ccf3a5720412\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"id\": \"80bc50ad-eaf5-4fce-a354-5e17d65e2da5\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 61.7 ms, sys: 23.1 ms, total: 84.8 ms\\n\",\n      \"Wall time: 16.7 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"predictions_2 = df.select(classify(struct(\\\"data\\\")).alias(\\\"prediction\\\"))\\n\",\n    \"results = predictions_2.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"id\": \"41cace80-7a4b-4929-8e63-9c83f9745e02\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     
\"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 141 ms, sys: 22.2 ms, total: 163 ms\\n\",\n      \"Wall time: 16 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions_2 = df.select(classify(\\\"data\\\").alias(\\\"prediction\\\"))\\n\",\n    \"results = predictions_2.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"id\": \"56a2ec8a-de09-4d7c-9666-1b3c76f10657\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 52.8 ms, sys: 14.3 ms, total: 67.1 ms\\n\",\n      \"Wall time: 15.5 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions_2 = df.select(classify(col(\\\"data\\\")).alias(\\\"prediction\\\"))\\n\",\n    \"results = predictions_2.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"id\": \"2dcf3791\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 13:===========================================>              (3 + 1) / 4]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                          prediction|\\n\",\n      
\"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|[1.293178E-4, 2.4644283E-4, 6.760039E-5, 1.2260793E-4, 5.7431564E-5, 3.9597694E-4, 7.0522524E-6, ...|\\n\",\n      \"|[4.4487308E-5, 3.5378174E-4, 4.6667028E-5, 8.102564E-5, 3.168566E-5, 1.9189132E-4, 7.903805E-6, 1...|\\n\",\n      \"|[1.0566196E-4, 2.2684377E-4, 3.00564E-5, 6.5251304E-5, 2.3520754E-5, 3.7116173E-4, 3.331476E-6, 2...|\\n\",\n      \"|[2.0337258E-5, 2.2749524E-4, 7.8351426E-5, 6.991163E-5, 4.7081656E-5, 9.8092445E-5, 5.564894E-6, ...|\\n\",\n      \"|[1.12979564E-4, 2.3172122E-4, 5.2946547E-5, 1.0876398E-4, 4.0259067E-5, 3.7143996E-4, 5.5940513E-...|\\n\",\n      \"|[9.093228E-5, 2.0639994E-4, 4.5151268E-5, 7.666316E-5, 3.2264295E-5, 3.387436E-4, 3.832487E-6, 4....|\\n\",\n      \"|[1.0783461E-4, 3.7850672E-4, 7.660902E-5, 1.2446321E-4, 4.7591406E-5, 3.3328883E-4, 1.067249E-5, ...|\\n\",\n      \"|[2.2258617E-5, 2.7345872E-4, 3.814439E-5, 6.229726E-5, 1.79387E-5, 1.7259057E-4, 6.0371217E-6, 1....|\\n\",\n      \"|[1.1067773E-4, 2.8997674E-4, 4.2570035E-5, 1.0747747E-4, 3.0524247E-5, 4.7921995E-4, 6.489833E-6,...|\\n\",\n      \"|[9.676251E-5, 2.0588847E-4, 7.467098E-5, 1.1326933E-4, 4.6123736E-5, 2.8609246E-4, 5.627118E-6, 5...|\\n\",\n      \"|[7.4104944E-5, 3.290917E-4, 1.3448784E-4, 1.7742367E-4, 8.463227E-5, 2.2462371E-4, 1.3614881E-5, ...|\\n\",\n      \"|[8.7211796E-5, 2.7337394E-4, 3.5953894E-5, 7.7924225E-5, 2.3554327E-5, 3.67775E-4, 3.5652213E-6, ...|\\n\",\n      \"|[9.7237185E-5, 2.762026E-4, 5.7450008E-5, 1.1019135E-4, 3.831896E-5, 3.4878452E-4, 6.1574788E-6, ...|\\n\",\n      \"|[6.938849E-5, 2.5376282E-4, 5.0565883E-5, 1.14880335E-4, 3.0061366E-5, 2.7866007E-4, 5.024482E-6,...|\\n\",\n      \"|[4.2096388E-5, 2.4889092E-4, 1.2363133E-4, 1.4304162E-4, 7.337785E-5, 1.6042824E-4, 7.959722E-6, ...|\\n\",\n      \"|[2.730248E-5, 3.851789E-4, 1.293143E-4, 1.5753493E-4, 7.302161E-5, 8.547956E-5, 1.26348905E-5, 
1....|\\n\",\n      \"|[3.0354899E-5, 3.5562844E-4, 1.6008675E-4, 2.1440513E-4, 8.062159E-5, 1.02023136E-4, 1.3876455E-5...|\\n\",\n      \"|[3.3083066E-5, 2.8158593E-4, 1.7979987E-4, 2.0232225E-4, 1.3704685E-4, 1.0091762E-4, 3.4243407E-5...|\\n\",\n      \"|[4.5485373E-5, 2.878148E-4, 2.3707838E-4, 2.4493985E-4, 1.21028905E-4, 1.3738636E-4, 1.6280053E-5...|\\n\",\n      \"|[1.22468E-4, 2.809503E-4, 6.3342835E-5, 1.021957E-4, 4.3373006E-5, 3.905496E-4, 8.212427E-6, 6.20...|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"predictions_2.show(truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 27,\n   \"id\": \"fc511eae\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"predictions_2.write.mode(\\\"overwrite\\\").parquet(output_file_path + \\\"_2\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"878ca7fb\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  
\\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"id\": \"2605d134-ef75-4d94-9b16-2c6d85f29bef\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"cdded12d\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"id\": \"a2475d41\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1f6701dc\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"id\": \"8c8c0744-0558-4dac-bbfe-8bdde4b2af2d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"    from tensorflow.keras.applications import ResNet50\\n\",\n    \"    from pytriton.decorators import 
batch\\n\",\n    \"    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"\\n\",\n    \"    print(f\\\"SERVER: Initializing ResNet on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"\\n\",\n    \"    # Enable GPU memory growth\\n\",\n    \"    gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"    \\n\",\n    \"    model = ResNet50()\\n\",\n    \"    normalization_layer = tf.keras.layers.Rescaling(scale=2./255, offset=-1)\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        images = inputs[\\\"images\\\"]\\n\",\n    \"        normalized_images = normalization_layer(images)\\n\",\n    \"        return {\\n\",\n    \"            \\\"preds\\\": model.predict(normalized_images),\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    workspace_path = f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"ResNet50\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"images\\\", dtype=np.float32, shape=(224, 224, 3)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"preds\\\", dtype=np.float32, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"             
   max_batch_size=100,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d74f7037\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"4bf99bde\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"2309a55c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"ResNet50\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"205fa1e8\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": 
\"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e49ebdbe\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"55c42174\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9e4ff20e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"481dbd42\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 35,\n   \"id\": \"a5ab49bb\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"CLIENT: Connecting to {model_name} at {url}\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=240) as 
client:\\n\",\n    \"            result_data = client.infer_batch(inputs)\\n\",\n    \"            return result_data[\\\"preds\\\"]\\n\",\n    \"            \\n\",\n    \"    return infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 37,\n   \"id\": \"9fabcaeb-5a44-42bb-8097-5dbc2d0cee3e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"classify = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                             input_tensor_shapes=[[224, 224, 3]],\\n\",\n    \"                             return_type=ArrayType(FloatType()),\\n\",\n    \"                             batch_size=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"fcd2328e\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 36,\n   \"id\": \"bbfc9009\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8c07365c-0a14-49b3-9bd8-cfb35f48b089\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 38,\n   \"id\": \"e595473d-1a5d-46a6-a6ba-89d2ea903de9\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 60.9 ms, sys: 21.3 ms, total: 82.2 ms\\n\",\n      \"Wall time: 18.4 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"predictions_3 = df.select(classify(struct(\\\"data\\\")).alias(\\\"prediction\\\"))\\n\",\n    
\"results = predictions_3.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 39,\n   \"id\": \"5f66d468-e0b1-4589-8606-b3848063a823\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 46.3 ms, sys: 16.1 ms, total: 62.4 ms\\n\",\n      \"Wall time: 12.3 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions_3 = df.select(classify(\\\"data\\\").alias(\\\"prediction\\\"))\\n\",\n    \"results = predictions_3.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 40,\n   \"id\": \"632c4c3a-fa52-4c3d-b71e-7526286e353a\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 57.5 ms, sys: 16.4 ms, total: 73.9 ms\\n\",\n      \"Wall time: 12.4 s\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions_3 = df.select(classify(col(\\\"data\\\")).alias(\\\"prediction\\\"))\\n\",\n    \"results = predictions_3.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 41,\n   \"id\": \"49870e39\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 22:===========================================>              (3 + 1) / 4]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      
\"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                          prediction|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|[1.293178E-4, 2.4644283E-4, 6.760039E-5, 1.2260793E-4, 5.7431564E-5, 3.9597694E-4, 7.0522524E-6, ...|\\n\",\n      \"|[4.4487308E-5, 3.5378174E-4, 4.6667028E-5, 8.102564E-5, 3.168566E-5, 1.9189132E-4, 7.903805E-6, 1...|\\n\",\n      \"|[1.0566196E-4, 2.2684377E-4, 3.00564E-5, 6.5251304E-5, 2.3520754E-5, 3.7116173E-4, 3.331476E-6, 2...|\\n\",\n      \"|[2.0337258E-5, 2.2749524E-4, 7.8351426E-5, 6.991163E-5, 4.7081656E-5, 9.8092445E-5, 5.564894E-6, ...|\\n\",\n      \"|[1.12979564E-4, 2.3172122E-4, 5.2946547E-5, 1.0876398E-4, 4.0259067E-5, 3.7143996E-4, 5.5940513E-...|\\n\",\n      \"|[9.093228E-5, 2.0639994E-4, 4.5151268E-5, 7.666316E-5, 3.2264295E-5, 3.387436E-4, 3.832487E-6, 4....|\\n\",\n      \"|[1.0783461E-4, 3.7850672E-4, 7.660902E-5, 1.2446321E-4, 4.7591406E-5, 3.3328883E-4, 1.067249E-5, ...|\\n\",\n      \"|[2.2258617E-5, 2.7345872E-4, 3.814439E-5, 6.229726E-5, 1.79387E-5, 1.7259057E-4, 6.0371217E-6, 1....|\\n\",\n      \"|[1.1067773E-4, 2.8997674E-4, 4.2570035E-5, 1.0747747E-4, 3.0524247E-5, 4.7921995E-4, 6.489833E-6,...|\\n\",\n      \"|[9.676251E-5, 2.0588847E-4, 7.467098E-5, 1.1326933E-4, 4.6123736E-5, 2.8609246E-4, 5.627118E-6, 5...|\\n\",\n      \"|[7.4104944E-5, 3.290917E-4, 1.3448784E-4, 1.7742367E-4, 8.463227E-5, 2.2462371E-4, 1.3614881E-5, ...|\\n\",\n      \"|[8.7211796E-5, 2.7337394E-4, 3.5953894E-5, 7.7924225E-5, 2.3554327E-5, 3.67775E-4, 3.5652213E-6, ...|\\n\",\n      \"|[9.7237185E-5, 2.762026E-4, 5.7450008E-5, 1.1019135E-4, 3.831896E-5, 3.4878452E-4, 6.1574788E-6, ...|\\n\",\n      \"|[6.938849E-5, 2.5376282E-4, 5.0565883E-5, 1.14880335E-4, 3.0061366E-5, 2.7866007E-4, 
5.024482E-6,...|\\n\",\n      \"|[4.2096388E-5, 2.4889092E-4, 1.2363133E-4, 1.4304162E-4, 7.337785E-5, 1.6042824E-4, 7.959722E-6, ...|\\n\",\n      \"|[2.730248E-5, 3.851789E-4, 1.293143E-4, 1.5753493E-4, 7.302161E-5, 8.547956E-5, 1.26348905E-5, 1....|\\n\",\n      \"|[3.0354899E-5, 3.5562844E-4, 1.6008675E-4, 2.1440513E-4, 8.062159E-5, 1.02023136E-4, 1.3876455E-5...|\\n\",\n      \"|[3.3083066E-5, 2.8158593E-4, 1.7979987E-4, 2.0232225E-4, 1.3704685E-4, 1.0091762E-4, 3.4243407E-5...|\\n\",\n      \"|[4.5485373E-5, 2.878148E-4, 2.3707838E-4, 2.4493985E-4, 1.21028905E-4, 1.3738636E-4, 1.6280053E-5...|\\n\",\n      \"|[1.22468E-4, 2.809503E-4, 6.3342835E-5, 1.021957E-4, 4.3373006E-5, 3.905496E-4, 8.212427E-6, 6.20...|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"predictions_3.show(truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"id\": \"86cd59f9\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"predictions_3.write.mode(\\\"overwrite\\\").parquet(output_file_path + \\\"_3\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"4dc06b7e-f750-40b5-9208-a035db11d937\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"#### Stop Triton Server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 43,\n   \"id\": \"bbfcaa51-3b9f-43ff-a4a8-4b46766115b8\",\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 14:03:34,747 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-04 14:03:39,935 - INFO - Sucessfully stopped 1 servers.                 \\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 43,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 44,\n   \"id\": \"0d88639b-d934-4eb4-ae2f-cc13b9b10456\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"df8cc28a-34d7-479c-be7e-9a380d39e25e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-tf\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/text_classification_tf.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2cd2accf-5877-4136-a243-7a33a13ce2b4\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# Pyspark TensorFlow Inference\\n\",\n    \"\\n\",\n    \"### Text Classification\\n\",\n    \"In this notebook, we demonstrate training a model to perform sentiment analysis, and using the trained model for distributed inference.  \\n\",\n    \"Based on: https://www.tensorflow.org/tutorials/keras/text_classification\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"bc72d0ed\",\n   \"metadata\": {},\n   \"source\": [\n    \"Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075)  \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"76f0f5df-502f-444e-b2ee-1122e1dea870\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 14:05:12.899608: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\\n\",\n      \"2025-02-04 14:05:12.907256: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\\n\",\n      \"2025-02-04 14:05:12.915374: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\\n\",\n      \"2025-02-04 14:05:12.917743: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\\n\",\n      \"2025-02-04 14:05:12.924372: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\\n\",\n      \"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\\n\",\n      \"2025-02-04 14:05:13.295411: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import os\\n\",\n    \"import re\\n\",\n    \"import shutil\\n\",\n    \"import string\\n\",\n    \"import matplotlib.pyplot as plt\\n\",\n    \"\\n\",\n    \"import tensorflow as tf\\n\",\n    \"from tensorflow.keras import layers, losses\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"a364ad5f-b269-45b5-ab8b-d8f34fb642b7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2.17.0\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"WARNING: All log messages before absl::InitializeLog() is called are 
written to STDERR\\n\",\n      \"I0000 00:00:1738706713.692042 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706713.716276 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706713.719037 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(tf.__version__)\\n\",\n    \"\\n\",\n    \"# Enable GPU memory growth\\n\",\n    \"gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"if gpus:\\n\",\n    \"    try:\\n\",\n    \"        for gpu in gpus:\\n\",\n    \"            tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"    except RuntimeError as e:\\n\",\n    \"        print(e)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b64bb471\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Download and explore the dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"d229c1b6-3967-46b5-9ea8-68f4b42dd211\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from datasets import load_dataset\\n\",\n    \"dataset = load_dataset(\\\"imdb\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"88f9a92e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": 
[\n    \"# Create directories for our data\\n\",\n    \"base_dir = \\\"spark-dl-datasets/imdb\\\"\\n\",\n    \"if os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False):\\n\",\n    \"    # For databricks, use the driver disk rather than Workspace (much faster)\\n\",\n    \"    base_dir = \\\"/local_disk0/\\\" + base_dir\\n\",\n    \"\\n\",\n    \"train_dir = base_dir + \\\"/train\\\"\\n\",\n    \"test_dir = base_dir + \\\"/test\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"3f984d5a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Create directories for positive (1) and negative (0) reviews\\n\",\n    \"for split in [\\\"train\\\", \\\"test\\\"]:\\n\",\n    \"    split_dir = os.path.join(base_dir, split)\\n\",\n    \"    pos_dir = split_dir + \\\"/pos\\\"\\n\",\n    \"    neg_dir = split_dir + \\\"/neg\\\"\\n\",\n    \"\\n\",\n    \"    os.makedirs(pos_dir, exist_ok=True)\\n\",\n    \"    os.makedirs(neg_dir, exist_ok=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"6cd2328a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def write_reviews_to_files(dataset_split, split_name):\\n\",\n    \"    for idx, example in enumerate(dataset_split):\\n\",\n    \"        label_dir = \\\"pos\\\" if example[\\\"label\\\"] == 1 else \\\"neg\\\"\\n\",\n    \"        dir_path = os.path.join(base_dir, split_name, label_dir)\\n\",\n    \"\\n\",\n    \"        file_path = dir_path + f\\\"/review_{idx}.txt\\\"\\n\",\n    \"        with open(file_path, \\\"w\\\", encoding=\\\"utf-8\\\") as f:\\n\",\n    \"            f.write(example[\\\"text\\\"])\\n\",\n    \"\\n\",\n    \"# Write train and test sets\\n\",\n    \"write_reviews_to_files(dataset[\\\"train\\\"], \\\"train\\\")\\n\",\n    \"write_reviews_to_files(dataset[\\\"test\\\"], \\\"test\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b02fde64\",\n   \"metadata\": {},\n 
  \"source\": [\n    \"There are 25,000 examples in the training folder, of which we will use 80% (or 20,000) for training, and 5,000 for validation.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"5c357f22\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Found 25000 files belonging to 2 classes.\\n\",\n      \"Using 20000 files for training.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"I0000 00:00:1738706719.326625 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706719.329542 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706719.332409 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706719.451656 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706719.452700 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"I0000 00:00:1738706719.453630 3744395 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\\n\",\n      \"2025-02-04 14:05:19.454569: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 40337 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Found 25000 files belonging to 2 classes.\\n\",\n      \"Using 5000 files for validation.\\n\",\n      \"Found 25000 files belonging to 2 classes.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"batch_size = 32\\n\",\n    \"seed = 42\\n\",\n    \"\\n\",\n    \"raw_train_ds = tf.keras.utils.text_dataset_from_directory(\\n\",\n    \"    str(train_dir),\\n\",\n    \"    batch_size=batch_size,\\n\",\n    \"    validation_split=0.2,\\n\",\n    \"    subset=\\\"training\\\",\\n\",\n    \"    seed=seed,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"raw_val_ds = tf.keras.utils.text_dataset_from_directory(\\n\",\n    \"    str(train_dir),\\n\",\n    \"    batch_size=batch_size,\\n\",\n    \"    validation_split=0.2,\\n\",\n    \"    subset=\\\"validation\\\",\\n\",\n    \"    seed=seed,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"raw_test_ds = 
tf.keras.utils.text_dataset_from_directory(\\n\",\n    \"    str(test_dir),\\n\",\n    \"    batch_size=batch_size\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"02994994\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can take a look at a sample of the dataset (note that OUT_OF_RANGE errors are safe to ignore: https://github.com/tensorflow/tensorflow/issues/62963):\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"1d528a95\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Review b'I was really, really disappointed with this movie. it started really well, and built up some great atmosphere and suspense, but when it finally got round to revealing the \\\"monster\\\"...it turned out to be just some psycho with skin problems......again. Whoop-de-do. Yet another nutjob movie...like we don\\\\'t already have enough of them.<br /><br />To be fair, the \\\"creep\\\" is genuinely unsettling to look at, and the way he moves and the strange sounds he makes are pretty creepy, but I\\\\'m sick of renting film like this only to discover that the monster is human, albeit a twisted, demented, freakish one. When I saw all the tell-tale rats early on I was hoping for some kind of freaky rat-monster hybrid thing...it was such a let down when the Creep was revealed.<br /><br />On top of this, some of the stuff in this movie makes no sense. (Spoiler) <br /><br />Why the hell does the Creep kill the security Guard? Whats the point, apart from sticking a great honking sign up that says \\\"HI I\\\\'m A PSYCHO AND I LIVE DOWN HERE!\\\"? Its stupid, and only seems to happen to prevent Franka Potente\\\\'s character from getting help.<br /><br />what the hells he been eating down there? 
I got the impression he was effectively walled in, and only the unexpected opening into that tunnel section let him loose...so has he been munching rats all that time, and if so why do they hang around him so much? Why is he so damn hard to kill? He\\\\'s thin, malnourished and not exactly at peak performance...but seems to keep going despite injuries that are equivalent to those that .cripple the non-psycho characters in the film.<br /><br />The DVD commentary says we are intended to empathise with Creep, but I just find him loathsome. Its an effective enough movie, but it wasted so many opportunities that it makes me sick.'\\n\",\n      \"Label 0\\n\",\n      \"Review b\\\"This has the absolute worst performance from Robert Duval who sounds just like William Buckley throughout the entire film. His hammy melodramatic acting takes away from any dramatic interest. I'm not sure if this was deliberate scene stealing or inadvertent but it's the only thing I can recall from a truly forgettable film. This picture should be shown in every amateur acting class of an example of what not to do. Thank God, Duvall went on to bigger and better things and stopped trying to effect a cultured accent. He is a good character actor but that's about it. Klaus is so much better. His performance is muted and noteworthy.\\\"\\n\",\n      \"Label 0\\n\",\n      \"Review b'A long time ago, in a galaxy far, far away.....There was a boy who was only two years old when the original \\\"Star Wars\\\" film was released. He doesn\\\\'t remember first seeing the movie, but he also doesn\\\\'t remember life before it. He does remember the first \\\"Star Wars\\\" themed gift he got...a shoebox full of action figures from the original set. He was too young to fully appreciate how special that gift would be. 
But years later, he would get what to this day goes down as one of the best gifts he\\\\'s ever received: another box full of action figures, ten of the final twelve he needed to complete his collection. It\\\\'s now legendary in this boy\\\\'s family how the last action figure he needed, Anakin Skywalker, stopped being produced and carried in stores, and how this boy went for about ten years (until he got into college) trying to track one down and finally bought it from someone on his dorm floor for a bag of beer nuggets (don\\\\'t ask...it\\\\'s a Northern Illinois University thing).<br /><br />I can\\\\'t review \\\"Star Wars\\\" as a movie. It represents absolutely everything good, fun and magical about my childhood. There\\\\'s no separating it in my mind from Christmases, birthdays, summers and winters growing up. In the winter, my friends and I would build snow forts and pretend we were on Hoth (I was always Han Solo). My friends\\\\' dad built them a kick-ass tree house, and that served as the Ewok village. They also had a huge pine tree whose bottom branches were high enough to create a sort of cave underneath it, and this made a great spot to pretend we were in Yoda\\\\'s home. I am unabashedly dorky when it comes to \\\"Star Wars\\\" and I think people either just understand that or they don\\\\'t. I don\\\\'t get the appeal of \\\"Lord of the Rings\\\" or \\\"Star Trek\\\" but I understand the rabid flocks of fans that follow them because I am a rabid fan of George Lucas\\\\'s films.<br /><br />I feel no need to defend my opinion of these movies as some of the greatest of all time. 
Every time I put them in the DVD player, I feel like I\\\\'m eight years old again, when life was simple and the biggest problem I had was figuring out how I was going to track down a figure of Anakin Skywalker.<br /><br />Grade (for the entire trilogy): A+'\\n\",\n      \"Label 1\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 14:05:20.533703: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for text_batch, label_batch in raw_train_ds.take(1):\\n\",\n    \"    for i in range(3):\\n\",\n    \"        print(\\\"Review\\\", text_batch.numpy()[i])\\n\",\n    \"        print(\\\"Label\\\", label_batch.numpy()[i])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"4bca98b1\",\n   \"metadata\": {},\n   \"source\": [\n    \"Notice the reviews contain raw text (with punctuation and occasional HTML tags like \\\\<br/>\\\\). We will show how to handle these in the following section.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"f8921ed2\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Label 0 corresponds to neg\\n\",\n      \"Label 1 corresponds to pos\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(\\\"Label 0 corresponds to\\\", raw_train_ds.class_names[0])\\n\",\n    \"print(\\\"Label 1 corresponds to\\\", raw_train_ds.class_names[1])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f6cf0e47\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Prepare the dataset for training\\n\",\n    \"\\n\",\n    \"Next, we will standardize, tokenize, and vectorize the data using the tf.keras.layers.TextVectorization layer.  
\\n\",\n    \"We will write a custom standardization function to remove the HTML.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"cb141709-fcc1-4cee-bc98-9c89aaba8648\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def custom_standardization(input_data):\\n\",\n    \"    lowercase = tf.strings.lower(input_data)\\n\",\n    \"    stripped_html = tf.strings.regex_replace(lowercase, \\\"<br />\\\", \\\" \\\")\\n\",\n    \"    return tf.strings.regex_replace(\\n\",\n    \"        stripped_html, \\\"[%s]\\\" % re.escape(string.punctuation), \\\"\\\"\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b35e36a2\",\n   \"metadata\": {},\n   \"source\": [\n    \"Next, we will create a TextVectorization layer to standardize, tokenize, and vectorize our data.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"id\": \"d4e80ea9-536a-4ebc-8b35-1eca73dbba7d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"max_features = 10000\\n\",\n    \"sequence_length = 250\\n\",\n    \"\\n\",\n    \"vectorize_layer = layers.TextVectorization(\\n\",\n    \"    standardize=custom_standardization,\\n\",\n    \"    max_tokens=max_features,\\n\",\n    \"    output_mode=\\\"int\\\",\\n\",\n    \"    output_sequence_length=sequence_length,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"879fbc3f\",\n   \"metadata\": {},\n   \"source\": [\n    \"Next, we will call adapt to fit the state of the preprocessing layer to the dataset. 
This will cause the model to build an index of strings to integers.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"ad1e5d81-7dae-4b08-b520-ca45501b9510\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 14:05:22.003236: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Make a text-only dataset (without labels), then call adapt\\n\",\n    \"train_text = raw_train_ds.map(lambda x, y: x)\\n\",\n    \"vectorize_layer.adapt(train_text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ad1e5d81-7dae-4b08-b520-ca45501b9510\",\n   \"metadata\": {},\n   \"source\": [\n    \"Let's create a function to see the result of using this layer to preprocess some data.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"80f243f5-edd3-4e1c-bddc-abc1cc6673ef\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def vectorize_text(text, label):\\n\",\n    \"    text = tf.expand_dims(text, -1)\\n\",\n    \"    return vectorize_layer(text), label\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"8f37e95c-515c-4edb-a1ee-fc47be5df4b9\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Review tf.Tensor(b\\\"To describe this film as garbage is unfair. At least rooting through garbage can be an absorbing hobby. This flick was neither absorbing nor entertaining.<br /><br />Kevin Bacon can act superbly given the chance, so no doubt had an IRS bill to settle when he agreed to this dire screenplay. 
The mad scientist story of 'Hollow Man' has been told before, been told better, and been told without resorting to so many ludicrously expensive special effects.<br /><br />Most of those special effects seem to be built around the transparent anatomical dolls of men, women and dogs you could buy in the early seventies. In the UK they were marketed as 'The Transparent Man (/Woman/Dog)' which is maybe where they got the title for this film.<br /><br />Clever special effects, dire script, non-existent plot.<br /><br />\\\", shape=(), dtype=string)\\n\",\n      \"Label neg\\n\",\n      \"Vectorized review (<tf.Tensor: shape=(1, 250), dtype=int64, numpy=\\n\",\n      \"array([[   6, 1507,   11,   19,   14, 1184,    7, 5230,   30,  217, 5821,\\n\",\n      \"         139, 1184,   68,   26,   33, 6676,    1,   11,  512,   13, 1078,\\n\",\n      \"        6676,  888,  439, 1727, 5292,   68,  503, 3597,  333,    2,  558,\\n\",\n      \"          37,   56,  797,   64,   33, 8270,  978,    6, 3956,   51,   27,\\n\",\n      \"        4531,    6,   11, 3756,  907,    2, 1106, 1660,   63,    5, 3514,\\n\",\n      \"         134,   43,   74,  566,  155,   74,  566,  122,    3,   74,  566,\\n\",\n      \"         204,    1,    6,   37,  106,    1, 3152,  307,  293,   88,    5,\\n\",\n      \"         143,  307,  293,  294,    6,   26, 2250,  183,    2, 7541,    1,\\n\",\n      \"        4379,    5,  352,  362,    3, 2312,   22,   99,  756,    8,    2,\\n\",\n      \"         402, 3887,    8,    2, 2142,   34,   65,    1,   14,    2, 7541,\\n\",\n      \"         134,    1,   61,    7,  271,  111,   34,  182,    2,  409,   15,\\n\",\n      \"          11,   19, 1066,  307,  293, 3756,  223, 2939,  112,    0,    0,\\n\",\n      \"           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\\n\",\n      \"           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\\n\",\n      \"           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    
0,\\n\",\n      \"           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\\n\",\n      \"           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\\n\",\n      \"           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\\n\",\n      \"           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\\n\",\n      \"           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\\n\",\n      \"           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\\n\",\n      \"           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\\n\",\n      \"           0,    0,    0,    0,    0,    0,    0,    0]])>, <tf.Tensor: shape=(), dtype=int32, numpy=0>)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# retrieve a batch (of 32 reviews and labels) from the dataset\\n\",\n    \"text_batch, label_batch = next(iter(raw_train_ds))\\n\",\n    \"first_review, first_label = text_batch[0], label_batch[0]\\n\",\n    \"print(\\\"Review\\\", first_review)\\n\",\n    \"print(\\\"Label\\\", raw_train_ds.class_names[first_label])\\n\",\n    \"print(\\\"Vectorized review\\\", vectorize_text(first_review, first_label))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"680f53bb\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can lookup the token (string) that each integer corresponds to by calling .get_vocabulary() on the layer.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"60c9208a-39ac-4e6c-a603-61038cdf3d10\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"1287 --->  nowhere\\n\",\n      \" 313 --->  house\\n\",\n      \"Vocabulary size: 10000\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(\\\"1287 ---> \\\",vectorize_layer.get_vocabulary()[1287])\\n\",\n    \"print(\\\" 313 ---> 
\\\",vectorize_layer.get_vocabulary()[313])\\n\",\n    \"print('Vocabulary size: {}'.format(len(vectorize_layer.get_vocabulary())))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"id\": \"3cf90d4b-8dae-44b2-b32b-80cb0092c430\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_ds = raw_train_ds.map(vectorize_text)\\n\",\n    \"val_ds = raw_val_ds.map(vectorize_text)\\n\",\n    \"test_ds = raw_test_ds.map(vectorize_text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b3db3f77\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Configure the dataset for performance\\n\",\n    \"\\n\",\n    \"These are two important methods you should use when loading data to make sure that I/O does not become blocking.\\n\",\n    \"\\n\",\n    \"`.cache()` keeps data in memory after it's loaded off disk. This will ensure the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache, which is more efficient to read than many small files.\\n\",\n    \"\\n\",\n    \"`.prefetch()` overlaps data preprocessing and model execution while training.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"id\": \"115a5aba-8a00-458f-be25-0aae9f55de22\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"AUTOTUNE = tf.data.AUTOTUNE\\n\",\n    \"\\n\",\n    \"train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)\\n\",\n    \"val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)\\n\",\n    \"test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0d6d6692\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Create the model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"id\": \"d64f4495-102d-4244-9b42-1ba9976a366e\",\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"embedding_dim = 16\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"id\": \"3dc95d22-935f-4091-b0ee-da95174eb9a0\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\"><span style=\\\"font-weight: bold\\\">Model: \\\"sequential\\\"</span>\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"\\u001b[1mModel: \\\"sequential\\\"\\u001b[0m\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\\n\",\n       \"┃<span style=\\\"font-weight: bold\\\"> Layer (type)                    </span>┃<span style=\\\"font-weight: bold\\\"> Output Shape           </span>┃<span style=\\\"font-weight: bold\\\">       Param # </span>┃\\n\",\n       \"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\\n\",\n       \"│ embedding (<span style=\\\"color: #0087ff; text-decoration-color: #0087ff\\\">Embedding</span>)           │ ?                      │   <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">0</span> (unbuilt) │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ dropout (<span style=\\\"color: #0087ff; text-decoration-color: #0087ff\\\">Dropout</span>)               │ ?                      
│   <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">0</span> (unbuilt) │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ global_average_pooling1d        │ ?                      │   <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">0</span> (unbuilt) │\\n\",\n       \"│ (<span style=\\\"color: #0087ff; text-decoration-color: #0087ff\\\">GlobalAveragePooling1D</span>)        │                        │               │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ dropout_1 (<span style=\\\"color: #0087ff; text-decoration-color: #0087ff\\\">Dropout</span>)             │ ?                      │   <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">0</span> (unbuilt) │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ dense (<span style=\\\"color: #0087ff; text-decoration-color: #0087ff\\\">Dense</span>)                   │ ?                      │   <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">0</span> (unbuilt) │\\n\",\n       \"└─────────────────────────────────┴────────────────────────┴───────────────┘\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\\n\",\n       \"┃\\u001b[1m \\u001b[0m\\u001b[1mLayer (type)                   \\u001b[0m\\u001b[1m \\u001b[0m┃\\u001b[1m \\u001b[0m\\u001b[1mOutput Shape          \\u001b[0m\\u001b[1m \\u001b[0m┃\\u001b[1m \\u001b[0m\\u001b[1m      Param #\\u001b[0m\\u001b[1m \\u001b[0m┃\\n\",\n       \"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\\n\",\n       \"│ embedding (\\u001b[38;5;33mEmbedding\\u001b[0m)           │ ?                      
│   \\u001b[38;5;34m0\\u001b[0m (unbuilt) │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ dropout (\\u001b[38;5;33mDropout\\u001b[0m)               │ ?                      │   \\u001b[38;5;34m0\\u001b[0m (unbuilt) │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ global_average_pooling1d        │ ?                      │   \\u001b[38;5;34m0\\u001b[0m (unbuilt) │\\n\",\n       \"│ (\\u001b[38;5;33mGlobalAveragePooling1D\\u001b[0m)        │                        │               │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ dropout_1 (\\u001b[38;5;33mDropout\\u001b[0m)             │ ?                      │   \\u001b[38;5;34m0\\u001b[0m (unbuilt) │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ dense (\\u001b[38;5;33mDense\\u001b[0m)                   │ ?                      
│   \\u001b[38;5;34m0\\u001b[0m (unbuilt) │\\n\",\n       \"└─────────────────────────────────┴────────────────────────┴───────────────┘\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\"><span style=\\\"font-weight: bold\\\"> Total params: </span><span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">0</span> (0.00 B)\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"\\u001b[1m Total params: \\u001b[0m\\u001b[38;5;34m0\\u001b[0m (0.00 B)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\"><span style=\\\"font-weight: bold\\\"> Trainable params: </span><span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">0</span> (0.00 B)\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"\\u001b[1m Trainable params: \\u001b[0m\\u001b[38;5;34m0\\u001b[0m (0.00 B)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\"><span style=\\\"font-weight: bold\\\"> Non-trainable params: </span><span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">0</span> (0.00 B)\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"\\u001b[1m Non-trainable params: \\u001b[0m\\u001b[38;5;34m0\\u001b[0m (0.00 B)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": 
\"display_data\"\n    }\n   ],\n   \"source\": [\n    \"model = tf.keras.Sequential([\\n\",\n    \"  layers.Embedding(max_features, embedding_dim),\\n\",\n    \"  layers.Dropout(0.2),\\n\",\n    \"  layers.GlobalAveragePooling1D(),\\n\",\n    \"  layers.Dropout(0.2),\\n\",\n    \"  layers.Dense(1, activation='sigmoid')])\\n\",\n    \"\\n\",\n    \"model.summary()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"id\": \"d9059b93-7666-46db-bf15-517c4c205df9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model.compile(loss=losses.BinaryCrossentropy(),\\n\",\n    \"              optimizer='adam',\\n\",\n    \"              metrics=[tf.metrics.BinaryAccuracy(threshold=0.5)])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f8b66d33\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Train model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"b1d5959f-1bd8-48da-9815-8239599519b2\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Epoch 1/10\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\\n\",\n      \"I0000 00:00:1738706722.621647 3744883 service.cc:146] XLA service 0x334cd320 initialized for platform CUDA (this does not guarantee that XLA will be used). 
Devices:\\n\",\n      \"I0000 00:00:1738706722.621667 3744883 service.cc:154]   StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6\\n\",\n      \"2025-02-04 14:05:22.635317: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\\n\",\n      \"2025-02-04 14:05:22.689182: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[1m262/625\\u001b[0m \\u001b[32m━━━━━━━━\\u001b[0m\\u001b[37m━━━━━━━━━━━━\\u001b[0m \\u001b[1m0s\\u001b[0m 578us/step - binary_accuracy: 0.5299 - loss: 0.6904\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"I0000 00:00:1738706723.175401 3744883 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[1m625/625\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m2s\\u001b[0m 2ms/step - binary_accuracy: 0.5692 - loss: 0.6832 - val_binary_accuracy: 0.7020 - val_loss: 0.6195\\n\",\n      \"Epoch 2/10\\n\",\n      \"\\u001b[1m625/625\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 455us/step - binary_accuracy: 0.7588 - loss: 0.5825 - val_binary_accuracy: 0.7954 - val_loss: 0.5009\\n\",\n      \"Epoch 3/10\\n\",\n      \"\\u001b[1m625/625\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 536us/step - binary_accuracy: 0.8293 - loss: 0.4681 - val_binary_accuracy: 0.8352 - val_loss: 0.4253\\n\",\n      \"Epoch 4/10\\n\",\n      \"\\u001b[1m625/625\\u001b[0m 
\\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 516us/step - binary_accuracy: 0.8523 - loss: 0.3967 - val_binary_accuracy: 0.8516 - val_loss: 0.3802\\n\",\n      \"Epoch 5/10\\n\",\n      \"\\u001b[1m625/625\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 448us/step - binary_accuracy: 0.8692 - loss: 0.3524 - val_binary_accuracy: 0.8592 - val_loss: 0.3522\\n\",\n      \"Epoch 6/10\\n\",\n      \"\\u001b[1m625/625\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 530us/step - binary_accuracy: 0.8810 - loss: 0.3199 - val_binary_accuracy: 0.8658 - val_loss: 0.3324\\n\",\n      \"Epoch 7/10\\n\",\n      \"\\u001b[1m625/625\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 489us/step - binary_accuracy: 0.8919 - loss: 0.2945 - val_binary_accuracy: 0.8666 - val_loss: 0.3188\\n\",\n      \"Epoch 8/10\\n\",\n      \"\\u001b[1m625/625\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 509us/step - binary_accuracy: 0.8975 - loss: 0.2744 - val_binary_accuracy: 0.8720 - val_loss: 0.3085\\n\",\n      \"Epoch 9/10\\n\",\n      \"\\u001b[1m625/625\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 389us/step - binary_accuracy: 0.9042 - loss: 0.2565 - val_binary_accuracy: 0.8756 - val_loss: 0.3017\\n\",\n      \"Epoch 10/10\\n\",\n      \"\\u001b[1m625/625\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 410us/step - binary_accuracy: 0.9121 - loss: 0.2409 - val_binary_accuracy: 0.8750 - val_loss: 0.2972\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"epochs = 10\\n\",\n    \"history = model.fit(\\n\",\n    \"    train_ds,\\n\",\n    \"    validation_data=val_ds,\\n\",\n    \"    epochs=epochs)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": 
\"4c8d8f2a\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Evaluate the model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"id\": \"656afe07-354f-4ff2-8e3e-d02bad6c5958\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[1m782/782\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 573us/step - binary_accuracy: 0.8719 - loss: 0.3147\\n\",\n      \"Loss:  0.3172186613082886\\n\",\n      \"Accuracy:  0.8701599836349487\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"loss, accuracy = model.evaluate(test_ds)\\n\",\n    \"\\n\",\n    \"print(\\\"Loss: \\\", loss)\\n\",\n    \"print(\\\"Accuracy: \\\", accuracy)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b2a307ce\",\n   \"metadata\": {},\n   \"source\": [\n    \"Create a plot of accuracy and loss over time:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"id\": \"a01d0f13-d0b8-4d78-9ddc-ede5ed402446\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"dict_keys(['binary_accuracy', 'loss', 'val_binary_accuracy', 'val_loss'])\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"history_dict = history.history\\n\",\n    \"history_dict.keys()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"id\": \"1f7484c3-3cdf-46d5-b95d-80316f0e6240\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"image/png\": 
\"iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABUfUlEQVR4nO3deZyNdf/H8deZGbNhBoNZzDBI9n2LuS2VQmVJNEoMdetO1qQbt52izRYhLbRKaSyVXZRQ3EmppNzZwqAwYx/OXL8/rt8cc8wYs5yZa+ac9/PxOA/nXOc65/ocM3Xevtf3e31shmEYiIiIiLgJL6sLEBEREXElhRsRERFxKwo3IiIi4lYUbkRERMStKNyIiIiIW1G4EREREbeicCMiIiJuReFGRERE3IrCjYiIiLgVhRsRC/Tu3Zvo6OgcvXb8+PHYbDbXFlTAHDhwAJvNxsKFC/P1uJs2bcJms7Fp0ybHtqz+rPKq5ujoaHr37u3S98yKhQsXYrPZOHDgQL4fWyS3FG5E0rDZbFm6pf3yE8mtrVu3Mn78eM6cOWN1KSJuwcfqAkQKknfffdfp8TvvvMO6devSba9evXqujvP666+TkpKSo9eOHj2aESNG5Or4knW5+Vll1datW5kwYQK9e/emRIkSTs/t3bsXLy/9O1QkOxRuRNJ45JFHnB5/8803rFu3Lt326124cIHAwMAsH6dIkSI5qg/Ax8cHHx/9p5tfcvOzcgU/Pz9Ljy9SGOmfAyLZ1Lp1a2rVqsV3331Hy5YtCQwM5D//+Q8Ay5cv59577yUiIgI/Pz8qV67MpEmTsNvtTu9x/TyO1PkaL7/8MvPnz6dy5cr4+fnRuHFjduzY4fTajObc2Gw2BgwYwLJly6hVqxZ+fn7UrFmT1atXp6t/06ZNNGrUCH9/fypXrsxrr72W5Xk8mzdvplu3bpQvXx4/Pz+ioqJ46qmnuHjxYrrPV6xYMY4cOULnzp0pVqwYZcqUYdiwYen+Ls6cOUPv3r0JDg6mRIkSxMXFZen0zH//+19sNhtvv/12uufWrFmDzWbjs88+A+DgwYM8+eSTVK1alYCAAEJCQujWrVuW5pNkNOcmqzX/+OOP9O7dm0qVKuHv709YWBiPPvoof//9t2Of8ePH88wzzwBQsWJFx6nP1NoymnPzxx9/0K1bN0qVKkVgYCC33XYbn3/+udM+qfOHPvroI5577jkiIyPx9/fnzjvvZN++fTf93DcyZ84catasiZ+fHxEREfTv3z/dZ//999954IEHCAsLw9/fn8jISLp3705iYqJjn3Xr1vGPf/yDEiVKUKxYMapWrer470gkt/TPP5Ec+Pvvv2nfvj3du3fnkUceITQ0FDAnYRYrVoyhQ4dSrFgxvvjiC8aOHUtSUhIvvfTSTd/3gw8+4OzZs/zrX//CZrPx4osv0qVLF/7444+bjiB8/fXXxMfH8+STT1K8eHFeeeUVHnjgAQ4dOkRISAgA33//Pe3atSM8PJwJEyZgt9uZOHEiZcqUydLn/vjjj7lw4QL9+vUjJCSE7du3M2vWLP78808+/vhjp33tdjtt27aladOmvPzyy6xfv56pU6dSuXJl+vXrB4BhGHTq1Imvv/6aJ554gurVq7N06VLi4uJuWkujRo2oVKkSH330Ubr9Fy9eTMmSJWnbti0AO3bsYOvWrXTv3p3IyEgOHDjA3Llzad26Nb/88ku2Rt2yU/O6dev4448/6NOnD2FhYfz888/Mnz+fn3/+mW+++QabzUaXLl347bffWLRoEdOnT6d06dIAN/yZHD9+nObNm3PhwgUGDRpESEgIb7/9Nh07dmTJkiXcf//9Tvs///zzeHl5MWzYMBITE3nxxRfp0aMH3377bZY/c6rx48czYcIE2rRpQ79+/di7dy9z585lx44dbNmyhSJFipCcnEzbtm25fPkyAwcOJCwsjCNHjvDZZ59x5swZgoOD+fnnn7nvvvuoU6cOEydOxM/Pj3379r
Fly5Zs1ySSIUNEbqh///7G9f+ZtGrVygCMefPmpdv/woUL6bb961//MgIDA41Lly45tsXFxRkVKlRwPN6/f78BGCEhIcapU6cc25cvX24AxqeffurYNm7cuHQ1AYavr6+xb98+x7YffvjBAIxZs2Y5tnXo0MEIDAw0jhw54tj2+++/Gz4+PuneMyMZfb4pU6YYNpvNOHjwoNPnA4yJEyc67Vu/fn2jYcOGjsfLli0zAOPFF190bLt69arRokULAzAWLFiQaT0jR440ihQp4vR3dvnyZaNEiRLGo48+mmnd27ZtMwDjnXfecWzbuHGjARgbN250+ixpf1bZqTmj4y5atMgAjK+++sqx7aWXXjIAY//+/en2r1ChghEXF+d4PGTIEAMwNm/e7Nh29uxZo2LFikZ0dLRht9udPkv16tWNy5cvO/adOXOmARi7d+9Od6y0FixY4FTTiRMnDF9fX+Puu+92HMMwDGP27NkGYLz11luGYRjG999/bwDGxx9/fMP3nj59ugEYJ0+ezLQGkZzSaSmRHPDz86NPnz7ptgcEBDjunz17lr/++osWLVpw4cIFfv3115u+b2xsLCVLlnQ8btGiBWCehriZNm3aULlyZcfjOnXqEBQU5Hit3W5n/fr1dO7cmYiICMd+t9xyC+3bt7/p+4Pz5zt//jx//fUXzZs3xzAMvv/++3T7P/HEE06PW7Ro4fRZVq5ciY+Pj2MkB8Db25uBAwdmqZ7Y2FiuXLlCfHy8Y9vatWs5c+YMsbGxGdZ95coV/v77b2655RZKlCjBzp07s3SsnNSc9riXLl3ir7/+4rbbbgPI9nHTHr9Jkyb84x//cGwrVqwYjz/+OAcOHOCXX35x2r9Pnz74+vo6Hmfndyqt9evXk5yczJAhQ5wmOPft25egoCDHabHg4GDAPDV44cKFDN8rddL08uXL83yytngmhRuRHChXrpzTF0aqn3/+mfvvv5/g4GCCgoIoU6aMYzJy2vkGN1K+fHmnx6lB5/Tp09l+berrU1974sQJLl68yC233JJuv4y2ZeTQoUP07t2bUqVKOebRtGrVCkj/+fz9/dOdWklbD5hzYcLDwylWrJjTflWrVs1SPXXr1qVatWosXrzYsW3x4sWULl2aO+64w7Ht4sWLjB07lqioKPz8/ChdujRlypThzJkzWfq5pJWdmk+dOsXgwYMJDQ0lICCAMmXKULFiRSBrvw83On5Gx0pdwXfw4EGn7bn5nbr+uJD+c/r6+lKpUiXH8xUrVmTo0KG88cYblC5dmrZt2/Lqq686fd7Y2FhiYmL45z//SWhoKN27d+ejjz5S0BGX0ZwbkRxI+y/yVGfOnKFVq1YEBQUxceJEKleujL+/Pzt37mT48OFZ+h+3t7d3htsNw8jT12aF3W7nrrvu4tSpUwwfPpxq1apRtGhRjhw5Qu/evdN9vhvV42qxsbE899xz/PXXXxQvXpwVK1bw0EMPOa0oGzhwIAsWLGDIkCE0a9aM4OBgbDYb3bt3z9Mv1AcffJCtW7fyzDPPUK9ePYoVK0ZKSgrt2rXLty/yvP69yMjUqVPp3bs3y5cvZ+3atQwaNIgpU6bwzTffEBkZSUBAAF999RUbN27k888/Z/Xq1SxevJg77riDtWvX5tvvjrgvhRsRF9m0aRN///038fHxtGzZ0rF9//79FlZ1TdmyZfH3989wpUxWVs/s3r2b3377jbfffptevXo5tq9bty7HNVWoUIENGzZw7tw5p5GQvXv3Zvk9YmNjmTBhAp988gmhoaEkJSXRvXt3p32WLFlCXFwcU6dOdWy7dOlSji6al9WaT58+zYYNG5gwYQJjx451bP/999/TvWd2rjhdoUKFDP9+Uk97VqhQIcvvlR2p77t3714qVark2J6cnMz+/ftp06aN0/61a9emdu3ajB49mq1btxITE8O8efN49tlnAfDy8uLOO+/kzjvvZNq0aUyePJlRo0axcePGdO8lkl06LSXiIqn/2k
z7L+Lk5GTmzJljVUlOvL29adOmDcuWLePo0aOO7fv27WPVqlVZej04fz7DMJg5c2aOa7rnnnu4evUqc+fOdWyz2+3MmjUry+9RvXp1ateuzeLFi1m8eDHh4eFO4TK19utHKmbNmpVuWbora87o7wtgxowZ6d6zaNGiAFkKW/fccw/bt29n27Ztjm3nz59n/vz5REdHU6NGjax+lGxp06YNvr6+vPLKK06f6c033yQxMZF7770XgKSkJK5ever02tq1a+Pl5cXly5cB83Td9erVqwfg2EckNzRyI+IizZs3p2TJksTFxTFo0CBsNhvvvvtung7/Z9f48eNZu3YtMTEx9OvXD7vdzuzZs6lVqxa7du3K9LXVqlWjcuXKDBs2jCNHjhAUFMQnn3yS7bkbaXXo0IGYmBhGjBjBgQMHqFGjBvHx8dmejxIbG8vYsWPx9/fnscceS3dF3/vuu493332X4OBgatSowbZt21i/fr1jiXxe1BwUFETLli158cUXuXLlCuXKlWPt2rUZjuQ1bNgQgFGjRtG9e3eKFClChw4dHKEnrREjRrBo0SLat2/PoEGDKFWqFG+//Tb79+/nk08+ybOrGZcpU4aRI0cyYcIE2rVrR8eOHdm7dy9z5syhcePGjrllX3zxBQMGDKBbt27ceuutXL16lXfffRdvb28eeOABACZOnMhXX33FvffeS4UKFThx4gRz5swhMjLSaaK0SE4p3Ii4SEhICJ999hlPP/00o0ePpmTJkjzyyCPceeedjuutWK1hw4asWrWKYcOGMWbMGKKiopg4cSJ79uy56WquIkWK8OmnnzrmT/j7+3P//fczYMAA6tatm6N6vLy8WLFiBUOGDOG9997DZrPRsWNHpk6dSv369bP8PrGxsYwePZoLFy44rZJKNXPmTLy9vXn//fe5dOkSMTExrF+/Pkc/l+zU/MEHHzBw4EBeffVVDMPg7rvvZtWqVU6r1QAaN27MpEmTmDdvHqtXryYlJYX9+/dnGG5CQ0PZunUrw4cPZ9asWVy6dIk6derw6aefOkZP8sr48eMpU6YMs2fP5qmnnqJUqVI8/vjjTJ482XEdprp169K2bVs+/fRTjhw5QmBgIHXr1mXVqlWOlWIdO3bkwIEDvPXWW/z111+ULl2aVq1aMWHCBMdqK5HcsBkF6Z+VImKJzp078/PPP2c4H0REpLDRnBsRD3N9q4Tff/+dlStX0rp1a2sKEhFxMY3ciHiY8PBwR7+jgwcPMnfuXC5fvsz3339PlSpVrC5PRCTXNOdGxMO0a9eORYsWkZCQgJ+fH82aNWPy5MkKNiLiNjRyIyIiIm5Fc25ERETErSjciIiIiFvxuDk3KSkpHD16lOLFi2frkuciIiJiHcMwOHv2LBERETe9WKXHhZujR48SFRVldRkiIiKSA4cPHyYyMjLTfTwu3BQvXhww/3KCgoIsrkZERESyIikpiaioKMf3eGY8LtyknooKCgpSuBERESlksjKlRBOKRURExK0o3IiIiIhbUbgRERERt+Jxc25ERMS17HY7V65csboMcQO+vr43XeadFQo3IiKSI4ZhkJCQwJkzZ6wuRdyEl5cXFStWxNfXN1fvo3AjIiI5khpsypYtS2BgoC6MKrmSepHdY8eOUb58+Vz9PinciIhIttntdkewCQkJsboccRNlypTh6NGjXL16lSJFiuT4fTShWEREsi11jk1gYKDFlYg7ST0dZbfbc/U+CjciIpJjOhUlruSq3yedlnIRux02b4ZjxyA8HFq0AG9vq6sSERHxPBq5cYH4eIiOhttvh4cfNv+Mjja3i4iI+4uOjmbGjBlZ3n/Tpk3YbLY8X2m2cOFCSpQokafHKIgUbnIpPh66doU//3TefuSIuV0BR0Qkc3Y7bNoEixaZf+ZyukWmbDZbprfx48fn6H137NjB448/nuX9mzdvzrFjxwgODs7R8SRzOi2VC3Y7DB4MhpH+OcMAmw2GDIFOnXSKSkQkI/
Hx5v9H0/4DMTISZs6ELl1cf7xjx4457i9evJixY8eyd+9ex7ZixYo57huGgd1ux8fn5l+VZcqUyVYdvr6+hIWFZes1knUaucmFzZvTj9ikZRhw+LC5n4iIOLNi5DssLMxxCw4OxmazOR7/+uuvFC9enFWrVtGwYUP8/Pz4+uuv+d///kenTp0IDQ2lWLFiNG7cmPXr1zu97/WnpWw2G2+88Qb3338/gYGBVKlShRUrVjiev/60VOrpozVr1lC9enWKFStGu3btnMLY1atXGTRoECVKlCAkJIThw4cTFxdH586ds/V3MHfuXCpXroyvry9Vq1bl3XffdTxnGAbjx4+nfPny+Pn5ERERwaBBgxzPz5kzhypVquDv709oaChdu3bN1rHzi8JNLqT5nXPJfiIinuJmI99gjnzn5SmqGxkxYgTPP/88e/bsoU6dOpw7d4577rmHDRs28P3339OuXTs6dOjAoUOHMn2fCRMm8OCDD/Ljjz9yzz330KNHD06dOnXD/S9cuMDLL7/Mu+++y1dffcWhQ4cYNmyY4/kXXniB999/nwULFrBlyxaSkpJYtmxZtj7b0qVLGTx4ME8//TQ//fQT//rXv+jTpw8bN24E4JNPPmH69Om89tpr/P777yxbtozatWsD8N///pdBgwYxceJE9u7dy+rVq2nZsmW2jp9vDA+TmJhoAEZiYmKu32vjRsMw/zPM/LZxY64PJSJSoFy8eNH45ZdfjIsXL+bo9QXh/58LFiwwgoOD09S00QCMZcuW3fS1NWvWNGbNmuV4XKFCBWP69OmOx4AxevRox+Nz584ZgLFq1SqnY50+fdpRC2Ds27fP8ZpXX33VCA0NdTwODQ01XnrpJcfjq1evGuXLlzc6deqU5c/YvHlzo2/fvk77dOvWzbjnnnsMwzCMqVOnGrfeequRnJyc7r0++eQTIygoyEhKSrrh8XIrs9+r7Hx/a+QmF1q0MM8N32hZvs0GUVHmfiIick1BHvlu1KiR0+Nz584xbNgwqlevTokSJShWrBh79uy56chNnTp1HPeLFi1KUFAQJ06cuOH+gYGBVK5c2fE4PDzcsX9iYiLHjx+nSZMmjue9vb1p2LBhtj7bnj17iImJcdoWExPDnj17AOjWrRsXL16kUqVK9O3bl6VLl3L16lUA7rrrLipUqEClSpXo2bMn77//PhcuXMjW8fOLwk0ueHubk94gfcBJfTxjhiYTi4hcLzzctfu5UtGiRZ0eDxs2jKVLlzJ58mQ2b97Mrl27qF27NsnJyZm+z/XtA2w2GykpKdna38jovF0eioqKYu/evcyZM4eAgACefPJJWrZsyZUrVyhevDg7d+5k0aJFhIeHM3bsWOrWrVsgG6cq3ORSly6wZAmUK+e8PTLS3J4Xs/1FRAq7wjTyvWXLFnr37s39999P7dq1CQsL48CBA/laQ3BwMKGhoezYscOxzW63s3Pnzmy9T/Xq1dmyZYvTti1btlCjRg3H44CAADp06MArr7zCpk2b2LZtG7t37wbAx8eHNm3a8OKLL/Ljjz9y4MABvvjii1x8sryhpeAu0KWLudxbVygWEcma1JHvrl3NIJN2gKKgjXxXqVKF+Ph4OnTogM1mY8yYMZmOwOSVgQMHMmXKFG655RaqVavGrFmzOH36dLZaFjzzzDM8+OCD1K9fnzZt2vDpp58SHx/vWP21cOFC7HY7TZs2JTAwkPfee4+AgAAqVKjAZ599xh9//EHLli0pWbIkK1euJCUlhapVq+bVR84xhRsX8faG1q2trkJEpPBIHfnO6Do3M2YUnJHvadOm8eijj9K8eXNKly7N8OHDSUpKyvc6hg8fTkJCAr169cLb25vHH3+ctm3b4p2NBNi5c2dmzpzJyy+/zODBg6lYsSILFiyg9f9/gZUoUYLnn3+eoUOHYrfbqV27Np9++ikhISGUKFGC+Ph4xo8fz6VLl6hSpQqLFi2iZs2aefSJc85m5PcJPYslJSURHBxMYm
IiQUFBVpcjIlIoXbp0if3791OxYkX8/f1z9V7qzZczKSkpVK9enQcffJBJkyZZXY5LZPZ7lZ3vb43ciIiIpTTynTUHDx5k7dq1tGrVisuXLzN79mz279/Pww8/bHVpBY4mFIuIiBQCXl5eLFy4kMaNGxMTE8Pu3btZv3491atXt7q0AkcjNyIiIoVAVFRUupVOkjGN3IiIiIhbUbgRERERt6JwIyIiIm5F4UZERETcisKNiIiIuBWFGxEREXErCjciIiLZ1Lp1a4YMGeJ4HB0dzYwZMzJ9jc1mY9myZbk+tqveJzPjx4+nXr16eXqMvKRwIyIiHqNDhw60a9cuw+c2b96MzWbjxx9/zPb77tixg8cffzy35Tm5UcA4duwY7du3d+mx3I3CjYiIeIzHHnuMdevW8WfaTp3/b8GCBTRq1Ig6depk+33LlClDYGCgK0q8qbCwMPz8/PLlWIWVwo2IiHiM++67jzJlyrBw4UKn7efOnePjjz/mscce4++//+ahhx6iXLlyBAYGUrt2bRYtWpTp+15/Wur333+nZcuW+Pv7U6NGDdatW5fuNcOHD+fWW28lMDCQSpUqMWbMGK5cuQLAwoULmTBhAj/88AM2mw2bzeao+frTUrt37+aOO+4gICCAkJAQHn/8cc6dO+d4vnfv3nTu3JmXX36Z8PBwQkJC6N+/v+NYWZGSksLEiROJjIzEz8+PevXqsXr1asfzycnJDBgwgPDwcPz9/alQoQJTpkwBwDAMxo8fT/ny5fHz8yMiIoJBgwZl+dg5ofYLIiLiEoYBFy5Yc+zAQLDZbr6fj48PvXr1YuHChYwaNQrb/7/o448/xm6389BDD3Hu3DkaNmzI8OHDCQoK4vPPP6dnz55UrlyZJk2a3PQYKSkpdOnShdDQUL799lsSExOd5uekKl68OAsXLiQiIoLdu3fTt29fihcvzr///W9iY2P56aefWL16NevXrwcgODg43XucP3+etm3b0qxZM3bs2MGJEyf45z//yYABA5wC3MaNGwkPD2fjxo3s27eP2NhY6tWrR9++fW/+lwbMnDmTqVOn8tprr1G/fn3eeustOnbsyM8//0yVKlV45ZVXWLFiBR999BHly5fn8OHDHD58GIBPPvmE6dOn8+GHH1KzZk0SEhL44YcfsnTcHDM8TGJiogEYiYmJVpciIlJoXbx40fjll1+MixcvOradO2cYZsTJ/9u5c1mvfc+ePQZgbNy40bGtRYsWxiOPPHLD19x7773G008/7XjcqlUrY/DgwY7HFSpUMKZPn24YhmGsWbPG8PHxMY4cOeJ4ftWqVQZgLF269IbHeOmll4yGDRs6Ho8bN86oW7duuv3Svs/8+fONkiVLGufS/AV8/vnnhpeXl5GQkGAYhmHExcUZFSpUMK5everYp1u3bkZsbOwNa7n+2BEREcZzzz3ntE/jxo2NJ5980jAMwxg4cKBxxx13GCkpKenea+rUqcatt95qJCcn3/B4qTL6vUqVne9vnZYSERGPUq1aNZo3b85bb70FwL59+9i8eTOPPfYYAHa7nUmTJlG7dm1KlSpFsWLFWLNmDYcOHcrS++/Zs4eoqCgiIiIc25o1a5Zuv8WLFxMTE0NYWBjFihVj9OjRWT5G2mPVrVuXokWLOrbFxMSQkpLC3r17Hdtq1qyJt7e343F4eDgnTpzI0jGSkpI4evQoMTExTttjYmLYs2cPYJ762rVrF1WrVmXQoEGsXbvWsV+3bt24ePEilSpVom/fvixdupSrV69m63Nml8KNiIi4RGAgnDtnzS27c3kfe+wxPvnkE86ePcuCBQuoXLkyrVq1AuCll15i5syZDB8+nI0bN7Jr1y7atm1LcnKyy/6utm3bRo8ePbjnnnv47LPP+P777xk1apRLj5FWkSJFnB7bbDZSUlJc9v4NGjRg//79TJo0iYsXL/Lggw/StWtXwOxmvnfvXubMmUNAQABPPvkkLVu2zNacn+zSnBsREXEJmw
3SDCAUaA8++CCDBw/mgw8+4J133qFfv36O+TdbtmyhU6dOPPLII4A5h+a3336jRo0aWXrv6tWrc/jwYY4dO0Z4eDgA33zzjdM+W7dupUKFCowaNcqx7eDBg077+Pr6Yrfbb3qshQsXcv78ecfozZYtW/Dy8qJq1apZqvdmgoKCiIiIYMuWLY4AmHqctHOQgoKCiI2NJTY2lq5du9KuXTtOnTpFqVKlCAgIoEOHDnTo0IH+/ftTrVo1du/eTYMGDVxS4/UUbkRExOMUK1aM2NhYRo4cSVJSEr1793Y8V6VKFZYsWcLWrVspWbIk06ZN4/jx41kON23atOHWW28lLi6Ol156iaSkJKcQk3qMQ4cO8eGHH9K4cWM+//xzli5d6rRPdHQ0+/fvZ9euXURGRlK8ePF0S8B79OjBuHHjiIuLY/z48Zw8eZKBAwfSs2dPQkNDc/aXk4FnnnmGcePGUblyZerVq8eCBQvYtWsX77//PgDTpk0jPDyc+vXr4+Xlxccff0xYWBglSpRg4cKF2O12mjZtSmBgIO+99x4BAQFUqFDBZfVdT6elRETEIz322GOcPn2atm3bOs2PGT16NA0aNKBt27a0bt2asLAwOnfunOX39fLyYunSpVy8eJEmTZrwz3/+k+eee85pn44dO/LUU08xYMAA6tWrx9atWxkzZozTPg888ADt2rXj9ttvp0yZMhkuRw8MDGTNmjWcOnWKxo0b07VrV+68805mz56dvb+Mmxg0aBBDhw7l6aefpnbt2qxevZoVK1ZQpUoVwFz59eKLL9KoUSMaN27MgQMHWLlyJV5eXpQoUYLXX3+dmJgY6tSpw/r16/n0008JCQlxaY1p2QzDMPLs3QugpKQkgoODSUxMJCgoyOpyREQKpUuXLrF//34qVqyIv7+/1eWIm8js9yo7398auRERERG3onAjIiIibkXhRkRERNyKwo2IiIi4FYUbERHJMQ9bkyJ5zFW/Two3IiKSbalXvL1gVadMcUupV2hO2yoiJ3QRPxc6fhySkuD/l/2LiLgtb29vSpQo4ehPFBgY6LjCr0hOpKSkcPLkSQIDA/HxyV08UbhxkRUr4OGHoWFD2LTJvAy5iIg7CwsLA8hyA0aRm/Hy8qJ8+fK5DsoKNy5Svz5cvQpffQWrVsE991hdkYhI3rLZbISHh1O2bNk8bYIonsPX1xcvr9zPmFG4cZGoKBg0CF56CUaMgLZtIZenDEVECgVvb+9cz5EQcSVNKHahESOgRAnYvRs++MDqakRERDyTwo0LlSoFI0ea90ePhkuXrK1HRETEEyncuNjAgVCuHBw6BHPnWl2NiIiI51G4cbGAAJgwwbz/7LOQmGhtPSIiIp5G4SYPxMVB9epw6hS8+KLV1YiIiHgWhZs84OMDU6aY96dPh2PHrK1HRETEkyjc5JGOHaF5c7h48dppKhEREcl7Cjd5xGaDF14w77/xBuzda209IiIinkLhJg/94x/QoQPY7TBqlNXViIiIeAbLw82rr75KdHQ0/v7+NG3alO3bt2e6/5kzZ+jfvz/h4eH4+flx6623snLlynyqNvsmTwYvL/jkE/j2W6urERERcX+WhpvFixczdOhQxo0bx86dO6lbty5t27a9YRO25ORk7rrrLg4cOMCSJUvYu3cvr7/+OuXKlcvnyrOuVi1z9RTA8OFgGNbWIyIi4u5shmHd123Tpk1p3Lgxs2fPBsx251FRUQwcOJARI0ak23/evHm89NJL/PrrrxQpUiRHx0xKSiI4OJjExESCgoJyVX9WHT4MVarA5cuwciW0b58vhxUREXEb2fn+tmzkJjk5me+++442bdpcK8bLizZt2rBt27YMX7NixQqaNWtG//79CQ0NpVatWkyePBm73X7D41y+fJmkpCSnW35LbaoJ5uhNJuWKiIhILlkWbv766y/sdjuhoaFO20NDQ0lISMjwNX/88QdLlizBbrezcuVKxowZw9SpU3n22WdveJwpU6YQHB
zsuEVFRbn0c2SVmmqKiIjkD8snFGdHSkoKZcuWZf78+TRs2JDY2FhGjRrFvHnzbviakSNHkpiY6LgdPnw4Hyu+plQpM+CAmmqKiIjkJcvCTenSpfH29ub48eNO248fP05YWFiGrwkPD+fWW2/F29vbsa169eokJCSQnJyc4Wv8/PwICgpyulll0CA11RQREclrloUbX19fGjZsyIYNGxzbUlJS2LBhA82aNcvwNTExMezbt4+UlBTHtt9++43w8HB8fX3zvObcSttU87nn1FRTREQkL1h6Wmro0KG8/vrrvP322+zZs4d+/fpx/vx5+vTpA0CvXr0YOXKkY/9+/fpx6tQpBg8ezG+//cbnn3/O5MmT6d+/v1UfIdtSm2r+/Te89JLV1YiIiLgfHysPHhsby8mTJxk7diwJCQnUq1eP1atXOyYZHzp0CC+va/krKiqKNWvW8NRTT1GnTh3KlSvH4MGDGT58uFUfIdtSm2p27gzTpkH//hAebnVVIiIi7sPS69xYwYrr3FzPMMzWDFu3wr/+BZnMhxYREREKyXVuPJnNBs8/b95XU00RERHXUrixSIsW15pqjh5tdTUiIiLuQ+HGQqlNNZcsUVNNERERV1G4sZCaaoqIiLiewo3FJkwAPz/48ktYvdrqakRERAo/hRuLRUXBwIHmfTXVFBERyT2FmwJg5EgIDlZTTREREVdQuCkASpUyAw7AmDFw+bK19YiIiBRmCjcFRGpTzYMH1VRTREQkNxRuCoi0TTWffVZNNUVERHJK4aYAUVNNERGR3FO4KUB8fMwL+4HZVPPYMWvrERERKYwUbgqYTp2gWTO4ePHaaSoRERHJOoWbAsZmgxdeMO+/8Qb89pu19YiIiBQ2CjcFUNqmmqNGWV2NiIhI4aJwU0CpqaaIiEjOKNwUULVqQa9e5n011RQREck6hZsCTE01RUREsk/hpgArX/5aU80RIyAlxdp6RERECgOFmwIutanmjz+qqaaIiEhWKNwUcGmbao4eraaaIiIiN6NwUwioqaaIiEjWKdwUAgEBMH68eV9NNUVERDKncFNI9O4N1aqpqaaIiMjNKNwUEj4+MGWKeX/6dDXVFBERuRGFm0IktanmhQswcWLeHMNuh02bYNEi80+7PW+OIyIiklcUbgqRtE01X3/d9U014+MhOhpuvx0eftj8Mzra3C4iIlJYKNwUMi1awH33ub6pZnw8dO0Kf/7pvP3IEXO7Ao6IiBQWCjeF0JQp5ijOkiWwfXvu389uh8GDM+5flbptyBCdohIRkcJB4aYQqlUL4uLM+65oqrl5c/oRm7QMAw4fNvcTEREp6BRuCqnUppqbNsGaNbl7r6yuvNIKLRERKQwUbgqptE01hw/PXVPN8HDX7iciImIlhZtCzFVNNVu0gMhIcx5PRmw2iIoy9xMRESnoFG4KsVKlYMQI835ummp6e8PMmeb96wNO6uMZM8z9RERECjqFm0Ju0CCIiMh9U80uXczVV+XKOW+PjDS3d+mSuzpFRETyi80wcrvWpnBJSkoiODiYxMREgoKCrC7HJd54A/r2hZAQ+N//zFNVOWW3m6uijh0z59i0aKERGxERsV52vr81cuMG0jbVfPnl3L2Xtze0bg0PPWT+qWAjIiKFjcKNG0jbVHPaNC3ZFhERz6Zw4ybyo6mmiIhIYaBw4ybyuqmmiIhIYaFw40byqqmmiIhIYaJw42Zc3VRTRESksFG4cTOubqopIiJS2CjcuCFXNtUUEREpbBRu3FD58jBggHk/t001RUREChuFGzflqqaaIiIihY3CjZsKCbnWVHPMmJw31RQRESlsFG7cWGpTzQMHYN48q6sRERHJHwo3biww0JxcDDBpEiQmWluPiIhIflC4cXOubKopIiJSGCjcuDkfH5g82byvppoiIuIJFG48QOfOcNttaqopIiKeQeHGA6ippoiIeBKFGw/RsuW1ppqjR1tdjYiISN5RuPEgqU01P/5YTTVFRMR9Kd
x4kFq1oFcv876aaoqIiLtSuPEwEyeqqaaIiLg3hRsPk7ap5ogRaqopIiLuR+HGA6U21fzhB1i0yOpqREREXEvhxgOlbao5erSaaoqIiHtRuPFQaqopIiLuSuHGQwUGwvjx5n011RQREXeicOPB+vSBqlXVVFNERNyLwo0H8/ExL+wHZlPNhARr6xEREXEFhRsPp6aaIiLibhRuPFzapprz56uppoiIFH4KN6KmmiIi4lYKRLh59dVXiY6Oxt/fn6ZNm7I9k66OCxcuxGazOd38/f3zsVr3NHnytaaaO3ZYXY2IiEjOWR5uFi9ezNChQxk3bhw7d+6kbt26tG3blhMnTtzwNUFBQRw7dsxxO3jwYD5W7J5q11ZTTRERcQ+Wh5tp06bRt29f+vTpQ40aNZg3bx6BgYG89dZbN3yNzWYjLCzMcQsNDc3Hit1XalPNjRth7VqrqxEREckZS8NNcnIy3333HW3atHFs8/Lyok2bNmzbtu2Grzt37hwVKlQgKiqKTp068fPPP+dHuW4vbVPNp582V1CJiIgUNpaGm7/++gu73Z5u5CU0NJSEG1x0pWrVqrz11lssX76c9957j5SUFJo3b86ff/6Z4f6XL18mKSnJ6SY3NnIklCkDP/8MPXuqa7iIiBQ+lp+Wyq5mzZrRq1cv6tWrR6tWrYiPj6dMmTK89tprGe4/ZcoUgoODHbeoqKh8rrhwCQmB+Hjw9TX/1OopEREpbCwNN6VLl8bb25vjx487bT9+/DhhYWFZeo8iRYpQv3599u3bl+HzI0eOJDEx0XE7fPhwrut2d//4B7zxhnl/yhR4+21r6xEREckOS8ONr68vDRs2ZMOGDY5tKSkpbNiwgWbNmmXpPex2O7t37yY8PDzD5/38/AgKCnK6yc317AmjRpn3+/aFzZutrUdERCSrLD8tNXToUF5//XXefvtt9uzZQ79+/Th//jx9+vQBoFevXowcOdKx/8SJE1m7di1//PEHO3fu5JFHHuHgwYP885//tOojuK2JE6FrV7hyBe6/H/73P6srEhERuTkfqwuIjY3l5MmTjB07loSEBOrVq8fq1asdk4wPHTqEl9e1DHb69Gn69u1LQkICJUuWpGHDhmzdupUaNWpY9RHclpeXeUrq4EHzwn733QfbtkGJElZXJiIicmM2w/Csy7UlJSURHBxMYmKiTlFl0bFj0KQJ/PkntGkDK1dCkSJWVyUiIp4kO9/flp+WkoIvPBw+/RSKFoX162HgQF3BWERECi6FG8mSevVg0SKz/9Rrr8Err1hdkYiISMYUbiTLOnSAl1827w8dCp9/bm09IiIiGVG4kWx56ilzaXhKCnTvDj/+aHVFIiIizhRuJFtsNnj1VbjjDjh3zhzNuUGnDBEREUso3Ei2FSkCS5bArbfCoUPQuTNcvGh1VSIiIiaFG8mRkiXhs8+gVCn49lvo00crqEREpGBQuJEcq1LFbK5ZpAgsXgzjx1tdkYiIiMKN5FKrVubScDDbNbz/vrX1iIiIKNxIrvXpA//+t3n/0Udh61Zr6xEREc+mcCMuMWWKObE4Odn888ABiwsSERGPpXAjLuHlBe+9B/Xrw8mTZpPNpCSrqxIREU+kcCMuU7So2YMqIgJ+/hliY+HqVaurEhERT6NwIy5VrhysWAEBAbB6tdmmQUREJD8p3IjLNWxonqICmDXLvKKxiIhIflG4kTzRpYs5yRhg8GBYs8baekRExHMo3EieGT4cevcGux0efBB++cXqikRExBMo3EiesdnMC/y1bGmunLrvPnMllYiISF5SuJE85esLn3wClSvD/v3mNXAuXbK6KhERcWc5CjeHDx/mzz//dDzevn07Q4YMYf78+S4rTNxH6dJmk83gYPPqxX37qsmmiIjknRyFm4cffpiNGzcCkJCQwF133cX27dsZNWoUEydOdGmB4h6qVYMlS8Db21xJNXmy1RWJiIi7ylG4+emnn2jSpAkAH330EbVq1WLr1q28//
77LFy40JX1iRtp0wbmzDHvjx4NH31kbT0iIuKechRurly5gp+fHwDr16+nY8eOAFSrVo1jx465rjpxO48/Dk89Zd6Pi4Pt262tR0RE3E+Owk3NmjWZN28emzdvZt26dbRr1w6Ao0ePEhIS4tICxf289BLce685sbhjRzh0yOqKRETEneQo3Lzwwgu89tprtG7dmoceeoi6desCsGLFCsfpKpEb8faGRYugTh04fhw6dICzZ62uSkRE3IXNMHK2bsVut5OUlETJkiUd2w4cOEBgYCBly5Z1WYGulpSURHBwMImJiQQFBVldjkc7dAiaNLkWcJYuNYOPiIjI9bLz/Z2jkZuLFy9y+fJlR7A5ePAgM2bMYO/evQU62EjBUr48LF8O/v5mN/F//9vqikRExB3kKNx06tSJd955B4AzZ87QtGlTpk6dSufOnZk7d65LCxT31rQpvP22eX/aNNClkkREJLdyFG527txJixYtAFiyZAmhoaEcPHiQd955h1deecWlBYr7e/BBSL08Uv/+sGGDtfWIiEjhlqNwc+HCBYoXLw7A2rVr6dKlC15eXtx2220cPHjQpQWKZxg9Gnr0gKtXoWtX2Ls3d+9nt8OmTebE5U2bzMciIuIZchRubrnlFpYtW8bhw4dZs2YNd999NwAnTpzQJF3JEZsN3ngDmjeHM2fMJpt//52z94qPh+houP12ePhh88/oaHO7iIi4vxyFm7FjxzJs2DCio6Np0qQJzZo1A8xRnPr167u0QPEc/v7miqnoaNi3Dx54AJKTs/ce8fHmyE+a1mcAHDliblfAERFxfzleCp6QkMCxY8eoW7cuXl5mRtq+fTtBQUFUq1bNpUW6kpaCF3w//wzNmpnXvunTB9580xzZuRm73QxG1webVDYbREaa3cm15FxEpHDJ86XgAGFhYdSvX5+jR486OoQ3adKkQAcbKRxq1jT7Tnl5wYIF5hWNs2Lz5hsHGzA7kR8+bO4nIiLuK0fhJiUlhYkTJxIcHEyFChWoUKECJUqUYNKkSaSkpLi6RvFA7drBzJnm/REjzNNVN5PVtmZqfyYi4t58cvKiUaNG8eabb/L8888TExMDwNdff8348eO5dOkSzz33nEuLFM80YIC5amr2bHjkEXPEpUGDG+8fHp61983qfiIiUjjlaM5NREQE8+bNc3QDT7V8+XKefPJJjhw54rICXU1zbgqXq1fNlVNr1kBEhNlFvFy5jPdNnXNz5Ih5Cup6mnMjIlJ45fmcm1OnTmU4t6ZatWqcOnUqJ28pkiEfH1i8GGrUgKNHzS7i589nvK+397VTWddPQE59PGOGgo2IiLvLUbipW7cus2fPTrd99uzZ1KlTJ9dFiaQVHAyffQZlysDOndCzJ9xoaleXLrBkSfrRnchIc3uXLnlfr4iIWCtHp6W+/PJL7r33XsqXL++4xs22bds4fPgwK1eudLRmKIh0Wqrw2rrVvCBfcrI5yXjKlBvva7ebc3SOHTPn2LRooREbEZHCLM9PS7Vq1YrffvuN+++/nzNnznDmzBm6dOnCzz//zLvvvpujokVupnlzeOst8/7zz5vLxG/E2xtat4aHHjL/VLAREfEcOb6IX0Z++OEHGjRogL0AN/LRyE3hN3YsTJoERYrAunXQqpXVFYmISF7Ll4v4iVhl/Hizk/iVK+Ycmn37rK5IREQKEoUbKXS8vGDhQmjSBE6dMpeKnz5tdVUiIlJQKNxIoRQQAMuXQ1SUeaG/bt3MkRwREZFsXaG4y03W0Z45cyY3tYhkS1iYuUQ8JgY2bDCvaDxvXtaabIqIiPvKVrgJDg6+6fO9evXKVUEi2VGnDixaZF7cb/58qFYNnnrK6qpERMRKLl0tVRhotZR7mj4dhg41R22WL4cOHayuSEREXEmrpcTjDBkC//qX2VPqoYfghx+srkhERKyicCNuwWaDWbPgzjvN3lMdOkBCgtVViYiIFRRuxG0UKQIffwxVq8Lhw9CpE1y8aHVVIi
KS3xRuxK2ULGmuoCpVCrZvhx494MIFq6sSEZH8pHAjbueWW2DpUnMkZ+lSaNjQ7CYuIiKeQeFG3FLLlrB6NUREwK+/QtOmZrPNAtz2TEREXEThRtzWHXfAjz/CAw/A1aswcqQ54fjQIasrExGRvKRwI24tJMScZPzWW1CsGHz5pXnhvw8+sLoyERHJKwo34vZsNujTB3btgttug8REc6Jxjx6gjiEiIu5H4UY8RuXKsHkzjB8P3t7m6E3duvDVV1ZXJiIirqRwIx7FxwfGjYOvv4ZKlcz5N61bm/NxkpOtrk5ERFxB4UY80m23maepHn3UbNnw/PPQrJm5skpERAo3hRvxWMWLw5tvwiefmBf927kTGjSAefPMwCMiIoWTwo14vC5dYPduuOsus11Dv37QsSOcOGF1ZSIikhMKNyKYF/tbvRqmTwc/P7OFQ+3a8PnnVlcmIiLZpXAj8v+8vGDIENixwww2J07AfffBk0+qP5WISGGicCNyndq1zaabTz1lPp47V/2pREQKE4UbkQz4+8O0abB27bX+VLfdBi+8oP5UIiIFncKNSCbuuutaf6orV2DECPWnEhEp6BRuRG7iRv2pFi2yujIREclIgQg3r776KtHR0fj7+9O0aVO2b9+epdd9+OGH2Gw2OnfunLcFisfLqD/Vww/DI4+oP5WISEFjebhZvHgxQ4cOZdy4cezcuZO6devStm1bTtzkIiMHDhxg2LBhtGjRIp8qFUnfn+r999WfSkSkoLE83EybNo2+ffvSp08fatSowbx58wgMDOStt9664Wvsdjs9evRgwoQJVKpUKR+rFVF/KhGRgs7ScJOcnMx3331HmzZtHNu8vLxo06YN27Ztu+HrJk6cSNmyZXnsscdueozLly+TlJTkdBNxhYz6UzVvrv5UIiJWszTc/PXXX9jtdkJDQ522h4aGkpCQkOFrvv76a958801ef/31LB1jypQpBAcHO25RUVG5rlsk1fX9qb77Tv2pRESsZvlpqew4e/YsPXv25PXXX6d06dJZes3IkSNJTEx03A4fPpzHVYon6tLFXDLepo36U4mIWM3HyoOXLl0ab29vjh8/7rT9+PHjhIWFpdv/f//7HwcOHKBDhw6ObSkpKQD4+Piwd+9eKleu7PQaPz8//Pz88qB6EWflysGaNfDKK+b1cFL7U731Ftx7r9XViYh4DktHbnx9fWnYsCEbNmxwbEtJSWHDhg00a9Ys3f7VqlVj9+7d7Nq1y3Hr2LEjt99+O7t27dIpJ7Fc2v5UtWpd60/Vv7/6U4mI5BdLR24Ahg4dSlxcHI0aNaJJkybMmDGD8+fP06dPHwB69epFuXLlmDJlCv7+/tSqVcvp9SVKlABIt13ESrVrmwHnP/8xO43PmQNffGEuHW/QwOrqRETcm+XhJjY2lpMnTzJ27FgSEhKoV68eq1evdkwyPnToEF5ehWpqkAhwrT9V+/bQu/e1/lSTJsGwYeZ1ckRExPVshuFZazqSkpIIDg4mMTGRoKAgq8sRD/H33/D44xAfbz5u1QreeQfKl7e2LhGRwiI7398aEhHJByEhsGRJ9vtT2e2waZO536ZN6kguIpIVCjci+SS7/ani4yE6Gm6/3dzv9tvNx6mjPyIikjGFG5F8lpX+VPHx0LUr/Pmn82uPHDG3K+CIiNyYwo2IBVL7U23e7Nyf6j//MS8COHhwxlc4Tt02ZIhOUYmI3IjCjYiFmjVz7k81ZQrUq5d+xCYtw4DDh81gJCIi6SnciFjs+v5Uv/2WtdcdO5a3dYmIFFYKNyIFRGp/qoYNs7Z/eHje1iMiUlgp3IgUIOXKwbZtEBx8431sNoiKghYt8q8uEZHCROFGpIApUsS8Hs6NGAbMmKErHIuI3IjCjUgB1KWLOQenXLn0zwUHw+7dZlNOERFJT+0XRAowu91cFfX77/Df/8LKlddWUvn6Qo8e5rLxunWtrVNEJK9l5/tb4UakELlyxbyA3/Tp8O2317bffj
s89RTcey+oz6yIuCP1lhJxU0WKQGwsfPONOfE4Ntace7NxI3TsCFWrwuzZcO6c1ZWKiFhH4UakkLrtNvjwQ/jjD/j3v6FECdi3DwYOhMhIeOYZOHjQ6ipFRPKfwo1IIVe+PLzwgjkX59VX4dZbzaacL79stnbo1g22bs24nYOIiDtSuBFxE0WLwpNPwp498Pnn0KYNpKTAkiUQEwNNm8IHH5jzdkRE3JnCjYib8fKCe+6BdevMJeOPPQZ+frBjh7m6Kjra7GH1999WVyoikjcUbkTcWK1a8MYbZqPNSZMgLAyOHjW7j0dFwRNPmCM9IiLuROFGxAOUKQOjR8OBA/DOO1C/Ply8CK+9BjVqQPv2sGaN5uWIiHtQuBHxIH5+0LMnfPcdfPkl3H+/2atq9Wpo1w5q1jQDz4ULVlcqIpJzCjciHshmg5YtzQsC7tsHQ4ZA8eLmKaonnjBPWf3nP3DkiNWViohkn8KNiIerVMm84vGff5p/VqwIp06Zk46jo81JyDt2WF2liEjWKdyICABBQeYIzu+/myM6LVvC1avm8vEmTeAf/zCXlV+9anWlIiKZU7gRESfe3uZcnC+/NOfm9Oxptn3YssW8IOAtt8DUqXDmjNWViohkTOFGRG6oQQNzddXBgzBmDJQubd4fNsxs8TBokDlnR0SkIFG4EZGbCg+HiRPh0CHzujm1asH58zBrltnuoWNHs3mnlpKLSEGgcCMiWRYQYF7x+McfzSsg33uvGWg+/RTuuMO8fs7ChXDpktWViognU7gRkWyz2czeVZ99Br/+ava0CgyEH36APn2gQgUYPx6OH7e6UhHxRAo3IpIrVaua3cj//NPsTh4ZCSdOwIQJZsfyPn3M0CMikl8UbkTEJUqWhH//G/74AxYvhttug+Rk8zRVvXrmaaulS3XKSkTynsKNiLhUkSLw4IOwbZt5697dXF6+cSN06WKuuOraFd57D06ftrpaEXFHNsPwrPUNSUlJBAcHk5iYSFBQkNXliHiEw4fNU1fvvefc0sHbG1q3hk6dzFv58paVKCIFXHa+vxVuRCTfGIZ5YcBly2D5cvjpJ+fnGzQwQ07nzlC7tjlxWUQEFG4ypXAjkv/sdti8GY4dM6+Z06KFOWqzb58ZcpYvh6+/dr5OTsWK14JOTAz4+FhWvogUAAo3mVC4Eclf8fEweLC5mipVZCTMnGnOwUl14oS5tHz5cli71nnicUgIdOhghp277zaXnYuIZ1G4yYTCjUj+iY83Jw9f/3+Z1NNNS5Y4B5xU58+bAWfZMjPwnDp17bmAADPgdOoE990HZcrkWfkiUoAo3GRC4UYkf9jtEB3tPGKTls1mjuDs32+eorqRq1fNU1ap83QOHLj2nJeX2a089fRVpUquq19EChaFm0wo3Ijkj02b4Pbbb77fxo3miqmsMAyz9UNq0Pn+e+fna9e+FnQaNNCEZBF3kp3vb13nRkTyxLFjrt0PzLBSty6MGwc7d5qjODNnmiHK2xt274Znn4VGjcwWEAMHwvr1cOVKjj6CiBRSCjcikifCw127X0YqVIBBg+CLL8wJye+8Y87hCQw0r60zezbcdReULQuPPAIffwxnz+b8eCJSOOi0lIjkidQ5N0eOpJ9QDFmfc5MTFy/Chg3m6asVK+DkyWvP+fqaTT87dYKOHSEszLXHFpG8oTk3mVC4Eck/qaulwDng3Gy1lCvZ7fDNN2bQWbbMvLZO2jpuu82co9Opk9kEVEQKJoWbTCjciOSvjK5zExUFM2bkfbC5nmHAnj3Xgs6OHc7PV6tmBp3OnaFxY3M1logUDAo3mVC4Ecl/N7pCsdWOHDFPWy1bZq7aSjvxODzcPG3VqZPZ0dzPz7IyRQSFm0wp3IhIRhITYdUqM+isXOk88bh4cWjf3hzRad8eSpSwqEgRD6ZwkwmFGxG5mcuXzev0pF5PJ+1ydR8faNXK7HfVpIl501WSRfKewk0mFG5EJDtSUuC//702T2fPnvT7VK
x4Leg0bQr166v/lYirKdxkQuFGRHLjt9/MCwNu327eMgo73t7m1ZKbNr0WeqpXLxjzjEQKK4WbTCjciIgrJSaaIzvbt8O335q3hIT0+xUrZl45OXV0p0kTKFdOLSJEskrhJhMKNyKSlwzDXIX17bfXRnd27DA7nV8vPNx5dKdxY9D/lkQypnCTCYUbEclvdrt5+ip1dGf7drMPlt3uvJ/NZl5rJ+3oTu3a5lWVRTydwk0mFG5EpCC4cMFs/pk6uvPtt2Yj0Ov5+ZkdztNOWK5USaezxPMo3GRC4UZECqoTJ8xTWGlPaZ0+nX6/UqWcR3caN9ZydHF/CjeZULgRkcLCMMxeWKlBZ/t2+P578zo816tUKf1y9ICA/K9ZJK8o3GRC4UZECrPkZPjxR+f5O7/+mn4/H5/0y9GrVdNydCm8FG4yoXAjIu7mzJlry9FTQ09Gy9GLF3dejt64sZajS+GhcJMJhRsRyamC2gD0eoZhdmFPO7rz3/9mvBw9ONi8wGCNGs5/VqigruhSsCjcZELhRkRyIj4eBg82Q0OqyEiYORO6dLGurqyy2+GXX5xHd376Kf1y9FQBAeZprOtDT+XKUKRI/tYuAgo3mVK4EZHsio+Hrl3NEZG0Uk/nLFlSOALO9S5fht9/N0PPnj3X/ty715zbk5EiRaBKlfSjPbfeqgnMkrcUbjKhcCMi2WG3Q3S084hNWjabOYKzf3/BPEWVE1evmp/n+tCzZ0/Gp7bA/HuoVCl96KlWTVddFtdQuMmEwo2IZMemTXD77Tffb+NGaN06r6uxVkqKGfKuDz2//JLx9XhSRUZmPK+ndOn8q10Kv+x8f/vkU00iIoXSsWOu3a8w8/KC8uXNW7t217YbhnkBwoxGeo4dMwPRn3/CunXO71emTMahJyJCK7gkdxRuREQyER7u2v3ckc0GoaHm7fpRrtOnzevwXB98DhyAkyfN21dfOb8mKCjj0BMdrRVckjU6LSUikonUOTdHjqSfUAzuOecmP5w/b05c/uUX5+Dzv/9lvoKratX0weeWW7SCyxNozk0mFG5EJLtSV0uBc8Ap7KulCqLUFVypp7XSruDKqO0EmKEyIsK8IGFmt8DA/P0s4loKN5lQuBGRnMjoOjdRUTBjhoJNfrDbM17B9csvN17Bdb0SJZzDTmRk+gBUurROfRVUCjeZULgRkZwqLFco9iSGAUePmqHzyJEb37IagIoUydookL9/3n4uSa/QhZtXX32Vl156iYSEBOrWrcusWbNo0qRJhvvGx8czefJk9u3bx5UrV6hSpQpPP/00PXv2zNKxFG5ERDyLYUBiYubh58gRc8VXVr8RQ0JuHoBCQrTqy5UK1VLwxYsXM3ToUObNm0fTpk2ZMWMGbdu2Ze/evZQtWzbd/qVKlWLUqFFUq1YNX19fPvvsM/r06UPZsmVp27atBZ9AREQKMpvNPCVVogTUrHnj/a5cMUflbhR+UkeHLl2Cv/82bz/+eOP38/NLPwp0/amwiAjw9XX1JxbLR26aNm1K48aNmT17NgApKSlERUUxcOBARowYkaX3aNCgAffeey+TJk266b4auRERkZwyDHN5+81GgU6ezPp7lilzLeyUKQOlSmV+CwryzBGhQjNyk5yczHfffcfIkSMd27y8vGjTpg3btm276esNw+CLL75g7969vPDCC3lZqoiICDbbtZBRu/aN97t8OeNRoLRzg44eNfdLvd7Prl1Zq8Hb++YBKO0tJMT8MzjYcyZLWxpu/vrrL+x2O6GhoU7bQ0ND+fXXX2/4usTERMqVK8fly5fx9vZmzpw53HXXXRnue/nyZS6nWT+YlJTkmuJFRERuwM/PvD5SdPSN9zEM89RW2rDz999w6lT6W+r2ixfNie2pgSg7bDYoWTJrQSjtrUQJ8LF8Ekv2FLJyTcWLF2fXrl2cO3eODRs2MHToUCpVqkTrDBq7TJkyhQkTJuR/kSIiIpmw2cyl56VLQ926WXvNxYvmabGMAl
DaEHT97dw5M0ylPs6u4OCbh6C0t9KlzVNsVrF0zk1ycjKBgYEsWbKEzp07O7bHxcVx5swZli9fnqX3+ec//8nhw4dZs2ZNuucyGrmJiorSnBsREfEYycnXQtGNAlBGt8TEnB2vfn3YudO1n6HQzLnx9fWlYcOGbNiwwRFuUlJS2LBhAwMGDMjy+6SkpDgFmLT8/Pzw8/NzRbkiIiKFkq/vtf5f2XHlCpw5c/MQdH1gsrrju+WnpYYOHUpcXByNGjWiSZMmzJgxg/Pnz9OnTx8AevXqRbly5ZgyZQpgnmZq1KgRlStX5vLly6xcuZJ3332XuXPnWvkxREQKDV2MULKqSBHz9FJ2TzFZfQU9y8NNbGwsJ0+eZOzYsSQkJFCvXj1Wr17tmGR86NAhvNJM7z5//jxPPvkkf/75JwEBAVSrVo333nuP2NhYqz6CiEihkVEbichImDlTbSTEdaxeqm75dW7ym65zIyKeKrUB6PX/11cDUCkMsvP97SEr3kVEPJvdbo7YZPTP2dRtQ4aY+4kUdgo3IiIeYPNm51NR1zMMOHzY3E+ksFO4ERHxAMeOuXY/kYJM4UZExAOEh7t2P5GCTOFGRMQDtGhhroq60SoWmw2iosz9RAo7hRsREQ/g7W0u94b0ASf18YwZut6NuAeFGxERD9Gli7ncu1w55+2RkVoGLu7F8ov4iYhI/unSBTp10hWKxb0p3IiIeBhvb2jd2uoqRPKOTkuJiIiIW1G4EREREbeicCMiIiJuRXNuRESk0LLbNTla0lO4ERGRQik+3mwGmrZnVmSkeT0fLWv3bDotJSIihU58PHTtmr4Z6JEj5vb4eGvqkoJB4UZERAoVu90csTGM9M+lbhsyxNxPPJPCjYiIFCqbN6cfsUnLMODwYXM/8UwKNyIiUqgcO+ba/cT9KNyIiEihEh7u2v3E/SjciIhIodKihbkq6vru5qlsNoiKMvcTz6RwIyIihYq3t7ncG9IHnNTHM2boejeeTOFGREQKnS5dYMkSKFfOeXtkpLld17nxbLqIn4iIFEpdukCnTrpCsaSncCMiIoWWtze0bm11FVLQ6LSUiIiIuBWFGxEREXErOi0lIiJiMXU3dy2FGxEREQupu7nr6bSUiIiIRdTdPG8o3IiIiFhA3c3zjsKNiIiIBdTdPO8o3IiIiFhA3c3zjsKNiIiIBdTdPO8o3IiIiFhA3c3zjsKNiIiIBdTdPO8o3IiIiFhE3c3zhi7iJyIiYiF1N3c9hRsRERGLuUt384LSRkLhRkRERHKtILWR0JwbERERyZWC1kZC4UZERERyrCC2kVC4ERERkRwriG0kFG5EREQkxwpiGwmFGxEREcmxgthGQuFGREREcqwgtpFQuBEREZEcK4htJBRuREREJFcKWhsJXcRPREREcq0gtZFQuBERERGXKChtJHRaSkRERNyKwo2IiIi4FYUbERERcSsKNyIiIuJWFG5ERETErSjciIiIiFtRuBERERG3onAjIiIibkXhRkRERNyKx12h2DAMAJKSkiyuRERERLIq9Xs79Xs8Mx4Xbs6ePQtAVFSUxZWIiIhIdp09e5bg4OBM97EZWYlAbiQlJYWjR49SvHhxbNf3ZhfATMdRUVEcPnyYoKAgq8vxePp5FCz6eRQ8+pkULHn18zAMg7NnzxIREYGXV+azajxu5MbLy4vIyEiryygUgoKC9D+KAkQ/j4JFP4+CRz+TgiUvfh43G7FJpQnFIiIi4lYUbkRERMStKNxIOn5+fowbNw4/Pz+rSxH08yho9PMoePQzKVgKws/D4yYUi4iIiHvTyI2IiIi4FYUbERERcSsKNyIiIuJWFG5ERETErSjciMOUKVNo3LgxxYsXp2zZsnTu3Jm9e/daXZYAzz//PDabjSFDhlhdikc7cuQIjzzyCCEhIQQEBFC7dm3++9//Wl2WR7Lb7YwZM4aKFSsSEBBA5cqVmTRpUpb6DknuffXVV3To0IGIiAhsNhvLli1zet
4wDMaOHUt4eDgBAQG0adOG33//Pd/qU7gRhy+//JL+/fvzzTffsG7dOq5cucLdd9/N+fPnrS7No+3YsYPXXnuNOnXqWF2KRzt9+jQxMTEUKVKEVatW8csvvzB16lRKlixpdWke6YUXXmDu3LnMnj2bPXv28MILL/Diiy8ya9Ysq0vzCOfPn6du3bq8+uqrGT7/4osv8sorrzBv3jy+/fZbihYtStu2bbl06VK+1Kel4HJDJ0+epGzZsnz55Ze0bNnS6nI80rlz52jQoAFz5szh2WefpV69esyYMcPqsjzSiBEj2LJlC5s3b7a6FAHuu+8+QkNDefPNNx3bHnjgAQICAnjvvfcsrMzz2Gw2li5dSufOnQFz1CYiIoKnn36aYcOGAZCYmEhoaCgLFy6ke/fueV6TRm7khhITEwEoVaqUxZV4rv79+3PvvffSpk0bq0vxeCtWrKBRo0Z069aNsmXLUr9+fV5//XWry/JYzZs3Z8OGDfz2228A/PDDD3z99de0b9/e4spk//79JCQkOP1/Kzg4mKZNm7Jt27Z8qcHjGmdK1qSkpDBkyBBiYmKoVauW1eV4pA8//JCdO3eyY8cOq0sR4I8//mDu3LkMHTqU//znP+zYsYNBgwbh6+tLXFyc1eV5nBEjRpCUlES1atXw9vbGbrfz3HPP0aNHD6tL83gJCQkAhIaGOm0PDQ11PJfXFG4kQ/379+enn37i66+/troUj3T48GEGDx7MunXr8Pf3t7ocwQz8jRo1YvLkyQDUr1+fn376iXnz5incWOCjjz7i/fff54MPPqBmzZrs2rWLIUOGEBERoZ+H6LSUpDdgwAA+++wzNm7cSGRkpNXleKTvvvuOEydO0KBBA3x8fPDx8eHLL7/klVdewcfHB7vdbnWJHic8PJwaNWo4batevTqHDh2yqCLP9swzzzBixAi6d+9O7dq16dmzJ0899RRTpkyxujSPFxYWBsDx48edth8/ftzxXF5TuBEHwzAYMGAAS5cu5YsvvqBixYpWl+Sx7rzzTnbv3s2uXbsct0aNGtGjRw927dqFt7e31SV6nJiYmHSXRvjtt9+oUKGCRRV5tgsXLuDl5fwV5u3tTUpKikUVSaqKFSsSFhbGhg0bHNuSkpL49ttvadasWb7UoNNS4tC/f38++OADli9fTvHixR3nRoODgwkICLC4Os9SvHjxdHOdihYtSkhIiOZAWeSpp56iefPmTJ48mQcffJDt27czf/585s+fb3VpHqlDhw4899xzlC9fnpo1a/L9998zbdo0Hn30UatL8wjnzp1j3759jsf79+9n165dlCpVivLlyzNkyBCeffZZqlSpQsWKFRkzZgwRERGOFVV5zhD5f0CGtwULFlhdmhiG0apVK2Pw4MFWl+HRPv30U6NWrVqGn5+fUa1aNWP+/PlWl+SxkpKSjMGDBxvly5c3/P39jUqVKhmjRo0yLl++bHVpHmHjxo0Zfl/ExcUZhmEYKSkpxpgxY4zQ0FDDz8/PuPPOO429e/fmW326zo2IiIi4Fc25EREREbeicCMiIiJuReFGRERE3IrCjYiIiLgVhRsRERFxKwo3IiIi4lYUbkRERMStKNyIiEey2WwsW7bM6jJEJA8o3IhIvuvduzc2my3drV27dlaXJiJuQL2lRMQS7dq1Y8GCBU7b/Pz8LKpGRNyJRm5ExBJ+fn6EhYU53UqWLAmYp4zmzp1L+/btCQgIoFKlSixZssTp9bt37+aOO+4gICCAkJAQHn/8cc6dO+e0z1tvvUXNmjXx8/MjPDycAQMGOD3/119/cf/99xMYGEiVKlVYsWKF47nTp0/To0cPypQpQ0BAAFWqVEkXxkSkYFK4EZECacyYMTzwwAP88MMP9OjRg+7du7Nnzx4Azp8/T9u2bSlZsiQ7duzg448/Zv369U7hZe7cufTv35/HH3+c3bt3s2LFCm655RanY0yYMIEHH3yQH3/8kXvuuYcePXpw6tQpx/F/+eUXVq
1axZ49e5g7dy6lS5fOv78AEcm5fGvRKSLy/+Li4gxvb2+jaNGiTrfnnnvOMAyzQ/0TTzzh9JqmTZsa/fr1MwzDMObPn2+ULFnSOHfunOP5zz//3PDy8jISEhIMwzCMiIgIY9SoUTesATBGjx7teHzu3DkDMFatWmUYhmF06NDB6NOnj2s+sIjkK825ERFL3H777cydO9dpW6lSpRz3mzVr5vRcs2bN2LVrFwB79uyhbt26FC1a1PF8TEwMKSkp7N27F5vNxtGjR7nzzjszraFOnTqO+0WLFiUoKIgTJ04A0K9fPx544AF27tzJ3XffTefOnWnevHmOPquI5C+FGxGxRNGiRdOdJnKVgICALO1XpEgRp8c2m42UlBQA2rdvz8GDB1m5ciXr1q3jzjvvpH///rz88ssur1dEXEtzbkSkQPrmm2/SPa5evToA1atX54cffuD8+fOO57ds2YKXlxdVq1alePHiREdHs2HDhlzVUKZMGeLi4njvvfeYMWMG8+fPz9X7iUj+0MiNiFji8uXLJCQkOG3z8fFxTNr9+OOPadSoEf/4xz94//332b59O2+++SYAPXr0YNy4ccTFxTF+/HhOnjzJwIED6dmzJ6GhoQCMHz+eJ554grJly9K+fXvOnj3Lli1bGDhwYJbqGzt2LA0bNqRmzZpcvnyZzz77zBGuRKRgU7gREUusXr2a8PBwp21Vq1bl119/BcyVTB9++CFPPvkk4eHhLFq0iBo1agAQGBjImjVrGDx4MI0bNyYwMJAHHniAadOmOd4rLi6OS5cuMX36dIYNG0bp0qXp2rVrluvz9fVl5MiRHDhwgICAAFq0aMGHH37ogk8uInnNZhiGYXURIiJp2Ww2li5dSufOna0uRUQKIc25EREREbeicCMiIiJuRXNuRKTA0dlyEckNjdyIiIiIW1G4EREREbeicCMiIiJuReFGRERE3IrCjYiIiLgVhRsRERFxKwo3IiIi4lYUbkRERMStKNyIiIiIW/k/BPLpH1Zbfp0AAAAASUVORK5CYII=\",\n      \"text/plain\": [\n       \"<Figure size 640x480 with 1 Axes>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"acc = history_dict['binary_accuracy']\\n\",\n    \"val_acc = history_dict['val_binary_accuracy']\\n\",\n    \"loss = history_dict['loss']\\n\",\n    \"val_loss = history_dict['val_loss']\\n\",\n    \"\\n\",\n    \"epochs = range(1, len(acc) + 1)\\n\",\n    \"\\n\",\n    \"# \\\"bo\\\" is for \\\"blue dot\\\"\\n\",\n    \"plt.plot(epochs, loss, 'bo', label='Training loss')\\n\",\n    \"# b is for \\\"solid blue line\\\"\\n\",\n    \"plt.plot(epochs, val_loss, 'b', label='Validation loss')\\n\",\n    \"plt.title('Training and validation loss')\\n\",\n    \"plt.xlabel('Epochs')\\n\",\n    \"plt.ylabel('Loss')\\n\",\n    \"plt.legend()\\n\",\n    \"\\n\",\n    \"plt.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   
\"id\": \"af51178e-fe0b-40ca-9260-2190fb52d960\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"image/png\": \"iVBORw0KGgoAAAANSUhEUgAAAkAAAAHHCAYAAABXx+fLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABcgUlEQVR4nO3dd1hT598G8DuEvVWQJQKideJWioraSovaUnGideCottaNtmoVR63SWqu4qrV11Um1aP1Vq1WqdY+690RRFBQHCCpKOO8f5000EjCBwEnI/bmuXCZPTk6+J9Dm5pxnyARBEEBERERkQsykLoCIiIiopDEAERERkclhACIiIiKTwwBEREREJocBiIiIiEwOAxARERGZHAYgIiIiMjkMQERERGRyGICIiIjI5DAAEelB79694evrW6jXTpo0CTKZTL8FGZjr169DJpNh2bJlJfq+u3btgkwmw65du1Rt2v6siqtmX19f9O7dW6/7JCLdMQBRqSaTybS6vfoFSVRU+/fvx6RJk/Do0SOpSyGifJhLXQBRcVqxYoXa419//RXbt2/P0169evUivc/PP/+M3NzcQr12/PjxGDNmTJHen7RXlJ+Vtvbv34/Jkyejd+/ecHZ2Vnvu4sWLMDPj355EUmMAolKtR48eao8PHjyI7du352l/3ZMnT2Bra6v1+1hYWBSqPgAwNzeHuTn/UywpRflZ6YOVlZWk728ssrKyYGdnJ3UZVIrxzxAyeS1btkStWrVw9OhRNG/eHLa2tvjqq68AAH/88Qc++OADeHp6wsrKCv7+/pgyZQoUCoXaPl7vV6LsPzJjxgwsWrQI/v7+sLKyQqNGjXDkyBG112rqAySTyTB48GBs3LgRtWrVgpWVFWrWrImtW7fmqX/Xrl1o2LAhrK2t4e/vj59++knrfkV79uxB586dUbFiRVhZWcHb2xsjRozA06dP8xyfvb09kpOTER4eDnt7e7i6umLUqFF5PotHjx6hd+/ecHJygrOzMyIjI7W6FPTff/9BJpNh+fLleZ7btm0bZDIZ/vzzTwDAjRs38Pnnn6Nq1aqwsbFBuXLl0LlzZ1y/fv2N76OpD5C2NZ86dQq9e/dGpUqVYG1tDXd3d/Tt2xf3799XbTNp0iR88cUXAAA/Pz/VZVZlbZr6AF27dg2dO3dG2bJlYWtri7fffhubN29W20bZn+m3337D1KlTUaFCBVhbW6NVq1a4cuXKG49bl8/s0aNHGDFiBHx9fWFlZYUKFSqgV69eSEtLU23z7NkzTJo0CW+99Rasra3h4eGBDh064OrVq2r1vn55WVPfKuXv19WrV9G2bVs4ODige/fuALT/HQWACxcuoEuXLnB1dYWNjQ2qVq2KcePGAQB27twJmUyGDRs25Hnd6tWrIZPJcODAgTd+jlR68M9OIgD3799HmzZt0LVrV/To0QNubm4AgGXLlsHe3h5RUVGwt7fHP//8gwkTJiAjIwPff//9G/e7evVqPH78GJ9++ilkMhmmT5+ODh064Nq1a288E7F3717Ex8fj888/h4ODA+bMmYOOHTsiKSkJ5cqVAwAcP34crVu3hoeHByZPngyFQoGvv/4arq6uWh33unXr8OTJEwwcOBDlypXD4cOHMXfuXNy6dQvr1q1T21ahUCA0NBSBgYGYMWMGduzYgR9++AH+/v4YOHAgAEAQBLRr1w579+7FZ599hurVq2PDhg2IjIx8Yy0NGzZEpUqV8Ntvv+XZPi4uDmXKlEFoaCgA4MiRI9i/fz+6du2KChUq4Pr161iwYAFatmyJc+fO6XT2Tpeat2/fjmvXrqFPnz5wd3fH2bNnsWjRIpw9exYHDx6ETCZDhw4dcOnS
JaxZswazZs2Ci4sLAOT7M0lNTUWTJk3w5MkTDB06FOXKlcPy5cvx0UcfYf369Wjfvr3a9t9++y3MzMwwatQopKenY/r06ejevTsOHTpU4HFq+5llZmYiODgY58+fR9++fVG/fn2kpaVh06ZNuHXrFlxcXKBQKPDhhx8iISEBXbt2xbBhw/D48WNs374dZ86cgb+/v9afv1JOTg5CQ0PRrFkzzJgxQ1WPtr+jp06dQnBwMCwsLDBgwAD4+vri6tWr+N///oepU6eiZcuW8Pb2xqpVq/J8pqtWrYK/vz+CgoJ0rpuMmEBkQgYNGiS8/mvfokULAYCwcOHCPNs/efIkT9unn34q2NraCs+ePVO1RUZGCj4+PqrHiYmJAgChXLlywoMHD1Ttf/zxhwBA+N///qdqmzhxYp6aAAiWlpbClStXVG0nT54UAAhz585VtYWFhQm2trZCcnKyqu3y5cuCubl5nn1qoun4YmJiBJlMJty4cUPt+AAIX3/9tdq29erVExo0aKB6vHHjRgGAMH36dFVbTk6OEBwcLAAQli5dWmA9Y8eOFSwsLNQ+s+zsbMHZ2Vno27dvgXUfOHBAACD8+uuvqradO3cKAISdO3eqHcurPytdatb0vmvWrBEACLt371a1ff/99wIAITExMc/2Pj4+QmRkpOrx8OHDBQDCnj17VG2PHz8W/Pz8BF9fX0GhUKgdS/Xq1YXs7GzVtrNnzxYACKdPn87zXq/S9jObMGGCAECIj4/Ps31ubq4gCIKwZMkSAYAwc+bMfLfR9NkLwsv/Nl79XJW/X2PGjNGqbk2/o82bNxccHBzU2l6tRxDE3y8rKyvh0aNHqra7d+8K5ubmwsSJE/O8D5VuvARGBLFfRp8+ffK029jYqO4/fvwYaWlpCA4OxpMnT3DhwoU37jciIgJlypRRPQ4ODgYgXvJ4k5CQELW/pGvXrg1HR0fVaxUKBXbs2IHw8HB4enqqtqtcuTLatGnzxv0D6seXlZWFtLQ0NGnSBIIg4Pjx43m2/+yzz9QeBwcHqx3Lli1bYG5urjojBAByuRxDhgzRqp6IiAi8ePEC8fHxqra///4bjx49QkREhMa6X7x4gfv376Ny5cpwdnbGsWPHtHqvwtT86vs+e/YMaWlpePvttwFA5/d99f0bN26MZs2aqdrs7e0xYMAAXL9+HefOnVPbvk+fPrC0tFQ91vZ3StvP7Pfff0edOnXynCUBoLqs+vvvv8PFxUXjZ1SUKR1e/Rloqju/39F79+5h9+7d6Nu3LypWrJhvPb169UJ2djbWr1+vaouLi0NOTs4b+wVS6cMARATAy8tL7UtF6ezZs2jfvj2cnJzg6OgIV1dX1f8o09PT37jf1/9nrAxDDx8+1Pm1ytcrX3v37l08ffoUlStXzrOdpjZNkpKS0Lt3b5QtW1bVr6dFixYA8h6ftbV1nss4r9YDiP1MPDw8YG9vr7Zd1apVtaqnTp06qFatGuLi4lRtcXFxcHFxwbvvvqtqe/r0KSZMmABvb29YWVnBxcUFrq6uePTokVY/l1fpUvODBw8wbNgwuLm5wcbGBq6urvDz8wOg3e9Dfu+v6b2UIxNv3Lih1l7Y3yltP7OrV6+iVq1aBe7r6tWrqFq1ql4775ubm6NChQp52rX5HVWGvzfVXa1aNTRq1AirVq1Sta1atQpvv/221v/NUOnBPkBEUP8rU+nRo0do0aIFHB0d8fXXX8Pf3x/W1tY4duwYRo8erdVQarlcrrFdEIRifa02FAoF3nvvPTx48ACjR49GtWrVYGdnh+TkZPTu3TvP8eVXj75FRERg6tSpSEtLg4ODAzZt2oRu3bqpfdkOGTIES5cuxfDhwxEUFAQnJyfIZDJ07dq1WIe4d+nSBfv378cXX3yBunXrwt7eHrm5uWjdunWxD61XKuzvRUl/ZvmdCXq907ySlZVVnukBdP0d1UavXr0wbNgw3Lp1C9nZ2Th48CDmzZun837I+DEAEeVj165d
uH//PuLj49G8eXNVe2JiooRVvVS+fHlYW1trHAGkzaig06dP49KlS1i+fDl69eqlat++fXuha/Lx8UFCQgIyMzPVzqhcvHhR631ERERg8uTJ+P333+Hm5oaMjAx07dpVbZv169cjMjISP/zwg6rt2bNnhZp4UNuaHz58iISEBEyePBkTJkxQtV++fDnPPnW5DOTj46Px81FeYvXx8dF6XwXR9jPz9/fHmTNnCtyXv78/Dh06hBcvXuTbmV95Zur1/b9+Rqsg2v6OVqpUCQDeWDcAdO3aFVFRUVizZg2ePn0KCwsLtcurZDp4CYwoH8q/tF/9y/r58+f48ccfpSpJjVwuR0hICDZu3Ijbt2+r2q9cuYK//vpLq9cD6scnCAJmz55d6Jratm2LnJwcLFiwQNWmUCgwd+5crfdRvXp1BAQEIC4uDnFxcfDw8FALoMraXz/jMXfu3HzPLuijZk2fFwDExsbm2ady/hptAlnbtm1x+PBhtSHYWVlZWLRoEXx9fVGjRg1tD6VA2n5mHTt2xMmTJzUOF1e+vmPHjkhLS9N45kS5jY+PD+RyOXbv3q32vC7//Wj7O+rq6ormzZtjyZIlSEpK0liPkouLC9q0aYOVK1di1apVaN26tWqkHpkWngEiykeTJk1QpkwZREZGYujQoZDJZFixYoXeLkHpw6RJk/D333+jadOmGDhwIBQKBebNm4datWrhxIkTBb62WrVq8Pf3x6hRo5CcnAxHR0f8/vvvWvVPyk9YWBiaNm2KMWPG4Pr166hRowbi4+N17h8TERGBCRMmwNraGv369ctzaeTDDz/EihUr4OTkhBo1auDAgQPYsWOHanqA4qjZ0dERzZs3x/Tp0/HixQt4eXnh77//1nhGsEGDBgCAcePGoWvXrrCwsEBYWJjGif3GjBmDNWvWoE2bNhg6dCjKli2L5cuXIzExEb///rveZo3W9jP74osvsH79enTu3Bl9+/ZFgwYN8ODBA2zatAkLFy5EnTp10KtXL/z666+IiorC4cOHERwcjKysLOzYsQOff/452rVrBycnJ3Tu3Blz586FTCaDv78//vzzT9y9e1frmnX5HZ0zZw6aNWuG+vXrY8CAAfDz88P169exefPmPP8t9OrVC506dQIATJkyRfcPk0qHEh93RiSh/IbB16xZU+P2+/btE95++23BxsZG8PT0FL788kth27ZtbxxarRzq+/333+fZJwC1Ibf5DYMfNGhQnte+PoRaEAQhISFBqFevnmBpaSn4+/sLv/zyizBy5EjB2to6n0/hpXPnzgkhISGCvb294OLiIvTv31813P71Ycp2dnZ5Xq+p9vv37ws9e/YUHB0dBScnJ6Fnz57C8ePHtRoGr3T58mUBgABA2Lt3b57nHz58KPTp00dwcXER7O3thdDQUOHChQt5Ph9thsHrUvOtW7eE9u3bC87OzoKTk5PQuXNn4fbt23l+poIgCFOmTBG8vLwEMzMztSHxmn6GV69eFTp16iQ4OzsL1tbWQuPGjYU///xTbRvlsaxbt06tXdOwck20/cyUn8fgwYMFLy8vwdLSUqhQoYIQGRkppKWlqbZ58uSJMG7cOMHPz0+wsLAQ3N3dhU6dOglXr15VbXPv3j2hY8eOgq2trVCmTBnh008/Fc6cOaP175cgaP87KgiCcObMGdXPx9raWqhataoQHR2dZ5/Z2dlCmTJlBCcnJ+Hp06cFfm5UeskEwYD+nCUivQgPD8fZs2c19k8hMnU5OTnw9PREWFgYFi9eLHU5JBH2ASIycq8vCXD58mVs2bIFLVu2lKYgIgO3ceNG3Lt3T61jNZkengEiMnIeHh6q9alu3LiBBQsWIDs7G8ePH0eVKlWkLo/IYBw6dAinTp3ClClT4OLiUujJK6l0YCdoIiPXunVrrFmzBikpKbCyskJQUBCmTZvG8EP0mgULFmDlypWoW7eu2mKsZJp4BoiIiIhMDvsAERERkclhACIiIiKTwz5AGuTm
5uL27dtwcHAo0srGREREVHIEQcDjx4/h6en5xklEGYA0uH37Nry9vaUug4iIiArh5s2bqFChQoHbMABp4ODgAED8AB0dHSWuhoiIiLSRkZEBb29v1fd4QRiANFBe9nJ0dGQAIiIiMjLadF9hJ2giIiIyOQxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOZwJmoiIiEqEQgHs2QPcuQN4eADBwYBcLk0tDEBERERU7OLjgWHDgFu3XrZVqADMng106FDy9fASGBERERWr+HigUyf18AMAyclie3x8ydfEAERERETFRqEQz/wIQt7nlG3Dh4vblSQGICIiIio2e/bkPfPzKkEAbt4UtytJDEBERERUbO7c0e92+sIARERERMXGw0O/2+kLR4EREREZOEMaPq6r4GBxtFdysuZ+QDKZ+HxwcMnWxTNAREREBiw+HvD1Bd55B/j4Y/FfX19pRk4VhlwuDnUHxLDzKuXj2NiSD3QMQERERAbKEIePF0aHDsD69YCXl3p7hQpiuxTzAMkEQdMJKdOWkZEBJycnpKenw9HRUepyiIjIBCkU4pme/EZQKS8dJSYaz+Ww4r6Up8v3N/sAERERGSBdho+3bFliZRWJXG44tfISGBERkQEy1OHjpQUDEBERkQEy1OHjpYXkAWj+/Pnw9fWFtbU1AgMDcfjw4Xy3ffHiBb7++mv4+/vD2toaderUwdatW4u0TyIiIkOkHD7++sgpJZkM8PYu+eHjpYWkASguLg5RUVGYOHEijh07hjp16iA0NBR3797VuP348ePx008/Ye7cuTh37hw+++wztG/fHsePHy/0PomIqPRSKIBdu4A1a8R/S3q9qaIw1OHjpYYgocaNGwuDBg1SPVYoFIKnp6cQExOjcXsPDw9h3rx5am0dOnQQunfvXuh9apKeni4AENLT07V+DRERGZbffxeEChUEQewuLN4qVBDbjYmm4/D2Nr7jKAm6fH9Ldgbo+fPnOHr0KEJCQlRtZmZmCAkJwYEDBzS+Jjs7G9bW1mptNjY22Lt3b6H3SUREpU9pmT8HEOfIuX4d2LkTWL1a/DcxUZq5c0oTyYbBp6WlQaFQwM3NTa3dzc0NFy5c0Pia0NBQzJw5E82bN4e/vz8SEhIQHx8Pxf+f0yzMPgExWGVnZ6seZ2RkFPawiIhIYgoFMGyY5mUXBEG8fDR8ONCunfFcPjKk4eOlheSdoHUxe/ZsVKlSBdWqVYOlpSUGDx6MPn36wMysaIcRExMDJycn1c3b21tPFRMRUUnTZf4cMl2SBSAXFxfI5XKkpqaqtaempsLd3V3ja1xdXbFx40ZkZWXhxo0buHDhAuzt7VGpUqVC7xMAxo4di/T0dNXt5s2bRTw6IiKSCufPIW1IFoAsLS3RoEEDJCQkqNpyc3ORkJCAoKCgAl9rbW0NLy8v5OTk4Pfff0e7du2KtE8rKys4Ojqq3YiIyDhx/hzShqRLYURFRSEyMhINGzZE48aNERsbi6ysLPTp0wcA0KtXL3h5eSEmJgYAcOjQISQnJ6Nu3bpITk7GpEmTkJubiy+//FLrfRIRUemmnD8nOVlzPyDlGlqcP8e0SRqAIiIicO/ePUyYMAEpKSmoW7cutm7dqurEnJSUpNa/59mzZxg/fjyuXbsGe3t7tG3bFitWrICzs7PW+yQiojcr7kUri5Ny/pxOncSw82oI4vw5pMTV4DXgavBEZMri48VRVK92JK5QQQwVxjT0WtNxeHuL4ceYjoO0p8v3NwOQBgxARGSqlPPnvP7NoDxzsn69cYUHYz6TRbpjACoiBiAiMkUKBeDrm/8QcmXfmcREhggyTLp8fxvVPEBERFR8OH8OmRIGICIiAsD5c8i0MAAREREAzp9DpoUBiIiIALycP0fZ4fl1Mpk4iorz51BpwABEREQAXs6fA+QNQZw/h0obBiAiIlLp0EEc6u7lpd5eoYLx
DYEnKoikM0ETEZHh6dABaNeO8+dQ6cYAREREecjlQMuWUldBVHx4CYyIiIhMDs8AERHpEZdeIDIODEBERHpSWhYRJTIFvARGRKQHykVEX19KIjlZbI+Pl6YuItKMAYiIqIgUCvHMj6alpZVtw4eL2xGRYWAAIiIqIi4iSmR8GICIiIqIi4gSGR8GICKiIuIiokTGhwGIiKiIuIgokfFhACIiKiIuIkpkfBiAiIj0gIuIEhkXToRIRKQnXESU9CUnB0hLA+7efXm7dw94+BAwMwPMzcWbXK7+75vaivr8621mZvlf+jV0DEBERHrERURJk9xcMby8GmZeDTevtz14IHXF2itsEOvdGxg8WMK6pXtrIiIi4yQIwOPH2oWZu3fFszm6ToRpZga4uADly4s3V1egTBnxvRUK8SxRTs7L+5ratLmvzbaaJvlUUm6Tna3b8YWG6ra9vjEAEZFB4CKiJLWnT7ULM8rHun7hA2KAUYYZZbB5NeC8+rhMGcP5byA3V/xvVJ8hq1IlaY+JAYiIJMdFRKm4PH0K3LghzsRdUJi5exfIzNR9//b22oUZV1fxbI6lpf6PsSSYmYk3CwupK9EfBiAikpRyEdHXT7ErFxHlCCoqyIsXQFIScP06kJgo3l69n5Ki2/4sLbULM8p/bW2L46ioJMgEoaAre6YpIyMDTk5OSE9Ph6Ojo9TlEJVaCgXg65v/OloymXgmKDHRcC4FUMlSKMQwrAw1rwedW7fEyzMFcXAAKlYE3N3zDzPKm4OD8Y5qIt2+v3kGiIgko8siohxZVToJApCaqvnszfXr4tmdFy8K3oe1NeDnJ4bpV/9V3i9blqGG8mIAIiLJcBHR0k8QxCHdmsKN8t9nzwreh4WFeAbn9XCjfOzmxoBDumMAIiLJcBHR0iEjI/8+ONevi8PFC2JmJl7qfP3MjfK+pycvgZL+MQARkWSUi4gmJ2ueZ0TZB4iLiErryRNxJFV+AUebSfvc3TWHG19fcaFYYx0dRcaLAYiIJKNcRLRTJzHsvBqCSuMiogqF2J/l1VtOTt42Kdtff+7RI7GPzpuUK6c53Pj5AT4+gI1NMX+4RDpiACIiSSkXEdU0D1BsrOEPgX/8GLh4EbhwQf2Wmpo3VBjzmFsHh7x9b1697+AgdYVEumEAIiLJGfoiooIghrPXQ87Fi+Llu6IwNxc7+Spvrz8uqfb8nrO3FwNOmTLsaEylCwMQERkEQ1hE9Nkz4PJlzUEnKyv/17m7A9Wqqd+8vMR+LQUFEXNzhgoiqTAAEZFJEQRx+YPXQ86FC2KH3vwuU5mbA1WqiOGmatWXQadqVcDZuSSPgIj0gQGIiEqlFy+Aa9c09895+DD/1zk7A9Wr5z2j4+dXutZBIjJ1DEBEZNQePdIccq5cETseayKTiYHm9bM51aqJSyPwshRR6ccAREQGLzdXXBJBU9+cgha7tLXNeyanWjWgcmUOyyYydQxARGQwnjwBLl3KG3QuXQKePs3/dV5eec/kKDsim5mVXP1EZDwYgIioxAmC2OH41Cn12+XL+XdCtrR82Qn51dtbbwFvWPSZiCgPBiAiI6dQGO78OYC4TtTp0+pB5/Tp/NeHKltWcydkX19xJBYRkT7wfydERiw+XvMMyrNnl/wMygqF2PH49bM6169r3t7SEqhRA6hdW/3m5laiZRORiWIAIjJS8fHiGlqvXzJKThbb168vvhCUlpb3rM6ZM+JEgppUqJA36Lz1FoeVE5F0ZIJgzKvTFI+MjAw4OTkhPT0djuxcQAZIoRAvCb165udVylXUExOLdjns+XNxpNXrZ3Vu39a8va0tUKuWetAJCBAvaxERFTddvr95BojICO3Zk3/4AcSzQjdvittps7yEIIjDyV8POufPixMKalKpUt6zOpUqGVb/IyKi/DAAERmhO3cKv93Tp8DZs+odkk+dEi9raeLk9PJMjjLo1KrF1b+JyLgxABEZIQ8P7baTyYBNm/IONc/NzbutmZk4
l87rZ3W8vTkzMhGVPpJPETZ//nz4+vrC2toagYGBOHz4cIHbx8bGomrVqrCxsYG3tzdGjBiBZ6/0vJw0aRJkMpnarVq1asV9GEQlKjhY7ONTUDCRyYBu3YB27YDoaGDdOrE/T24u4OICtGoFjBgBLF0KHD0KZGYC584Ba9cCX30FfPghULEiww8RlU6SngGKi4tDVFQUFi5ciMDAQMTGxiI0NBQXL15E+fLl82y/evVqjBkzBkuWLEGTJk1w6dIl9O7dGzKZDDNnzlRtV7NmTezYsUP12JyTh1ApI5cDs2YBnTvnv40giKOs8htqzmBDRKZM0mQwc+ZM9O/fH3369AEALFy4EJs3b8aSJUswZsyYPNvv378fTZs2xccffwwA8PX1Rbdu3XDo0CG17czNzeHu7l78B0AkgeRkYNkyYMkSzc87OAD9+om3qlU51JyISBPJAtDz589x9OhRjB07VtVmZmaGkJAQHDhwQONrmjRpgpUrV+Lw4cNo3Lgxrl27hi1btqBnz55q212+fBmenp6wtrZGUFAQYmJiULFixXxryc7ORnZ2tupxRkZGEY+OSL+ePwf+/BNYvBjYuvVlHx4HB6BrV6B+fXE5CE9Pw5sJmojIEEkWgNLS0qBQKOD22rSvbm5uuHDhgsbXfPzxx0hLS0OzZs0gCAJycnLw2Wef4auvvlJtExgYiGXLlqFq1aq4c+cOJk+ejODgYJw5cwYO+QxbiYmJweTJk/V3cER6cuGCGHqWLwfu3XvZHhwsnuHp1Amws5OuPiIiYyV5J2hd7Nq1C9OmTcOPP/6IY8eOIT4+Hps3b8aUKVNU27Rp0wadO3dG7dq1ERoaii1btuDRo0f47bff8t3v2LFjkZ6errrdvHmzJA6HSKPMTPHyVtOm4ppYM2aI4cfNDRg9WuzIvHs3EBnJ8ENEVFiSnQFycXGBXC5HamqqWntqamq+/Xeio6PRs2dPfPLJJwCAgIAAZGVlYcCAARg3bhzMzPLmOWdnZ7z11lu4cuVKvrVYWVnBysqqCEdDVDSCABw6JJ7tWbtWDEGAeCmrbVvxbE/btuzPQ0SkL5KdAbK0tESDBg2QkJCgasvNzUVCQgKCgoI0vubJkyd5Qo78/zs75LeiR2ZmJq5evQoPbSdOISpB9+4BM2eKEwsGBQG//CKGn8qVgZgYIClJnMenXTuGHyIifZJ0FFhUVBQiIyPRsGFDNG7cGLGxscjKylKNCuvVqxe8vLwQExMDAAgLC8PMmTNRr149BAYG4sqVK4iOjkZYWJgqCI0aNQphYWHw8fHB7du3MXHiRMjlcnTr1k2y4yR6lUIBbN8unu3544+XS03Y2Ih9evr1A5o35zB1IqLiJGkAioiIwL179zBhwgSkpKSgbt262Lp1q6pjdFJSktoZn/Hjx0Mmk2H8+PFITk6Gq6srwsLCMHXqVNU2t27dQrdu3XD//n24urqiWbNmOHjwIFxdXUv8+Ihedf262Ldn2TJxnS6lBg2ATz4RJy10cpKqOiIi08LV4DXgavCkL8+eARs3imd7EhLEvj4AUKYM0KOHeLanTh1JSyQiKjW4GjyRxE6eFEPPypXAw4cv21u1Es/2hIcD1taSlUdEZPIYgIj0JD0dWLNGDD7//feyvUIFoE8f8ebnJ119RET0EgMQUREIgjgnz+LFwPr1wNOnYruFhThyq18/4L33ODMzEZGhYQAiKoQ7d8TZmZcsAS5fftleo4YYenr2BNjvnojIcDEAEWkpJwfYskWcq2fLFnE4OwDY2wMREWLfnsBADl8nIjIGDEBEb3DpknimZ/lyICXlZXuTJuLZni5dxBBERETGgwGISIOsLLFPz+LFwJ49L9tdXcU1uPr2FdfpIiIi48QARPT/BEEcvbV4sTiaKyNDbDczA1q3Fs/2fPghYGkpbZ1ERFR0DEBkshQK8ezOpUvAqVPiaK7Tp18+7+cnhp7ISHEoOxERlR4MQGSS4uOBIUOA27fV2y0sgM6d
xeDTsqV49oeIiEofBiAyOfHxQMeOmp978UJ87t13S7YmIiIqWfz7lkyKQgEMGpT/8zIZMHz4yyHuRERUOjEAkUlZvVp9KPvrBEFcqf3VkV9ERFT6MACRyTh1Chg8WLtt79wp3lqIiEhaDEBkEg4eBFq0eDm0/U08PIq3HiIikhYDEJV6O3YAISHAo0dAUBDg6Zn/chUyGeDtDQQHl2iJRERUwhiAqFT74w/ggw/EmZ3ffx/Yvh2YO1d87vUQpHwcG8vV24mISjsGICq1Vq4Uh7Q/fw506ABs2gTY2Yn3168HvLzUt69QQWzv0EGaeomIqORwHiAqlX788eVw98hIcQV381d+2zt0ANq1E0d73bkj9vkJDuaZHyIiU8EARKVOTAzw1Vfi/SFDxEtammZ0lsvF2Z6JiMj08BIYlRqCAIwZ8zL8REcDs2dzOQsiIsqLZ4CoVMjNFS95LVwoPp4xAxg5UtqaiIjIcDEAkdF78QLo3Vuc5VkmA376CejfX+qqiIjIkDEAkVF79gyIiBBHeJmbAytWAF27Sl0VEREZOgYgMlqPH4sjuXbuBKytxSHsH3wgdVVERGQMGIDIKD14ALRtCxw6BNjbA3/+KS51QUREpA0GIDI6KSnirM6nTwNlywJbtwKNGkldFRERGRMGIDIqN26I63pduSJOXvj330CtWlJXRURExoYBiIzGhQvAe+8Bt24Bvr7iIqf+/lJXRURExohTxJFROH4caN5cDD/VqwN79zL8EBFR4TEAkcHbtw945x3g3j2gQQNg9+68C5kSERHpggGIDNrff4sdntPTxcVKExIAFxepqyIiImPHAEQG6/ffgQ8/BJ48AVq3Fkd7OTlJXRUREZUGDEBkkJYvB7p0EZe56NwZ+OMPwNZW6qqIiKi0YAAigzN3rri2V24u0K8fsGYNYGkpdVVERFSaMACRwRAE4JtvgKFDxccjRgA//wzI5dLWRUREpQ8DEBkEQQC++AKIjhYfT5oE/PCDuLo7ERGRvnEiRJKcQgEMHCie7QGAWbOA4cMlLYmIiEo5BiCS1PPnQK9eQFwcYGYmhqC+faWuioiISjsGIJLM06dAp07Ali2AhQWwapU44ouIiKi4MQCRJDIygI8+Av79F7CxAeLjxbl+iIiISgIDEJW4+/fFsPPff4CDA7B5szjLMxERUUlhAKISdfu2uKL7uXPikhZbt4rrexEREZUkBiAqMdeuASEhQGIi4OkJ7NghruxORERU0jgPEJWIc+fEy1yJiUClSsDevQw/REQkHQYgKnZHjwLNm4uXv2rWBPbsAfz8pK6KiIhMGQMQFavdu4F33hE7PjdqJI768vSUuioiIjJ1DEBUbLZsAUJDgcePgRYtgIQEoFw5qasiIiIygAA0f/58+Pr6wtraGoGBgTh8+HCB28fGxqJq1aqwsbGBt7c3RowYgWfPnhVpn6R/v/0GtGsHPHsGfPAB8Ndf4pB3IiIiQyBpAIqLi0NUVBQmTpyIY8eOoU6dOggNDcXdu3c1br969WqMGTMGEydOxPnz57F48WLExcXhq6++KvQ+Sf8WLwa6dQNycoCuXYENG8TJDomIiAyFTBAEQao3DwwMRKNGjTBv3jwAQG5uLry9vTFkyBCMGTMmz/aDBw/G+fPnkZCQoGobOXIkDh06hL179xZqn5pkZGTAyckJ6enpcHR0LOphmpRZs4CoKPF+//7AggWAXC5tTUREZBp0+f6W7AzQ8+fPcfToUYSEhLwsxswMISEhOHDggMbXNGnSBEePHlVd0rp27Rq2bNmCtm3bFnqfAJCdnY2MjAy1G+lGEICJE1+Gn1GjgJ9+YvghIiLDJNlEiGlpaVAoFHBzc1Nrd3Nzw4ULFzS+5uOPP0ZaWhqaNWsGQRCQk5ODzz77THUJrDD7BICYmBhMnjy5iEdkunJzxeAze7b4+JtvgK++AmQyaesiIiLKj+SdoHWxa9cuTJs2DT/++COOHTuG+Ph4bN68GVOmTCnS
fseOHYv09HTV7ebNm3qquPRTKIBPPnkZfubMAcaNY/ghIiLDJtkZIBcXF8jlcqSmpqq1p6amwt3dXeNroqOj0bNnT3zyyScAgICAAGRlZWHAgAEYN25cofYJAFZWVrCysiriEZme7GygRw9g/XrAzAxYsgSIjJS6KiIiojeT7AyQpaUlGjRooNahOTc3FwkJCQgKCtL4midPnsDMTL1k+f93MhEEoVD7pMJ58kQc5r5+PWBpCaxbx/BDRETGQ9LFUKOiohAZGYmGDRuicePGiI2NRVZWFvr06QMA6NWrF7y8vBATEwMACAsLw8yZM1GvXj0EBgbiypUriI6ORlhYmCoIvWmfVHRZWUDr1uJ6Xra2wMaN4grvRERExkLSABQREYF79+5hwoQJSElJQd26dbF161ZVJ+akpCS1Mz7jx4+HTCbD+PHjkZycDFdXV4SFhWHq1Kla75OKbtYsMfw4OQGbNwNNm0pdERERkW4knQfIUHEeoPw9ewb4+AB37wIrVwLdu0tdERERkahY5wHy9fXF119/jaSkpEIXSMZr5Uox/Hh7A126SF0NERFR4egcgIYPH474+HhUqlQJ7733HtauXYvs7OziqI0MTG4u8MMP4v3hwwELC0nLISIiKrRCBaATJ07g8OHDqF69OoYMGQIPDw8MHjwYx44dK44ayUBs2QJcuAA4Oopz/xARERmrQg+Dr1+/PubMmYPbt29j4sSJ+OWXX9CoUSPUrVsXS5YsAbsWlT4zZoj/fvqpGIKIiIiMVaFHgb148QIbNmzA0qVLsX37drz99tvo168fbt26ha+++go7duzA6tWr9VkrSejIEeDffwFzc2DoUKmrISIiKhqdA9CxY8ewdOlSrFmzBmZmZujVqxdmzZqFatWqqbZp3749GjVqpNdCSVrKvj/dugEVKkhbCxERUVHpHIAaNWqE9957DwsWLEB4eDgsNPSE9fPzQ9euXfVSIEnv+nVxpmcAGDlS0lKIiIj0QucAdO3aNfj4+BS4jZ2dHZYuXVroosiwxMaKI8Deew+oU0fqaoiIiIpO507Qd+/exaFDh/K0Hzp0CP/9959eiiLD8fAh8Msv4v1Ro6SthYiISF90DkCDBg3CzZs387QnJydj0KBBeimKDMdPP4lrfwUEcL0vIiIqPXQOQOfOnUP9+vXztNerVw/nzp3TS1FkGLKzgTlzxPujRgEymbT1EBER6YvOAcjKygqpqal52u/cuQNzc0nXViU9W7MGuHMH8PQE2KediIhKE50D0Pvvv4+xY8ciPT1d1fbo0SN89dVXeI/XSEoNQXg58eGwYYClpbT1EBER6ZPOp2xmzJiB5s2bw8fHB/Xq1QMAnDhxAm5ublixYoXeCyRpbNsGnD0L2NsDAwZIXQ0REZF+6RyAvLy8cOrUKaxatQonT56EjY0N+vTpg27dummcE4iMk/LsT//+gLOzpKUQERHpnUzgol15ZGRkwMnJCenp6XA0wUWvjh8H6tcH5HLg6lXgDdM+ERERGQRdvr8L3Wv53LlzSEpKwvPnz9XaP/roo8LukgyEctmLLl00hx+FAtizR+wg7eEBBAeLYYmIiMhYFGom6Pbt2+P06dOQyWSqVd9l/z9GWqFQ6LdCKlE3bwJr14r3NS17ER8vdoq+detlW4UKwOzZQIcOJVMjERFRUek8CmzYsGHw8/PD3bt3YWtri7Nnz2L37t1o2LAhdu3aVQwlUkmaPVs8w/POO0CDBurPxccDnTqphx8ASE4W2+PjS65OIiKiotA5AB04cABff/01XFxcYGZmBjMzMzRr1gwxMTEYOnRocdRIJSQ9HVi0SLz/+rIXCoV45kdTjzFl2/Dh4nZERESGTucApFAo4ODgAABwcXHB7du3AQA+Pj64ePGifqujEvXzz8Djx0CNGkDr1urP7dmT98zPqwRBvHy2Z0/x1khERKQPOvcBqlWrFk6ePAk/Pz8EBgZi+vTpsLS0xKJFi1CpUqXiqJFKwPPn
4qrvgNj3x+y1aHznjnb70XY7IiIiKekcgMaPH4+srCwAwNdff40PP/wQwcHBKFeuHOLi4vReIJWM334T+/K4uQHdu+d93sNDu/1oux0REZGU9DIP0IMHD1CmTBnVSDBjZ2rzAAkCUK8ecPIkMHUq8NVXebdRKABfXzEkafqNkcnE0WCJiRwST0RE0tDl+1unPkAvXryAubk5zpw5o9ZetmzZUhN+TFFCghh+bG2Bzz7TvI1cLo4QA/KuCq98HBvL8ENERMZBpwBkYWGBihUrcq6fUka57EW/fkDZsvlv16EDsH494OWl3l6hgtjOeYCIiMhY6HwJbPHixYiPj8eKFStQtqBvSyNmSpfATp0C6tQROz1fvgxo04+dM0ETEZEhKtalMObNm4crV67A09MTPj4+sLOzU3v+2LFjuu6SJDRzpvhvx47ahR9ADDstWxZbSURERMVO5wAUHh5eDGWQFJKTgdWrxfualr0gIiIqrXQOQBMnTiyOOkgCc+cCL16Il7ACA6WuhoiIqOToPBM0lQ6PHwMLF4r3X1/2goiIqLTT+QyQmZlZgUPeOULMOCxeLK799dZbwIcfSl0NERFRydI5AG3YsEHt8YsXL3D8+HEsX74ckydP1lthVHxycoBZs8T7mpa9ICIiKu30MhM0AKxevRpxcXH4448/9LE7SZX2YfBr1wLdugGursCNG4CNjdQVERERFV2xzQRdkLfffhsJCQn62h0VE0EAvv9evD94MMMPERGZJr0EoKdPn2LOnDnwen2KYDI4//4LHDsGWFsDn38udTVERETS0LkP0OuLngqCgMePH8PW1hYrV67Ua3Gkf8plL/r0AVxcpK2FiIhIKjoHoFmzZqkFIDMzM7i6uiIwMBBlypTRa3GkX+fOAZs3i4uXjhghdTVERETS0TkA9e7duxjKoJKgXPYiPByoUkXSUoiIiCSlcx+gpUuXYt26dXna161bh+XLl+ulKNK/lBRgxQrxPic+JCIiU6dzAIqJiYGLhs4j5cuXx7Rp0/RSFOnfvHnA8+dAUBDQpInU1RAREUlL5wCUlJQEPz+/PO0+Pj5ISkrSS1GkX1lZwI8/ivd59oeIiKgQAah8+fI4depUnvaTJ0+iXLlyeimK9GvpUuDhQ8DfH2jXTupqiIiIpKdzAOrWrRuGDh2KnTt3QqFQQKFQ4J9//sGwYcPQtWvX4qiRikCheNn5OSoKkMulrYeIiMgQ6DwKbMqUKbh+/TpatWoFc3Px5bm5uejVqxf7ABmgDRuAxESgXDmAA/iIiIhEOgcgS0tLxMXF4ZtvvsGJEydgY2ODgIAA+Pj4FEd9VASvLnvx+eeAra209RARERkKnQOQUpUqVVCFk8kYtH37gMOHASsrYNAgqashIiIyHDr3AerYsSO+++67PO3Tp09H586d9VIU6Ydy2YtevQA3N2lrISIiMiQ6B6Ddu3ejbdu2edrbtGmD3bt3F6qI+fPnw9fXF9bW1ggMDMThw4fz3bZly5aQyWR5bh988IFqm969e+d5vnXr1oWqzVhdvAhs2iTej4qSthYiIiJDo/MlsMzMTFhaWuZpt7CwQEZGhs4FxMXFISoqCgsXLkRgYCBiY2MRGhqKixcvonz58nm2j4+Px/Pnz1WP79+/jzp16uQ5+9S6dWssXbpU9djKykrn2ozZrFliH6CwMKBaNamrISIiMiw6nwEKCAhAXFxcnva1a9eiRo0aOhcwc+ZM9O/fH3369EGNGjWwcOFC2NraYsmSJRq3L1u2LNzd3VW37du3w9bWNk8AsrKyUtvOlBZqvXsXUK5KwokPiYiI8tL5DFB0dDQ6dOiAq1ev4t133wUAJCQkYPXq1Vi/fr1O+3r+/DmOHj2KsWPHqtrMzMwQEhKCAwcOaLWPxYsXo2vXrrCzs1Nr37VrF8qXL48yZcrg3XffxTfffJPvRI3Z2dnIzs5WPS7MmSxD8uOPwLNnQKNGQHCw1NUQEREZHp3PAIWFhWHjxo24
cuUKPv/8c4wcORLJycn4559/ULlyZZ32lZaWBoVCAbfXeui6ubkhJSXlja8/fPgwzpw5g08++UStvXXr1vj111+RkJCA7777Dv/++y/atGkDhUKhcT8xMTFwcnJS3by9vXU6DkPy5Akwf754f9QoQCaTth4iIiJDVKhh8B988IGq03FGRgbWrFmDUaNG4ejRo/mGjOKwePFiBAQEoHHjxmrtr85IHRAQgNq1a8Pf3x+7du1Cq1at8uxn7NixiHqlp3BGRobRhqBffwXS0gBfX6BDB6mrISIiMkw6nwFS2r17NyIjI+Hp6YkffvgB7777Lg4ePKjTPlxcXCCXy5GamqrWnpqaCnd39wJfm5WVhbVr16Jfv35vfJ9KlSrBxcUFV65c0fi8lZUVHB0d1W7G6NVlL0aMAMwLPcsTERFR6aZTAEpJScG3336LKlWqoHPnznB0dER2djY2btyIb7/9Fo0aNdLpzS0tLdGgQQMkJCSo2nJzc5GQkICgoKACX7tu3TpkZ2ejR48eb3yfW7du4f79+/Dw8NCpPmPzv/8Bly8DZcoAfftKXQ0REZHh0joAhYWFoWrVqjh16hRiY2Nx+/ZtzJ07t8gFREVF4eeff8by5ctx/vx5DBw4EFlZWejTpw8AoFevXmqdpJUWL16M8PDwPB2bMzMz8cUXX+DgwYO4fv06EhIS0K5dO1SuXBmhoaFFrteQKSc+HDgQsLeXthYiIiJDpvVFkr/++gtDhw7FwIED9boERkREBO7du4cJEyYgJSUFdevWxdatW1Udo5OSkmBmpp7TLl68iL179+Lvv//Osz+5XI5Tp05h+fLlePToETw9PfH+++9jypQppXouoAMHxKUvLC2BwYOlroaIiMiwyQRBELTZ8ODBg1i8eDHi4uJQvXp19OzZE127doWHhwdOnjxZqDmADFVGRgacnJyQnp5uNP2BOnUCfv9dvPS1eLHU1RAREZU8Xb6/tb4E9vbbb+Pnn3/GnTt38Omnn2Lt2rXw9PREbm4utm/fjsePHxe5cCqcq1eB+HjxPpe9ICIiejOdR4HZ2dmhb9++2Lt3L06fPo2RI0fi22+/Rfny5fHRRx8VR430BsplL9q2BWrWlLoaIiIiw1foYfAAULVqVUyfPh23bt3CmjVr9FUT6eD+fUC5agiXvSAiItJOkQKQklwuR3h4ODYplx+nErNgAfD0KVC/PtCypdTVEBERGQe9BCCSxrNngHImAi57QUREpD0GICO2cqW48nvFiuIoMCIiItIOA5CRys0FfvhBvD98OGBhIWk5RERERoUByEht2QJcuAA4OQGffCJ1NURERMaFAchIKZe9+PRTwMFB2lqIiIiMDQOQETpyBPj3X3G196FDpa6GiIjI+DAAGSFl35+PPwa8vKSthYiIyBgxABmZ69eBdevE+yNHSloKERGR0WIAMjKxseIIsPffB2rXlroaIiIi48QAZEQePgR++UW8z2UviIiICo8ByIj89BOQlSWe+QkJkboaIiIi48UAZCSys4E5c8T7XPaCiIioaBiAjMSaNcCdO+Kor4gIqashIiIybgxARkAQXk58OGwYYGkpbT1ERETGjgHICGzbBpw9K874PGCA1NUQEREZPwYgI6A8+9O/v7j2FxERERUNA5CBO34cSEgA5HLx8hcREREVHQOQgVMuexERAVSsKG0tREREpQUDkAG7eRNYu1a8z2UviIiI9IcByIDNng0oFMC77wL160tdDRERUenBAGSg0tOBRYvE+1z2goiISL8YgAzUzz8Djx8DNWoArVtLXQ0REVHpwgBkgJ4/F1d9B7jsBRERUXFgADJAv/0GJCcD7u7Axx9LXQ0REVHpwwBkYF5d9mLoUMDKStp6iIiISiMGIAOTkACcPAnY2QGffip1NURERKUTA5CBUZ796dcPKFtW2lqIiIhKKwYgA3LqlLjwqZkZMHy41NUQERGVXgxABkS57EWnToCfn7S1EBERlWYMQAbi1i1g
9WrxPic+JCIiKl4MQAZi7lwgJwdo3hxo1EjqaoiIiEo3BiADkJEBLFwo3ufZHyIiouLHAGQAFi8WQ1DVqsAHH0hdDRERUenHACSxFy9eLnsxcqQ4AoyIiIiKF79uJbZ+PZCUBJQvD/TsKXU1REREpoEBSEKvLnsxeDBgbS1tPURERKaCAUhCu3YBx44BNjbAwIFSV0NERGQ6GIAkpDz706cP4OIibS1ERESmhAFIImfPAlu2ADIZMGKE1NUQERGZFgYgicycKf7bvj1QubK0tRAREZkaBiAJ3LkDrFwp3ufEh0RERCWPAUgC8+YBz58DTZoAQUFSV0NERGR6GIBKWGYmsGCBeJ9nf4iIiKTBAFTCli4FHj4U+/189JHU1RAREZkmgwhA8+fPh6+vL6ytrREYGIjDhw/nu23Lli0hk8ny3D54ZREtQRAwYcIEeHh4wMbGBiEhIbh8+XJJHEqBcnKAWbPE+1FRgFwubT1ERESmSvIAFBcXh6ioKEycOBHHjh1DnTp1EBoairt372rcPj4+Hnfu3FHdzpw5A7lcjs6dO6u2mT59OubMmYOFCxfi0KFDsLOzQ2hoKJ49e1ZSh6XRhg1AYiJQrhwQGSlpKURERCZN8gA0c+ZM9O/fH3369EGNGjWwcOFC2NraYsmSJRq3L1u2LNzd3VW37du3w9bWVhWABEFAbGwsxo8fj3bt2qF27dr49ddfcfv2bWzcuLEEjyyv48fFeX8GDQJsbSUthYiIyKRJGoCeP3+Oo0ePIiQkRNVmZmaGkJAQHDhwQKt9LF68GF27doWdnR0AIDExESkpKWr7dHJyQmBgYL77zM7ORkZGhtqtOEybBpw/DwwdWiy7JyIiIi1JGoDS0tKgUCjg5uam1u7m5oaUlJQ3vv7w4cM4c+YMPvnkE1Wb8nW67DMmJgZOTk6qm7e3t66HorWqVcVLYERERCQdyS+BFcXixYsREBCAxo0bF2k/Y8eORXp6uup28+ZNPVVIREREhkjSAOTi4gK5XI7U1FS19tTUVLi7uxf42qysLKxduxb9+vVTa1e+Tpd9WllZwdHRUe1GREREpZekAcjS0hINGjRAQkKCqi03NxcJCQkIesMUyevWrUN2djZ69Oih1u7n5wd3d3e1fWZkZODQoUNv3CcRERGZBnOpC4iKikJkZCQaNmyIxo0bIzY2FllZWejTpw8AoFevXvDy8kJMTIza6xYvXozw8HCUe61DjUwmw/Dhw/HNN9+gSpUq8PPzQ3R0NDw9PREeHl5Sh0VEREQGTPIAFBERgXv37mHChAlISUlB3bp1sXXrVlUn5qSkJJiZqZ+ounjxIvbu3Yu///5b4z6//PJLZGVlYcCAAXj06BGaNWuGrVu3wtrautiPh4iIiAyfTBAEQeoiDE1GRgacnJyQnp7O/kBERERGQpfvb6MeBUZERERUGAxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOQxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOQxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOQxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOQxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOQxAREREZHLMpS6AiIhMi0KhwIsXL6Qug4yQhYUF5HK5XvbFAERERCVCEASkpKTg0aNHUpdCRszZ2Rnu7u6QyWRF2g8DEBERlQhl+ClfvjxsbW2L/AVGpkUQBDx58gR3794FAHh4eBRpfwxARERU7BQKhSr8lCtXTupyyEjZ2NgAAO7evYvy
5csX6XIYO0ETEVGxU/b5sbW1lbgSMnbK36Gi9iNjACIiohLDy15UVPr6HZI8AM2fPx++vr6wtrZGYGAgDh8+XOD2jx49wqBBg+Dh4QErKyu89dZb2LJli+r5SZMmQSaTqd2qVatW3IdBRESkFV9fX8TGxmq9/a5duyCTydh5XM8k7QMUFxeHqKgoLFy4EIGBgYiNjUVoaCguXryI8uXL59n++fPneO+991C+fHmsX78eXl5euHHjBpydndW2q1mzJnbs2KF6bG7Ork5ERKWBQgHs2QPcuQN4eADBwYCeRkXn8aYzDRMnTsSkSZN03u+RI0dgZ2en9fZNmjTBnTt34OTkpPN7Uf4kTQYzZ85E//790adPHwDAwoULsXnzZixZsgRjxozJs/2SJUvw4MED7N+/HxYWFgDEJP06c3NzuLu7F2vtRERUsuLjgWHDgFu3XrZVqADMng106KD/97tz547qflxcHCZMmICLFy+q2uzt7VX3BUGAQqHQ6g9uV1dXneqwtLTkd1oxkOwS2PPnz3H06FGEhIS8LMbMDCEhIThw4IDG12zatAlBQUEYNGgQ3NzcUKtWLUybNg0KhUJtu8uXL8PT0xOVKlVC9+7dkZSUVKzHQkRExSs+HujUST38AEBystgeH6//93R3d1fdnJycIJPJVI8vXLgABwcH/PXXX2jQoAGsrKywd+9eXL16Fe3atYObmxvs7e3RqFEjtSsSQN5LYDKZDL/88gvat28PW1tbVKlSBZs2bVI9//olsGXLlsHZ2Rnbtm1D9erVYW9vj9atW6sFtpycHAwdOhTOzs4oV64cRo8ejcjISISHh+d7vPfv30e3bt3g5eUFW1tbBAQEYM2aNWrb5ObmYvr06ahcuTKsrKxQsWJFTJ06VfX8rVu30K1bN5QtWxZ2dnZo2LAhDh06VIhPv/hJFoDS0tKgUCjg5uam1u7m5oaUlBSNr7l27RrWr18PhUKBLVu2IDo6Gj/88AO++eYb1TaBgYFYtmwZtm7digULFiAxMRHBwcF4/PhxvrVkZ2cjIyND7UZERIZBoRDP/AhC3ueUbcOHi9uVtDFjxuDbb7/F+fPnUbt2bWRmZqJt27ZISEjA8ePH0bp1a4SFhb3xD/HJkyejS5cuOHXqFNq2bYvu3bvjwYMH+W7/5MkTzJgxAytWrMDu3buRlJSEUaNGqZ7/7rvvsGrVKixduhT79u1DRkYGNm7cWGANz549Q4MGDbB582acOXMGAwYMQM+ePdX65o4dOxbffvstoqOjce7cOaxevVr1PZ6ZmYkWLVogOTkZmzZtwsmTJ/Hll18iNzdXi09SAoJEkpOTBQDC/v371dq/+OILoXHjxhpfU6VKFcHb21vIyclRtf3www+Cu7t7vu/z8OFDwdHRUfjll1/y3WbixIkCgDy39PR0HY+KiIg0efr0qXDu3Dnh6dOnOr92505BEKNOwbedO/VetsrSpUsFJyenV2raKQAQNm7c+MbX1qxZU5g7d67qsY+PjzBr1izVYwDC+PHjVY8zMzMFAMJff/2l9l4PHz5U1QJAuHLliuo18+fPF9zc3FSP3dzchO+//171OCcnR6hYsaLQrl07bQ9ZEARB+OCDD4SRI0cKgiAIGRkZgpWVlfDzzz9r3Pann34SHBwchPv37+v0Hroq6HcpPT1d6+9vyc4Aubi4QC6XIzU1Va09NTU132udHh4eeOutt9QmPqpevTpSUlLw/Plzja9xdnbGW2+9hStXruRby9ixY5Genq663bx5sxBHRERExeGVKzt62U6fGjZsqPY4MzMTo0aNQvXq1eHs7Ax7e3ucP3/+jWeAateurbpvZ2cHR0dH1YzHmtja2sLf31/12MPDQ7V9eno6UlNT0bhxY9XzcrkcDRo0KLAGhUKBKVOmICAgAGXLloW9vT22bdumqv38+fPIzs5Gq1atNL7+xIkTqFevHsqWLVvg+xgKyQKQpaUlGjRogISEBFVbbm4uEhISEBQU
pPE1TZs2xZUrV9ROp126dAkeHh6wtLTU+JrMzExcvXq1wCmzrays4OjoqHYjIiLDoO2KB0VcGaFQXh/NNWrUKGzYsAHTpk3Dnj17cOLECQQEBOT7R7qScmCPkkwmK/DSkabtBU3XCHXw/fffY/bs2Rg9ejR27tyJEydOIDQ0VFW7chbm/LzpeUMj6TxAUVFR+Pnnn7F8+XKcP38eAwcORFZWlmpUWK9evTB27FjV9gMHDsSDBw8wbNgwXLp0CZs3b8a0adMwaNAg1TajRo3Cv//+i+vXr2P//v1o37495HI5unXrVuLHR0RERRccLI72ym9UukwGeHuL20lt37596N27N9q3b4+AgAC4u7vj+vXrJVqDk5MT3NzccOTIEVWbQqHAsWPHCnzdvn370K5dO/To0QN16tRBpUqVcOnSJdXzVapUgY2NjdqJi1fVrl0bJ06cKLDvkiGRdBh8REQE7t27hwkTJiAlJQV169bF1q1bVR2qkpKSYGb2MqN5e3tj27ZtGDFiBGrXrg0vLy8MGzYMo0ePVm2j7IF+//59uLq6olmzZjh48KDOww6JiMgwyOXiUPdOncSw8+qJDmUoio0tvvmAdFGlShXEx8cjLCwMMpkM0dHRknQCHjJkCGJiYlC5cmVUq1YNc+fOxcOHDwuc26hKlSpYv3499u/fjzJlymDmzJlITU1FjRo1AADW1tYYPXo0vvzyS1haWqJp06a4d+8ezp49i379+qFbt26YNm0awsPDERMTAw8PDxw/fhyenp75XtmRkuQzBA4ePBiDBw/W+NyuXbvytAUFBeHgwYP57m/t2rX6Ko2IiAxEhw7A+vWa5wGKjS2eeYAKY+bMmejbty+aNGkCFxcXjB49WpKRxaNHj0ZKSgp69eoFuVyOAQMGIDQ0tMDFQ8ePH49r164hNDQUtra2GDBgAMLDw5Genq7aJjo6Gubm5pgwYQJu374NDw8PfPbZZwDEri1///03Ro4cibZt2yInJwc1atTA/Pnzi/14C0MmFPWiYSmUkZEBJycnpKensz8QEZEePHv2DImJifDz84O1tXWh91OSM0GXJrm5uahevTq6dOmCKVOmSF1OkRT0u6TL97fkZ4CIiIi0JZcDLVtKXYXhu3HjBv7++2+0aNEC2dnZmDdvHhITE/Hxxx9LXZrBkHwxVCIiItIvMzMzLFu2DI0aNULTpk1x+vRp7NixA9WrV5e6NIPBM0BERESljLe3N/bt2yd1GQaNZ4CIiIjI5DAAERERkclhACIiIiKTwwBEREREJocBiIiIiEwOAxARERGZHAYgIiKiYtSyZUsMHz5c9djX1xexsbEFvkYmk2Hjxo1Ffm997ac0YgAiIiLSICwsDK1bt9b43J49eyCTyXDq1Cmd93vkyBEMGDCgqOWpmTRpEurWrZun/c6dO2jTpo1e36u0YAAiIiLSoF+/fti+fTtuvbr66v9bunQpGjZsiNq1a+u8X1dXV9ja2uqjxDdyd3eHlZVVibyXsWEAIiIi0uDDDz+Eq6srli1bptaemZmJdevWoV+/frh//z66desGLy8v2NraIiAgAGvWrClwv69fArt8+TKaN28Oa2tr1KhRA9u3b8/zmtGjR+Ott96Cra0tKlWqhOjoaLx48QIAsGzZMkyePBknT56ETCaDTCZT1fz6JbDTp0/j3XffhY2NDcqVK4cBAwYgMzNT9Xzv3r0RHh6OGTNmwMPDA+XKlcOgQYNU76XJ1atX0a5dO7i5ucHe3h6NGjXCjh071LbJzs7G6NGj4e3tDSsrK1SuXBmLFy9WPX/27Fl8+OGHcHR0hIODA4KDg3H16tUCP8ei4lIYREQkCUEAnjwp+fe1tQVksjdvZ25ujl69emHZsmUYN24cZP//onXr1kGhUKBbt27IzMxEgwYNMHr0aDg6OmLz5s3o2bMn/P390bhx4ze+R25uLjp06AA3NzccOnQI6enpav2FlBwcHLBs2TJ4enri9OnT6N+/Pxwc
HPDll18iIiICZ86cwdatW1XBw8nJKc8+srKyEBoaiqCgIBw5cgR3797FJ598gsGDB6uFvJ07d8LDwwM7d+7ElStXEBERgbp166J///4ajyEzMxNt27bF1KlTYWVlhV9//RVhYWG4ePEiKlasCADo1asXDhw4gDlz5qBOnTpITExEWloaACA5ORnNmzdHy5Yt8c8//8DR0RH79u1DTk7OGz+/IhEoj/T0dAGAkJ6ertf95uQIws6dgrB6tfhvTo5ed09EZLCePn0qnDt3Tnj69KmqLTNTEMQYVLK3zEzt6z5//rwAQNi5c6eqLTg4WOjRo0e+r/nggw+EkSNHqh63aNFCGDZsmOqxj4+PMGvWLEEQBGHbtm2Cubm5kJycrHr+r7/+EgAIGzZsyPc9vv/+e6FBgwaqxxMnThTq1KmTZ7tX97No0SKhTJkyQuYrH8DmzZsFMzMzISUlRRAEQYiMjBR8fHyEnFe+oDp37ixERETkW4smNWvWFObOnSsIgiBcvHhRACBs375d47Zjx44V/Pz8hOfPn2u1b02/S0q6fH/zElgJiY8HfH2Bd94BPv5Y/NfXV2wnIiLDVK1aNTRp0gRLliwBAFy5cgV79uxBv379AAAKhQJTpkxBQEAAypYtC3t7e2zbtg1JSUla7f/8+fPw9vaGp6enqi0oKCjPdnFxcWjatCnc3d1hb2+P8ePHa/0er75XnTp1YGdnp2pr2rQpcnNzcfHiRVVbzZo1IZfLVY89PDxw9+7dfPebmZmJUaNGoXr16nB2doa9vT3Onz+vqu/EiROQy+Vo0aKFxtefOHECwcHBsLCw0Ol4ioqXwEpAfDzQqZP4t8erkpPF9vXrgQ4dpKmNiEgqtrbAK91PSvR9ddGvXz8MGTIE8+fPx9KlS+Hv76/6Mv/+++8xe/ZsxMbGIiAgAHZ2dhg+fDieP3+ut3oPHDiA7t27Y/LkyQgNDYWTkxPWrl2LH374QW/v8arXg4hMJkNubm6+248aNQrbt2/HjBkzULlyZdjY2KBTp06qz8DGxqbA93vT88WFAaiYKRTAsGF5ww8gtslkwPDhQLt2wCuBm4io1JPJgFdORhisLl26YNiwYVi9ejV+/fVXDBw4UNUfaN++fWjXrh169OgBQOzTc+nSJdSoUUOrfVevXh03b97EnTt34OHhAQA4ePCg2jb79++Hj48Pxo0bp2q7ceOG2jaWlpZQKBRvfK9ly5YhKytLdRZo3759MDMzQ9WqVbWqV5N9+/ahd+/eaN++PQDxjND169dVzwcEBCA3Nxf//vsvQkJC8ry+du3aWL58OV68eFGiZ4F4CayY7dkDaBhBqSIIwM2b4nZERGR47O3tERERgbFjx+LOnTvo3bu36rkqVapg+/bt2L9/P86fP49PP/0UqampWu87JCQEb731FiIjI3Hy5Ens2bNHLego3yMpKQlr167F1atXMWfOHGzYsEFtG19fXyQmJuLEiRNIS0tDdnZ2nvfq3r07rK2tERkZiTNnzmDnzp0YMmQIevbsCTc3N90+lNfqi4+Px4kTJ3Dy5El8/PHHameMfH19ERkZib59+2Ljxo1ITEzErl278NtvvwEABg8ejIyMDHTt2hX//fcfLl++jBUrVqhdlisODEDF7M4d/W5HREQlr1+/fnj48CFCQ0PV+uuMHz8e9evXR2hoKFq2bAl3d3eEh4drvV8zMzNs2LABT58+RePGjfHJJ59g6tSpatt89NFHGDFiBAYPHoy6deti//79iI6OVtumY8eOaN26Nd555x24urpqHIpva2uLbdu24cGDB2jUqBE6deqEVq1aYd68ebp9GK+ZOXMmypQpgyZNmiAsLAyhoaGoX7++2jYLFixAp06d8Pnnn6NatWro378/srKyAADlypXDP//8g8zMTLRo0QINGjTAzz//XOxng2SCoOnijGnLyMiAk5MT0tPT4ejoWKR97doldnh+k507gZYti/RWREQG69mzZ0hMTISfnx+s
ra2lLoeMWEG/S7p8f/MMUDELDgYqVMh/zgmZDPD2FrcjIiKiksEAVMzkcmD2bPH+6yFI+Tg2lh2giYiIShIDUAno0EEc6u7lpd5eoQKHwBMREUmBw+BLSIcO4lD3PXvEDs8eHuJlL575ISIiKnkMQCVILmdHZyIiIkPAS2BERFRiOPCYikpfv0MMQEREVOyUc7o8kWL5dypVlL9DRZ0niJfAiIio2Mnlcjg7O6sW1bS1tVUtJ0GkDUEQ8OTJE9y9exfOzs5qC7YWBgMQERGVCHd3dwAocGVxojdxdnZW/S4VBQMQERGVCJlMBg8PD5QvXx4vXryQuhwyQhYWFkU+86PEAERERCVKLpfr7UuMqLDYCZqIiIhMDgMQERERmRwGICIiIjI57AOkgXKSpYyMDIkrISIiIm0pv7e1mSyRAUiDx48fAwC8vb0lroSIiIh09fjxYzg5ORW4jUzgvOR55Obm4vbt23BwcOBEXfnIyMiAt7c3bt68CUdHR6nLMXn8eRgW/jwMC38ehqU4fx6CIODx48fw9PSEmVnBvXx4BkgDMzMzVKhQQeoyjIKjoyP/h2JA+PMwLPx5GBb+PAxLcf083nTmR4mdoImIiMjkMAARERGRyWEAokKxsrLCxIkTYWVlJXUpBP48DA1/HoaFPw/DYig/D3aCJiIiIpPDM0BERERkchiAiIiIyOQwABEREZHJYQAiIiIik8MARFqLiYlBo0aN4ODggPLlyyM8PBwXL16Uuiz6f99++y1kMhmGDx8udSkmLTk5GT169EC5cuVgY2ODgIAA/Pfff1KXZZIUCgWio6Ph5+cHGxsb+Pv7Y8qUKVqtE0VFt3v3boSFhcHT0xMymQwbN25Ue14QBEyYMAEeHh6wsbFBSEgILl++XGL1MQCR1v79918MGjQIBw8exPbt2/HixQu8//77yMrKkro0k3fkyBH89NNPqF27ttSlmLSHDx+iadOmsLCwwF9//YVz587hhx9+QJkyZaQuzSR99913WLBgAebNm4fz58/ju+++w/Tp0zF37lypSzMJWVlZqFOnDubPn6/x+enTp2POnDlYuHAhDh06BDs7O4SGhuLZs2clUh+HwVOh3bt3D+XLl8e///6L5s2bS12OycrMzET9+vXx448/4ptvvkHdunURGxsrdVkmacyYMdi3bx/27NkjdSkE4MMPP4SbmxsWL16sauvYsSNsbGywcuVKCSszPTKZDBs2bEB4eDgA8eyPp6cnRo4ciVGjRgEA0tPT4ebmhmXLlqFr167FXhPPAFGhpaenAwDKli0rcSWmbdCgQfjggw8QEhIidSkmb9OmTWjYsCE6d+6M8uXLo169evj555+lLstkNWnSBAkJCbh06RIA4OTJk9i7dy/atGkjcWWUmJiIlJQUtf9vOTk5ITAwEAcOHCiRGrgYKhVKbm4uhg8fjqZNm6JWrVpSl2Oy1q5di2PHjuHIkSNSl0IArl27hgULFiAqKgpfffUVjhw5gqFDh8LS0hKRkZFSl2dyxowZg4yMDFSrVg1yuRwKhQJTp05F9+7dpS7N5KWkpAAA3Nzc1Nrd3NxUzxU3BiAqlEGDBuHMmTPYu3ev1KWYrJs3b2LYsGHYvn07rK2tpS6HIP5h0LBhQ0ybNg0AUK9ePZw5cwYLFy5kAJLAb7/9hlWrVmH16tWoWbMmTpw4geHDh8PT05M/D+IlMNLd4MGD8eeff2Lnzp2oUKGC1OWYrKNHj+Lu3buoX78+zM3NYW5ujn///Rdz5syBubk5FAqF1CWaHA8PD9SoUUOtrXr16khKSpKoItP2xRdfYMyYMejatSsCAgLQs2dPjBgxAjExMVKXZvLc3d0BAKmpqWrtqampqueKGwMQaU0QBAwePBgbNmzAP//8Az8/P6lLMmmtWrXC6dOnceLECdWtYcOG6N69O06cOAG5XC51iSanadOmeaaGuHTpEnx8fCSqyLQ9efIEZmbqX3NyuRy5ubkSVURKfn5+cHd3
R0JCgqotIyMDhw4dQlBQUInUwEtgpLVBgwZh9erV+OOPP+Dg4KC6Tuvk5AQbGxuJqzM9Dg4Oefpf2dnZoVy5cuyXJZERI0agSZMmmDZtGrp06YLDhw9j0aJFWLRokdSlmaSwsDBMnToVFStWRM2aNXH8+HHMnDkTffv2lbo0k5CZmYkrV66oHicmJuLEiRMoW7YsKlasiOHDh+Obb75BlSpV4Ofnh+joaHh6eqpGihU7gUhLADTeli5dKnVp9P9atGghDBs2TOoyTNr//vc/oVatWoKVlZVQrVo1YdGiRVKXZLIyMjKEYcOGCRUrVhSsra2FSpUqCePGjROys7OlLs0k7Ny5U+N3RmRkpCAIgpCbmytER0cLbm5ugpWVldCqVSvh4sWLJVYf5wEiIiIik8M+QERERGRyGICIiIjI5DAAERERkclhACIiIiKTwwBEREREJocBiIiIiEwOAxARERGZHAYgIqJ8yGQybNy4UeoyiKgYMAARkUHq3bs3ZDJZnlvr1q2lLo2ISgGuBUZEBqt169ZYunSpWpuVlZVE1RBRacIzQERksKysrODu7q52K1OmDADx8tSCBQvQpk0b2NjYoFKlSli/fr3a60+fPo13330XNjY2KFeuHAYMGIDMzEy1bZYsWYKaNWvCysoKHh4eGDx4sNrzaWlpaN++PWxtbVGlShVs2rRJ9dzDhw/RvXt3uLq6wsbGBlWqVMkT2IjIMDEAEZHRio6ORseOHXHy5El0794dXbt2xfnz5wEAWVlZCA0NRZkyZXDkyBGsW7cOO3bsUAs4CxYswKBBgzBgwACcPn0amzZtQuXKldXeY/LkyejSpQtOnTqFtm3bonv37njw4IHq/c+dO4e//voL58+fx4IFC+Di4lJyHwARFV6JLbtKRKSDyMhIQS6XC3Z2dmq3qVOnCoIgCACEzz77TO01gYGBwsCBAwVBEIRFixYJZcqUETIzM1XPb968WTAzMxNSUlIEQRAET09PYdy4cfnWAEAYP3686nFmZqYAQPjrr78EQRCEsLAwoU+fPvo5YCIqUewDREQG65133sGCBQvU2sqWLau6HxQUpPZcUFAQTpw4AQA4f/486tSpAzs7O9XzTZs2RW5uLi5evAiZTIbbt2+jVatWBdZQu3Zt1X07Ozs4Ojri7t27AICBAweiY8eOOHbsGN5//32Eh4ejSZMmhTpWIipZDEBEZLDs7OzyXJLSFxsbG622s7CwUHssk8mQm5sLAGjTpg1u3LiBLVu2YPv27WjVqhUGDRqEGTNm6L1eItIv9gEiIqN18ODBPI+rV68OAKhevTpOnjyJrKws1fP79u2DmZkZqlatCgcHB/j6+iIhIaFINbi6uiIyMhIrV65EbGwsFi1aVKT9EVHJ4BkgIjJY2dnZSElJUWszNzdXdTRet24dGjZsiGbNmmHVqlU4fPgwFi9eDADo3r07Jk6ciMjISEyaNAn37t3DkCFD0LNnT7i5uQEAJk2ahM8++wzly5dHmzZt8PjxY+zbtw9DhgzRqr4JEyagQYMGqFmzJrKzs/Hnn3+qAhgRGTYGICIyWFu3boWHh4daW9WqVXHhwgUA4gittWvX4vPPP4eHhwfWrFmDGjVqAABsbW2xbds2DBs2DI0aNYKtrS06duyImTNnqvYVGRmJZ8+eYdasWRg1ahRcXFzQqVMnreuztLTE2LFjcf36ddjY2CA4OBhr167Vw5ETUXGTCYIgSF0EEZGuZDIZNmzYgPDwcKlLISIjxD5AREREZHIYgIiIiMjksA8QERklXr0noqLgGSAiIiIyOQxAREREZHIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOf8Hw4fJoTNf+rYAAAAASUVORK5CYII=\",\n      \"text/plain\": [\n       \"<Figure size 640x480 with 1 Axes>\"\n      ]\n   
  },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"plt.plot(epochs, acc, 'bo', label='Training acc')\\n\",\n    \"plt.plot(epochs, val_acc, 'b', label='Validation acc')\\n\",\n    \"plt.title('Training and validation accuracy')\\n\",\n    \"plt.xlabel('Epochs')\\n\",\n    \"plt.ylabel('Accuracy')\\n\",\n    \"plt.legend(loc='lower right')\\n\",\n    \"\\n\",\n    \"plt.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7865d6f2\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Export the model\\n\",\n    \"\\n\",\n    \"We can export the model including the TextVectorization layer inside the model to conduct inference on raw text.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"id\": \"93b0a42c-437e-41bb-99e7-d58cb8036a3a\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[1m782/782\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m2s\\u001b[0m 2ms/step - accuracy: 0.4935 - binary_accuracy: 0.0000e+00 - loss: 0.0000e+00\\n\",\n      \"{'accuracy': 0.5, 'binary_accuracy': 0.0, 'loss': 0.0}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"export_model = tf.keras.Sequential([\\n\",\n    \"  vectorize_layer,\\n\",\n    \"  model,\\n\",\n    \"  layers.Activation('sigmoid')\\n\",\n    \"])\\n\",\n    \"\\n\",\n    \"export_model.compile(\\n\",\n    \"    loss=losses.BinaryCrossentropy(from_logits=False), optimizer=\\\"adam\\\", metrics=['accuracy']\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Test it with `raw_test_ds`, which yields raw strings\\n\",\n    \"metrics = export_model.evaluate(raw_test_ds, return_dict=True)\\n\",\n    \"print(metrics)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d0795584\",\n   \"metadata\": {},\n   \"source\": [\n    \"Conduct inference on new data:\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   \"execution_count\": 27,\n   \"id\": \"8939539b-a600-48b1-a55e-3f1087f4a855\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[1m1/1\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 61ms/step\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"array([[0.67346764],\\n\",\n       \"       [0.634105  ],\\n\",\n       \"       [0.61044645]], dtype=float32)\"\n      ]\n     },\n     \"execution_count\": 27,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"examples = tf.constant([\\n\",\n    \"  \\\"The movie was great!\\\",\\n\",\n    \"  \\\"The movie was okay.\\\",\\n\",\n    \"  \\\"The movie was terrible...\\\"\\n\",\n    \"])\\n\",\n    \"\\n\",\n    \"export_model.predict(examples)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f6b40a59-8d3b-44ec-a4f7-92c5742a0c1c\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Save Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"id\": \"3e520822\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"os.mkdir('models') if not os.path.exists('models') else None\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"id\": \"7f22cc32-2708-4808-8e76-99024da87a21\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"export_model.save('models/text_model.keras')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e0461f74-fdd0-4f30-9f44-0be7ad00d9b0\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Load model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"id\": \"c9cf2c7f-5e86-4ff8-984e-dd0ed7a3ece9\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre 
style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\"><span style=\\\"font-weight: bold\\\">Model: \\\"sequential_1\\\"</span>\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"\\u001b[1mModel: \\\"sequential_1\\\"\\u001b[0m\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\\n\",\n       \"┃<span style=\\\"font-weight: bold\\\"> Layer (type)                    </span>┃<span style=\\\"font-weight: bold\\\"> Output Shape           </span>┃<span style=\\\"font-weight: bold\\\">       Param # </span>┃\\n\",\n       \"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\\n\",\n       \"│ text_vectorization              │ (<span style=\\\"color: #00d7ff; text-decoration-color: #00d7ff\\\">None</span>, <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">250</span>)            │             <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">0</span> │\\n\",\n       \"│ (<span style=\\\"color: #0087ff; text-decoration-color: #0087ff\\\">TextVectorization</span>)             │                        │               │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ sequential (<span style=\\\"color: #0087ff; text-decoration-color: #0087ff\\\">Sequential</span>)         │ (<span style=\\\"color: #00d7ff; text-decoration-color: #00d7ff\\\">None</span>, <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">1</span>)              │       <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">160,017</span> │\\n\",\n       
\"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ activation (<span style=\\\"color: #0087ff; text-decoration-color: #0087ff\\\">Activation</span>)         │ (<span style=\\\"color: #00d7ff; text-decoration-color: #00d7ff\\\">None</span>, <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">1</span>)              │             <span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">0</span> │\\n\",\n       \"└─────────────────────────────────┴────────────────────────┴───────────────┘\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\\n\",\n       \"┃\\u001b[1m \\u001b[0m\\u001b[1mLayer (type)                   \\u001b[0m\\u001b[1m \\u001b[0m┃\\u001b[1m \\u001b[0m\\u001b[1mOutput Shape          \\u001b[0m\\u001b[1m \\u001b[0m┃\\u001b[1m \\u001b[0m\\u001b[1m      Param #\\u001b[0m\\u001b[1m \\u001b[0m┃\\n\",\n       \"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\\n\",\n       \"│ text_vectorization              │ (\\u001b[38;5;45mNone\\u001b[0m, \\u001b[38;5;34m250\\u001b[0m)            │             \\u001b[38;5;34m0\\u001b[0m │\\n\",\n       \"│ (\\u001b[38;5;33mTextVectorization\\u001b[0m)             │                        │               │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ sequential (\\u001b[38;5;33mSequential\\u001b[0m)         │ (\\u001b[38;5;45mNone\\u001b[0m, \\u001b[38;5;34m1\\u001b[0m)              │       \\u001b[38;5;34m160,017\\u001b[0m │\\n\",\n       \"├─────────────────────────────────┼────────────────────────┼───────────────┤\\n\",\n       \"│ activation (\\u001b[38;5;33mActivation\\u001b[0m)         │ (\\u001b[38;5;45mNone\\u001b[0m, \\u001b[38;5;34m1\\u001b[0m)              │             \\u001b[38;5;34m0\\u001b[0m │\\n\",\n       
\"└─────────────────────────────────┴────────────────────────┴───────────────┘\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\"><span style=\\\"font-weight: bold\\\"> Total params: </span><span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">160,017</span> (625.07 KB)\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"\\u001b[1m Total params: \\u001b[0m\\u001b[38;5;34m160,017\\u001b[0m (625.07 KB)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\"><span style=\\\"font-weight: bold\\\"> Trainable params: </span><span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">160,017</span> (625.07 KB)\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"\\u001b[1m Trainable params: \\u001b[0m\\u001b[38;5;34m160,017\\u001b[0m (625.07 KB)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<pre style=\\\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\\\"><span style=\\\"font-weight: bold\\\"> Non-trainable params: </span><span style=\\\"color: #00af00; text-decoration-color: #00af00\\\">0</span> (0.00 B)\\n\",\n       \"</pre>\\n\"\n      ],\n      \"text/plain\": [\n       \"\\u001b[1m Non-trainable params: \\u001b[0m\\u001b[38;5;34m0\\u001b[0m (0.00 B)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   
\"source\": [\n    \"# register callables as custom objects before loading\\n\",\n    \"custom_objects = {\\\"vectorize_layer\\\": vectorize_layer, \\\"custom_standardization\\\": custom_standardization}\\n\",\n    \"with tf.keras.utils.custom_object_scope(custom_objects):\\n\",\n    \"    new_model = tf.keras.models.load_model('models/text_model.keras', compile=False)\\n\",\n    \"\\n\",\n    \"new_model.summary()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"242a4f7e-fa45-4d21-b103-fe3718bc0f10\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 31,\n   \"id\": \"531680b2-42ef-4205-9a38-6995aee9f340\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[1m1/1\\u001b[0m \\u001b[32m━━━━━━━━━━━━━━━━━━━━\\u001b[0m\\u001b[37m\\u001b[0m \\u001b[1m0s\\u001b[0m 59ms/step\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"array([[0.67346764],\\n\",\n       \"       [0.634105  ],\\n\",\n       \"       [0.61044645]], dtype=float32)\"\n      ]\n     },\n     \"execution_count\": 31,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"new_model.predict(examples)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a82ae387-1587-4175-b4b2-66586e4668f7\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 32,\n   \"id\": \"d6d515c2-ce53-4af5-a936-ae91fdecea99\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.ml.functions import predict_batch_udf\\n\",\n    \"from pyspark.sql.functions import struct, col, array, pandas_udf\\n\",\n    \"from pyspark.sql.types import ArrayType, FloatType, DoubleType\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark import 
SparkConf\\n\",\n    \"import pandas as pd\\n\",\n    \"import json\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"39c35256\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific Spark configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 33,\n   \"id\": \"31de0c5f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"55ad7f00\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 34,\n   \"id\": \"6b653c43\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 14:05:31 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/02/04 14:05:31 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/02/04 14:05:31 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        \\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"    elif on_dataproc:\\n\",\n    \"        conf.set(\\\"spark.executorEnv.TF_GPU_ALLOCATOR\\\", \\\"cuda_malloc_async\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", \\\"1000\\\")\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"53b39d27\",\n   \"metadata\": {},\n   \"source\": [\n    \"Load the IMDB dataset. 
We'll perform inference on the first sentence of each sample.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 35,\n   \"id\": \"ef3309eb\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from datasets import load_dataset\\n\",\n    \"\\n\",\n    \"dataset = load_dataset(\\\"imdb\\\", split=\\\"test\\\")\\n\",\n    \"dataset = dataset.to_pandas().drop(columns=\\\"label\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3a7672d1\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create PySpark DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 36,\n   \"id\": \"bb05466f\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType([StructField('text', StringType(), True)])\"\n      ]\n     },\n     \"execution_count\": 36,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df = spark.createDataFrame(dataset).repartition(8)\\n\",\n    \"df.schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 37,\n   \"id\": \"3f0a594b\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 14:05:36 WARN TaskSetManager: Stage 0 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\\n\",\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[Row(text=\\\"Anyone remember the first CKY, CKY2K etc..? 
Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.<br /><br />The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.<br /><br />The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.<br /><br />I really got nothing much left to say except, give us back CKY2K, cause Bam suck..<br /><br />I enjoy watching Steve-o, Knoxville etc. a thousand times more.\\\")]\"\n      ]\n     },\n     \"execution_count\": 37,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"df.take(1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 38,\n   \"id\": \"9d9db063\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 14:05:37 WARN TaskSetManager: Stage 3 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/imdb_test\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2f78a16a\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and Preprocess PySpark DataFrame\\n\",\n    \"\\n\",\n    \"Define our preprocess function. 
We'll take the first sentence of each sample as our input for sentiment analysis.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 39,\n   \"id\": \"1c081557\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@pandas_udf(\\\"string\\\")\\n\",\n    \"def preprocess(text: pd.Series) -> pd.Series:\\n\",\n    \"    return pd.Series([s.split(\\\".\\\")[0] for s in text])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 40,\n   \"id\": \"60af570a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Limit to N rows, since this can be slow\\n\",\n    \"df = spark.read.parquet(data_path).limit(512).repartition(8)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 41,\n   \"id\": \"a690f6df\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"input_df = df.select(preprocess(col(\\\"text\\\")).alias(\\\"lines\\\")).cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"01166d97\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Inference using Spark DL API\\n\",\n    \"\\n\",\n    \"Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\\n\",\n    \"\\n\",\n    \"- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \\n\",\n    \"- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"id\": \"7b7a8395-e2ae-4c3c-bf57-763dfde600ad\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"text_model_path = \\\"{}/models/text_model.keras\\\".format(os.getcwd())\\n\",\n    \"\\n\",\n    \"# For cloud environments, copy the model to the distributed file system.\\n\",\n    
\"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-models\\\")\\n\",\n    \"    dbfs_model_path = \\\"/dbfs/FileStore/spark-dl-models/text_model.keras\\\"\\n\",\n    \"    shutil.copy(text_model_path, dbfs_model_path)\\n\",\n    \"    text_model_path = dbfs_model_path\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    # GCS is mounted at /mnt/gcs by the init script\\n\",\n    \"    models_dir = \\\"/mnt/gcs/spark-dl/models\\\"\\n\",\n    \"    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\\n\",\n    \"    gcs_model_path = models_dir + \\\"/text_model.keras\\\"\\n\",\n    \"    shutil.copy(text_model_path, gcs_model_path)\\n\",\n    \"    text_model_path = gcs_model_path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 43,\n   \"id\": \"8c0524cf-3a75-4fb8-8025-f0654acce13e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def predict_batch_fn():\\n\",\n    \"    # since this function runs on the executor, any required imports should be added inside the function.\\n\",\n    \"    import re\\n\",\n    \"    import string\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"    from tensorflow.keras import layers\\n\",\n    \"\\n\",\n    \"    # Enable GPU memory growth to avoid CUDA OOM\\n\",\n    \"    gpus = tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"\\n\",\n    \"    def custom_standardization(input_data):\\n\",\n    \"        lowercase = tf.strings.lower(input_data)\\n\",\n    \"        stripped_html = tf.strings.regex_replace(lowercase, \\\"<br />\\\", \\\" \\\")\\n\",\n    \"        return tf.strings.regex_replace(\\n\",\n    \"            stripped_html, \\\"[%s]\\\" % re.escape(string.punctuation), 
\\\"\\\"\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"    max_features = 10000\\n\",\n    \"    sequence_length = 250\\n\",\n    \"\\n\",\n    \"    vectorize_layer = layers.TextVectorization(\\n\",\n    \"        standardize=custom_standardization,\\n\",\n    \"        max_tokens=max_features,\\n\",\n    \"        output_mode=\\\"int\\\",\\n\",\n    \"        output_sequence_length=sequence_length,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    custom_objects = {\\\"vectorize_layer\\\": vectorize_layer,\\n\",\n    \"                      \\\"custom_standardization\\\": custom_standardization}\\n\",\n    \"    with tf.keras.utils.custom_object_scope(custom_objects):\\n\",\n    \"        model = tf.keras.models.load_model(text_model_path)\\n\",\n    \"\\n\",\n    \"    def predict(inputs):\\n\",\n    \"        return model.predict(inputs)\\n\",\n    \"\\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 44,\n   \"id\": \"0d603644-d938-4c87-aa8a-2512251638d5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"classify = predict_batch_udf(predict_batch_fn,\\n\",\n    \"                             return_type=FloatType(),\\n\",\n    \"                             batch_size=256)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 45,\n   \"id\": \"0b480622-8dc1-4879-933e-c43112768630\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 9:>                                                          (0 + 8) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 6.81 ms, sys: 3.75 ms, total: 10.6 ms\\n\",\n      \"Wall time: 4.62 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                    
                            \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions = input_df.withColumn(\\\"preds\\\", classify(struct(\\\"lines\\\")))\\n\",\n    \"results = predictions.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 46,\n   \"id\": \"31b0a262-387e-4a5e-a60e-b9b8ee456199\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 4.58 ms, sys: 0 ns, total: 4.58 ms\\n\",\n      \"Wall time: 142 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions = input_df.withColumn(\\\"preds\\\", classify(\\\"lines\\\"))\\n\",\n    \"results = predictions.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 47,\n   \"id\": \"7ef9e431-59f5-4b29-9f79-ae16a9cfb0b9\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 903 μs, sys: 4.09 ms, total: 5 ms\\n\",\n      \"Wall time: 222 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions = input_df.withColumn(\\\"preds\\\", classify(col(\\\"lines\\\")))\\n\",\n    \"results = predictions.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 48,\n   \"id\": \"9a325ee2-3268-414a-bb75-a5fcf794f512\",\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------------------------------------------------------------------+----------+\\n\",\n      \"|                                                                           lines|     preds|\\n\",\n      \"+--------------------------------------------------------------------------------+----------+\\n\",\n      \"|The only reason I'm even giving this movie a 4 is 
because it was made in to a...|  0.571606|\\n\",\n      \"|Awkward disaster mishmash has a team of scavengers coming across the overturn...| 0.6264358|\\n\",\n      \"|Here is a fantastic concept for a film - a series of meteors crash into a sma...| 0.6764294|\\n\",\n      \"|              I walked out of the cinema having suffered this film after 30 mins| 0.6258814|\\n\",\n      \"|A wildly uneven film where the major problem is the uneasy mix of comedy and ...|0.63658905|\\n\",\n      \"|Leonard Rossiter and Frances de la Tour carry this film, not without a strugg...|  0.633625|\\n\",\n      \"|                                                                     A good cast|0.65998995|\\n\",\n      \"|Yet again, I appear to be the only person on planet Earth who is capable of c...| 0.6435825|\\n\",\n      \"|As a serious horror fan, I get that certain marketing ploys are used to sell ...| 0.6453945|\\n\",\n      \"|Upon writing this review I have difficulty trying to think of what to write a...|0.61587423|\\n\",\n      \"|                                                                    Simply awful|  0.594154|\\n\",\n      \"|I am a fan of Ed Harris' work and I really had high expectations about this film| 0.6366444|\\n\",\n      \"|                                                                            Well|0.65976477|\\n\",\n      \"|                                                This is a new approach to comedy| 0.6555772|\\n\",\n      \"|     It's been mentioned by others the inane dialogue in this series and I agree| 0.6534178|\\n\",\n      \"|One of the most boring movies I've ever had to sit through, it's completely f...| 0.5919746|\\n\",\n      \"|This movie was playing on Lifetime Movie Network last month and I decided to ...| 0.6527056|\\n\",\n      \"|                                       1983's \\\"Frightmare\\\" is an odd little film|0.64622015|\\n\",\n      \"|                                                           'Felony' is a 
B-movie|0.64882356|\\n\",\n      \"|                                          This movie defines the word \\\"confused\\\"|0.63689107|\\n\",\n      \"+--------------------------------------------------------------------------------+----------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"predictions.show(truncate=80)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ad9b07e6\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Triton Inference Server\\n\",\n    \"In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \\n\",\n    \"We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  \\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\\n\",\n    \"- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\\n\",\n    \"- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"889a1623\",\n   \"metadata\": {},\n   \"source\": [\n    \"First we'll cleanup the vocabulary layer of the model to remove non-ASCII characters. 
This ensures the inputs can be properly serialized and sent to Triton.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 49,\n   \"id\": \"f4f14c8f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import unicodedata\\n\",\n    \"\\n\",\n    \"def normalize_vocabulary(vocab):\\n\",\n    \"    # Normalize each word in the vocabulary to remove non-ASCII characters\\n\",\n    \"    normalized_vocab = [\\n\",\n    \"        unicodedata.normalize('NFKD', word).encode('ascii', 'ignore').decode('utf-8')\\n\",\n    \"        for word in vocab\\n\",\n    \"    ]\\n\",\n    \"    normalized_vocab = filter(lambda x: x != '', normalized_vocab)\\n\",\n    \"    normalized_vocab = list(set(normalized_vocab)) \\n\",\n    \"\\n\",\n    \"\\n\",\n    \"    return normalized_vocab\\n\",\n    \"\\n\",\n    \"vocab = vectorize_layer.get_vocabulary()\\n\",\n    \"normalized_vocab = normalize_vocabulary(vocab)\\n\",\n    \"\\n\",\n    \"# Reassign the cleaned vocabulary to the TextVectorization layer\\n\",\n    \"vectorize_layer.set_vocabulary(normalized_vocab)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 50,\n   \"id\": \"9614a192\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Save the model with the cleaned vocabulary\\n\",\n    \"triton_model_path = '{}/models/text_model_cleaned.keras'.format(os.getcwd())\\n\",\n    \"export_model.save(triton_model_path)\\n\",\n    \"\\n\",\n    \"# For cloud environments, copy the model to the distributed file system.\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-models\\\")\\n\",\n    \"    dbfs_model_path = \\\"/dbfs/FileStore/spark-dl-models/text_model_cleaned.keras\\\"\\n\",\n    \"    shutil.copy(triton_model_path, dbfs_model_path)\\n\",\n    \"    triton_model_path = dbfs_model_path\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    # GCS is mounted at /mnt/gcs by the init script\\n\",\n    \"    models_dir 
= \\\"/mnt/gcs/spark-dl/models\\\"\\n\",\n    \"    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\\n\",\n    \"    gcs_model_path = models_dir + \\\"/text_model_cleaned.keras\\\"\\n\",\n    \"    shutil.copy(triton_model_path, gcs_model_path)\\n\",\n    \"    triton_model_path = gcs_model_path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 51,\n   \"id\": \"32d0142a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"edddffb9\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 52,\n   \"id\": \"444bad3f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import TritonServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f0923a56\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton Server function:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 53,\n   \"id\": \"a4d37d33\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_server(ports, model_path):\\n\",\n    \"    import time\\n\",\n    \"    import signal\\n\",\n    \"    import numpy as np\\n\",\n    \"    import tensorflow as tf\\n\",\n    \"    from pytriton.decorators import batch\\n\",\n    \"    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\\n\",\n    \"    from pytriton.triton import Triton, TritonConfig\\n\",\n    \"    from pyspark import TaskContext\\n\",\n    \"    from tensorflow.keras import layers \\n\",\n    \"\\n\",\n    \"    \\n\",\n    \"    print(f\\\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\\\")\\n\",\n    \"    # Enable GPU memory growth\\n\",\n    \"    gpus = 
tf.config.experimental.list_physical_devices('GPU')\\n\",\n    \"    if gpus:\\n\",\n    \"        try:\\n\",\n    \"            for gpu in gpus:\\n\",\n    \"                tf.config.experimental.set_memory_growth(gpu, True)\\n\",\n    \"        except RuntimeError as e:\\n\",\n    \"            print(e)\\n\",\n    \"\\n\",\n    \"    def custom_standardization(input_data):\\n\",\n    \"        lowercase = tf.strings.lower(input_data)\\n\",\n    \"        stripped_html = tf.strings.regex_replace(lowercase, \\\"<br />\\\", \\\" \\\")\\n\",\n    \"        return tf.strings.regex_replace(\\n\",\n    \"            stripped_html, \\\"[%s]\\\" % re.escape(string.punctuation), \\\"\\\"\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"    max_features = 10000\\n\",\n    \"    sequence_length = 250\\n\",\n    \"\\n\",\n    \"    vectorize_layer = layers.TextVectorization(\\n\",\n    \"        standardize=custom_standardization,\\n\",\n    \"        max_tokens=max_features,\\n\",\n    \"        output_mode=\\\"int\\\",\\n\",\n    \"        output_sequence_length=sequence_length,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    custom_objects = {\\\"vectorize_layer\\\": vectorize_layer,\\n\",\n    \"                \\\"custom_standardization\\\": custom_standardization}\\n\",\n    \"\\n\",\n    \"    with tf.keras.utils.custom_object_scope(custom_objects):\\n\",\n    \"        model = tf.keras.models.load_model(model_path)\\n\",\n    \"\\n\",\n    \"    @batch\\n\",\n    \"    def _infer_fn(**inputs):\\n\",\n    \"        sentences = inputs[\\\"text\\\"]\\n\",\n    \"        print(f\\\"SERVER: Received batch of size {len(sentences)}.\\\")\\n\",\n    \"        decoded_sentences = tf.convert_to_tensor(np.vectorize(lambda x: x.decode('utf-8'))(sentences))\\n\",\n    \"        return {\\n\",\n    \"            \\\"preds\\\": model.predict(decoded_sentences)\\n\",\n    \"        }\\n\",\n    \"    \\n\",\n    \"    workspace_path = 
f\\\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\\\"\\n\",\n    \"    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\\n\",\n    \"    with Triton(config=triton_conf, workspace=workspace_path) as triton:\\n\",\n    \"        triton.bind(\\n\",\n    \"            model_name=\\\"TextModel\\\",\\n\",\n    \"            infer_func=_infer_fn,\\n\",\n    \"            inputs=[\\n\",\n    \"                Tensor(name=\\\"text\\\", dtype=np.bytes_, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            outputs=[\\n\",\n    \"                Tensor(name=\\\"preds\\\", dtype=np.float32, shape=(-1,)),\\n\",\n    \"            ],\\n\",\n    \"            config=ModelConfig(\\n\",\n    \"                max_batch_size=128,\\n\",\n    \"                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\\n\",\n    \"            ),\\n\",\n    \"            strict=True,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        def _stop_triton(signum, frame):\\n\",\n    \"            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\\n\",\n    \"            print(\\\"SERVER: Received SIGTERM. 
Stopping Triton server.\\\")\\n\",\n    \"            triton.stop()\\n\",\n    \"\\n\",\n    \"        signal.signal(signal.SIGTERM, _stop_triton)\\n\",\n    \"\\n\",\n    \"        print(\\\"SERVER: Serving inference\\\")\\n\",\n    \"        triton.serve()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d340e231\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Triton servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"fcdb7c5a\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP/gRPC/metrics\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 55,\n   \"id\": \"4d5dc419\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"TextModel\\\"\\n\",\n    \"server_manager = TritonServerManager(model_name=model_name, model_path=triton_model_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"20198644\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# Returns 
{'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\\n\",\n    \"server_manager.start_servers(triton_server)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e1477f4b\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"798c2815\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"813d42cf\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f16617e3\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the Triton inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 58,\n   \"id\": \"0ad47438\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def triton_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import numpy as np\\n\",\n    \"    from pytriton.client import ModelClient\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    print(f\\\"CLIENT: Connecting to {model_name} at {url}\\\")\\n\",\n    \"\\n\",\n    \"    def infer_batch(inputs):\\n\",\n    \"        with ModelClient(url, model_name, inference_timeout_s=240) as client:\\n\",\n    \"            encoded_inputs = np.vectorize(lambda x: x.encode(\\\"utf-8\\\"))(inputs).astype(np.bytes_)\\n\",\n    \"            encoded_inputs = np.expand_dims(encoded_inputs, axis=1)\\n\",\n    \"            result_data = client.infer_batch(encoded_inputs)\\n\",\n    \"            \\n\",\n    \"            return result_data[\\\"preds\\\"]\\n\",\n    \"            \\n\",\n    \"    return 
infer_batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 61,\n   \"id\": \"8e06d33f-5cef-4a48-afc3-5d468f8ec2b4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"classify = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                             return_type=FloatType(),\\n\",\n    \"                             batch_size=64)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"91974885\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and preprocess DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 59,\n   \"id\": \"41106a02-236e-4cb3-ac51-76aa64b663c2\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.parquet(data_path).limit(512).repartition(8)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 60,\n   \"id\": \"e851870b\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/02/04 14:05:48 WARN CacheManager: Asked to cache already cached data.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"input_df = df.select(preprocess(col(\\\"text\\\")).alias(\\\"lines\\\")).cache()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 62,\n   \"id\": \"d89e74ad-e551-4bfa-ad08-98725878630a\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 24:==============>                                           (2 + 6) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 2.92 ms, sys: 4.06 ms, total: 6.97 ms\\n\",\n      \"Wall time: 1.03 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                       
                                                         \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions = input_df.withColumn(\\\"preds\\\", classify(struct(\\\"lines\\\")))\\n\",\n    \"results = predictions.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 63,\n   \"id\": \"b4fa7fc9-341c-49a6-9af2-e316f2355d67\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 1.39 ms, sys: 2.15 ms, total: 3.53 ms\\n\",\n      \"Wall time: 237 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions = input_df.withColumn(\\\"preds\\\", classify(\\\"lines\\\"))\\n\",\n    \"results = predictions.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 64,\n   \"id\": \"564f999b\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 862 μs, sys: 2.77 ms, total: 3.63 ms\\n\",\n      \"Wall time: 225 ms\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"predictions = input_df.withColumn(\\\"preds\\\", classify(col(\\\"lines\\\")))\\n\",\n    \"results = predictions.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 65,\n   \"id\": \"9222e8a9\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------------------------------------------------------------------+----------+\\n\",\n      \"|                                                                           lines|     preds|\\n\",\n      \"+--------------------------------------------------------------------------------+----------+\\n\",\n      \"|The only reason I'm even giving this movie a 4 is because it was made in to a...|0.67212176|\\n\",\n    
  \"|Awkward disaster mishmash has a team of scavengers coming across the overturn...|0.63807774|\\n\",\n      \"|Here is a fantastic concept for a film - a series of meteors crash into a sma...|0.65471745|\\n\",\n      \"|              I walked out of the cinema having suffered this film after 30 mins| 0.6527998|\\n\",\n      \"|A wildly uneven film where the major problem is the uneasy mix of comedy and ...| 0.6405446|\\n\",\n      \"|Leonard Rossiter and Frances de la Tour carry this film, not without a strugg...|0.63534474|\\n\",\n      \"|                                                                     A good cast|0.64761806|\\n\",\n      \"|Yet again, I appear to be the only person on planet Earth who is capable of c...|0.66956663|\\n\",\n      \"|As a serious horror fan, I get that certain marketing ploys are used to sell ...|0.62346375|\\n\",\n      \"|Upon writing this review I have difficulty trying to think of what to write a...|  0.681598|\\n\",\n      \"|                                                                    Simply awful| 0.6537583|\\n\",\n      \"|I am a fan of Ed Harris' work and I really had high expectations about this film| 0.6382922|\\n\",\n      \"|                                                                            Well|0.65424603|\\n\",\n      \"|                                                This is a new approach to comedy| 0.6628315|\\n\",\n      \"|     It's been mentioned by others the inane dialogue in this series and I agree|0.63345987|\\n\",\n      \"|One of the most boring movies I've ever had to sit through, it's completely f...| 0.6459369|\\n\",\n      \"|This movie was playing on Lifetime Movie Network last month and I decided to ...|0.65335083|\\n\",\n      \"|                                       1983's \\\"Frightmare\\\" is an odd little film|0.65602964|\\n\",\n      \"|                                                           'Felony' is a B-movie| 0.6583404|\\n\",\n      \"|                           
               This movie defines the word \\\"confused\\\"| 0.6217103|\\n\",\n      \"+--------------------------------------------------------------------------------+----------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"predictions.show(truncate=80)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d45e8981-ca44-429b-9b37-e04035c3a86b\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"#### Stop Triton Server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 66,\n   \"id\": \"a71ac9b6-47a2-4306-bc40-9ce7b4e968ec\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-02-04 14:05:50,166 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"2025-02-04 14:06:00,351 - INFO - Sucessfully stopped 1 servers.                 \\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 66,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 67,\n   \"id\": \"54a90574-7cbb-487b-b7a8-dcda0e6e301f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"88e3bfea-a825-46eb-b8c2-921a932c0089\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-tf\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n 
   \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/tf_requirements.txt",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n-r requirements.txt\ntensorflow[and-cuda]\ntf-keras"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/torch_requirements.txt",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n-r requirements.txt\ntorch<=2.5.1\ntorchvision\ntorch-tensorrt\ntensorrt --extra-index-url https://download.pytorch.org/whl/cu121\nsentence_transformers\nsentencepiece\nnvidia-modelopt[all] --extra-index-url https://pypi.nvidia.com"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/vllm/qwen-2.5-14b-tensor-parallel_vllm.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"https://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark LLM Inference: Qwen-2.5-14b Data Structuring\\n\",\n    \"\\n\",\n    \"In this notebook, we demonstrate distributed batch inference with [Qwen-2.5](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct), using open weights on Huggingface.\\n\",\n    \"\\n\",\n    \"The Qwen-2.5-14b-instruct is an instruction-fine-tuned version of the Qwen-2.5-14b base model. We'll show how to use the model to prepare unstructured text data into a structured schema for downstream tasks.\\n\",\n    \"\\n\",\n    \"**Note:** This example demonstrates **tensor parallelism**, which requires multiple GPUs per node. For standalone users, make sure to use a Spark worker with 2 GPUs. If you follow the Databricks or Dataproc instructions, make sure to include the `tp` argument to the cluster startup scripts.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"\\n\",\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# For cloud environments, load the model to the distributed file system.\\n\",\n    \"if on_databricks:\\n\",\n    \"    models_dir = 
\\\"/dbfs/FileStore/spark-dl-models\\\"\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-models\\\")\\n\",\n    \"    model_path = f\\\"{models_dir}/qwen2.5-14b\\\"\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    models_dir = \\\"/mnt/gcs/spark-dl-models\\\"\\n\",\n    \"    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\\n\",\n    \"    model_path = f\\\"{models_dir}/qwen2.5-14b\\\"\\n\",\n    \"else:\\n\",\n    \"    model_path = os.path.abspath(\\\"qwen2.5-14b\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Download the model from huggingface hub.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"f75ef1f2071f413da5ae502589293c62\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Fetching 18 files:   0%|          | 0/18 [00:00<?, ?it/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"from huggingface_hub import snapshot_download\\n\",\n    \"\\n\",\n    \"model_path = snapshot_download(\\n\",\n    \"    repo_id=\\\"Qwen/Qwen2.5-14B-Instruct\\\",\\n\",\n    \"    local_dir=model_path\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import pandas as pd\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.functions import pandas_udf, col, struct, length, lit, concat\\n\",\n    \"from pyspark.ml.functions import 
predict_batch_udf\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import datasets\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"datasets.disable_progress_bars()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). 
For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/03/20 00:55:18 INFO SparkEnv: Registering MapOutputTracker\\n\",\n      \"25/03/20 00:55:18 INFO SparkEnv: Registering BlockManagerMaster\\n\",\n      \"25/03/20 00:55:18 INFO SparkEnv: Registering BlockManagerMasterHeartbeat\\n\",\n      \"25/03/20 00:55:18 INFO SparkEnv: Registering OutputCommitCoordinator\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        \\n\",\n    \"    # For standalone users: adjust executor.cores and task.resource.gpu.amount based on available cores\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"24\\\")  \\n\",\n    \"    conf.set(\\\"spark.task.maxFailures\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.083333\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"2\\\")  # 2 GPUs per executor\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and Preprocess DataFrame\\n\",\n    \"\\n\",\n    \"Load the first 500 samples of the [Amazon Video Game Product Reviews 
dataset](https://huggingface.co/datasets/logankells/amazon_product_reviews_video_games) from Huggingface and store in a Spark Dataframe.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"adead9f390a042a286de64a03a88fca8\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"README.md:   0%|          | 0.00/6.00 [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Repo card metadata block was not found. Setting CardData to empty.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"product_reviews_ds = load_dataset(\\\"LoganKells/amazon_product_reviews_video_games\\\", split=\\\"train\\\", streaming=True)\\n\",\n    \"product_reviews_pds = pd.Series([sample[\\\"reviewText\\\"] for sample in product_reviews_ds.take(500)])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.createDataFrame(product_reviews_pds, schema=StringType())\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                  
             value|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|Installing the game was a struggle (because of games for windows live bugs).Some championship rac...|\\n\",\n      \"|If you like rally cars get this game you will have fun.It is more oriented to &#34;European marke...|\\n\",\n      \"|1st shipment received a book instead of the game.2nd shipment got a FAKE one. Game arrived with a...|\\n\",\n      \"|I had Dirt 2 on Xbox 360 and it was an okay game. I started playing games on my laptop and bought...|\\n\",\n      \"|Overall this is a well done racing game, with very good graphics for its time period. My family h...|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df.show(5, truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Format each sample into the Qwen chat template, including a system prompt to guide generation.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"system_prompt = \\\"\\\"\\\"You are a specialized review analysis AI that categorizes product reviews into precise sentiment categories.\\n\",\n    \"IMPORTANT: Your response must contain ONLY valid JSON and nothing else - no explanations, no additional text.\\n\",\n    \"For each review, analyze and output EXACTLY this JSON structure:\\n\",\n    \"{\\n\",\n    \"  \\\"primary_sentiment\\\": [EXACTLY ONE OF: \\\"positive\\\", \\\"negative\\\", \\\"neutral\\\", \\\"mixed\\\"],\\n\",\n    \"  \\\"sentiment_score\\\": [integer between 1-10, where 1 is extremely negative and 10 is extremely positive],\\n\",\n    \"  \\\"purchase_intention\\\": [EXACTLY 
ONE OF: \\\"will repurchase\\\", \\\"might repurchase\\\", \\\"will not repurchase\\\", \\\"recommends alternatives\\\", \\\"uncertain\\\"]\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"Do not include any text before or after the JSON. The response should start with '{' and end with '}' with no trailing characters, comments, or explanations.\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"df = df.select(\\n\",\n    \"    concat(\\n\",\n    \"        lit(\\\"<|im_start|>system\\\\n\\\"),\\n\",\n    \"        lit(system_prompt),\\n\",\n    \"        lit(\\\"<|im_end|>\\\\n<|im_start|>user\\\\n\\\"),\\n\",\n    \"        lit(\\\"Analyze this review: \\\"),\\n\",\n    \"        col(\\\"value\\\"),\\n\",\n    \"        lit(\\\"<|im_end|>\\\\n<|im_start|>assistant\\\\n\\\")\\n\",\n    \"    ).alias(\\\"prompt\\\")\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"<|im_start|>system\\n\",\n      \"You are a specialized review analysis AI that categorizes product reviews into precise sentiment categories.\\n\",\n      \"IMPORTANT: Your response must contain ONLY valid JSON and nothing else - no explanations, no additional text.\\n\",\n      \"For each review, analyze and output EXACTLY this JSON structure:\\n\",\n      \"{\\n\",\n      \"  \\\"primary_sentiment\\\": [EXACTLY ONE OF: \\\"positive\\\", \\\"negative\\\", \\\"neutral\\\", \\\"mixed\\\"],\\n\",\n      \"  \\\"sentiment_score\\\": [integer between 1-10, where 1 is extremely negative and 10 is extremely positive],\\n\",\n      \"  \\\"purchase_intention\\\": [EXACTLY ONE OF: \\\"will repurchase\\\", \\\"might repurchase\\\", \\\"will not repurchase\\\", \\\"recommends alternatives\\\", \\\"uncertain\\\"]\\n\",\n      \"}\\n\",\n      \"\\n\",\n      \"Do not include any text before or after the JSON. 
The response should start with '{' and end with '}' with no trailing characters, comments, or explanations.\\n\",\n      \"<|im_end|>\\n\",\n      \"<|im_start|>user\\n\",\n      \"Analyze this review: Installing the game was a struggle (because of games for windows live bugs).Some championship races and cars can only be \\\"unlocked\\\" by buying them as an addon to the game. I paid nearly 30 dollars when the game was new. I don't like the idea that I have to keep paying to keep playing.I noticed no improvement in the physics or graphics compared to Dirt 2.I tossed it in the garbage and vowed never to buy another codemasters game. I'm really tired of arcade style rally/racing games anyway.I'll continue to get my fix from Richard Burns Rally, and you should to. :)http://www.amazon.com/Richard-Burns-Rally-PC/dp/B000C97156/ref=sr_1_1?ie=UTF8&qid;=1341886844&sr;=8-1&keywords;=richard+burns+rallyThank you for reading my review! If you enjoyed it, be sure to rate it as helpful.<|im_end|>\\n\",\n      \"<|im_start|>assistant\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(df.take(1)[0].prompt)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/amazon_video_game_reviews\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using vLLM Server\\n\",\n    \"In this section, we demonstrate integration with 
[vLLM Serving](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html), an open-source server with an OpenAI-compatible completions endpoint for LLMs.  \n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a server startup task across the Spark cluster, instructing each node to launch a vLLM server process.\\n\",\n    \"- Define a vLLM inference function, which sends inference requests to the local server on a given node.\\n\",\n    \"- Wrap the vLLM inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the vLLM server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server-mg.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import VLLMServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"There are currently some hanging issues with vLLM's `torch.compile` on Databricks, which we are working to resolve. 
For now we will enforce eager mode on Databricks, which disables compilation at some performance cost.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"enforce_eager = True if on_databricks else False\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start vLLM servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `VLLMServerManager` will handle the lifecycle of vLLM server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shut down servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"qwen-2.5-14b\\\"\\n\",\n    \"server_manager = VLLMServerManager(model_name=model_name, model_path=model_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"You can pass any of the supported [vLLM serve CLI arguments](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#vllm-serve) as keyword arguments when starting the servers. 
Note that this can take some time, as it includes loading the model from disk, Torch compilation, and capturing CUDA graphs.\\n\",\n    \"\\n\",\n    \"Here, we set `tensor_parallel_size` to the number of GPUs per node:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-03-20 01:04:42,978 - INFO - Requesting stage-level resources: (cores=13, gpu=2.0)\\n\",\n      \"2025-03-20 01:04:42,979 - INFO - Starting 2 VLLM servers.\\n\",\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'spark-dl-inference-vllm-tp-w-0': (35438, [7000]),\\n\",\n       \" 'spark-dl-inference-vllm-tp-w-1': (35288, [7000])}\"\n      ]\n     },\n     \"execution_count\": 22,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tensor_parallel_size = int(spark.conf.get(\\\"spark.executor.resource.gpu.amount\\\"))\\n\",\n    \"server_manager.start_servers(tensor_parallel_size=tensor_parallel_size,\\n\",\n    \"                             gpu_memory_utilization=0.95,\\n\",\n    \"                             max_model_len=6600,\\n\",\n    \"                             task=\\\"generate\\\",\\n\",\n    \"                             enforce_eager=enforce_eager,\\n\",\n    \"                             wait_retries=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n 
  \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 51,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"def vllm_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import json\\n\",\n    \"    import requests\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    \\n\",\n    \"    def predict(inputs):\\n\",\n    \"        print(inputs)\\n\",\n    \"        response = requests.post(\\n\",\n    \"            \\\"http://localhost:7000/v1/completions\\\",\\n\",\n    \"            json={\\n\",\n    \"                \\\"model\\\": model_name,\\n\",\n    \"                \\\"prompt\\\": inputs.tolist(),\\n\",\n    \"                \\\"max_tokens\\\": 50,\\n\",\n    \"                \\\"temperature\\\": 0.7,\\n\",\n    \"                \\\"top_p\\\": 0.8,\\n\",\n    \"                \\\"repetition_penalty\\\": 1.05,\\n\",\n    \"            }\\n\",\n    \"        )\\n\",\n    \"        result_dicts = [json.loads(o[\\\"text\\\"]) for o in response.json()[\\\"choices\\\"]]\\n\",\n    \"        return result_dicts\\n\",\n    \"    \\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 55,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"generate = predict_batch_udf(partial(vllm_fn, model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                             return_type=StructType([\\n\",\n    \"                                 StructField(\\\"primary_sentiment\\\", StringType()),\\n\",\n    \"                                 StructField(\\\"sentiment_score\\\", IntegerType()),\\n\",\n    \"                                 StructField(\\\"purchase_intention\\\", StringType())\\n\",\n    \"                             ]),\\n\",\n    \"                             
batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 56,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df = spark.read.parquet(data_path).repartition(16)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 57,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 31:=================================================>      (14 + 2) / 16]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 29.6 ms, sys: 6.89 ms, total: 36.5 ms\\n\",\n      \"Wall time: 33 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn\\n\",\n    \"preds = df.withColumn(\\\"outputs\\\", generate(col(\\\"prompt\\\"))).select(\\\"prompt\\\", \\\"outputs.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 58,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 34:=================================================>      (14 + 2) / 16]\\r\"\n     ]\n    
},\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 25.6 ms, sys: 6.73 ms, total: 32.3 ms\\n\",\n      \"Wall time: 32 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"outputs\\\", generate(col(\\\"prompt\\\"))).select(\\\"prompt\\\", \\\"outputs.*\\\")\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 59,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 37:>                                                         (0 + 1) / 1]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------------------------------------+-----------------+---------------+-------------------+\\n\",\n      \"|                                            prompt|primary_sentiment|sentiment_score| purchase_intention|\\n\",\n      \"+--------------------------------------------------+-----------------+---------------+-------------------+\\n\",\n      \"|<|im_start|>system\\\\nYou are a specialized revie...|         positive|              9|    will repurchase|\\n\",\n      \"|<|im_start|>system\\\\nYou are a specialized revie...|         positive|              9|    will repurchase|\\n\",\n      \"|<|im_start|>system\\\\nYou are a specialized revie...|         positive|              8|    will repurchase|\\n\",\n      \"|<|im_start|>system\\\\nYou are a specialized revie...|         negative|              4|will not repurchase|\\n\",\n      \"|<|im_start|>system\\\\nYou are a specialized revie...|            mixed|              6| 
  might repurchase|\\n\",\n      \"+--------------------------------------------------+-----------------+---------------+-------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"preds.show(5, truncate=50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 60,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Review: <|im_start|>system\\n\",\n      \"You are a specialized review analysis AI that categorizes product reviews into precise sentiment categories.\\n\",\n      \"IMPORTANT: Your response must contain ONLY valid JSON and nothing else - no explanations, no additional text.\\n\",\n      \"For each review, analyze and output EXACTLY this JSON structure:\\n\",\n      \"{\\n\",\n      \"  \\\"primary_sentiment\\\": [EXACTLY ONE OF: \\\"positive\\\", \\\"negative\\\", \\\"neutral\\\", \\\"mixed\\\"],\\n\",\n      \"  \\\"sentiment_score\\\": [integer between 1-10, where 1 is extremely negative and 10 is extremely positive],\\n\",\n      \"  \\\"purchase_intention\\\": [EXACTLY ONE OF: \\\"will repurchase\\\", \\\"might repurchase\\\", \\\"will not repurchase\\\", \\\"recommends alternatives\\\", \\\"uncertain\\\"]\\n\",\n      \"}\\n\",\n      \"\\n\",\n      \"Do not include any text before or after the JSON. The response should start with '{' and end with '}' with no trailing characters, comments, or explanations.\\n\",\n      \"<|im_end|>\\n\",\n      \"<|im_start|>user\\n\",\n      \"Analyze this review: I have never played anything like this since. 
Everything from Sly  Racoon, to Ratchet and Clank, owe it to this.Wicked witch Gruntilda takes Banjo's sister to hey layer, miles away in a realistic 3D cartoon world.Banjo is a bear with Kazooie a bird in his backpack that can help him jump and fly and basically you learn to do lots of things with it. You solve puzzles via action and collect tolkens across lovely maps. Mumbo Jumbo transforms Banjo into some other creatures along the way. You can fly. It was amazing. A full adventure all the way to end. We played it for months and I have NEVER played anything like it again. The makers of Donkey Kong released it at the best time. It is now up to the future generations to make adventure concepts better than this one. This is one of the best N64 games ever.<|im_end|>\\n\",\n      \"<|im_start|>assistant\\n\",\n      \"\\n\",\n      \"Sentiment: positive, Score: 9, Status: will repurchase\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"sample = results[0]\\n\",\n    \"print(\\\"Review:\\\", sample[\\\"prompt\\\"])\\n\",\n    \"print(f\\\"Sentiment: {sample['primary_sentiment']}, Score: {sample['sentiment_score']}, Status: {sample['purchase_intention']}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Shut down server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 61,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2025-03-20 01:19:32,218 - INFO - Requesting stage-level resources: (cores=13, gpu=2.0)\\n\",\n      \"2025-03-20 01:19:33,872 - INFO - Successfully stopped 2 VLLM servers.           
\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True, True]\"\n      ]\n     },\n     \"execution_count\": 61,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 62,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-vllm\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.11\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/vllm/qwen-2.5-7b_vllm.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"<img src=\\\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\\\" width=\\\"90px\\\">\\n\",\n    \"\\n\",\n    \"# PySpark LLM Inference: Qwen-2.5 Text Summarization\\n\",\n    \"\\n\",\n    \"In this notebook, we demonstrate distributed batch inference with [Qwen-2.5](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct), using open weights on Huggingface.\\n\",\n    \"\\n\",\n    \"The Qwen-2.5-7b-instruct is an instruction-fine-tuned version of the Qwen-2.5-7b base model. We'll show how to use the model to perform text summarization.\\n\",\n    \"\\n\",\n    \"**Note:** Running this model on GPU with 16-bit precision requires **~16GB** of GPU RAM. Make sure your instances have sufficient GPU capacity.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\\n\",\n    \"# See (https://github.com/huggingface/transformers/issues/5486) for more info. \\n\",\n    \"import os\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"true\\\"\\n\",\n    \"\\n\",\n    \"# vLLM does CUDA init at import time. 
Forking will try to re-initialize CUDA if vLLM was imported before and throw an error.\\n\",\n    \"os.environ[\\\"VLLM_WORKER_MULTIPROC_METHOD\\\"] = \\\"spawn\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Check the cluster environment to handle any platform-specific configurations.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"on_databricks = os.environ.get(\\\"DATABRICKS_RUNTIME_VERSION\\\", False)\\n\",\n    \"on_dataproc = os.environ.get(\\\"DATAPROC_IMAGE_VERSION\\\", False)\\n\",\n    \"on_standalone = not (on_databricks or on_dataproc)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# For cloud environments, load the model to the distributed file system.\\n\",\n    \"if on_databricks:\\n\",\n    \"    models_dir = \\\"/dbfs/FileStore/spark-dl-models\\\"\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-models\\\")\\n\",\n    \"    model_path = f\\\"{models_dir}/qwen-2.5-7b\\\"\\n\",\n    \"elif on_dataproc:\\n\",\n    \"    models_dir = \\\"/mnt/gcs/spark-dl-models\\\"\\n\",\n    \"    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\\n\",\n    \"    model_path = f\\\"{models_dir}/qwen-2.5-7b\\\"\\n\",\n    \"else:\\n\",\n    \"    model_path = os.path.abspath(\\\"qwen-2.5-7b\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Download the model from huggingface hub.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"eeb0daf2bd7948bebd94ce2a9a5a01b8\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Fetching 14 
files:   0%|          | 0/14 [00:00<?, ?it/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"from huggingface_hub import snapshot_download\\n\",\n    \"\\n\",\n    \"model_path = snapshot_download(\\n\",\n    \"    repo_id=\\\"Qwen/Qwen2.5-7B-Instruct\\\",\\n\",\n    \"    local_dir=model_path\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"There are currently some hanging issues with vLLM's `torch.compile` on Databricks, which we are working to resolve. For now we will enforce eager mode on Databricks, which disables compilation at some performance cost.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"enforce_eager = True if on_databricks else False\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Warmup: Running locally\\n\",\n    \"\\n\",\n    \"**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from vllm import LLM, SamplingParams\\n\",\n    \"from transformers import AutoTokenizer\\n\",\n    \"\\n\",\n    \"sampling_params = SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=128)\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\\\"left\\\")\\n\",\n    \"llm = LLM(model=model_path, gpu_memory_utilization=0.95, max_model_len=6600, enforce_eager=enforce_eager)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"system_prompt = {\\n\",\n    \"    \\\"role\\\": \\\"system\\\",\\n\",\n    \"    \\\"content\\\": \\\"You are a knowledgeable 
AI assistant that provides accurate answers to questions.\\\"\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"queries = [\\n\",\n    \"    \\\"What does CUDA stand for?\\\",\\n\",\n    \"    \\\"In one sentence, what's the difference between a CPU and a GPU?\\\",\\n\",\n    \"    \\\"What's the hottest planet in the solar system?\\\"\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"prompts = [\\n\",\n    \"    [\\n\",\n    \"        system_prompt,\\n\",\n    \"        {\\\"role\\\": \\\"user\\\", \\\"content\\\": query}\\n\",\n    \"    ] for query in queries\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"text = tokenizer.apply_chat_template(\\n\",\n    \"    prompts,\\n\",\n    \"    tokenize=False,\\n\",\n    \"    add_generation_prompt=True,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Processed prompts: 100%|██████████| 3/3 [00:01<00:00,  1.76it/s, est. speed input: 63.83 toks/s, output: 100.14 toks/s]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"outputs = llm.generate(text, sampling_params=sampling_params)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Q: What does CUDA stand for?\\n\",\n      \"A: CUDA stands for Compute Unified Device Architecture. It is a parallel computing platform and application programming interface (API) model created by NVIDIA. 
CUDA allows software developers to use a CUDA-enabled graphics processing unit (GPU) for general purpose processing.\\n\",\n      \"\\n\",\n      \"Q: In one sentence, what's the difference between a CPU and a GPU?\\n\",\n      \"A: A CPU (Central Processing Unit) is designed for general-purpose processing and managing the overall operations of a computer, while a GPU (Graphics Processing Unit) is specialized for parallel processing tasks, particularly those related to rendering graphics and accelerating machine learning tasks.\\n\",\n      \"\\n\",\n      \"Q: What's the hottest planet in the solar system?\\n\",\n      \"A: The hottest planet in the solar system is Venus. Despite Mercury being closer to the Sun, Venus has a thick atmosphere that traps heat in a runaway version of the greenhouse effect, creating a much hotter surface temperature than Mercury. The average surface temperature on Venus is about 462°C (864°F), making it the hottest planet in our solar system.\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for q, o in zip(queries, outputs):\\n\",\n    \"    print(f\\\"Q: {q}\\\")\\n\",\n    \"    print(f\\\"A: {o.outputs[0].text}\\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Unload the model to free up the GPU for the PySpark section.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import contextlib\\n\",\n    \"import gc\\n\",\n    \"import torch\\n\",\n    \"from vllm.distributed import destroy_model_parallel, destroy_distributed_environment\\n\",\n    \"\\n\",\n    \"def cleanup():\\n\",\n    \"    destroy_model_parallel()\\n\",\n    \"    destroy_distributed_environment()\\n\",\n    \"    with contextlib.suppress(AssertionError):\\n\",\n    \"        torch.distributed.destroy_process_group()\\n\",\n    \"    gc.collect()\\n\",\n    \"    torch.cuda.empty_cache()\\n\",\n 
   \"\\n\",\n    \"del llm\\n\",\n    \"cleanup()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## PySpark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import pandas as pd\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.functions import pandas_udf, col, struct, length, lit, concat\\n\",\n    \"from pyspark.ml.functions import predict_batch_udf\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-03-24 11:37:46] INFO config.py:54: PyTorch version 2.6.0 available.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import os\\n\",\n    \"import datasets\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"datasets.disable_progress_bars()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session\\n\",\n    \"\\n\",\n    \"For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
\\n\",\n    \"For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"25/03/24 11:37:47 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"25/03/24 11:37:47 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"25/03/24 11:37:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"conf = SparkConf()\\n\",\n    \"\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    if on_standalone:\\n\",\n    \"        import socket\\n\",\n    \"        conda_env = os.environ.get(\\\"CONDA_PREFIX\\\")\\n\",\n    \"        hostname = socket.gethostname()\\n\",\n    \"        conf.setMaster(f\\\"spark://{hostname}:7077\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"        conf.set(\\\"spark.pyspark.driver.python\\\", f\\\"{conda_env}/bin/python\\\")\\n\",\n    \"\\n\",\n    \"    conf.set(\\\"spark.executor.cores\\\", \\\"8\\\")\\n\",\n    \"    conf.set(\\\"spark.task.maxFailures\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.125\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n  
  \"\\n\",\n    \"spark = SparkSession.builder.appName(\\\"spark-dl-examples\\\").config(conf=conf).getOrCreate()\\n\",\n    \"sc = spark.sparkContext\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Load and Preprocess DataFrame\\n\",\n    \"\\n\",\n    \"Load the first 500 samples of the [ML ArXiv dataset](https://huggingface.co/datasets/CShorten/ML-ArXiv-Papers) from Huggingface and store in a Spark Dataframe.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"ml_arxiv_dataset = load_dataset(\\\"CShorten/ML-ArXiv-Papers\\\", split=\\\"train\\\", streaming=True)\\n\",\n    \"ml_arxiv_pds = pd.Series([sample[\\\"abstract\\\"] for sample in ml_arxiv_dataset.take(500)])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.createDataFrame(ml_arxiv_pds, schema=StringType())\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|                                                                                               value|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"|  The problem of statistical learning is to construct a predictor of a random\\\\nvariable $Y$ as a ...|\\n\",\n      \"|  In a sensor network, in practice, the communication among sensors is 
subject\\\\nto:(1) errors or ...|\\n\",\n      \"|  The on-line shortest path problem is considered under various models of\\\\npartial monitoring. Gi...|\\n\",\n      \"|  Ordinal regression is an important type of learning, which has properties of\\\\nboth classificati...|\\n\",\n      \"|  This paper uncovers and explores the close relationship between Monte Carlo\\\\nOptimization of a ...|\\n\",\n      \"+----------------------------------------------------------------------------------------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df.show(5, truncate=100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Format each sample into the chat template, including a system prompt to guide generation.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"system_prompt = '''You are a knowledgeable AI assistant. Your job is to create a 1 sentence summary \\n\",\n    \"of a research abstract that captures the main objective, methodology, and key findings, using clear \\n\",\n    \"language while preserving technical accuracy and quantitative results.'''\\n\",\n    \"\\n\",\n    \"df = df.select(\\n\",\n    \"    concat(\\n\",\n    \"        lit(\\\"<|im_start|>system\\\\n\\\"),\\n\",\n    \"        lit(system_prompt),\\n\",\n    \"        lit(\\\"<|im_end|>\\\\n<|im_start|>user\\\\n\\\"),\\n\",\n    \"        col(\\\"value\\\"),\\n\",\n    \"        lit(\\\"<|im_end|>\\\\n<|im_start|>assistant\\\\n\\\")\\n\",\n    \"    ).alias(\\\"prompt\\\")\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"<|im_start|>system\\n\",\n      \"You are a knowledgeable AI assistant. 
Your job is to create a 1 sentence summary \\n\",\n      \"of a research abstract that captures the main objective, methodology, and key findings, using clear \\n\",\n      \"language while preserving technical accuracy and quantitative results.<|im_end|>\\n\",\n      \"<|im_start|>user\\n\",\n      \"  The problem of statistical learning is to construct a predictor of a random\\n\",\n      \"variable $Y$ as a function of a related random variable $X$ on the basis of an\\n\",\n      \"i.i.d. training sample from the joint distribution of $(X,Y)$. Allowable\\n\",\n      \"predictors are drawn from some specified class, and the goal is to approach\\n\",\n      \"asymptotically the performance (expected loss) of the best predictor in the\\n\",\n      \"class. We consider the setting in which one has perfect observation of the\\n\",\n      \"$X$-part of the sample, while the $Y$-part has to be communicated at some\\n\",\n      \"finite bit rate. The encoding of the $Y$-values is allowed to depend on the\\n\",\n      \"$X$-values. Under suitable regularity conditions on the admissible predictors,\\n\",\n      \"the underlying family of probability distributions and the loss function, we\\n\",\n      \"give an information-theoretic characterization of achievable predictor\\n\",\n      \"performance in terms of conditional distortion-rate functions. 
The ideas are\\n\",\n      \"illustrated on the example of nonparametric regression in Gaussian noise.\\n\",\n      \"<|im_end|>\\n\",\n      \"<|im_start|>assistant\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(df.take(1)[0].prompt)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data_path = \\\"spark-dl-datasets/arxiv_abstracts\\\"\\n\",\n    \"if on_databricks:\\n\",\n    \"    dbutils.fs.mkdirs(\\\"/FileStore/spark-dl-datasets\\\")\\n\",\n    \"    data_path = \\\"dbfs:/FileStore/\\\" + data_path\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(data_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using vLLM Server\\n\",\n    \"In this section, we demonstrate integration with [vLLM Serving](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html), an open-source server with an OpenAI-compatible completions endpoint for LLMs.  
\\n\",\n    \"\\n\",\n    \"The process looks like this:\\n\",\n    \"- Distribute a server startup task across the Spark cluster, instructing each node to launch a vLLM server process.\\n\",\n    \"- Define a vLLM inference function, which sends inference request to the local server on a given node.\\n\",\n    \"- Wrap the vLLM inference function in a predict_batch_udf to launch parallel inference requests using Spark.\\n\",\n    \"- Finally, distribute a shutdown signal to terminate the vLLM server processes on each node.\\n\",\n    \"\\n\",\n    \"<img src=\\\"../images/spark-server.png\\\" alt=\\\"drawing\\\" width=\\\"700\\\"/>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from functools import partial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Import the helper class from server_utils.py:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sc.addPyFile(\\\"server_utils.py\\\")\\n\",\n    \"\\n\",\n    \"from server_utils import VLLMServerManager\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start vLLM servers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The `VLLMServerManager` will handle the lifecycle of vLLM server instances across the Spark cluster:\\n\",\n    \"- Find available ports for HTTP\\n\",\n    \"- Deploy a server on each node via stage-level scheduling\\n\",\n    \"- Gracefully shutdown servers across nodes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                            
    \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"model_name = \\\"qwen-2.5-7b\\\"\\n\",\n    \"server_manager = VLLMServerManager(model_name=model_name, model_path=model_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"You can pass any of the supported [vLLM serve CLI arguments](https://docs.vllm.ai/en/stable/serving/openai_compatible_server.html#vllm-serve) as key-word arguments when starting the servers. Note that this can take some time, as it includes loading the model from disk, Torch compilation, and capturing CUDA graphs.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-03-24 11:37:57] INFO server_utils.py:359: Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"[2025-03-24 11:37:57] INFO server_utils.py:390: Starting 1 VLLM servers.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'cb4ae00-lcedt': (4022579, [7000])}\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.start_servers(gpu_memory_utilization=0.95,\\n\",\n    \"                             max_model_len=6600,\\n\",\n    \"                             task=\\\"generate\\\",\\n\",\n    \"                             enforce_eager=enforce_eager,\\n\",\n    \"                             wait_retries=60)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define client function\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   
\"source\": [\n    \"Get the hostname -> url mapping from the server manager:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"host_to_http_url = server_manager.host_to_http_url\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Define the vLLM inference function, which returns a predict function for batch inference through the server:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def vllm_fn(model_name, host_to_url):\\n\",\n    \"    import socket\\n\",\n    \"    import numpy as np\\n\",\n    \"    import requests\\n\",\n    \"\\n\",\n    \"    url = host_to_url[socket.gethostname()]\\n\",\n    \"    \\n\",\n    \"    def predict(inputs):\\n\",\n    \"        response = requests.post(\\n\",\n    \"            f\\\"{url}/v1/completions\\\",\\n\",\n    \"            json={\\n\",\n    \"                \\\"model\\\": model_name,\\n\",\n    \"                \\\"prompt\\\": inputs.tolist(),\\n\",\n    \"                \\\"max_tokens\\\": 128,\\n\",\n    \"                \\\"temperature\\\": 0.7,\\n\",\n    \"                \\\"top_p\\\": 0.8,\\n\",\n    \"                \\\"repetition_penalty\\\": 1.05,\\n\",\n    \"            }\\n\",\n    \"        )\\n\",\n    \"        return np.array([r[\\\"text\\\"] for r in response.json()[\\\"choices\\\"]])\\n\",\n    \"    \\n\",\n    \"    return predict\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"generate = predict_batch_udf(partial(vllm_fn, model_name=model_name, host_to_url=host_to_http_url),\\n\",\n    \"                             return_type=StringType(),\\n\",\n    \"                             batch_size=32)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   
\"metadata\": {},\n   \"source\": [\n    \"#### Load DataFrame\\n\",\n    \"\\n\",\n    \"We'll parallelize over a small set of prompts for demonstration.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 27,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df = spark.read.parquet(data_path).limit(256).repartition(8)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Run Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 11:==================================================>       (7 + 1) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 7.53 ms, sys: 2.19 ms, total: 9.72 ms\\n\",\n      \"Wall time: 13.9 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"# first pass caches model/fn and does JIT compilation\\n\",\n    \"preds = df.withColumn(\\\"outputs\\\", generate(col(\\\"prompt\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 17:===========================================>              (6 + 2) / 8]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 10.7 ms, sys: 3.65 ms, total: 14.3 ms\\n\",\n      \"Wall time: 6.26 s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     
\"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"preds = df.withColumn(\\\"outputs\\\", generate(col(\\\"prompt\\\")))\\n\",\n    \"results = preds.collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Sample output:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Q: <|im_start|>system\\n\",\n      \"You are a knowledgeable AI assistant. Your job is to create a 1 sentence summary \\n\",\n      \"of a research abstract that captures the main objective, methodology, and key findings, using clear \\n\",\n      \"language while preserving technical accuracy and quantitative results.<|im_end|>\\n\",\n      \"<|im_start|>user\\n\",\n      \"  Images can be segmented by first using a classifier to predict an affinity\\n\",\n      \"graph that reflects the degree to which image pixels must be grouped together\\n\",\n      \"and then partitioning the graph to yield a segmentation. Machine learning has\\n\",\n      \"been applied to the affinity classifier to produce affinity graphs that are\\n\",\n      \"good in the sense of minimizing edge misclassification rates. However, this\\n\",\n      \"error measure is only indirectly related to the quality of segmentations\\n\",\n      \"produced by ultimately partitioning the affinity graph. We present the first\\n\",\n      \"machine learning algorithm for training a classifier to produce affinity graphs\\n\",\n      \"that are good in the sense of producing segmentations that directly minimize\\n\",\n      \"the Rand index, a well known segmentation performance measure. 
The Rand index\\n\",\n      \"measures segmentation performance by quantifying the classification of the\\n\",\n      \"connectivity of image pixel pairs after segmentation. By using the simple graph\\n\",\n      \"partitioning algorithm of finding the connected components of the thresholded\\n\",\n      \"affinity graph, we are able to train an affinity classifier to directly\\n\",\n      \"minimize the Rand index of segmentations resulting from the graph partitioning.\\n\",\n      \"Our learning algorithm corresponds to the learning of maximin affinities\\n\",\n      \"between image pixel pairs, which are predictive of the pixel-pair connectivity.\\n\",\n      \"<|im_end|>\\n\",\n      \"<|im_start|>assistant\\n\",\n      \" \\n\",\n      \"\\n\",\n      \"A: The research presents a machine learning algorithm that trains an affinity classifier to directly minimize the Rand index of image segmentations by producing affinity graphs optimized for pixel-pair connectivity, using a simple graph partitioning method. 
\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(f\\\"Q: {results[0].prompt} \\\\n\\\")\\n\",\n    \"print(f\\\"A: {results[0].outputs} \\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Shut down server on each executor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 31,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-03-24 11:38:49] INFO server_utils.py:359: Requesting stage-level resources: (cores=5, gpu=1.0)\\n\",\n      \"[2025-03-24 11:38:50] INFO server_utils.py:447: Successfully stopped 1 VLLM servers.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[True]\"\n      ]\n     },\n     \"execution_count\": 31,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"server_manager.stop_servers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 32,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\\n\",\n    \"    spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"spark-dl-vllm\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.11\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-DL/dl_inference/vllm_requirements.txt",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\npyspark>=3.4.0\ndatasets\nvllm\nipywidgets\njupyterlab\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-Rapids-ML/pca/README.md",
    "content": "# Spark-Rapids-ML PCA example\n\nThis is an example of the GPU accelerated PCA algorithm from the [Spark-Rapids-ML](https://github.com/NVIDIA/spark-rapids-ml) library, which provides PySpark ML compatible algorithms powered by RAPIDS cuML. \nThe notebook uses PCA to reduce a random dataset with 2048 feature dimensions to 3 dimensions. We train both the GPU and CPU algorithms for comparison. \n\n## Build\n\nPlease refer to the Spark-Rapids-ML [README](https://github.com/NVIDIA/spark-rapids-ml/blob/HEAD/python) to setup the RAPIDS conda environment and install Spark-Rapids-ML dependencies. \n\n## Download RAPIDS Jar from Maven Central\n\nDownload the [Spark-Rapids plugin](https://nvidia.github.io/spark-rapids/docs/download.html#download-rapids-accelerator-for-apache-spark-v24081).  \nFor Spark-RAPIDS-ML version 26.02.0, download the RAPIDS jar from Maven Central: [rapids-4-spark_2.12-26.02.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar). \n\n## Running the Notebooks\n\nOnce you have built your environment, please follow these instructions to run the notebooks. Make sure `jupyterlab` is installed in the environment.\n\n**Note**: for demonstration purposes, these examples just use a local Spark Standalone cluster with a single executor, but you should be able to run them on any distributed Spark cluster.\n```\n# setup environment variables\nexport SPARK_HOME=/path/to/spark\nexport RAPIDS_JAR=/path/to/rapids.jar\n\n# launches the standalone cluster and jupyter with pyspark\n./start-spark-rapids.sh\n\n# BROWSE to localhost:8888 to view/run notebooks\n\n# stop spark standalone cluster\n${SPARK_HOME}/sbin/stop-worker.sh; ${SPARK_HOME}/sbin/stop-master.sh\n```\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-Rapids-ML/pca/notebooks/pca.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Principal Component Analysis (PCA)\\n\",\n    \"\\n\",\n    \"In this notebook, we will demonstrate the end-to-end workflow of Spark RAPIDS accelerated PCA.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import numpy as np\\n\",\n    \"import pandas as pd\\n\",\n    \"import time\\n\",\n    \"import os\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"No active Spark session found, initializing manually.\\n\",\n      \"File already exists. Skipping download.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"24/10/04 18:04:27 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\\n\",\n      \"24/10/04 18:04:27 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\\n\",\n      \"24/10/04 18:04:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). 
For SparkR, use setLogLevel(newLevel).\\n\",\n      \"24/10/04 18:04:27 WARN RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1, private revision 9fac64da220ddd6bf5626bd7bd1dd74c08603eac\\n\",\n      \"24/10/04 18:04:27 WARN RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\\n\",\n      \"24/10/04 18:04:31 WARN GpuDeviceManager: RMM pool is disabled since spark.rapids.memory.gpu.pooling.enabled is set to false; however, this configuration is deprecated and the behavior may change in a future release.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark import SparkConf\\n\",\n    \"\\n\",\n    \"def get_rapids_jar():\\n\",\n    \"    import os\\n\",\n    \"    import requests\\n\",\n    \"\\n\",\n    \"    SPARK_RAPIDS_VERSION = \\\"26.02.0\\\"\\n\",\n    \"    rapids_jar = f\\\"rapids-4-spark_2.12-{SPARK_RAPIDS_VERSION}.jar\\\"\\n\",\n    \"    if not os.path.exists(rapids_jar):\\n\",\n    \"        print(\\\"Downloading spark rapids jar\\\")\\n\",\n    \"        url = f\\\"https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/{SPARK_RAPIDS_VERSION}/{rapids_jar}\\\"\\n\",\n    \"        response = requests.get(url)\\n\",\n    \"        if response.status_code == 200:\\n\",\n    \"            with open(rapids_jar, \\\"wb\\\") as f:\\n\",\n    \"                f.write(response.content)\\n\",\n    \"            print(f\\\"File '{rapids_jar}' downloaded and saved successfully.\\\")\\n\",\n    \"        else:\\n\",\n    \"            print(f\\\"Failed to download the file. Status code: {response.status_code}\\\")\\n\",\n    \"    else:\\n\",\n    \"        print(\\\"File already exists. Skipping download.\\\")\\n\",\n    \"    return rapids_jar\\n\",\n    \"\\n\",\n    \"def initialize_spark(rapids_jar: str):\\n\",\n    \"    '''\\n\",\n    \"    If no active Spark session is found, initialize and configure a new one. 
\\n\",\n    \"    '''\\n\",\n    \"    import socket\\n\",\n    \"    hostname = socket.gethostname()\\n\",\n    \"\\n\",\n    \"    conf = SparkConf()\\n\",\n    \"    conf.setMaster(f\\\"spark://{hostname}:7077\\\") # Assuming master is on host and default port. \\n\",\n    \"    conf.set(\\\"spark.task.maxFailures\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.driver.memory\\\", \\\"10g\\\")\\n\",\n    \"    conf.set(\\\"spark.executor.memory\\\", \\\"8g\\\")\\n\",\n    \"    conf.set(\\\"spark.rpc.message.maxSize\\\", \\\"1024\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.pyspark.jvmStacktrace.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled\\\", \\\"false\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.pyspark.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.python.worker.reuse\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.ml.uvm.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.jars\\\", rapids_jar)\\n\",\n    \"    conf.set(\\\"spark.executorEnv.PYTHONPATH\\\", rapids_jar)\\n\",\n    \"    conf.set(\\\"spark.rapids.memory.gpu.minAllocFraction\\\", \\\"0.0001\\\")\\n\",\n    \"    conf.set(\\\"spark.plugins\\\", \\\"com.nvidia.spark.SQLPlugin\\\")\\n\",\n    \"    conf.set(\\\"spark.locality.wait\\\", \\\"0s\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.cache.serializer\\\", \\\"com.nvidia.spark.ParquetCachedBatchSerializer\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.memory.gpu.pooling.enabled\\\", \\\"false\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.sortBeforeRepartition\\\", \\\"false\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.sql.format.parquet.reader.type\\\", \\\"MULTITHREADED\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel\\\", \\\"20\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.sql.multiThreadedRead.numThreads\\\", \\\"20\\\")\\n\",\n    \"    
conf.set(\\\"spark.rapids.sql.python.gpu.enabled\\\", \\\"true\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.memory.pinnedPool.size\\\", \\\"2G\\\")\\n\",\n    \"    conf.set(\\\"spark.python.daemon.module\\\", \\\"rapids.daemon\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.sql.batchSizeBytes\\\", \\\"512m\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.adaptive.enabled\\\", \\\"false\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.files.maxPartitionBytes\\\", \\\"512m\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.sql.concurrentGpuTasks\\\", \\\"1\\\")\\n\",\n    \"    conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", \\\"20000\\\")\\n\",\n    \"    conf.set(\\\"spark.rapids.sql.explain\\\", \\\"NONE\\\")\\n\",\n    \"    \\n\",\n    \"    spark = SparkSession.builder.appName(\\\"spark-rapids-ml-pca\\\").config(conf=conf).getOrCreate()\\n\",\n    \"    return spark\\n\",\n    \"\\n\",\n    \"# Check if Spark session is already active, if not, initialize it\\n\",\n    \"if 'spark' not in globals():\\n\",\n    \"    print(\\\"No active Spark session found, initializing manually.\\\")\\n\",\n    \"    rapids_jar = os.environ.get('RAPIDS_JAR')\\n\",\n    \"    if rapids_jar is None:\\n\",\n    \"        rapids_jar = get_rapids_jar()\\n\",\n    \"    spark = initialize_spark(rapids_jar)\\n\",\n    \"else:\\n\",\n    \"    print(\\\"Using existing Spark session.\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Generate synthetic dataset\\n\",\n    \"\\n\",\n    \"Here we generate a 100,000 x 2048 random dataset.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"24/10/04 18:04:45 WARN TaskSetManager: Stage 0 contains a task of very large size (160085 KiB). 
The maximum recommended task size is 1000 KiB.\\n\",\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"rows = 100000\\n\",\n    \"dim = 2048\\n\",\n    \"dtype = 'float32'\\n\",\n    \"np.random.seed(42)\\n\",\n    \"\\n\",\n    \"data = np.random.rand(rows, dim).astype(dtype)\\n\",\n    \"pd_data = pd.DataFrame({\\\"features\\\": list(data)})\\n\",\n    \"prepare_df = spark.createDataFrame(pd_data)\\n\",\n    \"prepare_df.write.mode(\\\"overwrite\\\").parquet(\\\"PCA_data.parquet\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Spark-RAPIDS-ML accepts ArrayType input\\n\",\n    \"\\n\",\n    \"Note that in the original Spark-ML PCA, we must `Vectorize` the input column:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"from pyspark.ml.linalg import Vectors\\n\",\n    \"data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),),\\n\",\n    \"    (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),\\n\",\n    \"    (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)]\\n\",\n    \"df = spark.createDataFrame(data,[\\\"features\\\"])\\n\",\n    \"df.show()\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"...whereas the Spark-RAPIDS-ML version does not require extra Vectorization, and can accept an ArrayType column as the input column:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"root\\n\",\n      \" |-- features: array (nullable = true)\\n\",\n      \" |    |-- element: float (containsNull = true)\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"data_df = spark.read.parquet(\\\"PCA_data.parquet\\\")\\n\",\n    \"data_df.printSchema()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Using Spark-RAPIDS-ML PCA 
(GPU)\\n\",\n    \"\\n\",\n    \"Compared to the Spark-ML PCA training API:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"from pyspark.ml.feature import PCA\\n\",\n    \"pca = PCA(k=3, inputCol=\\\"features\\\")\\n\",\n    \"pca.setOutputCol(\\\"pca_features\\\")\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"We use a customized class which requires **no code change** from the user to enjoy GPU acceleration:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"from spark_rapids_ml.feature import PCA\\n\",\n    \"pca = PCA(k=3, inputCol=\\\"features\\\")\\n\",\n    \"pca.setOutputCol(\\\"pca_features\\\")\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"PCA_570681141389\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"from spark_rapids_ml.feature import PCA\\n\",\n    \"\\n\",\n    \"gpu_pca = PCA(k=2, inputCol=\\\"features\\\")\\n\",\n    \"gpu_pca.setOutputCol(\\\"pca_features\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The PCA estimator object can be persisted and reloaded.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"estimator_path = \\\"/tmp/pca_estimator\\\"\\n\",\n    \"gpu_pca.write().overwrite().save(estimator_path)\\n\",\n    \"gpu_pca_loaded = PCA.load(estimator_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Fit\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"24/10/04 18:04:58 WARN MultiFileReaderThreadPool: Configuring the file 
reader thread pool with a max of 32 threads instead of spark.rapids.sql.multiThreadedRead.numThreads = 20\\n\",\n      \"2024-10-04 18:04:58,487 - spark_rapids_ml.feature.PCA - INFO - CUDA managed memory enabled.\\n\",\n      \"2024-10-04 18:04:58,570 - spark_rapids_ml.feature.PCA - INFO - Training spark-rapids-ml with 1 worker(s) ...\\n\",\n      \"INFO: Process 2762394 found CUDA visible device(s): 0\\n\",\n      \"2024-10-04 18:05:01,613 - spark_rapids_ml.feature.PCA - INFO - Loading data into python worker memory\\n\",\n      \"2024-10-04 18:05:02,551 - spark_rapids_ml.feature.PCA - INFO - Initializing cuml context\\n\",\n      \"2024-10-04 18:05:03,795 - spark_rapids_ml.feature.PCA - INFO - Invoking cuml fit\\n\",\n      \"2024-10-04 18:05:05,326 - spark_rapids_ml.feature.PCA - INFO - Cuml fit complete\\n\",\n      \"2024-10-04 18:05:06,858 - spark_rapids_ml.feature.PCA - INFO - Finished training\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"GPU PCA fit took: 8.90433144569397 sec\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"start_time = time.time()\\n\",\n    \"gpu_pca_model = gpu_pca_loaded.fit(data_df)\\n\",\n    \"gpu_fit_time = time.time() - start_time\\n\",\n    \"print(f\\\"GPU PCA fit took: {gpu_fit_time} sec\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Transform\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+---------------------------+\\n\",\n      \"|pca_features               |\\n\",\n      \"+---------------------------+\\n\",\n      \"|[0.062363233, 0.4037608]   |\\n\",\n      \"|[0.49734917, 0.703541]     |\\n\",\n      \"|[0.0035427138, 0.29358602] |\\n\",\n      \"|[-0.06798951, 0.37400067]  |\\n\",\n      \"|[0.10075127, 0.34651726]   
|\\n\",\n      \"|[-0.22320557, 0.6660976]   |\\n\",\n      \"|[0.49608234, 0.6761328]    |\\n\",\n      \"|[0.25515205, 0.20352581]   |\\n\",\n      \"|[-0.5102935, 0.319284]     |\\n\",\n      \"|[-0.5109488, 0.2756377]    |\\n\",\n      \"|[0.411546, -0.17954555]    |\\n\",\n      \"|[0.21616393, -0.46268395]  |\\n\",\n      \"|[-0.0924304, 0.65660465]   |\\n\",\n      \"|[0.12355948, 0.9478601]    |\\n\",\n      \"|[0.49234354, 0.63746333]   |\\n\",\n      \"|[-0.86077166, 0.0037032962]|\\n\",\n      \"|[-0.013956882, 0.663955]   |\\n\",\n      \"|[-0.30510652, 0.02372247]  |\\n\",\n      \"|[-0.05999008, 0.28261736]  |\\n\",\n      \"|[0.36605445, 0.9674797]    |\\n\",\n      \"+---------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\",\n      \"GPU PCA transform took: 0.43911027908325195 sec\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"start_time = time.time()\\n\",\n    \"embeddings = gpu_pca_model.transform(data_df).select(\\\"pca_features\\\").show(truncate=False)\\n\",\n    \"gpu_transform_time = time.time() - start_time\\n\",\n    \"print(f\\\"GPU PCA transform took: {gpu_transform_time} sec\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Using Spark-ML PCA (CPU)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"PCA_58add243f20d\"\n      ]\n     },\n     \"execution_count\": 13,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"from pyspark.ml.feature import PCA\\n\",\n    \"\\n\",\n    \"cpu_pca = PCA(k=2, inputCol=\\\"features\\\")\\n\",\n    \"cpu_pca.setOutputCol(\\\"pca_features\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n    
 \"text\": [\n      \"root\\n\",\n      \" |-- features: vector (nullable = true)\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from pyspark.ml.functions import array_to_vector\\n\",\n    \"\\n\",\n    \"vector_df = data_df.select(array_to_vector(\\\"features\\\").alias(\\\"features\\\"))\\n\",\n    \"vector_df.printSchema()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Fit\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"24/10/04 17:07:07 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU PCA fit took: 63.37388610839844 sec\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"start_time = time.time()\\n\",\n    \"cpu_pca_model = cpu_pca.fit(vector_df)\\n\",\n    \"pca_fit_time = time.time() - start_time\\n\",\n    \"print(f\\\"CPU PCA fit took: {pca_fit_time} sec\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Transform\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+-------------------------------------------+\\n\",\n      \"|pca_features                               |\\n\",\n      \"+-------------------------------------------+\\n\",\n      \"|[0.24926765828229927,0.3425432972889563]   |\\n\",\n      \"|[-0.5175207040808384,0.48893065865444574]  |\\n\",\n      \"|[-0.2505049373829902,0.381272141155778]    |\\n\",\n      \"|[-0.39046980420292005,0.4870705091697811]  |\\n\",\n      \"|[-0.4024088726395023,0.707133448810984]    |\\n\",\n  
    \"|[-0.3061227832285992,0.5363554872099332]   |\\n\",\n      \"|[-0.6065136982526093,0.5205197626985932]   |\\n\",\n      \"|[-0.21870566838630084,0.6516598402789231]  |\\n\",\n      \"|[0.1910036552854184,0.6336513389989592]    |\\n\",\n      \"|[0.6139537641786907,0.6055187085018856]    |\\n\",\n      \"|[-0.026502904776425647,-0.0366087508156753]|\\n\",\n      \"|[-0.2989311781309336,-0.05136110567458389] |\\n\",\n      \"|[-0.5474468086054212,-0.18779964958125014] |\\n\",\n      \"|[-0.6644746232216499,0.10351178251944647]  |\\n\",\n      \"|[-0.12685301272617464,0.47394431583661295] |\\n\",\n      \"|[-0.4355221246718862,-0.00346289187881239] |\\n\",\n      \"|[0.6222719258951077,0.5488293416698503]    |\\n\",\n      \"|[0.04966907735703511,0.7138677407505005]   |\\n\",\n      \"|[0.6260486995906139,0.3553228450428632]    |\\n\",\n      \"|[0.16396683091519929,0.7382693234881972]   |\\n\",\n      \"+-------------------------------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\",\n      \"CPU PCA transform took: 0.19607114791870117 sec\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"start_time = time.time()\\n\",\n    \"embeddings = cpu_pca_model.transform(vector_df).select(\\\"pca_features\\\").show(truncate=False)\\n\",\n    \"pca_transform_time = time.time() - start_time\\n\",\n    \"print(f\\\"CPU PCA transform took: {pca_transform_time} sec\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Summary\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 27,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU runtime: (64.02s + 0.20s)\\n\",\n      \"GPU runtime: (8.76s + 0.42s)\\n\",\n      \"End-to-end speedup: CPU / GPU = 7.00x\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"speedup = (pca_fit_time + pca_transform_time) / (gpu_fit_time + 
gpu_transform_time)\\n\",\n    \"print(f\\\"CPU runtime: ({pca_fit_time:.2f}s + {pca_transform_time:.2f}s)\\\")\\n\",\n    \"print(f\\\"GPU runtime: ({gpu_fit_time:.2f}s + {gpu_transform_time:.2f}s)\\\")\\n\",\n    \"print(f\\\"End-to-end speedup: CPU / GPU = {speedup:.2f}x\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"rapids-25.02\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.10\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "examples/ML+DL-Examples/Spark-Rapids-ML/pca/start-spark-rapids.sh",
    "content": "#!/bin/bash\n#\n# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Check if SPARK_HOME is set\nif [ -z \"$SPARK_HOME\" ]; then\n  echo \"Please set the SPARK_HOME environment variable before running this script.\"\n  exit 1\nfi\n\n# Check if RAPIDS_JAR is set\nif [ -z \"$RAPIDS_JAR\" ]; then\n  echo \"Please set the RAPIDS_JAR environment variable before running this script.\"\n  exit 1\nfi\n\n# Configuration\nMASTER_HOSTNAME=$(hostname)\nMASTER=spark://${MASTER_HOSTNAME}:7077\nCORES_PER_WORKER=8\nMEMORY_PER_WORKER=16G\n\n# Environment variables\nexport SPARK_HOME=${SPARK_HOME}\nexport MASTER=${MASTER}\nexport SPARK_WORKER_INSTANCES=1\nexport CORES_PER_WORKER=${CORES_PER_WORKER}\nexport PYSPARK_DRIVER_PYTHON=jupyter\nexport PYSPARK_DRIVER_PYTHON_OPTS='lab'\n\n# Start standalone cluster\necho \"Starting Spark standalone cluster...\"\n${SPARK_HOME}/sbin/start-master.sh\n${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m ${MEMORY_PER_WORKER} ${MASTER}\n\n# Start Jupyter with PySpark\necho \"Launching PySpark with Jupyter...\"\n${SPARK_HOME}/bin/pyspark --master ${MASTER} \\\n--driver-memory 10G \\\n--executor-memory 8G \\\n--conf spark.task.maxFailures=1 \\\n--conf spark.rpc.message.maxSize=1024 \\\n--conf spark.sql.pyspark.jvmStacktrace.enabled=true \\\n--conf spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled=false \\\n--conf 
spark.sql.execution.arrow.pyspark.enabled=true \\\n--conf spark.python.worker.reuse=true \\\n--conf spark.rapids.ml.uvm.enabled=true \\\n--conf spark.jars=${RAPIDS_JAR} \\\n--conf spark.executorEnv.PYTHONPATH=${RAPIDS_JAR} \\\n--conf spark.rapids.memory.gpu.minAllocFraction=0.0001 \\\n--conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n--conf spark.locality.wait=0s \\\n--conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer \\\n--conf spark.rapids.memory.gpu.pooling.enabled=false \\\n--conf spark.sql.execution.sortBeforeRepartition=false \\\n--conf spark.rapids.sql.format.parquet.reader.type=MULTITHREADED \\\n--conf spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel=20 \\\n--conf spark.rapids.sql.multiThreadedRead.numThreads=20 \\\n--conf spark.rapids.sql.python.gpu.enabled=true \\\n--conf spark.rapids.memory.pinnedPool.size=2G \\\n--conf spark.python.daemon.module=rapids.daemon \\\n--conf spark.rapids.sql.batchSizeBytes=512m \\\n--conf spark.sql.adaptive.enabled=false \\\n--conf spark.sql.files.maxPartitionBytes=512m \\\n--conf spark.rapids.sql.concurrentGpuTasks=1 \\\n--conf spark.sql.execution.arrow.maxRecordsPerBatch=20000 \\\n--conf spark.rapids.sql.explain=NONE"
  },
  {
    "path": "examples/SQL+DF-Examples/customer-churn/README.md",
    "content": "# Customer Churn\n\nThis demo is derived from [data-science-blueprints](https://github.com/NVIDIA/data-science-blueprints) repository.\nThe repository shows a realistic ETL workflow based on synthetic normalized data.  It consists of two pieces:\n\n1.  _an augmentation notebook_, which synthesizes normalized (long-form) data from a wide-form input file,\n    optionally augmenting it by duplicating records, and\n2. _an ETL notebook_, which performs joins and aggregations in order to generate wide-form data from the synthetic long-form data.\n\nTo learn more about the customer churn use case, you can read our [ebook](https://www.nvidia.com/en-us/ai-data-science/resources/churn-prediction-blueprint/). \n"
  },
  {
    "path": "examples/SQL+DF-Examples/customer-churn/notebooks/python/README.md",
    "content": "# telco-churn-augmentation\n\nThis demo shows a realistic ETL workflow based on synthetic normalized data.  It consists of two pieces:\n\n1.  _an [augmentation notebook](augment.ipynb)_, which synthesizes normalized (long-form) data from a wide-form input file, \n    optionally augmenting it by duplicating records, and\n2. _an [ETL notebook](etl.ipynb)_, which performs joins and aggregations in order to generate\n   wide-form data from the synthetic long-form data.\n\nFrom a performance evaluation perspective, the latter is the interesting workload; \nthe former is just a data generator for the latter.\n"
  },
  {
    "path": "examples/SQL+DF-Examples/customer-churn/notebooks/python/augment.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Customer churn augment\\n\",\n    \"\\n\",\n    \"This notebook is derived from [customer churn augment notebook](https://github.com/NVIDIA/data-science-blueprints/blob/main/churn/augment.ipynb), please refer to this [git repo](https://github.com/NVIDIA/data-science-blueprints/tree/main/churn) for more detail information.\\n\",\n    \"\\n\",\n    \" \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {\n    \"tags\": [\n     \"parameters\"\n    ]\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# notebook parameters\\n\",\n    \"\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"spark_master = os.getenv(\\\"SPARK_MASTER_URL\\\", \\\"spark://ip:port\\\")\\n\",\n    \"app_name = \\\"augment\\\"\\n\",\n    \"dataRoot = os.getenv(\\\"DATA_ROOT\\\", \\\"data\\\")\\n\",\n    \"input_file = os.path.join(dataRoot, \\\"WA_Fn-UseC_-Telco-Customer-Churn-.csv\\\")\\n\",\n    \"output_mode = \\\"overwrite\\\"\\n\",\n    \"output_kind = \\\"parquet\\\"\\n\",\n    \"driver_memory = '12g'\\n\",\n    \"executor_memory = '8g'\\n\",\n    \"\\n\",\n    \"dup_times = 100\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import churn.augment\\n\",\n    \"\\n\",\n    \"churn.augment.register_options(\\n\",\n    \"    spark_master = spark_master,\\n\",\n    \"    app_name = app_name,\\n\",\n    \"    input_file = input_file,\\n\",\n    \"    output_mode = output_mode,\\n\",\n    \"    output_kind = output_kind,\\n\",\n    \"    driver_memory = driver_memory,\\n\",\n    \"    executor_memory = executor_memory,\\n\",\n    \"    dup_times = dup_times,\\n\",\n    \"    use_decimal = True\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Sanity-checking\\n\",\n    
\"\\n\",\n    \"We're going to make sure we're running with a compatible JVM first — if we run on macOS, we might get one that doesn't work with Scala.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from os import getenv\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"'/data/usr/lib/jvm/java-8-openjdk-amd64'\"\n      ]\n     },\n     \"execution_count\": 4,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"getenv(\\\"JAVA_HOME\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Spark setup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import pyspark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"\\n\",\n       \"            <div>\\n\",\n       \"                <p><b>SparkSession - hive</b></p>\\n\",\n       \"                \\n\",\n       \"        <div>\\n\",\n       \"            <p><b>SparkContext</b></p>\\n\",\n       \"\\n\",\n       \"            <p><a href=\\\"http://10.19.183.210:4040\\\">Spark UI</a></p>\\n\",\n       \"\\n\",\n       \"            <dl>\\n\",\n       \"              <dt>Version</dt>\\n\",\n       \"                <dd><code>v3.2.0</code></dd>\\n\",\n       \"              <dt>Master</dt>\\n\",\n       \"                <dd><code>spark://yuanli-System-Product-Name:7077</code></dd>\\n\",\n       \"              <dt>AppName</dt>\\n\",\n       \"                <dd><code>PySparkShell</code></dd>\\n\",\n       \"            </dl>\\n\",\n       \"        </div>\\n\",\n       \"        
\\n\",\n       \"            </div>\\n\",\n       \"        \"\n      ],\n      \"text/plain\": [\n       \"<pyspark.sql.session.SparkSession at 0x7f2751631520>\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"session = pyspark.sql.SparkSession.builder \\\\\\n\",\n    \"    .master(spark_master) \\\\\\n\",\n    \"    .appName(app_name) \\\\\\n\",\n    \"    .config(\\\"spark.driver.memory\\\", driver_memory) \\\\\\n\",\n    \"    .config(\\\"spark.executor.memory\\\", executor_memory) \\\\\\n\",\n    \"    .getOrCreate()\\n\",\n    \"session\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Schema definition\\n\",\n    \"\\n\",\n    \"Most of the fields are strings representing booleans or categoricals, but a few (`tenure`, `MonthlyCharges`, and `TotalCharges`) are numeric.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"read 7043 records from source dataset (7032 non-null records)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from churn.augment import load_supplied_data\\n\",\n    \"\\n\",\n    \"df = load_supplied_data(session, input_file)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Splitting the data frame\\n\",\n    \"\\n\",\n    \"The training data schema looks like this:\\n\",\n    \"\\n\",\n    \"- customerID\\n\",\n    \"- gender\\n\",\n    \"- SeniorCitizen\\n\",\n    \"- Partner\\n\",\n    \"- Dependents\\n\",\n    \"- tenure\\n\",\n    \"- PhoneService\\n\",\n    \"- 
MultipleLines\\n\",\n    \"- InternetService\\n\",\n    \"- OnlineSecurity\\n\",\n    \"- OnlineBackup\\n\",\n    \"- DeviceProtection\\n\",\n    \"- TechSupport\\n\",\n    \"- StreamingTV\\n\",\n    \"- StreamingMovies\\n\",\n    \"- Contract\\n\",\n    \"- PaperlessBilling\\n\",\n    \"- PaymentMethod\\n\",\n    \"- MonthlyCharges\\n\",\n    \"- TotalCharges\\n\",\n    \"- Churn\\n\",\n    \"\\n\",\n    \"We want to divide the data frame into several frames that we can join together in an ETL job.\\n\",\n    \"\\n\",\n    \"Those frames will look like this:\\n\",\n    \"\\n\",\n    \"- **Customer metadata**\\n\",\n    \"  - customerID\\n\",\n    \"  - gender\\n\",\n    \"  - date of birth (we'll derive age and senior citizen status from this)\\n\",\n    \"  - Partner\\n\",\n    \"  - Dependents\\n\",\n    \"  - (nominal) MonthlyCharges\\n\",\n    \"- **Billing events**\\n\",\n    \"  - customerID\\n\",\n    \"  - date (we'll derive tenure from the number/duration of billing events)\\n\",\n    \"  - kind (one of \\\"AccountCreation\\\", \\\"Charge\\\", or \\\"AccountTermination\\\")\\n\",\n    \"  - value (either a positive nonzero amount or 0.00; we'll derive TotalCharges from the sum of amounts and Churn from the existence of an AccountTermination event)\\n\",\n    \"- **Customer phone features**\\n\",\n    \"  - customerID\\n\",\n    \"  - feature (one of \\\"PhoneService\\\" or \\\"MultipleLines\\\")\\n\",\n    \"- **Customer internet features**\\n\",\n    \"  - customerID\\n\",\n    \"  - feature (one of \\\"InternetService\\\", \\\"OnlineSecurity\\\", \\\"OnlineBackup\\\", \\\"DeviceProtection\\\", \\\"TechSupport\\\", \\\"StreamingTV\\\", \\\"StreamingMovies\\\")\\n\",\n    \"  - value (one of \\\"Fiber\\\", \\\"DSL\\\", \\\"Yes\\\", \\\"No\\\")\\n\",\n    \"- **Customer account features**\\n\",\n    \"  - customerID\\n\",\n    \"  - feature (one of \\\"Contract\\\", \\\"PaperlessBilling\\\", \\\"PaymentMethod\\\")\\n\",\n    \"  - value (one of 
\\\"Month-to-month\\\", \\\"One year\\\", \\\"Two year\\\", \\\"No\\\", \\\"Yes\\\", \\\"Credit card (automatic)\\\", \\\"Mailed check\\\", \\\"Bank transfer (automatic)\\\", \\\"Electronic check\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"root\\n\",\n      \" |-- customerID: string (nullable = true)\\n\",\n      \" |-- gender: string (nullable = true)\\n\",\n      \" |-- SeniorCitizen: string (nullable = true)\\n\",\n      \" |-- Partner: string (nullable = true)\\n\",\n      \" |-- Dependents: string (nullable = true)\\n\",\n      \" |-- tenure: double (nullable = true)\\n\",\n      \" |-- PhoneService: string (nullable = true)\\n\",\n      \" |-- MultipleLines: string (nullable = true)\\n\",\n      \" |-- InternetService: string (nullable = true)\\n\",\n      \" |-- OnlineSecurity: string (nullable = true)\\n\",\n      \" |-- OnlineBackup: string (nullable = true)\\n\",\n      \" |-- DeviceProtection: string (nullable = true)\\n\",\n      \" |-- TechSupport: string (nullable = true)\\n\",\n      \" |-- StreamingTV: string (nullable = true)\\n\",\n      \" |-- StreamingMovies: string (nullable = true)\\n\",\n      \" |-- Contract: string (nullable = true)\\n\",\n      \" |-- PaperlessBilling: string (nullable = true)\\n\",\n      \" |-- PaymentMethod: string (nullable = true)\\n\",\n      \" |-- MonthlyCharges: double (nullable = true)\\n\",\n      \" |-- TotalCharges: double (nullable = true)\\n\",\n      \" |-- Churn: string (nullable = true)\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"df.printSchema()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"We'll start by generating a series of monthly charges, then a series of account creation events, and finally a series of churn events. 
`billingEvents` is the data frame containing all of these events:  account activation, account termination, and individual payment events.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/home/yuanli/work/spark-3.2.0-bin-hadoop3.2/python/pyspark/sql/functions.py:1353: FutureWarning: Deprecated in 3.2, use shiftright instead.\\n\",\n      \"  warnings.warn(\\\"Deprecated in 3.2, use shiftright instead.\\\", FutureWarning)\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from churn.augment import billing_events\\n\",\n    \"billingEvents = billing_events(df)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Our next step is to generate customer metadata, which includes the following fields:\\n\",\n    \"\\n\",\n    \"  - gender\\n\",\n    \"  - date of birth (we'll derive age and senior citizen status from this)\\n\",\n    \"  - Partner\\n\",\n    \"  - Dependents\\n\",\n    \"  \\n\",\n    \"We'll calculate date of birth by using the hash of the customer ID as a pseudorandom number and then assuming that ages are uniformly distributed between 18-65 and exponentially distributed over 65.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:36:31,848 WARN conf.HiveConf: HiveConf of name hive.stats.jdbc.timeout does not exist\\n\",\n      \"2022-04-05 09:36:31,849 WARN conf.HiveConf: HiveConf of name hive.stats.retries.wait does not exist\\n\",\n      \"2022-04-05 09:36:33,683 WARN metastore.ObjectStore: Version information not found in metastore. 
hive.metastore.schema.verification is not enabled so recording the schema version 2.3.0\\n\",\n      \"2022-04-05 09:36:33,683 WARN metastore.ObjectStore: setMetaStoreSchemaVersion called but recording version is disabled: version = 2.3.0, comment = Set by MetaStore yuanli@127.0.1.1\\n\",\n      \"2022-04-05 09:36:33,811 WARN metastore.ObjectStore: Failed to get database global_temp, returning NoSuchObjectException\\n\",\n      \"2022-04-05 09:36:33,892 WARN rapids.GpuOverrides: \\n\",\n      \"! <LocalTableScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.LocalTableScanExec\\n\",\n      \"  @Expression <AttributeReference> name#326 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> database#327 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> description#328 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> tableType#329 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> isTemporary#330 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:36:33,960 WARN rapids.GpuOverrides: \\n\",\n      \"        ! 
<RDDScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.RDDScanExec\\n\",\n      \"          @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"\\n\",\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from churn.augment import customer_meta\\n\",\n    \"customerMeta = customer_meta(df)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Now we can generate customer phone features, which include:\\n\",\n    \"\\n\",\n    \"  - customerID\\n\",\n    \"  - feature (one of \\\"PhoneService\\\" or \\\"MultipleLines\\\")\\n\",\n    \"  - value (always \\\"Yes\\\"; there are no records for \\\"No\\\" or \\\"No Phone Service\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from churn.augment import phone_features\\n\",\n    \"customerPhoneFeatures = phone_features(df)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Customer internet features include:\\n\",\n    \"  - customerID\\n\",\n    \"  - feature (one of \\\"InternetService\\\", \\\"OnlineSecurity\\\", \\\"OnlineBackup\\\", \\\"DeviceProtection\\\", \\\"TechSupport\\\", \\\"StreamingTV\\\", \\\"StreamingMovies\\\")\\n\",\n    \"  - value (one of \\\"Fiber\\\", \\\"DSL\\\", \\\"Yes\\\" -- no records for \\\"No\\\" or \\\"No internet service\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from churn.augment import internet_features\\n\",\n    \"customerInternetFeatures = internet_features(df)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Customer account features include:\\n\",\n    
\"\\n\",\n    \"  - customerID\\n\",\n    \"  - feature (one of \\\"Contract\\\", \\\"PaperlessBilling\\\", \\\"PaymentMethod\\\")\\n\",\n    \"  - value (one of \\\"Month-to-month\\\", \\\"One year\\\", \\\"Two year\\\", \\\"Yes\\\", \\\"Credit card (automatic)\\\", \\\"Mailed check\\\", \\\"Bank transfer (automatic)\\\", \\\"Electronic check\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from churn.augment import account_features\\n\",\n    \"customerAccountFeatures = account_features(df)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Write outputs\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:36:36,792 WARN rapids.GpuOverrides: \\n\",\n      \"! <LocalTableScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.LocalTableScanExec\\n\",\n      \"  @Expression <AttributeReference> name#798 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> database#799 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> description#800 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> tableType#801 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> isTemporary#802 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:36:37,142 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#816 could run on GPU\\n\",\n      \"    ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> kind#133 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> date#156 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> month#315 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"          @Expression <Alias> Charge AS kind#133 could run on GPU\\n\",\n      \"            @Expression <Literal> Charge could run on GPU\\n\",\n      \"          @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) AS date#156 could run on GPU\\n\",\n      \"            ! 
<AddMonths> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              @Expression <Cast> cast(-(cast(_we0#149 as bigint) + last_month#135L) as int) could run on GPU\\n\",\n      \"                @Expression <UnaryMinus> -(cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                  @Expression <Add> (cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(_we0#149 as bigint) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> _we0#149 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> last_month#135L could run on GPU\\n\",\n      \"              !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"                @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"                  @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"                  @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                  @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"                  @Expression <Alias> CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) 
% 6)) ELSE 0 END AS last_month#135L could run on GPU\\n\",\n      \"                    @Expression <CaseWhen> CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END could run on GPU\\n\",\n      \"                      @Expression <EqualTo> (Churn#20 = Yes) could run on GPU\\n\",\n      \"                        @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"                        @Expression <Literal> Yes could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) could run on 
GPU\\n\",\n      \"                            @Expression <Add> ((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                              @Expression <Add> (((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((abs(xxhash64(customerID#0, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (abs(xxhash64(customerID#0, 42), false) & 255) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                      ! <XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 36 could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                    @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 8) could run on GPU\\n\",\n      \"                                      @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on 
GPU\\n\",\n      \"                                        ! <XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                      @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                              @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                                @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                                  @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 16) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 16 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                @Expression <Literal> 14 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 24) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 10 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 32) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                @Expression <Literal> 32 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 6 could run on GPU\\n\",\n      \"                      @Expression <Literal> 0 could run on GPU\\n\",\n      \"                  !Exec <GenerateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"                    @Expression <Explode> explode(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\\n\",\n      \"                      ! <ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                        @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                          @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                            @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                            @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                        !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                          @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"                      !Exec <FilterExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"                        @Expression <And> ((atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) AND isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)))) could run on GPU\\n\",\n      \"                          @Expression <And> (atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) could run on GPU\\n\",\n      \"                            @Expression <AtLeastNNonNulls> atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, 
StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> gender#1 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Partner#3 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Dependents#4 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> PhoneService#6 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> MultipleLines#7 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> InternetService#8 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> OnlineSecurity#9 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> OnlineBackup#10 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> DeviceProtection#11 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> TechSupport#12 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> StreamingTV#13 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> StreamingMovies#14 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Contract#15 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> 
PaperlessBilling#16 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> PaymentMethod#17 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> MonthlyCharges#18 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"                            @Expression <GreaterThan> (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0) could run on GPU\\n\",\n      \"                              @Expression <Size> size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) could run on GPU\\n\",\n      \"                                ! <ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                                  @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                                    @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                                      @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                                      @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                                  !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                                    @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                              @Expression <Literal> 0 could run on GPU\\n\",\n      \"                          @Expression <IsNotNull> isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\\n\",\n      \"                            ! <ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                              @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                                @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                                  @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                                  @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                              !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                                @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountCreation AS kind#191 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountCreation could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#192 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) AS date#200 could run on GPU\\n\",\n      \"            ! 
<AddMonths> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              !Expression <Cast> cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                @Expression <Add> ((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) could run on GPU\\n\",\n      \"                  @Expression <Subtract> (-tenure#270 - 1.0) could run on GPU\\n\",\n      \"                    @Expression <UnaryMinus> -tenure#270 could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> tenure#270 could run on GPU\\n\",\n      \"                    @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"                  @Expression <CaseWhen> CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END could run on GPU\\n\",\n      \"                    @Expression <EqualTo> (Churn#285 = Yes) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> Churn#285 could run on GPU\\n\",\n      \"                      @Expression <Literal> Yes could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 
14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                            @Expression <Add> ((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                              @Expression <Add> (((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                                @Expression <Remainder> 
((abs(xxhash64(customerID#265, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (abs(xxhash64(customerID#265, 42), false) & 255) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! <XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 36 could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                    @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 8) could run on GPU\\n\",\n      \"                                      @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                        ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                      @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                              @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                                @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                                  @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 16) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 16 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                @Expression <Literal> 14 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 24) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 10 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 32) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                @Expression <Literal> 32 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 6 could run on GPU\\n\",\n      \"                    @Expression <Literal> 0.0 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountTermination AS kind#258 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountTermination could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#259 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END AS date#260 could run on GPU\\n\",\n      \"            @Expression <CaseWhen> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + 
((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END could run on GPU\\n\",\n      \"              @Expression <EqualTo> (Churn#310 = Yes) could run on GPU\\n\",\n      \"                @Expression <AttributeReference> Churn#310 could run on GPU\\n\",\n      \"                @Expression <Literal> Yes could run on GPU\\n\",\n      \"              ! <AddMonths> add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"                @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"                @Expression <Cast> cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int) could run on GPU\\n\",\n      \"                  @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                    
@Expression <Add> ((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                      @Expression <Add> (((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((abs(xxhash64(customerID#290, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (abs(xxhash64(customerID#290, 42), false) & 255) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 36 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 8) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 16) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 16 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 14 could run on GPU\\n\",\n      \"                        @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                          @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                            @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 24) could run on GPU\\n\",\n      \"                              @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                  @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                            @Expression <Literal> 255 could run on GPU\\n\",\n      \"                          @Expression <Literal> 10 could run on GPU\\n\",\n      \"                      @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                        @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                          @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 32) could run on GPU\\n\",\n      \"                            @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                              ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                            @Expression <Literal> 32 could run on GPU\\n\",\n      \"                          @Expression <Literal> 255 could run on GPU\\n\",\n      \"                        @Expression <Literal> 6 could run on GPU\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:36:37,176 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#816 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> kind#133 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> date#156 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> month#315 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"          @Expression <Alias> Charge AS kind#133 could run on GPU\\n\",\n      \"            @Expression <Literal> Charge could run 
on GPU\\n\",\n      \"          @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) AS date#156 could run on GPU\\n\",\n      \"            ! <AddMonths> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              @Expression <Cast> cast(-(cast(_we0#149 as bigint) + last_month#135L) as int) could run on GPU\\n\",\n      \"                @Expression <UnaryMinus> -(cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                  @Expression <Add> (cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(_we0#149 as bigint) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> _we0#149 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> last_month#135L could run on GPU\\n\",\n      \"              !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"                @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"                  @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"                  @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                  @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"                  @Expression <Alias> CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + 
((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END AS last_month#135L could run on GPU\\n\",\n      \"                    @Expression <CaseWhen> CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END could run on GPU\\n\",\n      \"                      @Expression <EqualTo> (Churn#20 = Yes) could run on GPU\\n\",\n      \"                        @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"                        @Expression <Literal> Yes could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                          @Expression <Add> 
(((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                            @Expression <Add> ((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                              @Expression <Add> (((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((abs(xxhash64(customerID#0, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (abs(xxhash64(customerID#0, 42), false) & 255) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 36 could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                    @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 8) could run on GPU\\n\",\n      \"                                      @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                        ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                      @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                              @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                                @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                                  @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 16) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 16 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                @Expression <Literal> 14 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 24) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 10 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 32) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                @Expression <Literal> 32 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 6 could run on GPU\\n\",\n      \"                      @Expression <Literal> 0 could run on GPU\\n\",\n      \"                  !Exec <GenerateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"                    @Expression <Explode> explode(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\\n\",\n      \"                      ! <ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                        @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                          @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                            @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                            @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                        !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                          @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"                      !Exec <FilterExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"                        @Expression <And> ((atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) AND isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)))) could run on GPU\\n\",\n      \"                          @Expression <And> (atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) could run on GPU\\n\",\n      \"                            @Expression <AtLeastNNonNulls> atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, 
StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> gender#1 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Partner#3 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Dependents#4 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> PhoneService#6 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> MultipleLines#7 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> InternetService#8 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> OnlineSecurity#9 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> OnlineBackup#10 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> DeviceProtection#11 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> TechSupport#12 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> StreamingTV#13 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> StreamingMovies#14 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Contract#15 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> 
PaperlessBilling#16 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> PaymentMethod#17 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> MonthlyCharges#18 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"                            @Expression <GreaterThan> (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0) could run on GPU\\n\",\n      \"                              @Expression <Size> size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) could run on GPU\\n\",\n      \"                                ! <ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                                  @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                                    @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                                      @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                                      @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                                  !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                                    @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                              @Expression <Literal> 0 could run on GPU\\n\",\n      \"                          @Expression <IsNotNull> isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\\n\",\n      \"                            ! <ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                              @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                                @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                                  @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                                  @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                              !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                                @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountCreation AS kind#191 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountCreation could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#192 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) AS date#200 could run on GPU\\n\",\n      \"            ! 
<AddMonths> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              !Expression <Cast> cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                @Expression <Add> ((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) could run on GPU\\n\",\n      \"                  @Expression <Subtract> (-tenure#270 - 1.0) could run on GPU\\n\",\n      \"                    @Expression <UnaryMinus> -tenure#270 could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> tenure#270 could run on GPU\\n\",\n      \"                    @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"                  @Expression <CaseWhen> CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END could run on GPU\\n\",\n      \"                    @Expression <EqualTo> (Churn#285 = Yes) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> Churn#285 could run on GPU\\n\",\n      \"                      @Expression <Literal> Yes could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 
14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                            @Expression <Add> ((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                              @Expression <Add> (((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                                @Expression <Remainder> 
((abs(xxhash64(customerID#265, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (abs(xxhash64(customerID#265, 42), false) & 255) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! <XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 36 could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                    @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 8) could run on GPU\\n\",\n      \"                                      @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                        ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                      @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                              @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                                @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                                  @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 16) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 16 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                @Expression <Literal> 14 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 24) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 10 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 32) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                @Expression <Literal> 32 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 6 could run on GPU\\n\",\n      \"                    @Expression <Literal> 0.0 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountTermination AS kind#258 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountTermination could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#259 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END AS date#260 could run on GPU\\n\",\n      \"            @Expression <CaseWhen> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + 
((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END could run on GPU\\n\",\n      \"              @Expression <EqualTo> (Churn#310 = Yes) could run on GPU\\n\",\n      \"                @Expression <AttributeReference> Churn#310 could run on GPU\\n\",\n      \"                @Expression <Literal> Yes could run on GPU\\n\",\n      \"              ! <AddMonths> add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"                @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"                @Expression <Cast> cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int) could run on GPU\\n\",\n      \"                  @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                    
@Expression <Add> ((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                      @Expression <Add> (((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((abs(xxhash64(customerID#290, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (abs(xxhash64(customerID#290, 42), false) & 255) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 36 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 8) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 16) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 16 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 14 could run on GPU\\n\",\n      \"                        @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                          @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                            @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 24) could run on GPU\\n\",\n      \"                              @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                  @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                            @Expression <Literal> 255 could run on GPU\\n\",\n      \"                          @Expression <Literal> 10 could run on GPU\\n\",\n      \"                      @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                        @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                          @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 32) could run on GPU\\n\",\n      \"                            @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                              ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                            @Expression <Literal> 32 could run on GPU\\n\",\n      \"                          @Expression <Literal> 255 could run on GPU\\n\",\n      \"                        @Expression <Literal> 6 could run on GPU\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:36:37,199 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#816 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> kind#133 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> date#156 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> month#315 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"          @Expression <Alias> Charge AS kind#133 could run on GPU\\n\",\n      \"            @Expression <Literal> Charge could run 
on GPU\\n\",\n      \"          @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) AS date#156 could run on GPU\\n\",\n      \"            ! <AddMonths> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              @Expression <Cast> cast(-(cast(_we0#149 as bigint) + last_month#135L) as int) could run on GPU\\n\",\n      \"                @Expression <UnaryMinus> -(cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                  @Expression <Add> (cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(_we0#149 as bigint) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> _we0#149 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> last_month#135L could run on GPU\\n\",\n      \"              !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"                @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"                  @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"                  @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                  @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"                  @Expression <Alias> CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + 
((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END AS last_month#135L could run on GPU\\n\",\n      \"                    @Expression <CaseWhen> CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END could run on GPU\\n\",\n      \"                      @Expression <EqualTo> (Churn#20 = Yes) could run on GPU\\n\",\n      \"                        @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"                        @Expression <Literal> Yes could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                          @Expression <Add> 
(((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                            @Expression <Add> ((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                              @Expression <Add> (((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((abs(xxhash64(customerID#0, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (abs(xxhash64(customerID#0, 42), false) & 255) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 36 could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                    @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 8) could run on GPU\\n\",\n      \"                                      @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                        ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                      @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                              @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                                @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                                  @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 16) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 16 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                @Expression <Literal> 14 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 24) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 10 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 32) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                @Expression <Literal> 32 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 6 could run on GPU\\n\",\n      \"                      @Expression <Literal> 0 could run on GPU\\n\",\n      \"                  !Exec <GenerateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"                    @Expression <Explode> explode(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\\n\",\n      \"                      ! <ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                        @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                          @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                            @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                            @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                        !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                          @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"                      !Exec <FilterExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"                        @Expression <And> ((atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) AND isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)))) could run on GPU\\n\",\n      \"                          @Expression <And> (atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) could run on GPU\\n\",\n      \"                            @Expression <AtLeastNNonNulls> atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, 
StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> gender#1 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Partner#3 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Dependents#4 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> PhoneService#6 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> MultipleLines#7 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> InternetService#8 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> OnlineSecurity#9 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> OnlineBackup#10 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> DeviceProtection#11 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> TechSupport#12 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> StreamingTV#13 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> StreamingMovies#14 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Contract#15 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> 
PaperlessBilling#16 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> PaymentMethod#17 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> MonthlyCharges#18 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"                            @Expression <GreaterThan> (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0) could run on GPU\\n\",\n      \"                              @Expression <Size> size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) could run on GPU\\n\",\n      \"                                ! <ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                                  @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                                    @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                                      @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                                      @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                                  !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                                    @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                              @Expression <Literal> 0 could run on GPU\\n\",\n      \"                          @Expression <IsNotNull> isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\\n\",\n      \"                            ! <ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                              @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                                @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                                  @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                                  @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                              !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                                @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountCreation AS kind#191 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountCreation could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#192 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) AS date#200 could run on GPU\\n\",\n      \"            ! 
<AddMonths> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              !Expression <Cast> cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                @Expression <Add> ((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) could run on GPU\\n\",\n      \"                  @Expression <Subtract> (-tenure#270 - 1.0) could run on GPU\\n\",\n      \"                    @Expression <UnaryMinus> -tenure#270 could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> tenure#270 could run on GPU\\n\",\n      \"                    @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"                  @Expression <CaseWhen> CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END could run on GPU\\n\",\n      \"                    @Expression <EqualTo> (Churn#285 = Yes) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> Churn#285 could run on GPU\\n\",\n      \"                      @Expression <Literal> Yes could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 
14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                            @Expression <Add> ((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                              @Expression <Add> (((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                                @Expression <Remainder> 
((abs(xxhash64(customerID#265, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (abs(xxhash64(customerID#265, 42), false) & 255) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! <XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 36 could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                    @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 8) could run on GPU\\n\",\n      \"                                      @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                        ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                      @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                              @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                                @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                                  @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 16) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 16 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                @Expression <Literal> 14 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 24) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 10 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 32) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                @Expression <Literal> 32 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 6 could run on GPU\\n\",\n      \"                    @Expression <Literal> 0.0 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountTermination AS kind#258 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountTermination could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#259 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END AS date#260 could run on GPU\\n\",\n      \"            @Expression <CaseWhen> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + 
((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END could run on GPU\\n\",\n      \"              @Expression <EqualTo> (Churn#310 = Yes) could run on GPU\\n\",\n      \"                @Expression <AttributeReference> Churn#310 could run on GPU\\n\",\n      \"                @Expression <Literal> Yes could run on GPU\\n\",\n      \"              ! <AddMonths> add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"                @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"                @Expression <Cast> cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int) could run on GPU\\n\",\n      \"                  @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                    
@Expression <Add> ((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                      @Expression <Add> (((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((abs(xxhash64(customerID#290, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (abs(xxhash64(customerID#290, 42), false) & 255) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 36 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 8) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 16) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 16 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 14 could run on GPU\\n\",\n      \"                        @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                          @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                            @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 24) could run on GPU\\n\",\n      \"                              @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                  @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                            @Expression <Literal> 255 could run on GPU\\n\",\n      \"                          @Expression <Literal> 10 could run on GPU\\n\",\n      \"                      @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                        @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                          @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 32) could run on GPU\\n\",\n      \"                            @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                              ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                            @Expression <Literal> 32 could run on GPU\\n\",\n      \"                          @Expression <Literal> 255 could run on GPU\\n\",\n      \"                        @Expression <Literal> 6 could run on GPU\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:36:37,210 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#816 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> kind#133 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> date#156 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> month#315 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"          @Expression <Alias> Charge AS kind#133 could run on GPU\\n\",\n      \"            @Expression <Literal> Charge could run 
on GPU\\n\",\n      \"          @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) AS date#156 could run on GPU\\n\",\n      \"            ! <AddMonths> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              @Expression <Cast> cast(-(cast(_we0#149 as bigint) + last_month#135L) as int) could run on GPU\\n\",\n      \"                @Expression <UnaryMinus> -(cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                  @Expression <Add> (cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(_we0#149 as bigint) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> _we0#149 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> last_month#135L could run on GPU\\n\",\n      \"              !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"                @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"                  @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"                  @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                  @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"                  @Expression <Alias> CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + 
((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END AS last_month#135L could run on GPU\\n\",\n      \"                    @Expression <CaseWhen> CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END could run on GPU\\n\",\n      \"                      @Expression <EqualTo> (Churn#20 = Yes) could run on GPU\\n\",\n      \"                        @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"                        @Expression <Literal> Yes could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                          @Expression <Add> 
(((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                            @Expression <Add> ((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                              @Expression <Add> (((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((abs(xxhash64(customerID#0, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (abs(xxhash64(customerID#0, 42), false) & 255) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 36 could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                    @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 8) could run on GPU\\n\",\n      \"                                      @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                        ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                      @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                              @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                                @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                                  @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 16) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 16 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                @Expression <Literal> 14 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 24) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 10 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 32) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                @Expression <Literal> 32 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 6 could run on GPU\\n\",\n      \"                      @Expression <Literal> 0 could run on GPU\\n\",\n      \"                  !Exec <GenerateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"                    @Expression <Explode> explode(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\\n\",\n      \"                      ! <ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                        @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                          @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                            @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                            @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                        !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                          @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"                      !Exec <FilterExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"                        @Expression <And> ((atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) AND isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)))) could run on GPU\\n\",\n      \"                          @Expression <And> (atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) could run on GPU\\n\",\n      \"                            @Expression <AtLeastNNonNulls> atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, 
StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> gender#1 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Partner#3 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Dependents#4 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> PhoneService#6 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> MultipleLines#7 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> InternetService#8 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> OnlineSecurity#9 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> OnlineBackup#10 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> DeviceProtection#11 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> TechSupport#12 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> StreamingTV#13 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> StreamingMovies#14 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Contract#15 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> 
PaperlessBilling#16 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> PaymentMethod#17 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> MonthlyCharges#18 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"                            @Expression <GreaterThan> (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0) could run on GPU\\n\",\n      \"                              @Expression <Size> size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) could run on GPU\\n\",\n      \"                                ! <ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                                  @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                                    @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                                      @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                                      @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                                  !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                                    @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                              @Expression <Literal> 0 could run on GPU\\n\",\n      \"                          @Expression <IsNotNull> isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\\n\",\n      \"                            ! <ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                              @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                                @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                                  @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                                  @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                              !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                                @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountCreation AS kind#191 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountCreation could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#192 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) AS date#200 could run on GPU\\n\",\n      \"            ! 
<AddMonths> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              !Expression <Cast> cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                @Expression <Add> ((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) could run on GPU\\n\",\n      \"                  @Expression <Subtract> (-tenure#270 - 1.0) could run on GPU\\n\",\n      \"                    @Expression <UnaryMinus> -tenure#270 could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> tenure#270 could run on GPU\\n\",\n      \"                    @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"                  @Expression <CaseWhen> CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END could run on GPU\\n\",\n      \"                    @Expression <EqualTo> (Churn#285 = Yes) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> Churn#285 could run on GPU\\n\",\n      \"                      @Expression <Literal> Yes could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 
14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                            @Expression <Add> ((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                              @Expression <Add> (((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                                @Expression <Remainder> 
((abs(xxhash64(customerID#265, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (abs(xxhash64(customerID#265, 42), false) & 255) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! <XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 36 could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                    @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 8) could run on GPU\\n\",\n      \"                                      @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                        ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                      @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                              @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                                @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                                  @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 16) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 16 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                @Expression <Literal> 14 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 24) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 10 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 32) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                @Expression <Literal> 32 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 6 could run on GPU\\n\",\n      \"                    @Expression <Literal> 0.0 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountTermination AS kind#258 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountTermination could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#259 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END AS date#260 could run on GPU\\n\",\n      \"            @Expression <CaseWhen> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + 
((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END could run on GPU\\n\",\n      \"              @Expression <EqualTo> (Churn#310 = Yes) could run on GPU\\n\",\n      \"                @Expression <AttributeReference> Churn#310 could run on GPU\\n\",\n      \"                @Expression <Literal> Yes could run on GPU\\n\",\n      \"              ! <AddMonths> add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"                @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"                @Expression <Cast> cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int) could run on GPU\\n\",\n      \"                  @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                    
@Expression <Add> ((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                      @Expression <Add> (((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((abs(xxhash64(customerID#290, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (abs(xxhash64(customerID#290, 42), false) & 255) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 36 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 8) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 16) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 16 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 14 could run on GPU\\n\",\n      \"                        @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                          @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                            @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 24) could run on GPU\\n\",\n      \"                              @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                  @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                            @Expression <Literal> 255 could run on GPU\\n\",\n      \"                          @Expression <Literal> 10 could run on GPU\\n\",\n      \"                      @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                        @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                          @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 32) could run on GPU\\n\",\n      \"                            @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                              ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                            @Expression <Literal> 32 could run on GPU\\n\",\n      \"                          @Expression <Literal> 255 could run on GPU\\n\",\n      \"                        @Expression <Literal> 6 could run on GPU\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:36:37,305 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"  @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"    @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"    @Expression <Alias> CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END AS last_month#135L could run on GPU\\n\",\n      \"      @Expression <CaseWhen> CASE WHEN (Churn#20 = Yes) THEN -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + 
((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) ELSE 0 END could run on GPU\\n\",\n      \"        @Expression <EqualTo> (Churn#20 = Yes) could run on GPU\\n\",\n      \"          @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"          @Expression <Literal> Yes could run on GPU\\n\",\n      \"        @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"          @Expression <Add> ((((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"            @Expression <Add> (((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"              @Expression <Add> ((((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                @Expression <Add> (((abs(xxhash64(customerID#0, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 
24)) could run on GPU\\n\",\n      \"                  @Expression <Remainder> ((abs(xxhash64(customerID#0, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                    @Expression <BitwiseAnd> (abs(xxhash64(customerID#0, 42), false) & 255) could run on GPU\\n\",\n      \"                      @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                        ! <XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                      @Expression <Literal> 255 could run on GPU\\n\",\n      \"                    @Expression <Literal> 36 could run on GPU\\n\",\n      \"                  @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                    @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                      @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 8) could run on GPU\\n\",\n      \"                        @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                          ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                            @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                        @Expression <Literal> 8 could run on GPU\\n\",\n      \"                      @Expression <Literal> 255 could run on GPU\\n\",\n      \"                    @Expression <Literal> 24 could run on GPU\\n\",\n      \"                @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                  @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                    @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 16) could run on GPU\\n\",\n      \"                      @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                        ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                      @Expression <Literal> 16 could run on GPU\\n\",\n      \"                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                  @Expression <Literal> 14 could run on GPU\\n\",\n      \"              @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                  @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 24) could run on GPU\\n\",\n      \"                    @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                      ! 
<XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                    @Expression <Literal> 24 could run on GPU\\n\",\n      \"                  @Expression <Literal> 255 could run on GPU\\n\",\n      \"                @Expression <Literal> 10 could run on GPU\\n\",\n      \"            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#0, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#0, 42), false), 32) could run on GPU\\n\",\n      \"                  @Expression <Abs> abs(xxhash64(customerID#0, 42), false) could run on GPU\\n\",\n      \"                    ! <XxHash64> xxhash64(customerID#0, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                  @Expression <Literal> 32 could run on GPU\\n\",\n      \"                @Expression <Literal> 255 could run on GPU\\n\",\n      \"              @Expression <Literal> 6 could run on GPU\\n\",\n      \"        @Expression <Literal> 0 could run on GPU\\n\",\n      \"    !Exec <GenerateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"      @Expression <Explode> explode(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\\n\",\n      \"        ! 
<ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"          @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"            @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"              @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"              @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"          !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"            @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"        !Exec <FilterExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <And> ((atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) AND isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)))) could run on GPU\\n\",\n 
     \"            @Expression <And> (atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) AND (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0)) could run on GPU\\n\",\n      \"              @Expression <AtLeastNNonNulls> atleastnnonnulls(21, customerID#0, gender#1, SeniorCitizen#2, Partner#3, Dependents#4, tenure#5, PhoneService#6, MultipleLines#7, InternetService#8, OnlineSecurity#9, OnlineBackup#10, DeviceProtection#11, TechSupport#12, StreamingTV#13, StreamingMovies#14, Contract#15, PaperlessBilling#16, PaymentMethod#17, MonthlyCharges#18, TotalCharges#19, Churn#20) could run on GPU\\n\",\n      \"                @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> gender#1 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> Partner#3 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> Dependents#4 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> PhoneService#6 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> MultipleLines#7 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> InternetService#8 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> OnlineSecurity#9 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> OnlineBackup#10 could run on 
GPU\\n\",\n      \"                @Expression <AttributeReference> DeviceProtection#11 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> TechSupport#12 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> StreamingTV#13 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> StreamingMovies#14 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> Contract#15 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> PaperlessBilling#16 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> PaymentMethod#17 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> MonthlyCharges#18 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                @Expression <AttributeReference> Churn#20 could run on GPU\\n\",\n      \"              @Expression <GreaterThan> (size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) > 0) could run on GPU\\n\",\n      \"                @Expression <Size> size(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)), true) could run on GPU\\n\",\n      \"                  ! 
<ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                    @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                      @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                        @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                        @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                    !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                      @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                @Expression <Literal> 0 could run on GPU\\n\",\n      \"            @Expression <IsNotNull> isnotnull(array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int))) could run on GPU\\n\",\n      \"              ! 
<ArrayRepeat> array_repeat(cast((TotalCharges#19 / tenure#5) as decimal(8,2)), cast(tenure#5 as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.ArrayRepeat\\n\",\n      \"                @Expression <Cast> cast((TotalCharges#19 / tenure#5) as decimal(8,2)) could run on GPU\\n\",\n      \"                  @Expression <Divide> (TotalCharges#19 / tenure#5) could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> TotalCharges#19 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"                !Expression <Cast> cast(tenure#5 as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                  @Expression <AttributeReference> tenure#5 could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:36:37,476 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#816 could run on GPU\\n\",\n      \"    ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> kind#133 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> date#156 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> month#315 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"          @Expression <Alias> Charge AS kind#133 could run on GPU\\n\",\n      \"            @Expression <Literal> Charge could run on GPU\\n\",\n      \"          @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) AS date#156 could run on GPU\\n\",\n      \"            ! 
<AddMonths> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              @Expression <Cast> cast(-(cast(_we0#149 as bigint) + last_month#135L) as int) could run on GPU\\n\",\n      \"                @Expression <UnaryMinus> -(cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                  @Expression <Add> (cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(_we0#149 as bigint) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> _we0#149 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> last_month#135L could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountCreation AS kind#191 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountCreation could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#192 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) AS 
date#200 could run on GPU\\n\",\n      \"            ! <AddMonths> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              !Expression <Cast> cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                @Expression <Add> ((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) could run on GPU\\n\",\n      \"                  @Expression <Subtract> (-tenure#270 - 1.0) could run on GPU\\n\",\n      \"                    @Expression <UnaryMinus> -tenure#270 could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> tenure#270 could run on GPU\\n\",\n      \"                    @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"                  @Expression <CaseWhen> CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END could run on GPU\\n\",\n      \"                    @Expression <EqualTo> (Churn#285 = Yes) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> Churn#285 could run on GPU\\n\",\n      \"                      @Expression <Literal> Yes could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 
14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                            @Expression <Add> ((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                              @Expression <Add> (((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                                @Expression <Remainder> 
((abs(xxhash64(customerID#265, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (abs(xxhash64(customerID#265, 42), false) & 255) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! <XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 36 could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                    @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 8) could run on GPU\\n\",\n      \"                                      @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                        ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                      @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                              @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                                @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                                  @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 16) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 16 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                @Expression <Literal> 14 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 24) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 10 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 32) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                @Expression <Literal> 32 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 6 could run on GPU\\n\",\n      \"                    @Expression <Literal> 0.0 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountTermination AS kind#258 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountTermination could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#259 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END AS date#260 could run on GPU\\n\",\n      \"            @Expression <CaseWhen> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + 
((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END could run on GPU\\n\",\n      \"              @Expression <EqualTo> (Churn#310 = Yes) could run on GPU\\n\",\n      \"                @Expression <AttributeReference> Churn#310 could run on GPU\\n\",\n      \"                @Expression <Literal> Yes could run on GPU\\n\",\n      \"              ! <AddMonths> add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"                @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"                @Expression <Cast> cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int) could run on GPU\\n\",\n      \"                  @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                    
@Expression <Add> ((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                      @Expression <Add> (((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((abs(xxhash64(customerID#290, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (abs(xxhash64(customerID#290, 42), false) & 255) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 36 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 8) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 16) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 16 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 14 could run on GPU\\n\",\n      \"                        @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                          @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                            @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 24) could run on GPU\\n\",\n      \"                              @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                  @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                            @Expression <Literal> 255 could run on GPU\\n\",\n      \"                          @Expression <Literal> 10 could run on GPU\\n\",\n      \"                      @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                        @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                          @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 32) could run on GPU\\n\",\n      \"                            @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                              ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                            @Expression <Literal> 32 could run on GPU\\n\",\n      \"                          @Expression <Literal> 255 could run on GPU\\n\",\n      \"                        @Expression <Literal> 6 could run on GPU\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:36:37,897 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#816 could run on GPU\\n\",\n      \"    ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> kind#133 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> date#156 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> month#315 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"          @Expression <Alias> Charge AS kind#133 could run on GPU\\n\",\n      \"            @Expression <Literal> Charge could run on GPU\\n\",\n      \"          @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) AS date#156 could run on GPU\\n\",\n      \"            ! 
<AddMonths> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              @Expression <Cast> cast(-(cast(_we0#149 as bigint) + last_month#135L) as int) could run on GPU\\n\",\n      \"                @Expression <UnaryMinus> -(cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                  @Expression <Add> (cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(_we0#149 as bigint) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> _we0#149 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> last_month#135L could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountCreation AS kind#191 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountCreation could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#192 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) AS 
date#200 could run on GPU\\n\",\n      \"            ! <AddMonths> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              !Expression <Cast> cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                @Expression <Add> ((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) could run on GPU\\n\",\n      \"                  @Expression <Subtract> (-tenure#270 - 1.0) could run on GPU\\n\",\n      \"                    @Expression <UnaryMinus> -tenure#270 could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> tenure#270 could run on GPU\\n\",\n      \"                    @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"                  @Expression <CaseWhen> CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END could run on GPU\\n\",\n      \"                    @Expression <EqualTo> (Churn#285 = Yes) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> Churn#285 could run on GPU\\n\",\n      \"                      @Expression <Literal> Yes could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 
14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                            @Expression <Add> ((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                              @Expression <Add> (((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                                @Expression <Remainder> 
((abs(xxhash64(customerID#265, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (abs(xxhash64(customerID#265, 42), false) & 255) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! <XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 36 could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                    @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 8) could run on GPU\\n\",\n      \"                                      @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                        ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                      @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                              @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                                @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                                  @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 16) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 16 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                @Expression <Literal> 14 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 24) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 10 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 32) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                @Expression <Literal> 32 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 6 could run on GPU\\n\",\n      \"                    @Expression <Literal> 0.0 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountTermination AS kind#258 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountTermination could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#259 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END AS date#260 could run on GPU\\n\",\n      \"            @Expression <CaseWhen> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + 
((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END could run on GPU\\n\",\n      \"              @Expression <EqualTo> (Churn#310 = Yes) could run on GPU\\n\",\n      \"                @Expression <AttributeReference> Churn#310 could run on GPU\\n\",\n      \"                @Expression <Literal> Yes could run on GPU\\n\",\n      \"              ! <AddMonths> add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"                @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"                @Expression <Cast> cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int) could run on GPU\\n\",\n      \"                  @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                    
@Expression <Add> ((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                      @Expression <Add> (((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((abs(xxhash64(customerID#290, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (abs(xxhash64(customerID#290, 42), false) & 255) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 36 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 8) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 16) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 16 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 14 could run on GPU\\n\",\n      \"                        @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                          @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                            @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 24) could run on GPU\\n\",\n      \"                              @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                  @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                            @Expression <Literal> 255 could run on GPU\\n\",\n      \"                          @Expression <Literal> 10 could run on GPU\\n\",\n      \"                      @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                        @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                          @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 32) could run on GPU\\n\",\n      \"                            @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                              ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                            @Expression <Literal> 32 could run on GPU\\n\",\n      \"                          @Expression <Literal> 255 could run on GPU\\n\",\n      \"                        @Expression <Literal> 6 could run on GPU\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:36:37,903 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#816 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> kind#133 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> date#156 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> month#315 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"          @Expression <Alias> Charge AS kind#133 could run on GPU\\n\",\n      \"            @Expression <Literal> Charge could run 
on GPU\\n\",\n      \"          @Expression <AttributeReference> value#136 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) AS date#156 could run on GPU\\n\",\n      \"            ! <AddMonths> add_months(2022-04-05, cast(-(cast(_we0#149 as bigint) + last_month#135L) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              @Expression <Cast> cast(-(cast(_we0#149 as bigint) + last_month#135L) as int) could run on GPU\\n\",\n      \"                @Expression <UnaryMinus> -(cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                  @Expression <Add> (cast(_we0#149 as bigint) + last_month#135L) could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(_we0#149 as bigint) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> _we0#149 could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> last_month#135L could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountCreation AS kind#191 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountCreation could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#192 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 
8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) AS date#200 could run on GPU\\n\",\n      \"            ! <AddMonths> add_months(2022-04-05, cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"              !Expression <Cast> cast(((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) as int) cannot run on GPU because Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.. 
To enable this operation on the GPU, set spark.rapids.sql.castFloatToIntegralTypes.enabled to true.\\n\",\n      \"                @Expression <Add> ((-tenure#270 - 1.0) + CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END) could run on GPU\\n\",\n      \"                  @Expression <Subtract> (-tenure#270 - 1.0) could run on GPU\\n\",\n      \"                    @Expression <UnaryMinus> -tenure#270 could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> tenure#270 could run on GPU\\n\",\n      \"                    @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"                  @Expression <CaseWhen> CASE WHEN (Churn#285 = Yes) THEN cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) ELSE 0.0 END could run on GPU\\n\",\n      \"                    @Expression <EqualTo> (Churn#285 = Yes) could run on GPU\\n\",\n      \"                      @Expression <AttributeReference> Churn#285 could run on GPU\\n\",\n      \"                      @Expression <Literal> Yes could run on GPU\\n\",\n      \"                    @Expression <Cast> cast(-((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 
14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) as double) could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                            @Expression <Add> ((((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                              @Expression <Add> (((abs(xxhash64(customerID#265, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                                @Expression <Remainder> 
((abs(xxhash64(customerID#265, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (abs(xxhash64(customerID#265, 42), false) & 255) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! <XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 36 could run on GPU\\n\",\n      \"                                @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                                  @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                    @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 8) could run on GPU\\n\",\n      \"                                      @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                        ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                          @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                      @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                              @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                                @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                                  @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 16) could run on GPU\\n\",\n      \"                                    @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                      ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                        @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                    @Expression <Literal> 16 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 255 could run on GPU\\n\",\n      \"                                @Expression <Literal> 14 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 24) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 24 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 10 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#265, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#265, 42), false), 32) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#265, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#265, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#265 could run on GPU\\n\",\n      \"                                @Expression <Literal> 32 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 6 could run on GPU\\n\",\n      \"                    @Expression <Literal> 0.0 could run on GPU\\n\",\n      \"        !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"          @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"          @Expression <Alias> AccountTermination AS kind#258 could run on GPU\\n\",\n      \"            @Expression <Literal> AccountTermination could run on GPU\\n\",\n      \"          @Expression <Alias> 0.00 AS value#259 could run on GPU\\n\",\n      \"            @Expression <Literal> 0.00 could run on GPU\\n\",\n      \"          @Expression <Alias> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END AS date#260 could run on GPU\\n\",\n      \"            @Expression <CaseWhen> CASE WHEN (Churn#310 = Yes) THEN add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + 
((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) ELSE 2022-04-05 END could run on GPU\\n\",\n      \"              @Expression <EqualTo> (Churn#310 = Yes) could run on GPU\\n\",\n      \"                @Expression <AttributeReference> Churn#310 could run on GPU\\n\",\n      \"                @Expression <Literal> Yes could run on GPU\\n\",\n      \"              ! <AddMonths> add_months(2022-04-05, cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.AddMonths\\n\",\n      \"                @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"                @Expression <Cast> cast(-((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) as int) could run on GPU\\n\",\n      \"                  @Expression <UnaryMinus> -((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                    
@Expression <Add> ((((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6)) could run on GPU\\n\",\n      \"                      @Expression <Add> (((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10)) could run on GPU\\n\",\n      \"                        @Expression <Add> ((((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14)) could run on GPU\\n\",\n      \"                          @Expression <Add> (((abs(xxhash64(customerID#290, 42), false) & 255) % 36) + ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24)) could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((abs(xxhash64(customerID#290, 42), false) & 255) % 36) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (abs(xxhash64(customerID#290, 42), false) & 255) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 36 could run on GPU\\n\",\n      \"                            @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) % 24) could run on GPU\\n\",\n      \"                              @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 8) & 255) could run on GPU\\n\",\n      \"                                @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 8) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                    ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                      @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 8 could run on GPU\\n\",\n      \"                                @Expression <Literal> 255 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                          @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) % 14) could run on GPU\\n\",\n      \"                            @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 16) & 255) could run on GPU\\n\",\n      \"                              @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 16) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                  ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                    @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                                @Expression <Literal> 16 could run on GPU\\n\",\n      \"                              @Expression <Literal> 255 could run on GPU\\n\",\n      \"                            @Expression <Literal> 14 could run on GPU\\n\",\n      \"                        @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) % 10) could run on GPU\\n\",\n      \"                          @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 24) & 255) could run on GPU\\n\",\n      \"                            @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 24) could run on GPU\\n\",\n      \"                              @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                                ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                  @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                              @Expression <Literal> 24 could run on GPU\\n\",\n      \"                            @Expression <Literal> 255 could run on GPU\\n\",\n      \"                          @Expression <Literal> 10 could run on GPU\\n\",\n      \"                      @Expression <Remainder> ((shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) % 6) could run on GPU\\n\",\n      \"                        @Expression <BitwiseAnd> (shiftright(abs(xxhash64(customerID#290, 42), false), 32) & 255) could run on GPU\\n\",\n      \"                          @Expression <ShiftRight> shiftright(abs(xxhash64(customerID#290, 42), false), 32) could run on GPU\\n\",\n      \"                            @Expression <Abs> abs(xxhash64(customerID#290, 42), false) could run on GPU\\n\",\n      \"                              ! 
<XxHash64> xxhash64(customerID#290, 42) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.XxHash64\\n\",\n      \"                                @Expression <AttributeReference> customerID#290 could run on GPU\\n\",\n      \"                            @Expression <Literal> 32 could run on GPU\\n\",\n      \"                          @Expression <Literal> 255 could run on GPU\\n\",\n      \"                        @Expression <Literal> 6 could run on GPU\\n\",\n      \"              @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:40:21,129 WARN rapids.GpuOverrides:                               \\n\",\n      \"  !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"    @Partitioning <RangePartitioning> could run on GPU\\n\",\n      \"      @Expression <SortOrder> customerID#395 ASC NULLS FIRST could run on GPU\\n\",\n      \"        @Expression <AttributeReference> customerID#395 could run on GPU\\n\",\n      \"    !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"      @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#395 could run on GPU\\n\",\n      \"        ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"          @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"          @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"      @Expression <Alias> date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) AS dateOfBirth#460 could run on GPU\\n\",\n      \"        @Expression <DateSub> date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) could run on GPU\\n\",\n      \"          @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int) could run on GPU\\n\",\n      \"            @Expression <Floor> FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), 
false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) could run on GPU\\n\",\n      \"              @Expression <CaseWhen> CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END could run on GPU\\n\",\n      \"                @Expression <EqualTo> (cast(SeniorCitizen#2 as int) = 0) could run on GPU\\n\",\n      \"                  @Expression <Cast> cast(SeniorCitizen#2 as int) could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"                  @Expression <Literal> 0 could run on GPU\\n\",\n      \"                @Expression <Add> (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) could run on GPU\\n\",\n      \"                  @Expression <Multiply> ((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) could run on GPU\\n\",\n      \"                    @Expression <Divide> (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                      @Expression <Cast> cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\\n\",\n      \"                        @Expression <Remainder> (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\\n\",\n      \"                          @Expression <Abs> abs(hash(format_string(%s-%s, 
customerID#0, u_value#337), 42), false) could run on GPU\\n\",\n      \"                            @Expression <Murmur3Hash> hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\\n\",\n      \"                              ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"                                @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"                                @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"                          @Expression <Literal> 4096 could run on GPU\\n\",\n      \"                      @Expression <Literal> 4096.0 could run on GPU\\n\",\n      \"                    @Expression <Literal> 16801.5 could run on GPU\\n\",\n      \"                  @Expression <Literal> 6574.5 could run on GPU\\n\",\n      \"                @Expression <Add> (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) could run on GPU\\n\",\n      \"                  @Expression <Multiply> ((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) could run on GPU\\n\",\n      \"                    @Expression <Multiply> (-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\\n\",\n      \"                        @Expression <Log1p> LOG1P(-(cast((abs(hash(format_string(%s-%s, 
customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\\n\",\n      \"                          @Expression <UnaryMinus> -(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                            @Expression <Divide> (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                              @Expression <Cast> cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\\n\",\n      \"                                @Expression <Remainder> (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\\n\",\n      \"                                    @Expression <Murmur3Hash> hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\\n\",\n      \"                                      ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"                                        @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                        @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 4096 could run on GPU\\n\",\n      \"                              @Expression <Literal> 4096.0 could run on GPU\\n\",\n      \"                      @Expression <Literal> 6.3 could run on GPU\\n\",\n      \"                    @Expression <Literal> 365.25 could run on GPU\\n\",\n      \"                  @Expression <Literal> 23741.25 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> gender#1 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> Partner#3 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> Dependents#4 could run on GPU\\n\",\n      \"      @Expression <Alias> cast(MonthlyCharges#18 as decimal(8,2)) AS MonthlyCharges#441 could run on GPU\\n\",\n      \"        @Expression <Cast> cast(MonthlyCharges#18 as decimal(8,2)) could run on GPU\\n\",\n      \"          @Expression <AttributeReference> MonthlyCharges#18 could run on GPU\\n\",\n      \"      @Expression <Alias> 2022-04-05 09:36:19.001066 AS now#439 could run on GPU\\n\",\n      \"        @Expression <Literal> 2022-04-05 09:36:19.001066 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:21,133 WARN rapids.GpuOverrides: \\n\",\n      \"  !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is 
inefficient\\n\",\n      \"    @Partitioning <RangePartitioning> could run on GPU\\n\",\n      \"      @Expression <SortOrder> customerID#395 ASC NULLS FIRST could run on GPU\\n\",\n      \"        @Expression <AttributeReference> customerID#395 could run on GPU\\n\",\n      \"    !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"      @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#395 could run on GPU\\n\",\n      \"        ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"          @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"          @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"      @Expression <Alias> date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) AS dateOfBirth#460 could run on GPU\\n\",\n      \"        @Expression <DateSub> date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) could run on GPU\\n\",\n      \"          @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(FLOOR(CASE WHEN 
(cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int) could run on GPU\\n\",\n      \"            @Expression <Floor> FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) could run on GPU\\n\",\n      \"              @Expression <CaseWhen> CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END could run on GPU\\n\",\n      \"                @Expression <EqualTo> (cast(SeniorCitizen#2 as int) = 0) could run on GPU\\n\",\n      \"                  @Expression <Cast> cast(SeniorCitizen#2 as int) could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"                  @Expression <Literal> 0 could run on GPU\\n\",\n      \"                @Expression <Add> (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) could run on GPU\\n\",\n      \"                  @Expression <Multiply> ((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) could run on GPU\\n\",\n      \"                    @Expression <Divide> 
(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                      @Expression <Cast> cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\\n\",\n      \"                        @Expression <Remainder> (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\\n\",\n      \"                          @Expression <Abs> abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\\n\",\n      \"                            @Expression <Murmur3Hash> hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\\n\",\n      \"                              ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"                                @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"                                @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"                          @Expression <Literal> 4096 could run on GPU\\n\",\n      \"                      @Expression <Literal> 4096.0 could run on GPU\\n\",\n      \"                    @Expression <Literal> 16801.5 could run on GPU\\n\",\n      \"                  @Expression <Literal> 6574.5 could run on GPU\\n\",\n      \"                @Expression <Add> (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) could run on GPU\\n\",\n      \"                  @Expression <Multiply> ((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 
4096.0)) * 6.3) * 365.25) could run on GPU\\n\",\n      \"                    @Expression <Multiply> (-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\\n\",\n      \"                        @Expression <Log1p> LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\\n\",\n      \"                          @Expression <UnaryMinus> -(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                            @Expression <Divide> (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                              @Expression <Cast> cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\\n\",\n      \"                                @Expression <Remainder> (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\\n\",\n      \"                                    @Expression <Murmur3Hash> hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\\n\",\n      \"                                      ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"                                        @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                        @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 4096 could run on GPU\\n\",\n      \"                              @Expression <Literal> 4096.0 could run on GPU\\n\",\n      \"                      @Expression <Literal> 6.3 could run on GPU\\n\",\n      \"                    @Expression <Literal> 365.25 could run on GPU\\n\",\n      \"                  @Expression <Literal> 23741.25 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> gender#1 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> Partner#3 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> Dependents#4 could run on GPU\\n\",\n      \"      @Expression <Alias> cast(MonthlyCharges#18 as decimal(8,2)) AS MonthlyCharges#441 could run on GPU\\n\",\n      \"        @Expression <Cast> cast(MonthlyCharges#18 as decimal(8,2)) could run on GPU\\n\",\n      \"          @Expression <AttributeReference> MonthlyCharges#18 could run on GPU\\n\",\n      \"      @Expression <Alias> 2022-04-05 09:36:19.001066 AS now#439 could run on GPU\\n\",\n      \"        @Expression <Literal> 2022-04-05 09:36:19.001066 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:21,138 WARN rapids.GpuOverrides: \\n\",\n      \"  !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is 
inefficient\\n\",\n      \"    @Partitioning <RangePartitioning> could run on GPU\\n\",\n      \"      @Expression <SortOrder> customerID#395 ASC NULLS FIRST could run on GPU\\n\",\n      \"        @Expression <AttributeReference> customerID#395 could run on GPU\\n\",\n      \"    !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"      @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#395 could run on GPU\\n\",\n      \"        ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"          @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"          @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"      @Expression <Alias> date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) AS dateOfBirth#460 could run on GPU\\n\",\n      \"        @Expression <DateSub> date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) could run on GPU\\n\",\n      \"          @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(FLOOR(CASE WHEN 
(cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int) could run on GPU\\n\",\n      \"            @Expression <Floor> FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) could run on GPU\\n\",\n      \"              @Expression <CaseWhen> CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END could run on GPU\\n\",\n      \"                @Expression <EqualTo> (cast(SeniorCitizen#2 as int) = 0) could run on GPU\\n\",\n      \"                  @Expression <Cast> cast(SeniorCitizen#2 as int) could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"                  @Expression <Literal> 0 could run on GPU\\n\",\n      \"                @Expression <Add> (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) could run on GPU\\n\",\n      \"                  @Expression <Multiply> ((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) could run on GPU\\n\",\n      \"                    @Expression <Divide> 
(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                      @Expression <Cast> cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\\n\",\n      \"                        @Expression <Remainder> (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\\n\",\n      \"                          @Expression <Abs> abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\\n\",\n      \"                            @Expression <Murmur3Hash> hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\\n\",\n      \"                              ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"                                @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"                                @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"                          @Expression <Literal> 4096 could run on GPU\\n\",\n      \"                      @Expression <Literal> 4096.0 could run on GPU\\n\",\n      \"                    @Expression <Literal> 16801.5 could run on GPU\\n\",\n      \"                  @Expression <Literal> 6574.5 could run on GPU\\n\",\n      \"                @Expression <Add> (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) could run on GPU\\n\",\n      \"                  @Expression <Multiply> ((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 
4096.0)) * 6.3) * 365.25) could run on GPU\\n\",\n      \"                    @Expression <Multiply> (-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\\n\",\n      \"                        @Expression <Log1p> LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\\n\",\n      \"                          @Expression <UnaryMinus> -(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                            @Expression <Divide> (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                              @Expression <Cast> cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\\n\",\n      \"                                @Expression <Remainder> (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\\n\",\n      \"                                    @Expression <Murmur3Hash> hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\\n\",\n      \"                                      ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"                                        @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                        @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 4096 could run on GPU\\n\",\n      \"                              @Expression <Literal> 4096.0 could run on GPU\\n\",\n      \"                      @Expression <Literal> 6.3 could run on GPU\\n\",\n      \"                    @Expression <Literal> 365.25 could run on GPU\\n\",\n      \"                  @Expression <Literal> 23741.25 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> gender#1 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> Partner#3 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> Dependents#4 could run on GPU\\n\",\n      \"      @Expression <Alias> cast(MonthlyCharges#18 as decimal(8,2)) AS MonthlyCharges#441 could run on GPU\\n\",\n      \"        @Expression <Cast> cast(MonthlyCharges#18 as decimal(8,2)) could run on GPU\\n\",\n      \"          @Expression <AttributeReference> MonthlyCharges#18 could run on GPU\\n\",\n      \"      @Expression <Alias> 2022-04-05 09:36:19.001066 AS now#439 could run on GPU\\n\",\n      \"        @Expression <Literal> 2022-04-05 09:36:19.001066 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:21,144 WARN rapids.GpuOverrides: \\n\",\n      \"  !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is 
inefficient\\n\",\n      \"    @Partitioning <RangePartitioning> could run on GPU\\n\",\n      \"      @Expression <SortOrder> customerID#395 ASC NULLS FIRST could run on GPU\\n\",\n      \"        @Expression <AttributeReference> customerID#395 could run on GPU\\n\",\n      \"    !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"      @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#395 could run on GPU\\n\",\n      \"        ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"          @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"          @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"      @Expression <Alias> date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) AS dateOfBirth#460 could run on GPU\\n\",\n      \"        @Expression <DateSub> date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) could run on GPU\\n\",\n      \"          @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(FLOOR(CASE WHEN 
(cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int) could run on GPU\\n\",\n      \"            @Expression <Floor> FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) could run on GPU\\n\",\n      \"              @Expression <CaseWhen> CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END could run on GPU\\n\",\n      \"                @Expression <EqualTo> (cast(SeniorCitizen#2 as int) = 0) could run on GPU\\n\",\n      \"                  @Expression <Cast> cast(SeniorCitizen#2 as int) could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"                  @Expression <Literal> 0 could run on GPU\\n\",\n      \"                @Expression <Add> (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) could run on GPU\\n\",\n      \"                  @Expression <Multiply> ((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) could run on GPU\\n\",\n      \"                    @Expression <Divide> 
(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                      @Expression <Cast> cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\\n\",\n      \"                        @Expression <Remainder> (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\\n\",\n      \"                          @Expression <Abs> abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\\n\",\n      \"                            @Expression <Murmur3Hash> hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\\n\",\n      \"                              ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"                                @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"                                @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"                          @Expression <Literal> 4096 could run on GPU\\n\",\n      \"                      @Expression <Literal> 4096.0 could run on GPU\\n\",\n      \"                    @Expression <Literal> 16801.5 could run on GPU\\n\",\n      \"                  @Expression <Literal> 6574.5 could run on GPU\\n\",\n      \"                @Expression <Add> (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) could run on GPU\\n\",\n      \"                  @Expression <Multiply> ((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 
4096.0)) * 6.3) * 365.25) could run on GPU\\n\",\n      \"                    @Expression <Multiply> (-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\\n\",\n      \"                        @Expression <Log1p> LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\\n\",\n      \"                          @Expression <UnaryMinus> -(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                            @Expression <Divide> (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                              @Expression <Cast> cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\\n\",\n      \"                                @Expression <Remainder> (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\\n\",\n      \"                                    @Expression <Murmur3Hash> hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\\n\",\n      \"                                      ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"                                        @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                        @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 4096 could run on GPU\\n\",\n      \"                              @Expression <Literal> 4096.0 could run on GPU\\n\",\n      \"                      @Expression <Literal> 6.3 could run on GPU\\n\",\n      \"                    @Expression <Literal> 365.25 could run on GPU\\n\",\n      \"                  @Expression <Literal> 23741.25 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> gender#1 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> Partner#3 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> Dependents#4 could run on GPU\\n\",\n      \"      @Expression <Alias> cast(MonthlyCharges#18 as decimal(8,2)) AS MonthlyCharges#441 could run on GPU\\n\",\n      \"        @Expression <Cast> cast(MonthlyCharges#18 as decimal(8,2)) could run on GPU\\n\",\n      \"          @Expression <AttributeReference> MonthlyCharges#18 could run on GPU\\n\",\n      \"      @Expression <Alias> 2022-04-05 09:36:19.001066 AS now#439 could run on GPU\\n\",\n      \"        @Expression <Literal> 2022-04-05 09:36:19.001066 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:21,206 WARN rapids.GpuOverrides: \\n\",\n      \"  !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is 
inefficient\\n\",\n      \"    @Partitioning <RangePartitioning> could run on GPU\\n\",\n      \"      @Expression <SortOrder> customerID#395 ASC NULLS FIRST could run on GPU\\n\",\n      \"        @Expression <AttributeReference> customerID#395 could run on GPU\\n\",\n      \"    !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"      @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#395 could run on GPU\\n\",\n      \"        ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"          @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"          @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"          @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"      @Expression <Alias> date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) AS dateOfBirth#460 could run on GPU\\n\",\n      \"        @Expression <DateSub> date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) could run on GPU\\n\",\n      \"          @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(FLOOR(CASE WHEN 
(cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int) could run on GPU\\n\",\n      \"            @Expression <Floor> FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) could run on GPU\\n\",\n      \"              @Expression <CaseWhen> CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END could run on GPU\\n\",\n      \"                @Expression <EqualTo> (cast(SeniorCitizen#2 as int) = 0) could run on GPU\\n\",\n      \"                  @Expression <Cast> cast(SeniorCitizen#2 as int) could run on GPU\\n\",\n      \"                    @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"                  @Expression <Literal> 0 could run on GPU\\n\",\n      \"                @Expression <Add> (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) could run on GPU\\n\",\n      \"                  @Expression <Multiply> ((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) could run on GPU\\n\",\n      \"                    @Expression <Divide> 
(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                      @Expression <Cast> cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\\n\",\n      \"                        @Expression <Remainder> (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\\n\",\n      \"                          @Expression <Abs> abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\\n\",\n      \"                            @Expression <Murmur3Hash> hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\\n\",\n      \"                              ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"                                @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"                                @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"                          @Expression <Literal> 4096 could run on GPU\\n\",\n      \"                      @Expression <Literal> 4096.0 could run on GPU\\n\",\n      \"                    @Expression <Literal> 16801.5 could run on GPU\\n\",\n      \"                  @Expression <Literal> 6574.5 could run on GPU\\n\",\n      \"                @Expression <Add> (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) could run on GPU\\n\",\n      \"                  @Expression <Multiply> ((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 
4096.0)) * 6.3) * 365.25) could run on GPU\\n\",\n      \"                    @Expression <Multiply> (-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) could run on GPU\\n\",\n      \"                      @Expression <UnaryMinus> -LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\\n\",\n      \"                        @Expression <Log1p> LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\\n\",\n      \"                          @Expression <UnaryMinus> -(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                            @Expression <Divide> (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                              @Expression <Cast> cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\\n\",\n      \"                                @Expression <Remainder> (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\\n\",\n      \"                                  @Expression <Abs> abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\\n\",\n      \"                                    @Expression <Murmur3Hash> hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\\n\",\n      \"                                      ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"                                        @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"                                        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                        @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"                                  @Expression <Literal> 4096 could run on GPU\\n\",\n      \"                              @Expression <Literal> 4096.0 could run on GPU\\n\",\n      \"                      @Expression <Literal> 6.3 could run on GPU\\n\",\n      \"                    @Expression <Literal> 365.25 could run on GPU\\n\",\n      \"                  @Expression <Literal> 23741.25 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> gender#1 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> Partner#3 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> Dependents#4 could run on GPU\\n\",\n      \"      @Expression <Alias> cast(MonthlyCharges#18 as decimal(8,2)) AS MonthlyCharges#441 could run on GPU\\n\",\n      \"        @Expression <Cast> cast(MonthlyCharges#18 as decimal(8,2)) could run on GPU\\n\",\n      \"          @Expression <AttributeReference> MonthlyCharges#18 could run on GPU\\n\",\n      \"      @Expression <Alias> 2022-04-05 09:36:19.001066 AS now#439 could run on GPU\\n\",\n      \"        @Expression <Literal> 2022-04-05 09:36:19.001066 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:21,209 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is 
inefficient\\n\",\n      \"  @Partitioning <RangePartitioning> could run on GPU\\n\",\n      \"    @Expression <SortOrder> customerID#395 ASC NULLS FIRST could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#395 could run on GPU\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#395 could run on GPU\\n\",\n      \"      ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"        @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"        @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"    @Expression <Alias> date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) AS dateOfBirth#460 could run on GPU\\n\",\n      \"      @Expression <DateSub> date_sub(2022-04-05, cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int)) could run on GPU\\n\",\n      \"        @Expression <Literal> 2022-04-05 could run on GPU\\n\",\n      \"        @Expression <Cast> cast(FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN 
(((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) as int) could run on GPU\\n\",\n      \"          @Expression <Floor> FLOOR(CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END) could run on GPU\\n\",\n      \"            @Expression <CaseWhen> CASE WHEN (cast(SeniorCitizen#2 as int) = 0) THEN (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) ELSE (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) END could run on GPU\\n\",\n      \"              @Expression <EqualTo> (cast(SeniorCitizen#2 as int) = 0) could run on GPU\\n\",\n      \"                @Expression <Cast> cast(SeniorCitizen#2 as int) could run on GPU\\n\",\n      \"                  @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"                @Expression <Literal> 0 could run on GPU\\n\",\n      \"              @Expression <Add> (((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) + 6574.5) could run on GPU\\n\",\n      \"                @Expression <Multiply> ((cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) * 16801.5) could run on GPU\\n\",\n      \"                  @Expression <Divide> (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) 
% 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                    @Expression <Cast> cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\\n\",\n      \"                      @Expression <Remainder> (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\\n\",\n      \"                        @Expression <Abs> abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\\n\",\n      \"                          @Expression <Murmur3Hash> hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\\n\",\n      \"                            ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"                              @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                              @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"                        @Expression <Literal> 4096 could run on GPU\\n\",\n      \"                    @Expression <Literal> 4096.0 could run on GPU\\n\",\n      \"                  @Expression <Literal> 16801.5 could run on GPU\\n\",\n      \"                @Expression <Literal> 6574.5 could run on GPU\\n\",\n      \"              @Expression <Add> (((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) + 23741.25) could run on GPU\\n\",\n      \"                @Expression <Multiply> ((-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) * 365.25) could run on GPU\\n\",\n      \"                  @Expression <Multiply> 
(-LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) * 6.3) could run on GPU\\n\",\n      \"                    @Expression <UnaryMinus> -LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\\n\",\n      \"                      @Expression <Log1p> LOG1P(-(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0)) could run on GPU\\n\",\n      \"                        @Expression <UnaryMinus> -(cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                          @Expression <Divide> (cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) / 4096.0) could run on GPU\\n\",\n      \"                            @Expression <Cast> cast((abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) as double) could run on GPU\\n\",\n      \"                              @Expression <Remainder> (abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) % 4096) could run on GPU\\n\",\n      \"                                @Expression <Abs> abs(hash(format_string(%s-%s, customerID#0, u_value#337), 42), false) could run on GPU\\n\",\n      \"                                  @Expression <Murmur3Hash> hash(format_string(%s-%s, customerID#0, u_value#337), 42) could run on GPU\\n\",\n      \"                                    ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"                                      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"                                      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"                                      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"                                @Expression <Literal> 4096 could run on GPU\\n\",\n      \"                            @Expression <Literal> 4096.0 could run on GPU\\n\",\n      \"                    @Expression <Literal> 6.3 could run on GPU\\n\",\n      \"                  @Expression <Literal> 365.25 could run on GPU\\n\",\n      \"                @Expression <Literal> 23741.25 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> gender#1 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> SeniorCitizen#2 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> Partner#3 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> Dependents#4 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(MonthlyCharges#18 as decimal(8,2)) AS MonthlyCharges#441 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(MonthlyCharges#18 as decimal(8,2)) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MonthlyCharges#18 could run on GPU\\n\",\n      \"    @Expression <Alias> 2022-04-05 09:36:19.001066 AS now#439 could run on GPU\\n\",\n      \"      @Expression <Literal> 2022-04-05 09:36:19.001066 could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:40:23,697 WARN rapids.GpuOverrides:                   (0 + 1) / 1]\\n\",\n      \"  !Exec 
<AQEShuffleReadExec> cannot run on GPU because Unable to replace CustomShuffleReader due to child not being columnar\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:24,451 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <LocalTableScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.LocalTableScanExec\\n\",\n      \"  @Expression <AttributeReference> name#894 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> database#895 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> description#896 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> tableType#897 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> isTemporary#898 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:24,499 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#910 could run on GPU\\n\",\n      \"    ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#479 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#480 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:24,502 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#910 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#479 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#480 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:24,504 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#910 could run on GPU\\n\",\n      \"    ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#479 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#480 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:24,507 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#910 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#479 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#480 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:24,555 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#910 could run on GPU\\n\",\n      \"    ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#479 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#480 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:24,557 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#910 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#479 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#480 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:25,815 WARN rapids.GpuOverrides:                               \\n\",\n      \"! 
<LocalTableScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.LocalTableScanExec\\n\",\n      \"  @Expression <AttributeReference> name#946 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> database#947 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> description#948 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> tableType#949 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> isTemporary#950 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:25,888 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#962 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#513 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#514 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:25,894 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#962 could run on GPU\\n\",\n      \"    ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#513 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#514 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:25,901 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#962 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#513 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#514 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:25,907 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#962 could run on GPU\\n\",\n      \"    ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#513 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#514 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:25,962 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#962 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#513 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#514 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:25,967 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#0, u_value#337) AS customerID#962 could run on GPU\\n\",\n      \"    ! 
<FormatString> format_string(%s-%s, customerID#0, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#0 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#513 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#514 could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:40:28,911 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <LocalTableScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.LocalTableScanExec\\n\",\n      \"  @Expression <AttributeReference> name#998 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> database#999 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> description#1000 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> tableType#1001 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> isTemporary#1002 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:28,964 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#721, u_value#337) AS customerID#1014 could run on GPU\\n\",\n      \"    ! 
<FormatString> format_string(%s-%s, customerID#721, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#721 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#722 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#723 could run on GPU\\n\",\n      \"      ! <RDDScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.RDDScanExec\\n\",\n      \"        @Expression <AttributeReference> customerID#721 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> feature#722 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> value#723 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:28,967 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#721, u_value#337) AS customerID#1014 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#721, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#721 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#722 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#723 could run on GPU\\n\",\n      \"      ! 
<RDDScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.RDDScanExec\\n\",\n      \"        @Expression <AttributeReference> customerID#721 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> feature#722 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> value#723 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:28,970 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#721, u_value#337) AS customerID#1014 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#721, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#721 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#722 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#723 could run on GPU\\n\",\n      \"      ! 
<RDDScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.RDDScanExec\\n\",\n      \"        @Expression <AttributeReference> customerID#721 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> feature#722 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> value#723 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:28,973 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#721, u_value#337) AS customerID#1014 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#721, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#721 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#722 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#723 could run on GPU\\n\",\n      \"      ! 
<RDDScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.RDDScanExec\\n\",\n      \"        @Expression <AttributeReference> customerID#721 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> feature#722 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> value#723 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:29,023 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#721, u_value#337) AS customerID#1014 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#721, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#721 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#722 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#723 could run on GPU\\n\",\n      \"      ! 
<RDDScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.RDDScanExec\\n\",\n      \"        @Expression <AttributeReference> customerID#721 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> feature#722 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> value#723 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:40:29,026 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"  @Expression <Alias> format_string(%s-%s, customerID#721, u_value#337) AS customerID#1014 could run on GPU\\n\",\n      \"    ! <FormatString> format_string(%s-%s, customerID#721, u_value#337) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.FormatString\\n\",\n      \"      @Expression <Literal> %s-%s could run on GPU\\n\",\n      \"      @Expression <AttributeReference> customerID#721 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> u_value#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature#722 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> value#723 could run on GPU\\n\",\n      \"      ! 
<RDDScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.RDDScanExec\\n\",\n      \"        @Expression <AttributeReference> customerID#721 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> feature#722 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> value#723 could run on GPU\\n\",\n      \"\\n\",\n      \"[Stage 41:==================================================>     (10 + 1) / 11]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 214 ms, sys: 34 ms, total: 248 ms\\n\",\n      \"Wall time: 3min 54s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\r\",\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"\\n\",\n    \"from churn.augment import write_df\\n\",\n    \"\\n\",\n    \"write_df(billingEvents, \\\"billing_events\\\", partition_by=\\\"month\\\")\\n\",\n    \"write_df(customerMeta, \\\"customer_meta\\\", skip_replication=True)\\n\",\n    \"write_df(customerPhoneFeatures, \\\"customer_phone_features\\\")\\n\",\n    \"write_df(customerInternetFeatures.orderBy(\\\"customerID\\\"), \\\"customer_internet_features\\\")\\n\",\n    \"write_df(customerAccountFeatures, \\\"customer_account_features\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"billing_events 703200\\n\",\n      \"customer_meta 703200\\n\",\n      
\"customer_phone_features 635200\\n\",\n      \"customer_internet_features 551200\\n\",\n      \"customer_account_features 703200\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for f in [\\\"billing_events\\\", \\\"customer_meta\\\", \\\"customer_phone_features\\\", \\\"customer_internet_features\\\", \\\"customer_account_features\\\"]:\\n\",\n    \"    output_df = session.read.parquet(\\\"%s.parquet\\\" % f)\\n\",\n    \"    print(f, output_df.select(\\\"customerID\\\").distinct().count())\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import pyspark.sql.functions as F\\n\",\n    \"from functools import reduce\\n\",\n    \"\\n\",\n    \"output_dfs = []\\n\",\n    \"\\n\",\n    \"for f in [\\\"billing_events\\\", \\\"customer_meta\\\", \\\"customer_phone_features\\\", \\\"customer_internet_features\\\", \\\"customer_account_features\\\"]:\\n\",\n    \"    output_dfs.append(\\n\",\n    \"        session.read.parquet(\\\"%s.parquet\\\" % f).select(\\n\",\n    \"            F.lit(f).alias(\\\"table\\\"),\\n\",\n    \"            \\\"customerID\\\"\\n\",\n    \"        )\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"all_customers = reduce(lambda l, r: l.unionAll(r), output_dfs)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:41:25,790 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it 
could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <Alias> cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) AS approx_unique_customers#2118 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) could run on GPU\\n\",\n      \"            @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"        !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"          @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"            @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"          !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"            @Expression <AttributeReference> 
table#1189 could run on GPU\\n\",\n      \"            @Expression <AggregateExpression> partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"              ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"                @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[0]#1354L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[1]#1355L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[2]#1356L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[3]#1357L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[4]#1358L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[5]#1359L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[6]#1360L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[7]#1361L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[8]#1362L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[9]#1363L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[10]#1364L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[11]#1365L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[12]#1366L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[13]#1367L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[14]#1368L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[15]#1369L could run on GPU\\n\",\n      \"            @Expression 
<AttributeReference> MS[16]#1370L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[17]#1371L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[18]#1372L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[19]#1373L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[20]#1374L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[21]#1375L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[22]#1376L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[23]#1377L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[24]#1378L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[25]#1379L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[26]#1380L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[27]#1381L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[28]#1382L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[29]#1383L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[30]#1384L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[31]#1385L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[32]#1386L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[33]#1387L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[34]#1388L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[35]#1389L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[36]#1390L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[37]#1391L could run on GPU\\n\",\n      \"            @Expression 
<AttributeReference> MS[38]#1392L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[39]#1393L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[40]#1394L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[41]#1395L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[42]#1396L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[43]#1397L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[44]#1398L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[45]#1399L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[46]#1400L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[47]#1401L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[48]#1402L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[49]#1403L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[50]#1404L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[51]#1405L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[0]#1406L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[1]#1407L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[2]#1408L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[3]#1409L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[4]#1410L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[5]#1411L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[6]#1412L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> 
MS[7]#1413L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[8]#1414L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[9]#1415L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[10]#1416L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[11]#1417L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[12]#1418L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[13]#1419L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[14]#1420L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[15]#1421L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[16]#1422L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[17]#1423L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[18]#1424L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[19]#1425L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[20]#1426L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[21]#1427L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[22]#1428L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[23]#1429L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[24]#1430L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[25]#1431L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[26]#1432L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[27]#1433L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[28]#1434L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[29]#1435L could run 
on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[30]#1436L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[31]#1437L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[32]#1438L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[33]#1439L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[34]#1440L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[35]#1441L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[36]#1442L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[37]#1443L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[38]#1444L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[39]#1445L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[40]#1446L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[41]#1447L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[42]#1448L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[43]#1449L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[44]#1450L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[45]#1451L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[46]#1452L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[47]#1453L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[48]#1454L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[49]#1455L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[50]#1456L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[51]#1457L could run on GPU\\n\",\n      
\"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1883 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"        @Expression <Alias> all AS table#2537 could run on GPU\\n\",\n      \"          @Expression <Literal> all could run on GPU\\n\",\n      \"        @Expression <Alias> cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) AS approx_unique_customers#2538 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) could run on GPU\\n\",\n      \"            @Expression <AttributeReference> approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"        !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"          @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"            @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"          !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"            @Expression <Alias> 0 AS 0#2539 could run on GPU\\n\",\n      \"              @Expression <Literal> 0 could run on GPU\\n\",\n      \"            @Expression <AggregateExpression> 
partial_approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\\n\",\n      \"              ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"                @Expression <AttributeReference> customerID#1883 could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[0]#1905L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[1]#1906L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[2]#1907L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[3]#1908L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[4]#1909L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[5]#1910L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[6]#1911L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[7]#1912L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[8]#1913L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[9]#1914L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[10]#1915L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[11]#1916L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[12]#1917L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[13]#1918L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[14]#1919L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[15]#1920L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[16]#1921L could run on GPU\\n\",\n      \"            @Expression 
<AttributeReference> MS[17]#1922L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[18]#1923L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[19]#1924L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[20]#1925L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[21]#1926L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[22]#1927L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[23]#1928L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[24]#1929L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[25]#1930L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[26]#1931L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[27]#1932L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[28]#1933L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[29]#1934L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[30]#1935L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[31]#1936L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[32]#1937L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[33]#1938L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[34]#1939L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[35]#1940L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[36]#1941L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[37]#1942L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[38]#1943L could run on GPU\\n\",\n      \"            @Expression 
<AttributeReference> MS[39]#1944L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[40]#1945L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[41]#1946L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[42]#1947L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[43]#1948L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[44]#1949L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[45]#1950L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[46]#1951L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[47]#1952L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[48]#1953L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[49]#1954L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[50]#1955L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[51]#1956L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[0]#1957L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[1]#1958L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[2]#1959L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[3]#1960L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[4]#1961L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[5]#1962L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[6]#1963L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[7]#1964L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[8]#1965L 
could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[9]#1966L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[10]#1967L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[11]#1968L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[12]#1969L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[13]#1970L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[14]#1971L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[15]#1972L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[16]#1973L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[17]#1974L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[18]#1975L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[19]#1976L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[20]#1977L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[21]#1978L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[22]#1979L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[23]#1980L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[24]#1981L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[25]#1982L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[26]#1983L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[27]#1984L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[28]#1985L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[29]#1986L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[30]#1987L could run on 
GPU\\n\",\n      \"            @Expression <AttributeReference> MS[31]#1988L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[32]#1989L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[33]#1990L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[34]#1991L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[35]#1992L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[36]#1993L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[37]#1994L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[38]#1995L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[39]#1996L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[40]#1997L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[41]#1998L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[42]#1999L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[43]#2000L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[44]#2001L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[45]#2002L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[46]#2003L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[47]#2004L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[48]#2005L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[49]#2006L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[50]#2007L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[51]#2008L could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:41:25,794 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec 
<CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <Alias> cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) AS approx_unique_customers#2118 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) could run on GPU\\n\",\n      \"            @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"        !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"          @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"            @Expression 
<AttributeReference> table#1189 could run on GPU\\n\",\n      \"          !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"            @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"            @Expression <AggregateExpression> partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"              ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"                @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[0]#1354L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[1]#1355L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[2]#1356L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[3]#1357L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[4]#1358L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[5]#1359L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[6]#1360L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[7]#1361L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[8]#1362L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[9]#1363L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[10]#1364L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[11]#1365L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[12]#1366L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[13]#1367L could run on GPU\\n\",\n      
\"            @Expression <AttributeReference> MS[14]#1368L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[15]#1369L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[16]#1370L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[17]#1371L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[18]#1372L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[19]#1373L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[20]#1374L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[21]#1375L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[22]#1376L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[23]#1377L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[24]#1378L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[25]#1379L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[26]#1380L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[27]#1381L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[28]#1382L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[29]#1383L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[30]#1384L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[31]#1385L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[32]#1386L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[33]#1387L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[34]#1388L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[35]#1389L could run on GPU\\n\",\n      \"            
@Expression <AttributeReference> MS[36]#1390L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[37]#1391L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[38]#1392L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[39]#1393L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[40]#1394L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[41]#1395L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[42]#1396L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[43]#1397L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[44]#1398L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[45]#1399L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[46]#1400L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[47]#1401L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[48]#1402L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[49]#1403L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[50]#1404L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[51]#1405L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[0]#1406L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[1]#1407L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[2]#1408L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[3]#1409L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[4]#1410L could run on GPU\\n\",\n      \"            @Expression 
<AttributeReference> MS[5]#1411L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[6]#1412L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[7]#1413L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[8]#1414L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[9]#1415L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[10]#1416L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[11]#1417L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[12]#1418L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[13]#1419L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[14]#1420L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[15]#1421L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[16]#1422L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[17]#1423L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[18]#1424L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[19]#1425L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[20]#1426L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[21]#1427L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[22]#1428L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[23]#1429L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[24]#1430L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[25]#1431L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[26]#1432L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> 
MS[27]#1433L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[28]#1434L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[29]#1435L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[30]#1436L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[31]#1437L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[32]#1438L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[33]#1439L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[34]#1440L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[35]#1441L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[36]#1442L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[37]#1443L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[38]#1444L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[39]#1445L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[40]#1446L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[41]#1447L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[42]#1448L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[43]#1449L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[44]#1450L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[45]#1451L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[46]#1452L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[47]#1453L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[48]#1454L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[49]#1455L could 
run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[50]#1456L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[51]#1457L could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1883 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"        @Expression <Alias> all AS table#2537 could run on GPU\\n\",\n      \"          @Expression <Literal> all could run on GPU\\n\",\n      \"        @Expression <Alias> cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) AS approx_unique_customers#2538 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) could run on GPU\\n\",\n      \"            @Expression <AttributeReference> approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"        !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"          @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"            @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"          !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"            
@Expression <Alias> 0 AS 0#2539 could run on GPU\\n\",\n      \"              @Expression <Literal> 0 could run on GPU\\n\",\n      \"            @Expression <AggregateExpression> partial_approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\\n\",\n      \"              ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"                @Expression <AttributeReference> customerID#1883 could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[0]#1905L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[1]#1906L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[2]#1907L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[3]#1908L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[4]#1909L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[5]#1910L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[6]#1911L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[7]#1912L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[8]#1913L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[9]#1914L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[10]#1915L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[11]#1916L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[12]#1917L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[13]#1918L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[14]#1919L could run on GPU\\n\",\n      \"            @Expression 
<AttributeReference> MS[15]#1920L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[16]#1921L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[17]#1922L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[18]#1923L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[19]#1924L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[20]#1925L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[21]#1926L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[22]#1927L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[23]#1928L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[24]#1929L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[25]#1930L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[26]#1931L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[27]#1932L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[28]#1933L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[29]#1934L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[30]#1935L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[31]#1936L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[32]#1937L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[33]#1938L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[34]#1939L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[35]#1940L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[36]#1941L could run on GPU\\n\",\n      \"            @Expression 
<AttributeReference> MS[37]#1942L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[38]#1943L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[39]#1944L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[40]#1945L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[41]#1946L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[42]#1947L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[43]#1948L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[44]#1949L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[45]#1950L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[46]#1951L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[47]#1952L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[48]#1953L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[49]#1954L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[50]#1955L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[51]#1956L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[0]#1957L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[1]#1958L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[2]#1959L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[3]#1960L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[4]#1961L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[5]#1962L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> 
MS[6]#1963L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[7]#1964L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[8]#1965L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[9]#1966L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[10]#1967L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[11]#1968L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[12]#1969L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[13]#1970L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[14]#1971L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[15]#1972L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[16]#1973L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[17]#1974L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[18]#1975L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[19]#1976L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[20]#1977L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[21]#1978L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[22]#1979L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[23]#1980L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[24]#1981L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[25]#1982L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[26]#1983L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[27]#1984L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[28]#1985L could run 
on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[29]#1986L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[30]#1987L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[31]#1988L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[32]#1989L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[33]#1990L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[34]#1991L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[35]#1992L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[36]#1993L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[37]#1994L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[38]#1995L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[39]#1996L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[40]#1997L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[41]#1998L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[42]#1999L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[43]#2000L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[44]#2001L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[45]#2002L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[46]#2003L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[47]#2004L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[48]#2005L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[49]#2006L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[50]#2007L could run on GPU\\n\",\n      
\"            @Expression <AttributeReference> MS[51]#2008L could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:41:25,797 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <Alias> cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) AS approx_unique_customers#2118 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) could run on GPU\\n\",\n      \"            @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"        !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar 
exchange without columnar children is inefficient\\n\",\n      \"          @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"            @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"          !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"            @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"            @Expression <AggregateExpression> partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"              ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"                @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[0]#1354L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[1]#1355L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[2]#1356L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[3]#1357L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[4]#1358L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[5]#1359L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[6]#1360L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[7]#1361L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[8]#1362L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[9]#1363L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[10]#1364L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[11]#1365L could run on GPU\\n\",\n      \"            
@Expression <AttributeReference> MS[12]#1366L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[13]#1367L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[14]#1368L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[15]#1369L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[16]#1370L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[17]#1371L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[18]#1372L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[19]#1373L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[20]#1374L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[21]#1375L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[22]#1376L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[23]#1377L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[24]#1378L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[25]#1379L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[26]#1380L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[27]#1381L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[28]#1382L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[29]#1383L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[30]#1384L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[31]#1385L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[32]#1386L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[33]#1387L could run on GPU\\n\",\n      \"            @Expression 
<AttributeReference> MS[34]#1388L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[35]#1389L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[36]#1390L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[37]#1391L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[38]#1392L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[39]#1393L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[40]#1394L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[41]#1395L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[42]#1396L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[43]#1397L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[44]#1398L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[45]#1399L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[46]#1400L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[47]#1401L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[48]#1402L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[49]#1403L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[50]#1404L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[51]#1405L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[0]#1406L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[1]#1407L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[2]#1408L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> 
MS[3]#1409L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[4]#1410L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[5]#1411L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[6]#1412L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[7]#1413L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[8]#1414L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[9]#1415L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[10]#1416L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[11]#1417L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[12]#1418L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[13]#1419L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[14]#1420L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[15]#1421L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[16]#1422L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[17]#1423L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[18]#1424L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[19]#1425L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[20]#1426L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[21]#1427L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[22]#1428L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[23]#1429L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[24]#1430L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[25]#1431L could run on 
GPU\\n\",\n      \"            @Expression <AttributeReference> MS[26]#1432L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[27]#1433L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[28]#1434L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[29]#1435L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[30]#1436L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[31]#1437L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[32]#1438L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[33]#1439L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[34]#1440L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[35]#1441L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[36]#1442L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[37]#1443L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[38]#1444L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[39]#1445L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[40]#1446L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[41]#1447L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[42]#1448L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[43]#1449L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[44]#1450L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[45]#1451L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[46]#1452L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[47]#1453L could run on GPU\\n\",\n      \" 
           @Expression <AttributeReference> MS[48]#1454L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[49]#1455L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[50]#1456L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[51]#1457L could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1883 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"        @Expression <Alias> all AS table#2537 could run on GPU\\n\",\n      \"          @Expression <Literal> all could run on GPU\\n\",\n      \"        @Expression <Alias> cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) AS approx_unique_customers#2538 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) could run on GPU\\n\",\n      \"            @Expression <AttributeReference> approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"        !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"          @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"            @Expression <AttributeReference> 0#2539 could run on 
GPU\\n\",\n      \"          !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"            @Expression <Alias> 0 AS 0#2539 could run on GPU\\n\",\n      \"              @Expression <Literal> 0 could run on GPU\\n\",\n      \"            @Expression <AggregateExpression> partial_approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\\n\",\n      \"              ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"                @Expression <AttributeReference> customerID#1883 could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[0]#1905L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[1]#1906L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[2]#1907L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[3]#1908L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[4]#1909L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[5]#1910L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[6]#1911L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[7]#1912L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[8]#1913L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[9]#1914L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[10]#1915L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[11]#1916L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[12]#1917L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[13]#1918L could run on 
GPU\\n\",\n      \"            @Expression <AttributeReference> MS[14]#1919L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[15]#1920L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[16]#1921L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[17]#1922L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[18]#1923L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[19]#1924L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[20]#1925L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[21]#1926L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[22]#1927L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[23]#1928L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[24]#1929L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[25]#1930L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[26]#1931L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[27]#1932L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[28]#1933L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[29]#1934L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[30]#1935L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[31]#1936L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[32]#1937L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[33]#1938L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[34]#1939L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[35]#1940L could run on GPU\\n\",\n      \" 
           @Expression <AttributeReference> MS[36]#1941L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[37]#1942L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[38]#1943L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[39]#1944L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[40]#1945L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[41]#1946L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[42]#1947L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[43]#1948L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[44]#1949L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[45]#1950L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[46]#1951L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[47]#1952L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[48]#1953L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[49]#1954L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[50]#1955L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[51]#1956L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[0]#1957L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[1]#1958L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[2]#1959L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[3]#1960L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[4]#1961L could run on GPU\\n\",\n      \"            @Expression 
<AttributeReference> MS[5]#1962L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[6]#1963L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[7]#1964L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[8]#1965L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[9]#1966L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[10]#1967L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[11]#1968L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[12]#1969L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[13]#1970L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[14]#1971L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[15]#1972L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[16]#1973L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[17]#1974L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[18]#1975L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[19]#1976L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[20]#1977L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[21]#1978L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[22]#1979L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[23]#1980L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[24]#1981L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[25]#1982L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[26]#1983L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> 
MS[27]#1984L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[28]#1985L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[29]#1986L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[30]#1987L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[31]#1988L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[32]#1989L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[33]#1990L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[34]#1991L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[35]#1992L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[36]#1993L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[37]#1994L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[38]#1995L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[39]#1996L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[40]#1997L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[41]#1998L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[42]#1999L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[43]#2000L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[44]#2001L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[45]#2002L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[46]#2003L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[47]#2004L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[48]#2005L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[49]#2006L could 
run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[50]#2007L could run on GPU\\n\",\n      \"            @Expression <AttributeReference> MS[51]#2008L could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:41:25,801 WARN util.package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\\n\",\n      \"2022-04-05 09:41:25,806 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"  @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[0]#1354L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[1]#1355L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[2]#1356L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[3]#1357L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[4]#1358L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[5]#1359L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[6]#1360L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[7]#1361L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[8]#1362L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[9]#1363L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[10]#1364L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[11]#1365L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[12]#1366L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[13]#1367L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[14]#1368L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[15]#1369L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[16]#1370L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[17]#1371L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[18]#1372L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[19]#1373L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 
MS[20]#1374L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[21]#1375L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[22]#1376L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[23]#1377L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[24]#1378L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[25]#1379L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[26]#1380L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[27]#1381L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[28]#1382L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[29]#1383L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[30]#1384L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[31]#1385L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[32]#1386L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[33]#1387L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[34]#1388L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[35]#1389L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[36]#1390L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[37]#1391L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[38]#1392L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[39]#1393L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[40]#1394L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[41]#1395L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[42]#1396L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[43]#1397L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[44]#1398L could run on 
GPU\\n\",\n      \"    @Expression <AttributeReference> MS[45]#1399L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[46]#1400L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[47]#1401L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[48]#1402L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[49]#1403L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[50]#1404L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[51]#1405L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[0]#1406L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[1]#1407L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[2]#1408L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[3]#1409L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[4]#1410L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[5]#1411L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[6]#1412L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[7]#1413L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[8]#1414L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[9]#1415L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[10]#1416L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[11]#1417L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[12]#1418L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[13]#1419L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[14]#1420L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[15]#1421L could run on GPU\\n\",\n      \"    @Expression 
<AttributeReference> MS[16]#1422L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[17]#1423L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[18]#1424L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[19]#1425L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[20]#1426L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[21]#1427L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[22]#1428L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[23]#1429L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[24]#1430L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[25]#1431L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[26]#1432L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[27]#1433L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[28]#1434L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[29]#1435L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[30]#1436L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[31]#1437L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[32]#1438L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[33]#1439L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[34]#1440L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[35]#1441L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[36]#1442L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[37]#1443L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[38]#1444L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[39]#1445L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 
MS[40]#1446L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[41]#1447L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[42]#1448L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[43]#1449L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[44]#1450L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[45]#1451L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[46]#1452L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[47]#1453L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[48]#1454L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[49]#1455L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[50]#1456L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[51]#1457L could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:41:25,810 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"  @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <Alias> 0 AS 0#2539 could run on GPU\\n\",\n      \"      @Expression <Literal> 0 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> partial_approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#1883 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[0]#1905L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[1]#1906L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[2]#1907L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[3]#1908L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[4]#1909L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[5]#1910L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[6]#1911L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[7]#1912L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[8]#1913L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[9]#1914L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[10]#1915L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[11]#1916L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[12]#1917L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[13]#1918L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[14]#1919L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[15]#1920L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[16]#1921L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[17]#1922L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[18]#1923L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[19]#1924L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 
MS[20]#1925L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[21]#1926L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[22]#1927L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[23]#1928L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[24]#1929L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[25]#1930L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[26]#1931L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[27]#1932L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[28]#1933L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[29]#1934L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[30]#1935L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[31]#1936L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[32]#1937L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[33]#1938L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[34]#1939L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[35]#1940L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[36]#1941L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[37]#1942L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[38]#1943L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[39]#1944L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[40]#1945L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[41]#1946L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[42]#1947L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[43]#1948L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[44]#1949L could run on 
GPU\\n\",\n      \"    @Expression <AttributeReference> MS[45]#1950L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[46]#1951L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[47]#1952L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[48]#1953L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[49]#1954L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[50]#1955L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[51]#1956L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[0]#1957L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[1]#1958L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[2]#1959L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[3]#1960L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[4]#1961L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[5]#1962L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[6]#1963L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[7]#1964L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[8]#1965L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[9]#1966L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[10]#1967L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[11]#1968L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[12]#1969L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[13]#1970L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[14]#1971L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[15]#1972L could run on GPU\\n\",\n      \"    @Expression 
<AttributeReference> MS[16]#1973L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[17]#1974L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[18]#1975L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[19]#1976L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[20]#1977L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[21]#1978L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[22]#1979L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[23]#1980L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[24]#1981L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[25]#1982L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[26]#1983L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[27]#1984L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[28]#1985L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[29]#1986L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[30]#1987L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[31]#1988L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[32]#1989L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[33]#1990L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[34]#1991L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[35]#1992L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[36]#1993L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[37]#1994L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[38]#1995L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[39]#1996L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 
MS[40]#1997L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[41]#1998L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[42]#1999L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[43]#2000L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[44]#2001L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[45]#2002L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[46]#2003L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[47]#2004L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[48]#2005L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[49]#2006L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[50]#2007L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[51]#2008L could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:42:07,736 WARN rapids.GpuOverrides: >               (0 + 0) / 815]\\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <Alias> cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) AS approx_unique_customers#2118 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) could run on GPU\\n\",\n      \"            @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1883 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"        @Expression <Alias> all AS table#2537 could run on GPU\\n\",\n      \"          @Expression <Literal> all could run on GPU\\n\",\n      \"        @Expression <Alias> cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) AS approx_unique_customers#2538 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) could run on GPU\\n\",\n      \"            @Expression <AttributeReference> approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:42:46,961 WARN rapids.GpuOverrides: =============>(812 + 1) / 815]\\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <Alias> cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) AS approx_unique_customers#2118 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) could run on GPU\\n\",\n      \"            @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1883 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"        @Expression <Alias> all AS table#2537 could run on GPU\\n\",\n      \"          @Expression <Literal> all could run on GPU\\n\",\n      \"        @Expression <Alias> cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) AS approx_unique_customers#2538 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) could run on GPU\\n\",\n      \"            @Expression <AttributeReference> approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:42:46,964 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <Alias> cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) AS approx_unique_customers#2118 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L as string) could run on GPU\\n\",\n      \"            @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"        !Exec <AQEShuffleReadExec> cannot run on GPU because Unable to replace CustomShuffleReader due to child not being columnar\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> 0#2539 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> approx_count_distinct(customerID#1883, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1883, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1883 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"        @Expression <Alias> all AS table#2537 could run on GPU\\n\",\n      \"          @Expression <Literal> all could run on GPU\\n\",\n      \"        @Expression <Alias> cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) AS approx_unique_customers#2538 could run on GPU\\n\",\n      \"          @Expression <Cast> cast(approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L as string) could run on GPU\\n\",\n      \"            @Expression <AttributeReference> approx_count_distinct(customerID#1883, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"        !Exec <AQEShuffleReadExec> cannot run on GPU because Unable to replace CustomShuffleReader due to child not being columnar\\n\",\n      \"\\n\",\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------+-----------------------+\\n\",\n      \"|               table|approx_unique_customers|\\n\",\n      \"+--------------------+-----------------------+\\n\",\n      \"|      billing_events|                 699470|\\n\",\n      \"|       customer_meta|                 699470|\\n\",\n      \"|customer_phone_fe...|                 631148|\\n\",\n      \"|customer_internet...|                 521053|\\n\",\n      \"|customer_account_...|                 699470|\\n\",\n      \"|                 all|                 699470|\\n\",\n      
\"+--------------------+-----------------------+\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"\\n\",\n    \"each_table = all_customers.groupBy(\\\"table\\\").agg(F.approx_count_distinct(\\\"customerID\\\").alias(\\\"approx_unique_customers\\\"))\\n\",\n    \"overall = all_customers.groupBy(F.lit(\\\"all\\\").alias(\\\"table\\\")).agg(F.approx_count_distinct(\\\"customerID\\\").alias(\\\"approx_unique_customers\\\"))\\n\",\n    \"\\n\",\n    \"each_table.union(overall).show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:42:47,133 WARN rapids.GpuOverrides: \\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <Alias> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L AS approx_unique_customers#1353L could run on GPU\\n\",\n      \"      @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"    !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"      @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[0]#1354L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[1]#1355L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[2]#1356L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[3]#1357L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[4]#1358L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[5]#1359L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[6]#1360L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[7]#1361L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[8]#1362L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[9]#1363L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[10]#1364L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[11]#1365L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[12]#1366L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[13]#1367L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[14]#1368L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[15]#1369L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[16]#1370L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[17]#1371L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[18]#1372L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> 
MS[19]#1373L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[20]#1374L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[21]#1375L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[22]#1376L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[23]#1377L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[24]#1378L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[25]#1379L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[26]#1380L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[27]#1381L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[28]#1382L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[29]#1383L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[30]#1384L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[31]#1385L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[32]#1386L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[33]#1387L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[34]#1388L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[35]#1389L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[36]#1390L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[37]#1391L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[38]#1392L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[39]#1393L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[40]#1394L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[41]#1395L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[42]#1396L could 
run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[43]#1397L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[44]#1398L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[45]#1399L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[46]#1400L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[47]#1401L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[48]#1402L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[49]#1403L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[50]#1404L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[51]#1405L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[0]#1406L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[1]#1407L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[2]#1408L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[3]#1409L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[4]#1410L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[5]#1411L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[6]#1412L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[7]#1413L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[8]#1414L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[9]#1415L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[10]#1416L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[11]#1417L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[12]#1418L could run on GPU\\n\",\n      \"        
@Expression <AttributeReference> MS[13]#1419L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[14]#1420L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[15]#1421L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[16]#1422L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[17]#1423L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[18]#1424L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[19]#1425L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[20]#1426L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[21]#1427L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[22]#1428L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[23]#1429L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[24]#1430L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[25]#1431L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[26]#1432L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[27]#1433L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[28]#1434L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[29]#1435L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[30]#1436L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[31]#1437L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[32]#1438L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[33]#1439L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[34]#1440L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[35]#1441L could run on GPU\\n\",\n      \"        @Expression 
<AttributeReference> MS[36]#1442L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[37]#1443L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[38]#1444L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[39]#1445L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[40]#1446L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[41]#1447L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[42]#1448L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[43]#1449L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[44]#1450L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[45]#1451L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[46]#1452L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[47]#1453L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[48]#1454L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[49]#1455L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[50]#1456L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[51]#1457L could run on GPU\\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#3375 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"    @Expression <Alias> all AS table#1564 could run on GPU\\n\",\n      \"      @Expression <Literal> all could run on GPU\\n\",\n      \"    @Expression <Alias> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L AS approx_unique_customers#1672L could run on GPU\\n\",\n      \"      @Expression <AttributeReference> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"    !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"      @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"        @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <Alias> 0 AS 0#4023 could run on GPU\\n\",\n      \"          @Expression <Literal> 0 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> partial_approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#3375 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[0]#3397L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[1]#3398L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[2]#3399L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[3]#3400L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[4]#3401L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[5]#3402L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[6]#3403L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[7]#3404L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[8]#3405L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[9]#3406L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[10]#3407L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[11]#3408L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[12]#3409L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[13]#3410L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[14]#3411L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[15]#3412L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[16]#3413L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[17]#3414L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[18]#3415L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> 
MS[19]#3416L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[20]#3417L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[21]#3418L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[22]#3419L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[23]#3420L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[24]#3421L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[25]#3422L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[26]#3423L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[27]#3424L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[28]#3425L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[29]#3426L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[30]#3427L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[31]#3428L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[32]#3429L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[33]#3430L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[34]#3431L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[35]#3432L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[36]#3433L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[37]#3434L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[38]#3435L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[39]#3436L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[40]#3437L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[41]#3438L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[42]#3439L could 
run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[43]#3440L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[44]#3441L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[45]#3442L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[46]#3443L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[47]#3444L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[48]#3445L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[49]#3446L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[50]#3447L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[51]#3448L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[0]#3449L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[1]#3450L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[2]#3451L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[3]#3452L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[4]#3453L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[5]#3454L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[6]#3455L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[7]#3456L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[8]#3457L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[9]#3458L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[10]#3459L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[11]#3460L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[12]#3461L could run on GPU\\n\",\n      \"        
@Expression <AttributeReference> MS[13]#3462L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[14]#3463L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[15]#3464L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[16]#3465L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[17]#3466L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[18]#3467L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[19]#3468L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[20]#3469L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[21]#3470L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[22]#3471L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[23]#3472L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[24]#3473L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[25]#3474L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[26]#3475L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[27]#3476L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[28]#3477L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[29]#3478L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[30]#3479L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[31]#3480L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[32]#3481L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[33]#3482L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[34]#3483L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[35]#3484L could run on GPU\\n\",\n      \"        @Expression 
<AttributeReference> MS[36]#3485L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[37]#3486L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[38]#3487L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[39]#3488L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[40]#3489L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[41]#3490L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[42]#3491L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[43]#3492L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[44]#3493L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[45]#3494L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[46]#3495L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[47]#3496L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[48]#3497L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[49]#3498L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[50]#3499L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[51]#3500L could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:42:47,136 WARN rapids.GpuOverrides: \\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <Alias> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L AS approx_unique_customers#1353L could run on GPU\\n\",\n      \"      @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"    !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"      @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[0]#1354L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[1]#1355L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[2]#1356L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[3]#1357L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[4]#1358L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[5]#1359L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[6]#1360L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[7]#1361L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[8]#1362L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[9]#1363L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[10]#1364L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[11]#1365L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[12]#1366L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[13]#1367L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[14]#1368L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[15]#1369L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[16]#1370L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[17]#1371L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[18]#1372L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> 
MS[19]#1373L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[20]#1374L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[21]#1375L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[22]#1376L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[23]#1377L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[24]#1378L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[25]#1379L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[26]#1380L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[27]#1381L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[28]#1382L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[29]#1383L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[30]#1384L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[31]#1385L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[32]#1386L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[33]#1387L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[34]#1388L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[35]#1389L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[36]#1390L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[37]#1391L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[38]#1392L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[39]#1393L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[40]#1394L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[41]#1395L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[42]#1396L could 
run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[43]#1397L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[44]#1398L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[45]#1399L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[46]#1400L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[47]#1401L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[48]#1402L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[49]#1403L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[50]#1404L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[51]#1405L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[0]#1406L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[1]#1407L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[2]#1408L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[3]#1409L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[4]#1410L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[5]#1411L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[6]#1412L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[7]#1413L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[8]#1414L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[9]#1415L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[10]#1416L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[11]#1417L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[12]#1418L could run on GPU\\n\",\n      \"        
@Expression <AttributeReference> MS[13]#1419L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[14]#1420L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[15]#1421L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[16]#1422L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[17]#1423L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[18]#1424L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[19]#1425L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[20]#1426L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[21]#1427L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[22]#1428L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[23]#1429L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[24]#1430L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[25]#1431L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[26]#1432L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[27]#1433L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[28]#1434L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[29]#1435L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[30]#1436L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[31]#1437L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[32]#1438L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[33]#1439L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[34]#1440L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[35]#1441L could run on GPU\\n\",\n      \"        @Expression 
<AttributeReference> MS[36]#1442L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[37]#1443L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[38]#1444L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[39]#1445L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[40]#1446L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[41]#1447L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[42]#1448L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[43]#1449L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[44]#1450L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[45]#1451L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[46]#1452L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[47]#1453L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[48]#1454L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[49]#1455L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[50]#1456L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[51]#1457L could run on GPU\\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#3375 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"    @Expression <Alias> all AS table#1564 could run on GPU\\n\",\n      \"      @Expression <Literal> all could run on GPU\\n\",\n      \"    @Expression <Alias> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L AS approx_unique_customers#1672L could run on GPU\\n\",\n      \"      @Expression <AttributeReference> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"    !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"      @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"        @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <Alias> 0 AS 0#4023 could run on GPU\\n\",\n      \"          @Expression <Literal> 0 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> partial_approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#3375 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[0]#3397L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[1]#3398L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[2]#3399L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[3]#3400L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[4]#3401L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[5]#3402L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[6]#3403L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[7]#3404L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[8]#3405L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[9]#3406L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[10]#3407L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[11]#3408L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[12]#3409L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[13]#3410L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[14]#3411L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[15]#3412L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[16]#3413L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[17]#3414L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[18]#3415L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> 
MS[19]#3416L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[20]#3417L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[21]#3418L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[22]#3419L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[23]#3420L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[24]#3421L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[25]#3422L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[26]#3423L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[27]#3424L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[28]#3425L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[29]#3426L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[30]#3427L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[31]#3428L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[32]#3429L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[33]#3430L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[34]#3431L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[35]#3432L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[36]#3433L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[37]#3434L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[38]#3435L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[39]#3436L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[40]#3437L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[41]#3438L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[42]#3439L could 
run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[43]#3440L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[44]#3441L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[45]#3442L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[46]#3443L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[47]#3444L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[48]#3445L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[49]#3446L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[50]#3447L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[51]#3448L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[0]#3449L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[1]#3450L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[2]#3451L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[3]#3452L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[4]#3453L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[5]#3454L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[6]#3455L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[7]#3456L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[8]#3457L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[9]#3458L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[10]#3459L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[11]#3460L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[12]#3461L could run on GPU\\n\",\n      \"        
@Expression <AttributeReference> MS[13]#3462L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[14]#3463L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[15]#3464L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[16]#3465L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[17]#3466L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[18]#3467L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[19]#3468L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[20]#3469L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[21]#3470L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[22]#3471L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[23]#3472L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[24]#3473L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[25]#3474L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[26]#3475L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[27]#3476L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[28]#3477L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[29]#3478L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[30]#3479L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[31]#3480L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[32]#3481L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[33]#3482L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[34]#3483L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[35]#3484L could run on GPU\\n\",\n      \"        @Expression 
<AttributeReference> MS[36]#3485L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[37]#3486L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[38]#3487L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[39]#3488L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[40]#3489L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[41]#3490L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[42]#3491L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[43]#3492L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[44]#3493L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[45]#3494L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[46]#3495L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[47]#3496L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[48]#3497L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[49]#3498L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[50]#3499L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[51]#3500L could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:42:47,139 WARN rapids.GpuOverrides: \\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <Alias> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L AS approx_unique_customers#1353L could run on GPU\\n\",\n      \"      @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"    !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"      @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[0]#1354L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[1]#1355L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[2]#1356L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[3]#1357L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[4]#1358L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[5]#1359L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[6]#1360L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[7]#1361L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[8]#1362L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[9]#1363L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[10]#1364L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[11]#1365L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[12]#1366L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[13]#1367L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[14]#1368L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[15]#1369L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[16]#1370L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[17]#1371L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[18]#1372L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> 
MS[19]#1373L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[20]#1374L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[21]#1375L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[22]#1376L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[23]#1377L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[24]#1378L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[25]#1379L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[26]#1380L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[27]#1381L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[28]#1382L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[29]#1383L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[30]#1384L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[31]#1385L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[32]#1386L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[33]#1387L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[34]#1388L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[35]#1389L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[36]#1390L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[37]#1391L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[38]#1392L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[39]#1393L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[40]#1394L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[41]#1395L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[42]#1396L could 
run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[43]#1397L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[44]#1398L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[45]#1399L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[46]#1400L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[47]#1401L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[48]#1402L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[49]#1403L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[50]#1404L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[51]#1405L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[0]#1406L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[1]#1407L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[2]#1408L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[3]#1409L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[4]#1410L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[5]#1411L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[6]#1412L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[7]#1413L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[8]#1414L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[9]#1415L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[10]#1416L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[11]#1417L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[12]#1418L could run on GPU\\n\",\n      \"        
@Expression <AttributeReference> MS[13]#1419L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[14]#1420L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[15]#1421L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[16]#1422L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[17]#1423L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[18]#1424L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[19]#1425L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[20]#1426L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[21]#1427L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[22]#1428L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[23]#1429L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[24]#1430L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[25]#1431L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[26]#1432L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[27]#1433L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[28]#1434L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[29]#1435L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[30]#1436L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[31]#1437L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[32]#1438L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[33]#1439L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[34]#1440L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[35]#1441L could run on GPU\\n\",\n      \"        @Expression 
<AttributeReference> MS[36]#1442L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[37]#1443L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[38]#1444L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[39]#1445L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[40]#1446L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[41]#1447L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[42]#1448L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[43]#1449L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[44]#1450L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[45]#1451L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[46]#1452L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[47]#1453L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[48]#1454L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[49]#1455L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[50]#1456L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[51]#1457L could run on GPU\\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#3375 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"    @Expression <Alias> all AS table#1564 could run on GPU\\n\",\n      \"      @Expression <Literal> all could run on GPU\\n\",\n      \"    @Expression <Alias> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L AS approx_unique_customers#1672L could run on GPU\\n\",\n      \"      @Expression <AttributeReference> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"    !Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"      @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"        @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"      !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"        @Expression <Alias> 0 AS 0#4023 could run on GPU\\n\",\n      \"          @Expression <Literal> 0 could run on GPU\\n\",\n      \"        @Expression <AggregateExpression> partial_approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\\n\",\n      \"          ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"            @Expression <AttributeReference> customerID#3375 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[0]#3397L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[1]#3398L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[2]#3399L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[3]#3400L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[4]#3401L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[5]#3402L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[6]#3403L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[7]#3404L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[8]#3405L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[9]#3406L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[10]#3407L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[11]#3408L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[12]#3409L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[13]#3410L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[14]#3411L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[15]#3412L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[16]#3413L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[17]#3414L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[18]#3415L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> 
MS[19]#3416L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[20]#3417L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[21]#3418L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[22]#3419L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[23]#3420L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[24]#3421L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[25]#3422L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[26]#3423L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[27]#3424L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[28]#3425L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[29]#3426L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[30]#3427L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[31]#3428L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[32]#3429L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[33]#3430L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[34]#3431L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[35]#3432L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[36]#3433L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[37]#3434L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[38]#3435L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[39]#3436L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[40]#3437L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[41]#3438L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[42]#3439L could 
run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[43]#3440L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[44]#3441L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[45]#3442L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[46]#3443L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[47]#3444L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[48]#3445L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[49]#3446L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[50]#3447L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[51]#3448L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[0]#3449L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[1]#3450L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[2]#3451L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[3]#3452L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[4]#3453L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[5]#3454L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[6]#3455L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[7]#3456L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[8]#3457L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[9]#3458L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[10]#3459L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[11]#3460L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[12]#3461L could run on GPU\\n\",\n      \"        
@Expression <AttributeReference> MS[13]#3462L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[14]#3463L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[15]#3464L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[16]#3465L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[17]#3466L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[18]#3467L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[19]#3468L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[20]#3469L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[21]#3470L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[22]#3471L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[23]#3472L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[24]#3473L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[25]#3474L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[26]#3475L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[27]#3476L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[28]#3477L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[29]#3478L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[30]#3479L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[31]#3480L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[32]#3481L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[33]#3482L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[34]#3483L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[35]#3484L could run on GPU\\n\",\n      \"        @Expression 
<AttributeReference> MS[36]#3485L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[37]#3486L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[38]#3487L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[39]#3488L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[40]#3489L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[41]#3490L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[42]#3491L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[43]#3492L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[44]#3493L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[45]#3494L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[46]#3495L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[47]#3496L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[48]#3497L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[49]#3498L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[50]#3499L could run on GPU\\n\",\n      \"        @Expression <AttributeReference> MS[51]#3500L could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:42:47,147 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"  @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> partial_approx_count_distinct(customerID#1179, 0.05, 0, 0) 
could run on GPU\\n\",\n      \"      ! <HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[0]#1354L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[1]#1355L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[2]#1356L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[3]#1357L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[4]#1358L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[5]#1359L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[6]#1360L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[7]#1361L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[8]#1362L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[9]#1363L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[10]#1364L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[11]#1365L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[12]#1366L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[13]#1367L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[14]#1368L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[15]#1369L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[16]#1370L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[17]#1371L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[18]#1372L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[19]#1373L could run on GPU\\n\",\n      \"    
@Expression <AttributeReference> MS[20]#1374L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[21]#1375L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[22]#1376L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[23]#1377L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[24]#1378L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[25]#1379L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[26]#1380L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[27]#1381L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[28]#1382L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[29]#1383L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[30]#1384L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[31]#1385L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[32]#1386L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[33]#1387L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[34]#1388L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[35]#1389L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[36]#1390L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[37]#1391L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[38]#1392L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[39]#1393L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[40]#1394L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[41]#1395L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[42]#1396L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[43]#1397L could run on GPU\\n\",\n      \"    @Expression 
<AttributeReference> MS[44]#1398L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[45]#1399L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[46]#1400L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[47]#1401L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[48]#1402L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[49]#1403L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[50]#1404L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[51]#1405L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[0]#1406L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[1]#1407L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[2]#1408L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[3]#1409L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[4]#1410L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[5]#1411L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[6]#1412L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[7]#1413L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[8]#1414L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[9]#1415L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[10]#1416L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[11]#1417L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[12]#1418L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[13]#1419L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[14]#1420L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[15]#1421L could run 
on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[16]#1422L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[17]#1423L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[18]#1424L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[19]#1425L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[20]#1426L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[21]#1427L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[22]#1428L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[23]#1429L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[24]#1430L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[25]#1431L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[26]#1432L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[27]#1433L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[28]#1434L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[29]#1435L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[30]#1436L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[31]#1437L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[32]#1438L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[33]#1439L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[34]#1440L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[35]#1441L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[36]#1442L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[37]#1443L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[38]#1444L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[39]#1445L could run on GPU\\n\",\n      \"    
@Expression <AttributeReference> MS[40]#1446L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[41]#1447L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[42]#1448L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[43]#1449L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[44]#1450L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[45]#1451L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[46]#1452L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[47]#1453L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[48]#1454L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[49]#1455L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[50]#1456L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[51]#1457L could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:42:47,151 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ShuffleExchangeExec> cannot run on GPU because Columnar exchange without columnar children is inefficient\\n\",\n      \"  @Partitioning <HashPartitioning> could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <Alias> 0 AS 0#4023 could run on GPU\\n\",\n      \"      @Expression <Literal> 0 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> partial_approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#3375 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[0]#3397L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[1]#3398L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[2]#3399L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[3]#3400L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[4]#3401L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[5]#3402L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[6]#3403L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[7]#3404L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[8]#3405L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[9]#3406L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[10]#3407L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[11]#3408L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[12]#3409L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[13]#3410L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[14]#3411L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[15]#3412L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[16]#3413L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[17]#3414L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[18]#3415L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[19]#3416L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 
MS[20]#3417L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[21]#3418L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[22]#3419L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[23]#3420L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[24]#3421L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[25]#3422L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[26]#3423L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[27]#3424L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[28]#3425L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[29]#3426L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[30]#3427L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[31]#3428L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[32]#3429L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[33]#3430L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[34]#3431L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[35]#3432L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[36]#3433L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[37]#3434L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[38]#3435L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[39]#3436L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[40]#3437L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[41]#3438L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[42]#3439L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[43]#3440L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[44]#3441L could run on 
GPU\\n\",\n      \"    @Expression <AttributeReference> MS[45]#3442L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[46]#3443L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[47]#3444L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[48]#3445L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[49]#3446L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[50]#3447L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[51]#3448L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[0]#3449L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[1]#3450L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[2]#3451L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[3]#3452L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[4]#3453L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[5]#3454L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[6]#3455L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[7]#3456L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[8]#3457L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[9]#3458L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[10]#3459L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[11]#3460L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[12]#3461L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[13]#3462L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[14]#3463L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[15]#3464L could run on GPU\\n\",\n      \"    @Expression 
<AttributeReference> MS[16]#3465L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[17]#3466L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[18]#3467L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[19]#3468L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[20]#3469L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[21]#3470L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[22]#3471L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[23]#3472L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[24]#3473L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[25]#3474L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[26]#3475L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[27]#3476L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[28]#3477L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[29]#3478L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[30]#3479L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[31]#3480L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[32]#3481L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[33]#3482L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[34]#3483L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[35]#3484L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[36]#3485L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[37]#3486L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[38]#3487L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[39]#3488L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 
MS[40]#3489L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[41]#3490L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[42]#3491L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[43]#3492L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[44]#3493L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[45]#3494L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[46]#3495L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[47]#3496L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[48]#3497L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[49]#3498L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[50]#3499L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> MS[51]#3500L could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:43:28,385 WARN rapids.GpuOverrides: >               (0 + 0) / 815]\\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <Alias> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L AS approx_unique_customers#1353L could run on GPU\\n\",\n      \"      @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#3375 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"    @Expression <Alias> all AS table#1564 could run on GPU\\n\",\n      \"      @Expression <Literal> all could run on GPU\\n\",\n      \"    @Expression <Alias> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L AS approx_unique_customers#1672L could run on GPU\\n\",\n      \"      @Expression <AttributeReference> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:44:07,480 WARN rapids.GpuOverrides: =============>(812 + 1) / 815]\\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <Alias> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L AS approx_unique_customers#1353L could run on GPU\\n\",\n      \"      @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#3375 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"    @Expression <Alias> all AS table#1564 could run on GPU\\n\",\n      \"      @Expression <Literal> all could run on GPU\\n\",\n      \"    @Expression <Alias> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L AS approx_unique_customers#1672L could run on GPU\\n\",\n      \"      @Expression <AttributeReference> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 09:44:07,482 WARN rapids.GpuOverrides: \\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> approx_count_distinct(customerID#1179, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#1179, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#1179 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"    @Expression <AttributeReference> table#1189 could run on GPU\\n\",\n      \"    @Expression <Alias> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L AS approx_unique_customers#1353L could run on GPU\\n\",\n      \"      @Expression <AttributeReference> approx_count_distinct(customerID#1179, 0.05, 0, 0)#1352L could run on GPU\\n\",\n      \"    !Exec <AQEShuffleReadExec> cannot run on GPU because Unable to replace CustomShuffleReader due to child not being columnar\\n\",\n      \"  !Exec <HashAggregateExec> cannot run on GPU because not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> 0#4023 could run on GPU\\n\",\n      \"    @Expression <AggregateExpression> approx_count_distinct(customerID#3375, 0.05, 0, 0) could run on GPU\\n\",\n      \"      ! 
<HyperLogLogPlusPlus> approx_count_distinct(customerID#3375, 0.05, 0, 0) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus\\n\",\n      \"        @Expression <AttributeReference> customerID#3375 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"    @Expression <Alias> all AS table#1564 could run on GPU\\n\",\n      \"      @Expression <Literal> all could run on GPU\\n\",\n      \"    @Expression <Alias> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L AS approx_unique_customers#1672L could run on GPU\\n\",\n      \"      @Expression <AttributeReference> approx_count_distinct(customerID#3375, 0.05, 0, 0)#1671L could run on GPU\\n\",\n      \"    !Exec <AQEShuffleReadExec> cannot run on GPU because Unable to replace CustomShuffleReader due to child not being columnar\\n\",\n      \"\\n\",\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"rows = each_table.union(overall).collect()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'billing_events': 699470,\\n\",\n       \" 'customer_meta': 699470,\\n\",\n       \" 'customer_phone_features': 631148,\\n\",\n       \" 'customer_internet_features': 521053,\\n\",\n       \" 'customer_account_features': 699470,\\n\",\n       \" 'all': 699470}\"\n      ]\n     },\n     \"execution_count\": 19,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"dict([(row[0], row[1]) for row in rows])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n 
\"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.10\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "examples/SQL+DF-Examples/customer-churn/notebooks/python/churn/augment.py",
    "content": "# Copyright (c) 2022, NVIDIA Corporation.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport datetime\nimport os\n\nimport pyspark\nfrom pyspark.sql.types import StructType, StructField, StringType, DoubleType, DecimalType\nimport pyspark.sql.functions as F\nfrom collections import defaultdict\n\noptions = defaultdict(lambda: None)\n\nnow = datetime.datetime.now(datetime.timezone.utc)\n\nAUGMENT_VERSION = \"0.7\"\nAUGMENT_CUSTOMER_TAG = \"0007\"\n\nsession = None\ncurrencyType = None\n\ndef get_currency_type():\n    global options\n    global currencyType\n\n    if currencyType is not None:\n        return currencyType\n    \n    if \"use_decimal\" in options and options[\"use_decimal\"]:\n        if \"decimal_precision\" in options :\n            assert options[\"decimal_precision\"] > 5, \"Decimal precision is too small; was %d but should be at least 6\" % options[\"decimal_precision\"]\n            currencyType = DecimalType(options[\"decimal_precision\"], 2)\n        else:\n            # \"999,999.99 should be enough for anyone\"\n            currencyType = DecimalType(8, 2)\n    else:\n        currencyType = DoubleType()\n    \n    return currencyType\n\ndef _register_session(s):\n    global session\n    session = s\n\ndef _get_uniques(ct):\n    global session\n    table_names = set([table.name for table in session.catalog.listTables()])\n\n    if (\"uniques_%d\" % ct) in table_names:\n        return session.table(\"uniques_%d\" % 
ct)\n    else:\n        def str_part(seed=0x5CA1AB1E):\n            \"generate the string part of a unique ID\"\n            import random\n\n            r = random.Random(seed)\n            from base64 import b64encode\n\n            while True:\n                yield \"%s-%s\" % (b64encode(r.getrandbits(72).to_bytes(9, \"big\"), b\"@_\").decode(\n                    \"utf-8\"\n                ), AUGMENT_CUSTOMER_TAG)\n        \n        sp = str_part()\n        \n        uniques = (\n            session.createDataFrame(\n                schema=StructType([StructField(\"u_value\", StringType())]),\n                data=[dict(u_value=next(sp)) for _ in range(min(int(ct * 1.02), ct + 2))],\n            )\n            .distinct()\n            .orderBy(\"u_value\")\n            .limit(ct)\n        ).cache()\n\n        uc = uniques.count()\n        assert (uc == ct), \"due to prng collision we had %d instead of %d replicas\" % (uc, ct)\n\n        uniques.createOrReplaceTempView(\"uniques_%d\" % ct)\n\n        return uniques\n        \n\ndef register_options(**kwargs):\n    global options\n    for k, v in kwargs.items():\n        options[k] = v\n\ndef load_supplied_data(session, input_file):\n    _register_session(session)\n\n    fields = [\n        \"customerID\",\n        \"gender\",\n        \"SeniorCitizen\",\n        \"Partner\",\n        \"Dependents\",\n        \"tenure\",\n        \"PhoneService\",\n        \"MultipleLines\",\n        \"InternetService\",\n        \"OnlineSecurity\",\n        \"OnlineBackup\",\n        \"DeviceProtection\",\n        \"TechSupport\",\n        \"StreamingTV\",\n        \"StreamingMovies\",\n        \"Contract\",\n        \"PaperlessBilling\",\n        \"PaymentMethod\",\n        \"MonthlyCharges\",\n        \"TotalCharges\",\n        \"Churn\",\n    ]\n    double_fields = set([\"tenure\", \"MonthlyCharges\", \"TotalCharges\"])\n\n    schema = pyspark.sql.types.StructType(\n        [\n            pyspark.sql.types.StructField(\n     
           f, DoubleType() if f in double_fields else StringType()\n            )\n            for f in fields\n        ]\n    )\n\n    df = session.read.csv(input_file, header=True, schema=schema)\n    \n    source_count = df.count()\n    df = df.dropna()\n    nn_count = df.count()\n\n    if source_count == nn_count:    \n        print(\"read %d records from source dataset with no nulls -- is this what you expect?\" % source_count)\n    else:\n        print(\"read %d records from source dataset (%d non-null records)\" % (source_count, nn_count))\n    \n    return df\n\ndef replicate_df(df, duplicates):\n\n    if duplicates > 1:\n        uniques = _get_uniques(duplicates)\n\n        df = (\n            df.crossJoin(uniques.distinct())\n            .withColumn(\"customerID\", F.format_string(\"%s-%s\", \"customerID\", \"u_value\"))\n            .drop(\"u_value\")\n        )\n\n    return df\n\ndef examine_categoricals(df, columns=None):\n    \"\"\" Returns (to driver memory) a list of tuples consisting of every unique value \n        for each column in `columns` or for every categorical column in the source \n        data if no columns are specified \"\"\"\n    default_columns = [\n        \"SeniorCitizen\",\n        \"Partner\",\n        \"Dependents\",\n        \"PhoneService\",\n        \"MultipleLines\",\n        \"InternetService\",\n        \"OnlineSecurity\",\n        \"OnlineBackup\",\n        \"DeviceProtection\",\n        \"TechSupport\",\n        \"StreamingTV\",\n        \"StreamingMovies\",\n        \"Contract\",\n        \"PaperlessBilling\",\n        \"PaymentMethod\",\n    ]\n\n    columns = columns or default_columns\n\n    return [(c, [row[0] for row in df.select(c).distinct().rdd.collect()]) for c in columns]\n\ndef billing_events(df):\n    import datetime\n\n    MAX_MONTH = 72\n\n    def get_last_month(col):\n        h = F.abs(F.xxhash64(col))\n        h1 = (h.bitwiseAND(0xff)) % (MAX_MONTH // 2)\n        h2 = (F.shiftRight(h, 
8).bitwiseAND(0xff)) % (MAX_MONTH // 3)\n        h3 = (F.shiftRight(h, 16).bitwiseAND(0xff)) % (MAX_MONTH // 5)\n        h4 = (F.shiftRight(h, 24).bitwiseAND(0xff)) % (MAX_MONTH // 7)\n        h5 = (F.shiftRight(h, 32).bitwiseAND(0xff)) % (MAX_MONTH // 11)\n        return -(h1 + h2 + h3 + h4 + h5)\n\n    w = pyspark.sql.Window.orderBy(F.lit(\"\")).partitionBy(df.customerID)\n\n    charges = (\n        df.select(\n            df.customerID,\n            F.lit(\"Charge\").alias(\"kind\"),\n            F.explode(\n                F.array_repeat((df.TotalCharges / df.tenure).cast(get_currency_type()), df.tenure.cast(\"int\"))\n            ).alias(\"value\"),\n            F.when(df.Churn == \"Yes\", get_last_month(df.customerID)).otherwise(0).alias(\"last_month\")\n        )\n        .withColumn(\"now\", F.lit(now).cast(\"date\"))\n        .withColumn(\"month_number\", -(F.row_number().over(w) + F.col(\"last_month\")))\n        .withColumn(\"date\", F.expr(\"add_months(now, month_number)\"))\n        .drop(\"now\", \"month_number\", \"last_month\")\n    )\n\n    serviceStarts = (\n        df.withColumn(\"last_month\", F.when(df.Churn == \"Yes\", get_last_month(df.customerID)).otherwise(0)).select(\n            df.customerID,\n            F.lit(\"AccountCreation\").alias(\"kind\"),\n            F.lit(0.0).cast(get_currency_type()).alias(\"value\"),\n            F.lit(now).alias(\"now\"),\n            (-df.tenure - 1 + F.col(\"last_month\")).alias(\"month_number\"),\n        )\n        .withColumn(\"date\", F.expr(\"add_months(now, month_number)\"))\n        .drop(\"now\", \"month_number\")\n    )\n\n    serviceTerminations = df.withColumn(\"last_month\", F.when(df.Churn == \"Yes\", get_last_month(df.customerID)).otherwise(0)).where(\n        df.Churn == \"Yes\"\n    ).withColumn(\"now\", F.lit(now)).select(\n        df.customerID,\n        F.lit(\"AccountTermination\").alias(\"kind\"),\n        F.lit(0.0).cast(get_currency_type()).alias(\"value\"),\n        
F.expr(\"add_months(now, last_month)\").alias(\"date\")\n    )\n\n    billingEvents = charges.union(serviceStarts).union(serviceTerminations).orderBy(\"date\").withColumn(\"month\", F.substring(\"date\", 0, 7))\n    return billingEvents\n\ndef resolve_path(name):\n    output_prefix = options[\"output_prefix\"] or \"\"\n    output_mode = options[\"output_mode\"] or \"overwrite\"\n    output_kind = options[\"output_kind\"] or \"parquet\"\n    name = \"%s.%s\" % (name, output_kind)\n    if output_prefix != \"\":\n        name = \"%s%s\" % (output_prefix, name)\n    \n    return name\n\ndef write_df(df, name, skip_replication=False, partition_by=None):\n    dup_times = options[\"dup_times\"] or 1\n    output_prefix = options[\"output_prefix\"] or \"\"\n    output_mode = options[\"output_mode\"] or \"overwrite\"\n    output_kind = options[\"output_kind\"] or \"parquet\"\n\n    if not skip_replication:\n        df = replicate_df(df, dup_times)\n    write = df.write\n    if partition_by is not None:\n        if type(partition_by) == str:\n            partition_by = [partition_by]\n        write = write.partitionBy(*partition_by)\n    name = \"%s.%s\" % (name, output_kind)\n    if output_prefix != \"\":\n        name = \"%s%s\" % (output_prefix, name)\n    kwargs = {}\n    if output_kind == \"csv\":\n        kwargs[\"header\"] = True\n    getattr(write.mode(output_mode), output_kind)(name, **kwargs)\n\ndef customer_meta(df):\n    SENIOR_CUTOFF = 65\n    ADULT_CUTOFF = 18\n    DAYS_IN_YEAR = 365.25\n    EXPONENTIAL_DIST_SCALE = 6.3\n\n    augmented_original = replicate_df(df, options[\"dup_times\"] or 1)\n\n    customerMetaRaw = augmented_original.select(\n        \"customerID\",\n        F.lit(now).alias(\"now\"),\n        (F.abs(F.hash(augmented_original.customerID)) % 4096 / 4096).alias(\"choice\"),\n        \"SeniorCitizen\",\n        \"gender\",\n        \"Partner\",\n        \"Dependents\",\n        
F.col(\"MonthlyCharges\").cast(get_currency_type()).alias(\"MonthlyCharges\"),\n    )\n\n    customerMetaRaw = customerMetaRaw.withColumn(\n        \"ageInDays\",\n        F.floor(\n            F.when(\n                customerMetaRaw.SeniorCitizen == 0,\n                (\n                    customerMetaRaw.choice\n                    * ((SENIOR_CUTOFF - ADULT_CUTOFF - 1) * DAYS_IN_YEAR)\n                )\n                + (ADULT_CUTOFF * DAYS_IN_YEAR),\n            ).otherwise(\n                (SENIOR_CUTOFF * DAYS_IN_YEAR)\n                + (\n                    DAYS_IN_YEAR\n                    * (-F.log1p(-customerMetaRaw.choice) * EXPONENTIAL_DIST_SCALE)\n                )\n            )\n        ).cast(\"int\"),\n    )\n\n    customerMetaRaw = customerMetaRaw.withColumn(\n        \"dateOfBirth\", F.expr(\"date_sub(now, ageInDays)\")\n    )\n\n    return customerMetaRaw.select(\n        \"customerID\",\n        \"dateOfBirth\",\n        \"gender\",\n        \"SeniorCitizen\",\n        \"Partner\",\n        \"Dependents\",\n        \"MonthlyCharges\",\n        \"now\",\n    ).orderBy(\"customerID\")\n\n\ndef phone_features(df):\n    phoneService = df.select(\n        \"customerID\", F.lit(\"PhoneService\").alias(\"feature\"), F.lit(\"Yes\").alias(\"value\")\n    ).where(df.PhoneService == \"Yes\")\n\n    multipleLines = df.select(\n        \"customerID\", F.lit(\"MultipleLines\").alias(\"feature\"), F.lit(\"Yes\").alias(\"value\")\n    ).where(df.MultipleLines == \"Yes\")\n\n    return phoneService.union(multipleLines).orderBy(\"customerID\")\n\ndef internet_features(df):\n    internet_service = df.select(\n        \"customerID\",\n        F.lit(\"InternetService\").alias(\"feature\"),\n        df.InternetService.alias(\"value\"),\n    ).where(df.InternetService != \"No\")\n\n    customerInternetFeatures = internet_service\n\n    for feature in [\n        \"InternetService\",\n        \"OnlineSecurity\",\n        \"OnlineBackup\",\n        
\"DeviceProtection\",\n        \"TechSupport\",\n        \"StreamingTV\",\n        \"StreamingMovies\",\n    ]:\n        tmpdf = df.select(\n            \"customerID\",\n            F.lit(feature).alias(\"feature\"),\n            df[feature].alias(\"value\"),\n        ).where(df[feature] == \"Yes\")\n\n        customerInternetFeatures = customerInternetFeatures.union(tmpdf)\n\n    return customerInternetFeatures\n\n\ndef account_features(df):\n    session = df.sql_ctx.sparkSession\n    accountSchema = pyspark.sql.types.StructType(\n        [\n            pyspark.sql.types.StructField(f, StringType())\n            for f in [\"customerID\", \"feature\", \"value\"]\n        ]\n    )\n\n    customerAccountFeatures = session.createDataFrame(schema=accountSchema, data=[])\n\n    for feature in [\"Contract\", \"PaperlessBilling\", \"PaymentMethod\"]:\n        tmpdf = df.select(\n            \"customerID\",\n            F.lit(feature).alias(\"feature\"),\n            df[feature].alias(\"value\"),\n        ).where(df[feature] != \"No\")\n\n        customerAccountFeatures = customerAccountFeatures.union(tmpdf)\n    \n    return customerAccountFeatures\n\n\ndef debug_augmentation(df):\n    return (\n        df.select(\"customerID\")\n        .distinct()\n        .select(\n            \"customerID\",\n            F.substring(\"customerID\", 0, 10).alias(\"originalID\"),\n            F.element_at(F.split(\"customerID\", \"-\", -1), 3).alias(\"suffix\"),\n        )\n    )"
  },
  {
    "path": "examples/SQL+DF-Examples/customer-churn/notebooks/python/churn/eda.py",
    "content": "# Copyright (c) 2022, NVIDIA Corporation.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom pyspark.sql import types as T\nfrom pyspark.sql import functions as F\n\neda_options = { 'use_array_ops' : False }\n\ndef isnumeric(data_type):\n    numeric_types = [T.ByteType, T.ShortType, T.IntegerType, T.LongType, T.FloatType, T.DoubleType, T.DecimalType]\n    return any([isinstance(data_type, t) for t in numeric_types])\n\n\ndef percent_true(df, cols):\n    denominator = df.count()\n    return {col : df.where(F.col(col) == True).count() / denominator for col in cols}\n\n\ndef cardinalities(df, cols):\n    from functools import reduce\n    \n    counts = df.agg(\n        F.struct(*[F.countDistinct(F.col(c)).alias(c) for c in cols] + [F.count(F.col(cols[0])).alias('total')]).alias(\"results\")\n    ).select(\"results\").collect()[0][0].asDict()\n    counts.update({'total' : df.count()})\n    return counts\n\n\ndef likely_unique(counts):\n    total = counts[\"total\"]\n    return [k for (k, v) in counts.items() if k != \"total\" and abs(total - v) < total * 0.15]\n\n\ndef likely_categoricals(counts):\n    total = counts[\"total\"]\n    return [k for (k, v) in counts.items() if v < total * 0.15 or v < 128]\n\ndef unique_values(df, cols):\n    if eda_options['use_array_ops']:\n        return unique_values_array(df, cols)\n    else:   \n        return unique_values_driver(df, cols)\n\ndef unique_values_array(df, cols):\n    from functools import 
reduce\n \n    counts = df.groupBy(\n        F.lit(True).alias(\"drop_me\")\n    ).agg(\n        *[F.array_sort(F.collect_set(F.col(c))).alias(c) for c in cols]\n    ).drop(\"drop_me\").cache()\n    \n    result = reduce(lambda l, r: l.unionAll(r), [counts.select(F.lit(c).alias(\"field\"), F.col(c).alias(\"unique_vals\")) for c in counts.columns]).collect()\n    \n    return dict([(r[0],r[1]) for r in result])\n\n\ndef unique_values_driver(df, cols):\n    return { col : [v[0] for v in df.select(F.col(col).alias('value')).distinct().orderBy(F.col('value')).collect()] for col in cols}\n\ndef approx_ecdf(df, cols):\n    from functools import reduce\n    \n    quantiles = [0.0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 1.0]\n\n    qs = df.approxQuantile(cols, quantiles, 0.01)\n    \n    result = dict(zip(cols, qs))\n    return {c: dict(zip(quantiles, vs)) for (c, vs) in result.items()}\n\n\ndef gen_summary(df, output_prefix=\"\"):\n    summary = {}\n    \n    string_cols = []\n    boolean_cols = []\n    numeric_cols = []\n    other_cols = []\n\n    for field in df.schema.fields:\n        if isinstance(field.dataType, T.StringType):\n            string_cols.append(field.name)\n        elif isinstance(field.dataType, T.BooleanType):\n            boolean_cols.append(field.name)\n        elif isnumeric(field.dataType):\n            numeric_cols.append(field.name)\n        else:\n            other_cols.append(field.name)\n    \n    counts = cardinalities(df, string_cols)\n    uniques = likely_unique(counts)\n    categoricals = unique_values(df, likely_categoricals(counts))\n\n    for span in [2,3,4,6,12]:\n        thecube = df.cube(\"Churn\", F.ceil(df.tenure / span).alias(\"%d_month_spans\" % span), \"gender\", \"Partner\", \"SeniorCitizen\", \"Contract\", \"PaperlessBilling\", \"PaymentMethod\", F.ceil(F.log2(F.col(\"MonthlyCharges\"))*10).alias(\"log_charges\")).count()\n        therollup = df.rollup(\"Churn\", F.ceil(df.tenure / span).alias(\"%d_month_spans\" % 
span), \"SeniorCitizen\", \"Contract\", \"PaperlessBilling\", \"PaymentMethod\", F.ceil(F.log2(F.col(\"MonthlyCharges\"))*10).alias(\"log_charges\")).agg(F.sum(F.col(\"TotalCharges\")).alias(\"sum_charges\"))\n        thecube.write.mode(\"overwrite\").parquet(\"%scube-%d.parquet\" % (output_prefix, span))\n        therollup.write.mode(\"overwrite\").parquet(\"%srollup-%d.parquet\" % (output_prefix, span))\n\n    encoding_struct = {\n        \"categorical\" : categoricals,\n        \"numeric\" : numeric_cols + boolean_cols,\n        \"unique\": uniques\n    }\n    \n    summary[\"schema\"] = df.schema.jsonValue()\n    summary[\"ecdfs\"] = approx_ecdf(df, numeric_cols)\n    summary[\"true_percentage\"] = percent_true(df, boolean_cols)\n    summary[\"encoding\"] = encoding_struct\n    summary[\"distinct_customers\"] = df.select(df.customerID).distinct().count()\n    \n    return summary\n\ndef losses_by_month(be):\n    customer_lifetime_values = be.groupBy(\"customerID\").sum(\"value\").alias(\"value\")\n    return be.where(be.kind == \"AccountTermination\").join(customer_lifetime_values, \"customerID\").groupBy(\"month\").sum(\"value\").alias(\"value\").sort(\"month\").toPandas().to_json()\n\ndef output_reports(df, be=None, report_prefix=\"\"):\n    import json\n\n    summary = gen_summary(df, report_prefix)\n\n    if be is not None:\n        summary[\"losses_by_month\"] = losses_by_month(be)\n\n    with open(\"%ssummary.json\" % report_prefix, \"w\") as sf:\n        json.dump(summary, sf)\n    \n    with open(\"%sencodings.json\" % report_prefix, \"w\") as ef:\n        json.dump(summary[\"encoding\"], ef)\n        \n"
  },
  {
    "path": "examples/SQL+DF-Examples/customer-churn/notebooks/python/churn/etl.py",
    "content": "#!/usr/bin/env python\n# coding: utf-8\n\n# Copyright (c) 2022, NVIDIA Corporation.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pyspark\nimport pyspark.sql\nimport pyspark.sql.functions as F\n\nfrom collections import defaultdict\n\noptions = defaultdict(lambda: None)\nsession = None\n\nETL_VERSION = '0.7'\n\ndef register_options(**kwargs):\n    global options\n    for k, v in kwargs.items():\n        options[k] = v\n\ndef _register_session(s):\n    global session\n    session = s\n\ndef _register_views(lvars, *names):\n    for n in names:\n        if n in lvars:\n            lvars[n].createOrReplaceTempView(n)\n\ndef withsession(df_arg=0):\n    def decorate(fn):\n        def wrapped(*args, **kwargs):\n            _register_session(args[df_arg].sql_ctx.sparkSession)\n            return fn(*args, **kwargs)\n        return wrapped\n    return decorate\n\ndef read_df(session, fn):\n    kwargs = {}\n    _register_session(session)\n    input_kind = options[\"input_kind\"]\n\n    if input_kind == \"csv\":\n        kwargs[\"header\"] = True\n    return getattr(session.read, input_kind)(\"%s.%s\" % (fn, input_kind), **kwargs)\n\n\ndef find_customers(billing_events_df):\n    customers = billing_events_df.select(\"customerID\").distinct()\n    if 'cache_customers' in options:\n        customers.cache()\n    customers.createOrReplaceTempView(\"customers\")\n    return customers\n\ndef customers():\n    global session\n    return 
session.table(\"customers\")\n\ndef join_billing_data(billing_events_df):\n    _register_session(billing_events_df.sql_ctx.sparkSession)\n\n    billing_events = billing_events_df.withColumn(\"value\", billing_events_df.value)\n\n    customers = find_customers(billing_events)\n\n    counts_and_charges = billing_events.groupBy(\"customerID\", \"kind\").agg(\n        F.count(billing_events.value).alias(\"event_counts\"),\n        F.sum(billing_events.value).alias(\"total_charges\"),\n    )\n\n    counts_and_charges.createOrReplaceTempView(\"counts_and_charges\")\n    \n    terminations = billing_events.where(F.col(\"kind\") == \"AccountTermination\").select(\n        F.col(\"customerID\").alias(\"Churn\")\n    )\n\n    churned = customers.join(\n        terminations, customers.customerID == terminations.Churn, how=\"leftouter\"\n    ).select(\n        \"customerID\", F.when(F.col(\"Churn\").isNull(), F.lit(False)).otherwise(F.lit(True)).alias(\"Churn\")\n    )\n\n    customer_charges = customers.join(\n        counts_and_charges.where(F.col(\"kind\") == \"Charge\"), \"customerID\", how=\"leftouter\"\n    ).select(\n        \"customerID\",\n        F.col(\"event_counts\").alias(\"tenure\"),\n        F.col(\"total_charges\").alias(\"TotalCharges\"),\n    ).fillna({'tenure': 0, 'TotalCharges': 0.0})\n\n    _register_views(locals(), \"counts_and_charges\", \"terminations\", \"churned\", \"customer_charges\")\n \n    # counts_and_charges.createOrReplaceTempView(\"counts_and_charges\")\n    # terminations.createOrReplaceTempView(\"terminations\")\n    # churned.createOrReplaceTempView(\"churned\")\n    # customer_charges.createOrReplaceTempView(\"customer_charges\")\n    \n    customer_billing = churned.join(customer_charges, \"customerID\")\n    _register_views(locals(), \"counts_and_charges\", \"terminations\", \"churned\", \"customer_charges\", \"customer_billing\")\n \n    return customer_billing\n\n\ndef join_phone_features(phone_features_df):\n    phone_features = 
phone_features_df\n\n    phone_service = phone_features.where(F.col(\"feature\") == \"PhoneService\").select(\n        \"customerID\", F.lit(\"Yes\").alias(\"PhoneService\")\n    )\n\n    multiple_lines = phone_features.where(F.col(\"feature\") == \"MultipleLines\").select(\n        \"customerID\", F.lit(\"Yes\").alias(\"MultipleLines\")\n    )\n\n    customer_phone_features = (\n        customers().join(phone_service, \"customerID\", how=\"leftouter\")\n        .join(multiple_lines, \"customerID\", how=\"leftouter\")\n        .select(\n            \"customerID\",\n            F.when(F.col(\"PhoneService\").isNull(), \"No\")\n            .otherwise(\"Yes\")\n            .alias(\"PhoneService\"),\n            \"MultipleLines\",\n        )\n        .select(\n            \"customerID\",\n            \"PhoneService\",\n            F.when(F.col(\"PhoneService\") == \"No\", \"No phone service\")\n            .otherwise(F.when(F.col(\"MultipleLines\").isNull(), \"No\").otherwise(\"Yes\"))\n            .alias(\"MultipleLines\"),\n        )\n    )\n\n    _register_views(locals(), \"phone_service\", \"multiple_lines\", \"customer_phone_features\")\n \n    return customer_phone_features\n\n\ndef untidy_feature(df, feature):\n    \"\"\" 'untidies' a feature by turning it into a column \"\"\"\n    return df.where(F.col(\"feature\") == feature).select(\n        \"customerID\", F.col(\"value\").alias(feature)\n    )\n\ndef chained_join(column, base_df, dfs, how=\"leftouter\"):\n    \"\"\" repeatedly joins a sequence of data frames on the same column \"\"\"\n    acc = base_df\n    for df in dfs:\n        acc = acc.join(df, column, how=how)\n\n    return acc\n\ndef resolve_nullable_column(df, col, null_val=\"No\"):\n    return F.when(df[col].isNull(), null_val).otherwise(df[col]).alias(col)\n\n\ndef resolve_dependent_column(\n    df,\n    col,\n    parent_col=\"InternetService\",\n    null_val=\"No\",\n    null_parent_val=\"No internet service\",\n):\n    return (\n        
F.when((df[parent_col] == \"No\") | (df[parent_col].isNull()), null_parent_val)\n        .otherwise(F.when(df[col].isNull(), null_val).otherwise(df[col]))\n        .alias(col)\n    )\n\n\ndef join_internet_features(internet_features_df):\n\n    internet_features = internet_features_df\n\n    internet_service = untidy_feature(internet_features, \"InternetService\")\n    online_security = untidy_feature(internet_features, \"OnlineSecurity\")\n    online_backup = untidy_feature(internet_features, \"OnlineBackup\")\n    device_protection = untidy_feature(internet_features, \"DeviceProtection\")\n    tech_support = untidy_feature(internet_features, \"TechSupport\")\n    streaming_tv = untidy_feature(internet_features, \"StreamingTV\")\n    streaming_movies = untidy_feature(internet_features, \"StreamingMovies\")\n\n    customer_internet_features = chained_join(\n        \"customerID\",\n        customers(),\n        [\n            internet_service,\n            online_security,\n            online_backup,\n            device_protection,\n            tech_support,\n            streaming_tv,\n            streaming_movies,\n        ],\n    )\n    \n    customer_internet_features = customer_internet_features.select(\n        \"customerID\",\n        resolve_nullable_column(customer_internet_features, \"InternetService\"),\n        resolve_dependent_column(\n            customer_internet_features, \"OnlineSecurity\", \"InternetService\"\n        ),\n        resolve_dependent_column(\n            customer_internet_features, \"OnlineBackup\", \"InternetService\"\n        ),\n        resolve_dependent_column(\n            customer_internet_features, \"DeviceProtection\", \"InternetService\"\n        ),\n        resolve_dependent_column(\n            customer_internet_features, \"TechSupport\", \"InternetService\"\n        ),\n        resolve_dependent_column(\n            customer_internet_features, \"StreamingTV\", \"InternetService\"\n        ),\n        
resolve_dependent_column(\n            customer_internet_features, \"StreamingMovies\", \"InternetService\"\n        ),\n    )\n\n    _register_views(locals(), \n        \"internet_service\",\n        \"online_security\",\n        \"online_backup\",\n        \"device_protection\",\n        \"tech_support\",\n        \"streaming_tv\",\n        \"streaming_movies\",\n        \"customer_internet_features\" \n    )\n\n    return customer_internet_features\n\n\ndef join_account_features(account_features_df):\n    account_features = account_features_df\n    contracts = untidy_feature(account_features, \"Contract\")\n\n    paperless = untidy_feature(account_features, \"PaperlessBilling\")\n\n    payment = untidy_feature(account_features, \"PaymentMethod\")\n\n    customer_account_features = chained_join(\n        \"customerID\", customers(), [contracts, paperless, payment]\n    )\n\n    customer_account_features = customer_account_features.select(\n        \"customerID\",\n        \"Contract\",\n        resolve_nullable_column(customer_account_features, \"PaperlessBilling\"),\n        \"PaymentMethod\",\n    )\n\n    _register_views(locals(), \"contracts\", \"paperless\", \"payment\", \"customer_account_features\")\n    \n    return customer_account_features\n\n\ndef process_account_meta(account_meta_df, usecal=None):\n    def is_senior_citizen(nowcol, dobcol):\n        if options['use_calendar_arithmetic']:\n            return F.when(\n                F.col(\"now\") >= F.add_months(\n                    F.col(\"dateOfBirth\"), 65 * 12\n                ), F.lit(True)\n            ).otherwise(F.lit(False))\n        else:\n            return (F.year(F.col(nowcol)) > (F.year(F.col(dobcol)) + 65)) | \\\n                (F.year(F.col(nowcol)) == (F.year(F.col(dobcol)) + 65)) & \\\n                (\n                    (F.month(F.col(nowcol)) < F.month(F.col(dobcol))) | \\\n                    (\n                        (F.month(F.col(nowcol)) == F.month(F.col(dobcol))) & \\\n 
                       (F.dayofmonth(F.col(nowcol)) <= F.dayofmonth(F.col(dobcol)))\n                    )\n                )\n\n    customer_account_meta = account_meta_df.select(\n        \"customerID\",\n        is_senior_citizen(\"now\", \"dateOfBirth\").alias(\"SeniorCitizen\"),\n        \"Partner\",\n        \"Dependents\",\n        \"gender\",\n        \"MonthlyCharges\",\n    )\n    \n    _register_views(locals(), \"customer_account_meta\")\n    return customer_account_meta\n\ndef forcefloat(c):\n    return F.col(c).cast(\"float\").alias(c)\n\n\ndef join_wide_table(customer_billing, customer_phone_features, customer_internet_features, customer_account_features, customer_account_meta):\n\n    wide_data = chained_join(\n        \"customerID\",\n        customers(),\n        [\n            customer_billing,\n            customer_phone_features,\n            customer_internet_features,\n            customer_account_features,\n            customer_account_meta,\n        ],\n    ).select(\n        \"customerID\",\n        \"gender\",\n        \"SeniorCitizen\",\n        \"Partner\",\n        \"Dependents\",\n        \"tenure\",\n        \"PhoneService\",\n        \"MultipleLines\",\n        \"InternetService\",\n        \"OnlineSecurity\",\n        \"OnlineBackup\",\n        \"DeviceProtection\",\n        \"TechSupport\",\n        \"StreamingTV\",\n        \"StreamingMovies\",\n        \"Contract\",\n        \"PaperlessBilling\",\n        \"PaymentMethod\",\n        \"MonthlyCharges\",\n        \"TotalCharges\",\n        \"Churn\",\n    )\n\n    return wide_data\n\n    # In[ ]:\n\ndef cast_and_coalesce_wide_data(wd):\n    if options[\"coalesce_output\"] > 0:\n        wd = wd.coalesce(options[\"coalesce_output\"])\n    \n    return wd.select(\n        \"customerID\",\n        \"gender\",\n        \"SeniorCitizen\",\n        \"Partner\",\n        \"Dependents\",\n        \"tenure\",\n        \"PhoneService\",\n        \"MultipleLines\",\n        
\"InternetService\",\n        \"OnlineSecurity\",\n        \"OnlineBackup\",\n        \"DeviceProtection\",\n        \"TechSupport\",\n        \"StreamingTV\",\n        \"StreamingMovies\",\n        \"Contract\",\n        \"PaperlessBilling\",\n        \"PaymentMethod\",\n        forcefloat(\"MonthlyCharges\"),\n        forcefloat(\"TotalCharges\"),\n        \"Churn\",\n    )\n\ndef write_df(df, name):\n    output_kind = options[\"output_kind\"]\n    output_mode = options[\"output_mode\"]\n    output_prefix = options[\"output_prefix\"]\n    \n    name = \"%s.%s\" % (name, output_kind)\n    if output_prefix != \"\":\n        name = \"%s%s\" % (output_prefix, name)\n    kwargs = {}\n    if output_kind == \"csv\":\n        kwargs[\"header\"] = True\n    getattr(df.write.mode(output_mode), output_kind)(name, **kwargs)\n\n"
  },
  {
    "path": "examples/SQL+DF-Examples/customer-churn/notebooks/python/etl.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Transforming and joining raw data\\n\",\n    \"\\n\",\n    \"The \\\"raw\\\" data is divided among the following tables:\\n\",\n    \"\\n\",\n    \"- **Customer metadata**\\n\",\n    \"  - customerID\\n\",\n    \"  - gender\\n\",\n    \"  - date of birth (we'll derive age and senior citizen status from this)\\n\",\n    \"  - Partner\\n\",\n    \"  - Dependents\\n\",\n    \"  - (nominal) MonthlyCharges\\n\",\n    \"- **Billing events**\\n\",\n    \"  - customerID\\n\",\n    \"  - date (we'll derive tenure from the number/duration of billing events)\\n\",\n    \"  - kind (one of \\\"AccountCreation\\\", \\\"Charge\\\", or \\\"AccountTermination\\\")\\n\",\n    \"  - value (either a positive nonzero amount or 0.00; we'll derive TotalCharges from the sum of amounts and Churn from the existence of an AccountTermination event)\\n\",\n    \"- **Customer phone features**\\n\",\n    \"  - customerID\\n\",\n    \"  - feature (one of \\\"PhoneService\\\" or \\\"MultipleLines\\\")\\n\",\n    \"- **Customer internet features**\\n\",\n    \"  - customerID\\n\",\n    \"  - feature (one of \\\"InternetService\\\", \\\"OnlineSecurity\\\", \\\"OnlineBackup\\\", \\\"DeviceProtection\\\", \\\"TechSupport\\\", \\\"StreamingTV\\\", \\\"StreamingMovies\\\")\\n\",\n    \"  - value (one of \\\"Fiber\\\", \\\"DSL\\\", \\\"Yes\\\", \\\"No\\\")\\n\",\n    \"- **Customer account features**\\n\",\n    \"  - customerID\\n\",\n    \"  - feature (one of \\\"Contract\\\", \\\"PaperlessBilling\\\", \\\"PaymentMethod\\\")\\n\",\n    \"  - value (one of \\\"Month-to-month\\\", \\\"One year\\\", \\\"Two year\\\", \\\"No\\\", \\\"Yes\\\", \\\"Credit card (automatic)\\\", \\\"Mailed check\\\", \\\"Bank transfer (automatic)\\\", \\\"Electronic check\\\")\\n\",\n    \"\\n\",\n    \"We want to join these together to reconstitute a training data set with this schema:\\n\",\n    \"\\n\",\n   
 \"- customerID\\n\",\n    \"- gender\\n\",\n    \"- SeniorCitizen\\n\",\n    \"- Partner\\n\",\n    \"- Dependents\\n\",\n    \"- tenure\\n\",\n    \"- PhoneService\\n\",\n    \"- MultipleLines\\n\",\n    \"- InternetService\\n\",\n    \"- OnlineSecurity\\n\",\n    \"- OnlineBackup\\n\",\n    \"- DeviceProtection\\n\",\n    \"- TechSupport\\n\",\n    \"- StreamingTV\\n\",\n    \"- StreamingMovies\\n\",\n    \"- Contract\\n\",\n    \"- PaperlessBilling\\n\",\n    \"- PaymentMethod\\n\",\n    \"- MonthlyCharges\\n\",\n    \"- TotalCharges\\n\",\n    \"- Churn\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {\n    \"tags\": [\n     \"parameters\"\n    ]\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# notebook parameters\\n\",\n    \"\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"spark_master = os.getenv(\\\"SPARK_MASTER_URL\\\", \\\"/spark://ip:port\\\")\\n\",\n    \"app_name = \\\"churn-etl\\\"\\n\",\n    \"input_files = dict(\\n\",\n    \"    billing=\\\"billing_events\\\", \\n\",\n    \"    account_features=\\\"customer_account_features\\\", \\n\",\n    \"    internet_features=\\\"customer_internet_features\\\", \\n\",\n    \"    meta=\\\"customer_meta\\\", \\n\",\n    \"    phone_features=\\\"customer_phone_features\\\"\\n\",\n    \")\\n\",\n    \"output_file = \\\"churn-etl\\\"\\n\",\n    \"output_mode = \\\"overwrite\\\"\\n\",\n    \"output_kind = \\\"parquet\\\"\\n\",\n    \"input_kind = \\\"parquet\\\"\\n\",\n    \"driver_memory = '8g'\\n\",\n    \"executor_memory = '8g'\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"\\n\",\n       \"            <div>\\n\",\n       \"                <p><b>SparkSession - hive</b></p>\\n\",\n       \"                \\n\",\n       \"        <div>\\n\",\n       \"            <p><b>SparkContext</b></p>\\n\",\n       \"\\n\",\n       \"          
  <p><a href=\\\"http://10.19.183.210:4040\\\">Spark UI</a></p>\\n\",\n       \"\\n\",\n       \"            <dl>\\n\",\n       \"              <dt>Version</dt>\\n\",\n       \"                <dd><code>v3.2.0</code></dd>\\n\",\n       \"              <dt>Master</dt>\\n\",\n       \"                <dd><code>spark://yuanli-System-Product-Name:7077</code></dd>\\n\",\n       \"              <dt>AppName</dt>\\n\",\n       \"                <dd><code>PySparkShell</code></dd>\\n\",\n       \"            </dl>\\n\",\n       \"        </div>\\n\",\n       \"        \\n\",\n       \"            </div>\\n\",\n       \"        \"\n      ],\n      \"text/plain\": [\n       \"<pyspark.sql.session.SparkSession at 0x7efe7c3e15b0>\"\n      ]\n     },\n     \"execution_count\": 2,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"import pyspark\\n\",\n    \"\\n\",\n    \"session = pyspark.sql.SparkSession.builder \\\\\\n\",\n    \"    .master(spark_master) \\\\\\n\",\n    \"    .appName(app_name) \\\\\\n\",\n    \"    .config(\\\"spark.eventLog.enabled\\\", True) \\\\\\n\",\n    \"    .config(\\\"spark.eventLog.dir\\\", \\\".\\\") \\\\\\n\",\n    \"    .config(\\\"spark.driver.memory\\\", driver_memory) \\\\\\n\",\n    \"    .config(\\\"spark.executor.memory\\\", executor_memory) \\\\\\n\",\n    \"    .getOrCreate()\\n\",\n    \"session\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import churn.etl\\n\",\n    \"\\n\",\n    \"churn.etl.register_options(\\n\",\n    \"    spark_master = spark_master,\\n\",\n    \"    app_name = app_name,\\n\",\n    \"    input_files = input_files,\\n\",\n    \"    output_mode = output_mode,\\n\",\n    \"    output_kind = output_kind,\\n\",\n    \"    input_kind = input_kind,\\n\",\n    \"    driver_memory = driver_memory,\\n\",\n    \"    executor_memory = executor_memory\\n\",\n    \")\"\n   ]\n  
},\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Reconstructing billing events and charges\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"                                                                                \\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"root\\n\",\n      \" |-- customerID: string (nullable = true)\\n\",\n      \" |-- kind: string (nullable = true)\\n\",\n      \" |-- value: decimal(8,2) (nullable = true)\\n\",\n      \" |-- date: date (nullable = true)\\n\",\n      \" |-- month: string (nullable = true)\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from churn.etl import read_df\\n\",\n    \"billing_events = read_df(session, input_files[\\\"billing\\\"])\\n\",\n    \"billing_events.printSchema()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from churn.etl import join_billing_data\\n\",\n    \"customer_billing = join_billing_data(billing_events)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DataFrame[customerID: string, Churn: boolean, tenure: bigint, TotalCharges: decimal(18,2)]\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"customer_billing\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"When we aggregated billing data, we also captured a unique list of customers in a temporary view.  
For convenience, we can access it as follows:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from churn.etl import customers as get_customers\\n\",\n    \"customers = get_customers()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Reconstructing phone features\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"root\\n\",\n      \" |-- customerID: string (nullable = true)\\n\",\n      \" |-- feature: string (nullable = true)\\n\",\n      \" |-- value: string (nullable = true)\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"phone_features = read_df(session, input_files[\\\"phone_features\\\"])\\n\",\n    \"phone_features.printSchema()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from churn.etl import join_phone_features\\n\",\n    \"customer_phone_features = join_phone_features(phone_features)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Reconstructing internet features\\n\",\n    \"\\n\",\n    \"Whereas phone features only include whether or not there are multiple lines, there are several internet-specific features in accounts:\\n\",\n    \"\\n\",\n    \"- `InternetService` (one of `Fiber optic` or `DSL` in the \\\"raw\\\" data; its absence translates to `No` in the processed data)\\n\",\n    \"- `OnlineSecurity` (`Yes` in the \\\"raw\\\" data if present; one of `No`, `Yes`, or `No internet service` in the processed data)\\n\",\n    \"- `OnlineBackup` (`Yes` in the \\\"raw\\\" data if present; one of `No`, `Yes`, or `No internet service` in the processed data)\\n\",\n    \"- 
`DeviceProtection` (`Yes` in the \\\"raw\\\" data if present; one of `No`, `Yes`, or `No internet service` in the processed data)\\n\",\n    \"- `TechSupport` (`Yes` in the \\\"raw\\\" data if present; one of `No`, `Yes`, or `No internet service` in the processed data)\\n\",\n    \"- `StreamingTV` (`Yes` in the \\\"raw\\\" data if present; one of `No`, `Yes`, or `No internet service` in the processed data)\\n\",\n    \"- `StreamingMovies` (`Yes` in the \\\"raw\\\" data if present; one of `No`, `Yes`, or `No internet service` in the processed data)\\n\",\n    \"\\n\",\n    \"This will lead to some slightly more interesting joins!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"root\\n\",\n      \" |-- customerID: string (nullable = true)\\n\",\n      \" |-- feature: string (nullable = true)\\n\",\n      \" |-- value: string (nullable = true)\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:59:39,224 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------+---------------+-----+\\n\",\n      \"|          customerID|        feature|value|\\n\",\n      \"+--------------------+---------------+-----+\\n\",\n      \"|7590-VHVEG-Mg8VG5...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-5xLi5Z...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-ZePlJi...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-x9IoNd...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-Z9yCIk...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-K8kBya...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-4ZjnIU...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-0stTDJ...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-lqhKlh...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-4Y_zUA...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-34V86Q...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-GCNzU2...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-i0AFUE...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-F1ALBc...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-aEfHl7...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-eiqTDe...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-3K15yQ...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-iMYyeZ...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-rReekB...|InternetService|  DSL|\\n\",\n      \"|7590-VHVEG-2l92Zs...|InternetService|  DSL|\\n\",\n      \"+--------------------+---------------+-----+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"internet_features = read_df(session, input_files[\\\"internet_features\\\"])\\n\",\n    \"internet_features.printSchema()\\n\",\n    \"internet_features.show()\"\n   ]\n  },\n  {\n   
\"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from churn.etl import join_internet_features\\n\",\n    \"customer_internet_features = join_internet_features(internet_features)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Reconstructing account features\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"root\\n\",\n      \" |-- customerID: string (nullable = true)\\n\",\n      \" |-- feature: string (nullable = true)\\n\",\n      \" |-- value: string (nullable = true)\\n\",\n      \"\\n\",\n      \"+--------------------+-------------+----------------+\\n\",\n      \"|          customerID|      feature|           value|\\n\",\n      \"+--------------------+-------------+----------------+\\n\",\n      \"|7590-VHVEG-Mg8VG5...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-5xLi5Z...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-ZePlJi...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-x9IoNd...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-Z9yCIk...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-K8kBya...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-4ZjnIU...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-0stTDJ...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-lqhKlh...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-4Y_zUA...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-34V86Q...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-GCNzU2...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-i0AFUE...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-F1ALBc...|PaymentMethod|Electronic check|\\n\",\n      
\"|7590-VHVEG-aEfHl7...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-3K15yQ...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-eiqTDe...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-iMYyeZ...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-rReekB...|PaymentMethod|Electronic check|\\n\",\n      \"|7590-VHVEG-2l92Zs...|PaymentMethod|Electronic check|\\n\",\n      \"+--------------------+-------------+----------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 09:59:42,068 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"account_features = read_df(session, input_files[\\\"account_features\\\"])\\n\",\n    \"account_features.printSchema()\\n\",\n    \"account_features.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from churn.etl import join_account_features\\n\",\n    \"customer_account_features = join_account_features(account_features)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Account metadata\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"root\\n\",\n      \" |-- 
customerID: string (nullable = true)\\n\",\n      \" |-- dateOfBirth: date (nullable = true)\\n\",\n      \" |-- gender: string (nullable = true)\\n\",\n      \" |-- SeniorCitizen: string (nullable = true)\\n\",\n      \" |-- Partner: string (nullable = true)\\n\",\n      \" |-- Dependents: string (nullable = true)\\n\",\n      \" |-- MonthlyCharges: decimal(8,2) (nullable = true)\\n\",\n      \" |-- now: timestamp (nullable = true)\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"account_meta = read_df(session, input_files[\\\"meta\\\"])\\n\",\n    \"account_meta.printSchema()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from churn.etl import process_account_meta\\n\",\n    \"customer_account_meta = process_account_meta(account_meta)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Putting it all together\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from churn.etl import chained_join\\n\",\n    \"from churn.etl import forcefloat\\n\",\n    \"\\n\",\n    \"wide_data = chained_join(\\n\",\n    \"    \\\"customerID\\\",\\n\",\n    \"    customers,\\n\",\n    \"    [\\n\",\n    \"        customer_billing,\\n\",\n    \"        customer_phone_features,\\n\",\n    \"        customer_internet_features,\\n\",\n    \"        customer_account_features,\\n\",\n    \"        customer_account_meta\\n\",\n    \"    ]\\n\",\n    \").select(\\n\",\n    \"    \\\"customerID\\\", \\n\",\n    \"    \\\"gender\\\", \\n\",\n    \"    \\\"SeniorCitizen\\\", \\n\",\n    \"    \\\"Partner\\\", \\n\",\n    \"    \\\"Dependents\\\", \\n\",\n    \"    \\\"tenure\\\", \\n\",\n    \"    \\\"PhoneService\\\", \\n\",\n    \"    \\\"MultipleLines\\\", \\n\",\n    \"    \\\"InternetService\\\", \\n\",\n    \"    
\\\"OnlineSecurity\\\", \\n\",\n    \"    \\\"OnlineBackup\\\", \\n\",\n    \"    \\\"DeviceProtection\\\", \\n\",\n    \"    \\\"TechSupport\\\", \\n\",\n    \"    \\\"StreamingTV\\\", \\n\",\n    \"    \\\"StreamingMovies\\\", \\n\",\n    \"    \\\"Contract\\\", \\n\",\n    \"    \\\"PaperlessBilling\\\", \\n\",\n    \"    \\\"PaymentMethod\\\", \\n\",\n    \"    forcefloat(\\\"MonthlyCharges\\\"),\\n\",\n    \"    forcefloat(\\\"TotalCharges\\\"), \\n\",\n    \"    \\\"Churn\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {\n    \"scrolled\": false\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"== Physical Plan ==\\n\",\n      \"AdaptiveSparkPlan isFinalPlan=false\\n\",\n      \"+- Project [customerID#0, gender#265, SeniorCitizen#279, Partner#267, Dependents#268, tenure#61L, PhoneService#97, MultipleLines#101, InternetService#199, OnlineSecurity#200, OnlineBackup#201, DeviceProtection#202, TechSupport#203, StreamingTV#204, StreamingMovies#205, Contract#233, PaperlessBilling#258, PaymentMethod#239, cast(MonthlyCharges#269 as float) AS MonthlyCharges#366, cast(TotalCharges#62 as float) AS TotalCharges#367, Churn#41]\\n\",\n      \"   +- BroadcastHashJoin [customerID#0], [customerID#263], LeftOuter, BuildRight, false\\n\",\n      \"      :- Project [customerID#0, Churn#41, tenure#61L, TotalCharges#62, PhoneService#97, MultipleLines#101, InternetService#199, OnlineSecurity#200, OnlineBackup#201, DeviceProtection#202, TechSupport#203, StreamingTV#204, StreamingMovies#205, Contract#233, PaperlessBilling#258, PaymentMethod#239]\\n\",\n      \"      :  +- SortMergeJoin [customerID#0], [customerID#324], LeftOuter\\n\",\n      \"      :     :- Project [customerID#0, Churn#41, tenure#61L, TotalCharges#62, PhoneService#97, MultipleLines#101, InternetService#199, OnlineSecurity#200, OnlineBackup#201, DeviceProtection#202, 
TechSupport#203, StreamingTV#204, StreamingMovies#205]\\n\",\n      \"      :     :  +- SortMergeJoin [customerID#0], [customerID#306], LeftOuter\\n\",\n      \"      :     :     :- Project [customerID#0, Churn#41, tenure#61L, TotalCharges#62, PhoneService#97, MultipleLines#101]\\n\",\n      \"      :     :     :  +- SortMergeJoin [customerID#0], [customerID#295], LeftOuter\\n\",\n      \"      :     :     :     :- Project [customerID#0, Churn#41, tenure#61L, TotalCharges#62]\\n\",\n      \"      :     :     :     :  +- SortMergeJoin [customerID#0], [customerID#286], LeftOuter\\n\",\n      \"      :     :     :     :     :- Sort [customerID#0 ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :     :     :  +- HashAggregate(keys=[customerID#0], functions=[])\\n\",\n      \"      :     :     :     :     :     +- Exchange hashpartitioning(customerID#0, 200), ENSURE_REQUIREMENTS, [id=#550]\\n\",\n      \"      :     :     :     :     :        +- HashAggregate(keys=[customerID#0], functions=[])\\n\",\n      \"      :     :     :     :     :           +- Project [customerID#0]\\n\",\n      \"      :     :     :     :     :              +- FileScan parquet [customerID#0,month#4] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<customerID:string>\\n\",\n      \"      :     :     :     :     +- Project [customerID#286, Churn#41, tenure#61L, TotalCharges#62]\\n\",\n      \"      :     :     :     :        +- SortMergeJoin [customerID#286], [customerID#66], Inner\\n\",\n      \"      :     :     :     :           :- Project [customerID#286, isnotnull(Churn#30) AS Churn#41]\\n\",\n      \"      :     :     :     :           :  +- SortMergeJoin [customerID#286], [Churn#30], LeftOuter\\n\",\n      \"      :     :     :     :           :     :- Sort [customerID#286 ASC NULLS FIRST], false, 0\\n\",\n   
   \"      :     :     :     :           :     :  +- HashAggregate(keys=[customerID#286], functions=[])\\n\",\n      \"      :     :     :     :           :     :     +- Exchange hashpartitioning(customerID#286, 200), ENSURE_REQUIREMENTS, [id=#552]\\n\",\n      \"      :     :     :     :           :     :        +- HashAggregate(keys=[customerID#286], functions=[])\\n\",\n      \"      :     :     :     :           :     :           +- Project [customerID#286]\\n\",\n      \"      :     :     :     :           :     :              +- Filter isnotnull(customerID#286)\\n\",\n      \"      :     :     :     :           :     :                 +- FileScan parquet [customerID#286,month#290] Batched: true, DataFilters: [isnotnull(customerID#286)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(customerID)], ReadSchema: struct<customerID:string>\\n\",\n      \"      :     :     :     :           :     +- Sort [Churn#30 ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :     :           :        +- Exchange hashpartitioning(Churn#30, 200), ENSURE_REQUIREMENTS, [id=#556]\\n\",\n      \"      :     :     :     :           :           +- Project [customerID#32 AS Churn#30]\\n\",\n      \"      :     :     :     :           :              +- Filter ((isnotnull(kind#33) AND (kind#33 = AccountTermination)) AND isnotnull(customerID#32))\\n\",\n      \"      :     :     :     :           :                 +- FileScan parquet [customerID#32,kind#33,month#36] Batched: true, DataFilters: [isnotnull(kind#33), (kind#33 = AccountTermination), isnotnull(customerID#32)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(kind), EqualTo(kind,AccountTermination), IsNotNull(customerID)], ReadSchema: 
struct<customerID:string,kind:string>\\n\",\n      \"      :     :     :     :           +- Project [customerID#66, coalesce(event_counts#23L, 0) AS tenure#61L, coalesce(total_charges#25, 0.00) AS TotalCharges#62]\\n\",\n      \"      :     :     :     :              +- SortMergeJoin [customerID#66], [customerID#44], LeftOuter\\n\",\n      \"      :     :     :     :                 :- Sort [customerID#66 ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :     :                 :  +- HashAggregate(keys=[customerID#66], functions=[])\\n\",\n      \"      :     :     :     :                 :     +- Exchange hashpartitioning(customerID#66, 200), ENSURE_REQUIREMENTS, [id=#561]\\n\",\n      \"      :     :     :     :                 :        +- HashAggregate(keys=[customerID#66], functions=[])\\n\",\n      \"      :     :     :     :                 :           +- Project [customerID#66]\\n\",\n      \"      :     :     :     :                 :              +- Filter isnotnull(customerID#66)\\n\",\n      \"      :     :     :     :                 :                 +- FileScan parquet [customerID#66,month#70] Batched: true, DataFilters: [isnotnull(customerID#66)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(customerID)], ReadSchema: struct<customerID:string>\\n\",\n      \"      :     :     :     :                 +- Sort [customerID#44 ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :     :                    +- Exchange hashpartitioning(customerID#44, 200), ENSURE_REQUIREMENTS, [id=#567]\\n\",\n      \"      :     :     :     :                       +- HashAggregate(keys=[customerID#44, kind#45], functions=[count(value#46), sum(UnscaledValue(value#46))])\\n\",\n      \"      :     :     :     :                          +- Exchange hashpartitioning(customerID#44, kind#45, 200), ENSURE_REQUIREMENTS, 
[id=#563]\\n\",\n      \"      :     :     :     :                             +- HashAggregate(keys=[customerID#44, kind#45], functions=[partial_count(value#46), partial_sum(UnscaledValue(value#46))])\\n\",\n      \"      :     :     :     :                                +- Project [customerID#44, kind#45, value#46]\\n\",\n      \"      :     :     :     :                                   +- Filter ((isnotnull(kind#45) AND (kind#45 = Charge)) AND isnotnull(customerID#44))\\n\",\n      \"      :     :     :     :                                      +- FileScan parquet [customerID#44,kind#45,value#46,month#48] Batched: true, DataFilters: [isnotnull(kind#45), (kind#45 = Charge), isnotnull(customerID#44)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(kind), EqualTo(kind,Charge), IsNotNull(customerID)], ReadSchema: struct<customerID:string,kind:string,value:decimal(8,2)>\\n\",\n      \"      :     :     :     +- Sort [customerID#295 ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :        +- Project [customerID#295, CASE WHEN isnull(PhoneService#82) THEN No ELSE Yes END AS PhoneService#97, CASE WHEN isnull(PhoneService#82) THEN No phone service ELSE CASE WHEN isnull(MultipleLines#85) THEN No ELSE Yes END END AS MultipleLines#101]\\n\",\n      \"      :     :     :           +- BroadcastHashJoin [customerID#295], [customerID#91], LeftOuter, BuildRight, false\\n\",\n      \"      :     :     :              :- Project [customerID#295, PhoneService#82]\\n\",\n      \"      :     :     :              :  +- BroadcastHashJoin [customerID#295], [customerID#76], LeftOuter, BuildRight, false\\n\",\n      \"      :     :     :              :     :- HashAggregate(keys=[customerID#295], functions=[])\\n\",\n      \"      :     :     :              :     :  +- Exchange hashpartitioning(customerID#295, 200), ENSURE_REQUIREMENTS, 
[id=#580]\\n\",\n      \"      :     :     :              :     :     +- HashAggregate(keys=[customerID#295], functions=[])\\n\",\n      \"      :     :     :              :     :        +- Project [customerID#295]\\n\",\n      \"      :     :     :              :     :           +- Filter isnotnull(customerID#295)\\n\",\n      \"      :     :     :              :     :              +- FileScan parquet [customerID#295,month#299] Batched: true, DataFilters: [isnotnull(customerID#295)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(customerID)], ReadSchema: struct<customerID:string>\\n\",\n      \"      :     :     :              :     +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#583]\\n\",\n      \"      :     :     :              :        +- Project [customerID#76, Yes AS PhoneService#82]\\n\",\n      \"      :     :     :              :           +- Filter ((isnotnull(feature#77) AND (feature#77 = PhoneService)) AND isnotnull(customerID#76))\\n\",\n      \"      :     :     :              :              +- FileScan parquet [customerID#76,feature#77] Batched: true, DataFilters: [isnotnull(feature#77), (feature#77 = PhoneService), isnotnull(customerID#76)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,PhoneService), IsNotNull(customerID)], ReadSchema: struct<customerID:string,feature:string>\\n\",\n      \"      :     :     :              +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#587]\\n\",\n      \"      :     :     :                 +- Project [customerID#91, Yes AS MultipleLines#85]\\n\",\n      \"      :     :     :                    +- Filter ((isnotnull(feature#92) AND (feature#92 = 
MultipleLines)) AND isnotnull(customerID#91))\\n\",\n      \"      :     :     :                       +- FileScan parquet [customerID#91,feature#92] Batched: true, DataFilters: [isnotnull(feature#92), (feature#92 = MultipleLines), isnotnull(customerID#91)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,MultipleLines), IsNotNull(customerID)], ReadSchema: struct<customerID:string,feature:string>\\n\",\n      \"      :     :     +- Sort [customerID#306 ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :        +- Project [customerID#306, CASE WHEN isnull(InternetService#124) THEN No ELSE InternetService#124 END AS InternetService#199, CASE WHEN ((InternetService#124 = No) OR isnull(InternetService#124)) THEN No internet service ELSE CASE WHEN isnull(OnlineSecurity#127) THEN No ELSE OnlineSecurity#127 END END AS OnlineSecurity#200, CASE WHEN ((InternetService#124 = No) OR isnull(InternetService#124)) THEN No internet service ELSE CASE WHEN isnull(OnlineBackup#130) THEN No ELSE OnlineBackup#130 END END AS OnlineBackup#201, CASE WHEN ((InternetService#124 = No) OR isnull(InternetService#124)) THEN No internet service ELSE CASE WHEN isnull(DeviceProtection#133) THEN No ELSE DeviceProtection#133 END END AS DeviceProtection#202, CASE WHEN ((InternetService#124 = No) OR isnull(InternetService#124)) THEN No internet service ELSE CASE WHEN isnull(TechSupport#136) THEN No ELSE TechSupport#136 END END AS TechSupport#203, CASE WHEN ((InternetService#124 = No) OR isnull(InternetService#124)) THEN No internet service ELSE CASE WHEN isnull(StreamingTV#139) THEN No ELSE StreamingTV#139 END END AS StreamingTV#204, CASE WHEN ((InternetService#124 = No) OR isnull(InternetService#124)) THEN No internet service ELSE CASE WHEN isnull(StreamingMovies#142) THEN No ELSE StreamingMovies#142 END END AS StreamingMovies#205]\\n\",\n      
\"      :     :           +- BroadcastHashJoin [customerID#306], [customerID#188], LeftOuter, BuildRight, false\\n\",\n      \"      :     :              :- Project [customerID#306, InternetService#124, OnlineSecurity#127, OnlineBackup#130, DeviceProtection#133, TechSupport#136, StreamingTV#139]\\n\",\n      \"      :     :              :  +- BroadcastHashJoin [customerID#306], [customerID#178], LeftOuter, BuildRight, false\\n\",\n      \"      :     :              :     :- Project [customerID#306, InternetService#124, OnlineSecurity#127, OnlineBackup#130, DeviceProtection#133, TechSupport#136]\\n\",\n      \"      :     :              :     :  +- BroadcastHashJoin [customerID#306], [customerID#169], LeftOuter, BuildRight, false\\n\",\n      \"      :     :              :     :     :- Project [customerID#306, InternetService#124, OnlineSecurity#127, OnlineBackup#130, DeviceProtection#133]\\n\",\n      \"      :     :              :     :     :  +- BroadcastHashJoin [customerID#306], [customerID#161], LeftOuter, BuildRight, false\\n\",\n      \"      :     :              :     :     :     :- Project [customerID#306, InternetService#124, OnlineSecurity#127, OnlineBackup#130]\\n\",\n      \"      :     :              :     :     :     :  +- BroadcastHashJoin [customerID#306], [customerID#154], LeftOuter, BuildRight, false\\n\",\n      \"      :     :              :     :     :     :     :- Project [customerID#306, InternetService#124, OnlineSecurity#127]\\n\",\n      \"      :     :              :     :     :     :     :  +- BroadcastHashJoin [customerID#306], [customerID#148], LeftOuter, BuildRight, false\\n\",\n      \"      :     :              :     :     :     :     :     :- Project [customerID#306, InternetService#124]\\n\",\n      \"      :     :              :     :     :     :     :     :  +- BroadcastHashJoin [customerID#306], [customerID#105], LeftOuter, BuildRight, false\\n\",\n      \"      :     :              :     :     :     :     :     :     :- 
HashAggregate(keys=[customerID#306], functions=[])\\n\",\n      \"      :     :              :     :     :     :     :     :     :  +- Exchange hashpartitioning(customerID#306, 200), ENSURE_REQUIREMENTS, [id=#595]\\n\",\n      \"      :     :              :     :     :     :     :     :     :     +- HashAggregate(keys=[customerID#306], functions=[])\\n\",\n      \"      :     :              :     :     :     :     :     :     :        +- Project [customerID#306]\\n\",\n      \"      :     :              :     :     :     :     :     :     :           +- Filter isnotnull(customerID#306)\\n\",\n      \"      :     :              :     :     :     :     :     :     :              +- FileScan parquet [customerID#306,month#310] Batched: true, DataFilters: [isnotnull(customerID#306)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(customerID)], ReadSchema: struct<customerID:string>\\n\",\n      \"      :     :              :     :     :     :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#598]\\n\",\n      \"      :     :              :     :     :     :     :     :        +- Project [customerID#105, value#107 AS InternetService#124]\\n\",\n      \"      :     :              :     :     :     :     :     :           +- Filter ((isnotnull(feature#106) AND (feature#106 = InternetService)) AND isnotnull(customerID#105))\\n\",\n      \"      :     :              :     :     :     :     :     :              +- FileScan parquet [customerID#105,feature#106,value#107] Batched: true, DataFilters: [isnotnull(feature#106), (feature#106 = InternetService), isnotnull(customerID#105)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), 
EqualTo(feature,InternetService), IsNotNull(customerID)], ReadSchema: struct<customerID:string,feature:string,value:string>\\n\",\n      \"      :     :              :     :     :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#602]\\n\",\n      \"      :     :              :     :     :     :     :        +- Project [customerID#148, value#150 AS OnlineSecurity#127]\\n\",\n      \"      :     :              :     :     :     :     :           +- Filter ((isnotnull(feature#149) AND (feature#149 = OnlineSecurity)) AND isnotnull(customerID#148))\\n\",\n      \"      :     :              :     :     :     :     :              +- FileScan parquet [customerID#148,feature#149,value#150] Batched: true, DataFilters: [isnotnull(feature#149), (feature#149 = OnlineSecurity), isnotnull(customerID#148)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,OnlineSecurity), IsNotNull(customerID)], ReadSchema: struct<customerID:string,feature:string,value:string>\\n\",\n      \"      :     :              :     :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#606]\\n\",\n      \"      :     :              :     :     :     :        +- Project [customerID#154, value#156 AS OnlineBackup#130]\\n\",\n      \"      :     :              :     :     :     :           +- Filter ((isnotnull(feature#155) AND (feature#155 = OnlineBackup)) AND isnotnull(customerID#154))\\n\",\n      \"      :     :              :     :     :     :              +- FileScan parquet [customerID#154,feature#155,value#156] Batched: true, DataFilters: [isnotnull(feature#155), (feature#155 = OnlineBackup), isnotnull(customerID#154)], Format: Parquet, Location: InMemoryFileIndex(1 
paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,OnlineBackup), IsNotNull(customerID)], ReadSchema: struct<customerID:string,feature:string,value:string>\\n\",\n      \"      :     :              :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#610]\\n\",\n      \"      :     :              :     :     :        +- Project [customerID#161, value#163 AS DeviceProtection#133]\\n\",\n      \"      :     :              :     :     :           +- Filter ((isnotnull(feature#162) AND (feature#162 = DeviceProtection)) AND isnotnull(customerID#161))\\n\",\n      \"      :     :              :     :     :              +- FileScan parquet [customerID#161,feature#162,value#163] Batched: true, DataFilters: [isnotnull(feature#162), (feature#162 = DeviceProtection), isnotnull(customerID#161)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,DeviceProtection), IsNotNull(customerID)], ReadSchema: struct<customerID:string,feature:string,value:string>\\n\",\n      \"      :     :              :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#614]\\n\",\n      \"      :     :              :     :        +- Project [customerID#169, value#171 AS TechSupport#136]\\n\",\n      \"      :     :              :     :           +- Filter ((isnotnull(feature#170) AND (feature#170 = TechSupport)) AND isnotnull(customerID#169))\\n\",\n      \"      :     :              :     :              +- FileScan parquet [customerID#169,feature#170,value#171] Batched: true, DataFilters: [isnotnull(feature#170), (feature#170 = TechSupport), isnotnull(customerID#169)], Format: Parquet, Location: InMemoryFileIndex(1 
paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,TechSupport), IsNotNull(customerID)], ReadSchema: struct<customerID:string,feature:string,value:string>\\n\",\n      \"      :     :              :     +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#618]\\n\",\n      \"      :     :              :        +- Project [customerID#178, value#180 AS StreamingTV#139]\\n\",\n      \"      :     :              :           +- Filter ((isnotnull(feature#179) AND (feature#179 = StreamingTV)) AND isnotnull(customerID#178))\\n\",\n      \"      :     :              :              +- FileScan parquet [customerID#178,feature#179,value#180] Batched: true, DataFilters: [isnotnull(feature#179), (feature#179 = StreamingTV), isnotnull(customerID#178)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,StreamingTV), IsNotNull(customerID)], ReadSchema: struct<customerID:string,feature:string,value:string>\\n\",\n      \"      :     :              +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#622]\\n\",\n      \"      :     :                 +- Project [customerID#188, value#190 AS StreamingMovies#142]\\n\",\n      \"      :     :                    +- Filter ((isnotnull(feature#189) AND (feature#189 = StreamingMovies)) AND isnotnull(customerID#188))\\n\",\n      \"      :     :                       +- FileScan parquet [customerID#188,feature#189,value#190] Batched: true, DataFilters: [isnotnull(feature#189), (feature#189 = StreamingMovies), isnotnull(customerID#188)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: 
[IsNotNull(feature), EqualTo(feature,StreamingMovies), IsNotNull(customerID)], ReadSchema: struct<customerID:string,feature:string,value:string>\\n\",\n      \"      :     +- Sort [customerID#324 ASC NULLS FIRST], false, 0\\n\",\n      \"      :        +- Project [customerID#324, Contract#233, CASE WHEN isnull(PaperlessBilling#236) THEN No ELSE PaperlessBilling#236 END AS PaperlessBilling#258, PaymentMethod#239]\\n\",\n      \"      :           +- BroadcastHashJoin [customerID#324], [customerID#251], LeftOuter, BuildRight, false\\n\",\n      \"      :              :- Project [customerID#324, Contract#233, PaperlessBilling#236]\\n\",\n      \"      :              :  +- BroadcastHashJoin [customerID#324], [customerID#245], LeftOuter, BuildRight, false\\n\",\n      \"      :              :     :- Project [customerID#324, Contract#233]\\n\",\n      \"      :              :     :  +- BroadcastHashJoin [customerID#324], [customerID#214], LeftOuter, BuildRight, false\\n\",\n      \"      :              :     :     :- HashAggregate(keys=[customerID#324], functions=[])\\n\",\n      \"      :              :     :     :  +- Exchange hashpartitioning(customerID#324, 200), ENSURE_REQUIREMENTS, [id=#630]\\n\",\n      \"      :              :     :     :     +- HashAggregate(keys=[customerID#324], functions=[])\\n\",\n      \"      :              :     :     :        +- Project [customerID#324]\\n\",\n      \"      :              :     :     :           +- Filter isnotnull(customerID#324)\\n\",\n      \"      :              :     :     :              +- FileScan parquet [customerID#324,month#328] Batched: true, DataFilters: [isnotnull(customerID#324)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(customerID)], ReadSchema: struct<customerID:string>\\n\",\n      \"      :              :     :     +- BroadcastExchange 
HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#633]\\n\",\n      \"      :              :     :        +- Project [customerID#214, value#216 AS Contract#233]\\n\",\n      \"      :              :     :           +- Filter ((isnotnull(feature#215) AND (feature#215 = Contract)) AND isnotnull(customerID#214))\\n\",\n      \"      :              :     :              +- FileScan parquet [customerID#214,feature#215,value#216] Batched: true, DataFilters: [isnotnull(feature#215), (feature#215 = Contract), isnotnull(customerID#214)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,Contract), IsNotNull(customerID)], ReadSchema: struct<customerID:string,feature:string,value:string>\\n\",\n      \"      :              :     +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#637]\\n\",\n      \"      :              :        +- Project [customerID#245, value#247 AS PaperlessBilling#236]\\n\",\n      \"      :              :           +- Filter ((isnotnull(feature#246) AND (feature#246 = PaperlessBilling)) AND isnotnull(customerID#245))\\n\",\n      \"      :              :              +- FileScan parquet [customerID#245,feature#246,value#247] Batched: true, DataFilters: [isnotnull(feature#246), (feature#246 = PaperlessBilling), isnotnull(customerID#245)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,PaperlessBilling), IsNotNull(customerID)], ReadSchema: struct<customerID:string,feature:string,value:string>\\n\",\n      \"      :              +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#641]\\n\",\n      \"      :                 +- Project [customerID#251, 
value#253 AS PaymentMethod#239]\\n\",\n      \"      :                    +- Filter ((isnotnull(feature#252) AND (feature#252 = PaymentMethod)) AND isnotnull(customerID#251))\\n\",\n      \"      :                       +- FileScan parquet [customerID#251,feature#252,value#253] Batched: true, DataFilters: [isnotnull(feature#252), (feature#252 = PaymentMethod), isnotnull(customerID#251)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(feature), EqualTo(feature,PaymentMethod), IsNotNull(customerID)], ReadSchema: struct<customerID:string,feature:string,value:string>\\n\",\n      \"      +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#650]\\n\",\n      \"         +- Project [customerID#263, ((year(cast(now#270 as date)) > (year(dateOfBirth#264) + 65)) OR ((year(cast(now#270 as date)) = (year(dateOfBirth#264) + 65)) AND ((month(cast(now#270 as date)) < month(dateOfBirth#264)) OR ((month(cast(now#270 as date)) = month(dateOfBirth#264)) AND (dayofmonth(cast(now#270 as date)) <= dayofmonth(cast(now#270 as date))))))) AS SeniorCitizen#279, Partner#267, Dependents#268, gender#265, MonthlyCharges#269]\\n\",\n      \"            +- Filter isnotnull(customerID#263)\\n\",\n      \"               +- FileScan parquet [customerID#263,dateOfBirth#264,gender#265,Partner#267,Dependents#268,MonthlyCharges#269,now#270] Batched: true, DataFilters: [isnotnull(customerID#263)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/data/home/yuanli/work/customer-churn/data-science-blueprints/chu..., PartitionFilters: [], PushedFilters: [IsNotNull(customerID)], ReadSchema: struct<customerID:string,dateOfBirth:date,gender:string,Partner:string,Dependents:string,MonthlyC...\\n\",\n      \"\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"wide_data.explain()\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[Stage 33:======================================================> (28 + 1) / 29]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"CPU times: user 1.15 s, sys: 188 ms, total: 1.34 s\\n\",\n      \"Wall time: 2min 58s\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\r\",\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%%time\\n\",\n    \"from churn.etl import write_df\\n\",\n    \"write_df(wide_data, output_file)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Inspecting individual tables\\n\",\n    \"\\n\",\n    \"If we need to inspect individual components of our processing, we can.  Each constituent of these joins is registered as a temporary view.  For example, we loaded `customers` earlier using a method from `churn.etl`, but it is also available as a table:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"customers = session.table(\\\"customers\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 10:02:56,112 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 10:02:56,113 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 10:02:56,114 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 10:03:35,633 WARN rapids.GpuOverrides: =============>(790 + 1) / 795]\\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"\\n\",\n      \"2022-04-05 10:03:35,634 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------------+\\n\",\n      \"|          customerID|\\n\",\n      \"+--------------------+\\n\",\n      \"|9102-OXKFY-fCyaBG...|\\n\",\n      \"|5478-JJVZK-pUoKfE...|\\n\",\n      \"|1843-TLSGD-PjdWrt...|\\n\",\n      \"|2027-FECZV-5HMwOd...|\\n\",\n      \"|3793-MMFUH-FBa4QK...|\\n\",\n      \"|5360-XGYAZ-F1ALBc...|\\n\",\n      \"|1843-TLSGD-L@JxWt...|\\n\",\n      \"|5872-OEQNH-5NXyac...|\\n\",\n      \"|6773-LQTVT-XB@vuC...|\\n\",\n      \"|3301-VKTGC-PjdWrt...|\\n\",\n      \"|9251-AWQGT-fCyaBG...|\\n\",\n      \"|9830-ECLEN-lqhKlh...|\\n\",\n      \"|7969-FFOWG-fPARzA...|\\n\",\n      \"|9451-WLYRI-0stTDJ...|\\n\",\n      \"|4293-ETKAP-dkh3P1...|\\n\",\n      \"|6281-FKEWS-0V3zMQ...|\\n\",\n      \"|8220-OCUFY-PjdWrt...|\\n\",\n      \"|0578-SKVMF-GSLp0h...|\\n\",\n      \"|2165-VOEGB-K8kBya...|\\n\",\n      \"|6754-WKSHP-rt81Nn...|\\n\",\n      \"+--------------------+\\n\",\n      \"only showing top 20 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\r\",\n      \"                                                                                \\r\"\n     ]\n    }\n   
],\n   \"source\": [\n    \"customers.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can see which tables are available by querying the session catalog:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-04-05 10:03:38,813 WARN conf.HiveConf: HiveConf of name hive.stats.jdbc.timeout does not exist\\n\",\n      \"2022-04-05 10:03:38,814 WARN conf.HiveConf: HiveConf of name hive.stats.retries.wait does not exist\\n\",\n      \"2022-04-05 10:03:40,550 WARN metastore.ObjectStore: Version information not found in metastore. hive.metastore.schema.verification is not enabled so recording the schema version 2.3.0\\n\",\n      \"2022-04-05 10:03:40,550 WARN metastore.ObjectStore: setMetaStoreSchemaVersion called but recording version is disabled: version = 2.3.0, comment = Set by MetaStore yuanli@127.0.1.1\\n\",\n      \"2022-04-05 10:03:40,703 WARN metastore.ObjectStore: Failed to get database global_temp, returning NoSuchObjectException\\n\",\n      \"2022-04-05 10:03:40,833 WARN rapids.GpuOverrides: \\n\",\n      \"! 
<LocalTableScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.LocalTableScanExec\\n\",\n      \"  @Expression <AttributeReference> name#507 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> database#508 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> description#509 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> tableType#510 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> isTemporary#511 could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"['churned',\\n\",\n       \" 'contracts',\\n\",\n       \" 'counts_and_charges',\\n\",\n       \" 'customer_account_features',\\n\",\n       \" 'customer_account_meta',\\n\",\n       \" 'customer_billing',\\n\",\n       \" 'customer_charges',\\n\",\n       \" 'customer_internet_features',\\n\",\n       \" 'customer_phone_features',\\n\",\n       \" 'customers',\\n\",\n       \" 'device_protection',\\n\",\n       \" 'internet_service',\\n\",\n       \" 'multiple_lines',\\n\",\n       \" 'online_backup',\\n\",\n       \" 'online_security',\\n\",\n       \" 'paperless',\\n\",\n       \" 'payment',\\n\",\n       \" 'phone_service',\\n\",\n       \" 'streaming_movies',\\n\",\n       \" 'streaming_tv',\\n\",\n       \" 'tech_support',\\n\",\n       \" 'terminations']\"\n      ]\n     },\n     \"execution_count\": 21,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tables = session.catalog.listTables()\\n\",\n    \"[t.name for t in tables]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Finishing up\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"session.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 
null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.10\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "examples/SQL+DF-Examples/demo/Spark_get_json_object.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"id\": \"Td_alkbOv3Aj\",\n      \"metadata\": {\n        \"id\": \"Td_alkbOv3Aj\"\n      },\n      \"source\": [\n        \"# Spark RAPIDS get_json_object acceleration\\n\",\n        \"\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"id\": \"c6ed860b\",\n      \"metadata\": {\n        \"id\": \"c6ed860b\"\n      },\n      \"source\": [\n        \"<a target=\\\"_blank\\\" href=\\\"https://colab.research.google.com/github/rapidsai-community/showcase/blob/main/getting_started_tutorials/10min_to_cudf_colab.ipynb\\\">\\n\",\n        \"  <img src=\\\"https://colab.research.google.com/assets/colab-badge.svg\\\" alt=\\\"Open In Colab\\\"/>\\n\",\n        \"</a>\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"id\": \"AhUsdz6jLdMi\",\n      \"metadata\": {\n        \"id\": \"AhUsdz6jLdMi\"\n      },\n      \"source\": [\n        \"\\n\",\n        \"Before getting started - be sure to change your runtime to use a GPU Hardware accelerator! 
Use the Runtime -> \\\"Change runtime type\\\" menu option to add a GPU.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"id\": \"ZfNDlz0SM0DB\",\n      \"metadata\": {\n        \"id\": \"ZfNDlz0SM0DB\"\n      },\n      \"source\": [\n        \"# Let's get started using the RAPIDS Accelerator for Apache Spark\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"id\": \"PzW61-K04A1E\",\n      \"metadata\": {\n        \"id\": \"PzW61-K04A1E\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"!nvidia-smi\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"!cat /proc/cpuinfo\"\n      ],\n      \"metadata\": {\n        \"id\": \"OIEun51OCyC4\"\n      },\n      \"id\": \"OIEun51OCyC4\",\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"spark_version='3.5.0'\\n\",\n        \"rapids_version='24.12.0'\"\n      ],\n      \"metadata\": {\n        \"id\": \"NEGt46X7nEqf\"\n      },\n      \"id\": \"NEGt46X7nEqf\",\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"%pip install --quiet \\\\\\n\",\n        \"  pyspark=={spark_version}\"\n      ],\n      \"metadata\": {\n        \"id\": \"g9XK28gcnHiG\"\n      },\n      \"id\": \"g9XK28gcnHiG\",\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"from importlib.resources import files\\n\",\n        \"from pyspark.sql import SparkSession\\n\",\n        \"import glob\\n\",\n        \"import os\\n\",\n        \"import re\\n\",\n        \"import time\\n\",\n        \"import statistics\"\n      ],\n      \"metadata\": {\n        \"id\": \"gr2msGD1nLh-\"\n      },\n      \"id\": \"gr2msGD1nLh-\",\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n   
 {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"pyspark_files = files('pyspark')\\n\",\n        \"spark_sql_jar_path, *_ = glob.glob(f\\\"{pyspark_files}/*/spark-sql_*jar\\\")\\n\",\n        \"spark_sql_jar = os.path.basename(spark_sql_jar_path)\\n\",\n        \"scala_version = re.search(r'^spark-sql_(\\\\d+.\\\\d+)-.*\\\\.jar$', spark_sql_jar).group(1)\"\n      ],\n      \"metadata\": {\n        \"id\": \"0uXK6z8KoFUt\"\n      },\n      \"id\": \"0uXK6z8KoFUt\",\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"spark = (\\n\",\n        \"    SparkSession.builder\\n\",\n        \"      .appName('JSON PySpark RAPIDS=ON/OFF')\\n\",\n        \"      .config('spark.driver.memory', '5g')\\n\",\n        \"      .config('spark.plugins', 'com.nvidia.spark.SQLPlugin')\\n\",\n        \"      .config('spark.jars.packages', f\\\"com.nvidia:rapids-4-spark_{scala_version}:{rapids_version}\\\")\\n\",\n        \"      .getOrCreate()\\n\",\n        \")\\n\",\n        \"spark\"\n      ],\n      \"metadata\": {\n        \"id\": \"ayT5VJQvnQv4\"\n      },\n      \"id\": \"ayT5VJQvnQv4\",\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"location = \\\"./TMP_DATA\\\"\\n\",\n        \"iters = 3\"\n      ],\n      \"metadata\": {\n        \"id\": \"3VsYyTATpNG1\"\n      },\n      \"id\": \"3VsYyTATpNG1\",\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"def mk_json_column(i):\\n\",\n        \"    return \\\"\\\"\\\" '\\\"', CAST(rand(\\\"\\\"\\\" + str(i) + \\\"\\\"\\\") * 10000 AS LONG), '\\\":\\\"\\\"\\\" + str(i) + \\\"\\\"\\\"'\\\"\\\"\\\"\\n\",\n        \"\\n\",\n        \"# generate json lines with very sparse keys\\n\",\n        \"spark.range(1000000).selectExpr(\\\"\\\"\\\"concat('{', \\\"\\\"\\\" + 
(\\\"\\\"\\\", ',' ,\\\"\\\"\\\".join([mk_json_column(i) for i in range(100)])) + \\\"\\\"\\\"'}') as json\\\"\\\"\\\").write.mode(\\\"overwrite\\\").parquet(location)\"\n      ],\n      \"metadata\": {\n        \"id\": \"diUi3mxWh91X\"\n      },\n      \"id\": \"diUi3mxWh91X\",\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# Test pulling out a few keys using the GPU\\n\",\n        \"spark.conf.set(\\\"spark.rapids.sql.enabled\\\",True)\\n\",\n        \"gpu_times = []\\n\",\n        \"for i in range(iters):\\n\",\n        \"    start = time.time()\\n\",\n        \"    df = spark.read.parquet(location).selectExpr(\\\"count(get_json_object(json,'$.0')) as zero\\\", \\\"count(get_json_object(json,'$.10')) as ten\\\", \\\"count(get_json_object(json,'$.100')) as hundred\\\", \\\"count(get_json_object(json,'$.1000')) as thousand\\\", \\\"count(get_json_object(json,'$.1001')) as thousandAndOne\\\", \\\"avg(octet_length(json)) as len\\\")\\n\",\n        \"    if i == 0:\\n\",\n        \"      df.show()\\n\",\n        \"    else:\\n\",\n        \"      df.collect()\\n\",\n        \"    end = time.time()\\n\",\n        \"    gpu_times.append(end - start)\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"print(f\\\"Median execution time of {iters} runs for GPU get_json_object: {statistics.median(gpu_times):.3f}\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"iXaXVgBNt4pK\"\n      },\n      \"id\": \"iXaXVgBNt4pK\",\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"#  Run the same test using the CPU. 
Note that this is an exceptional result\\n\",\n        \"#  because Colab provides very little CPU (2 cores) to go with the GPU (T4)\\n\",\n        \"#  on a 16 core AMD CPU that is not overcommitted and with an NVMe to load the\\n\",\n        \"#  data, and an A6000 GPU, the GPU takes about 0.662 seconds to complete and\\n\",\n        \"#  the CPU takes about 2.986 seconds, or about a 4.5x speedup, compared to this\\n\",\n        \"#  notebook's ~30x speedup.\\n\",\n        \"spark.conf.set(\\\"spark.rapids.sql.enabled\\\",False)\\n\",\n        \"cpu_times = []\\n\",\n        \"for i in range(iters):\\n\",\n        \"    start = time.time()\\n\",\n        \"    df = spark.read.parquet(location).selectExpr(\\\"count(get_json_object(json,'$.0')) as zero\\\", \\\"count(get_json_object(json,'$.10')) as ten\\\", \\\"count(get_json_object(json,'$.100')) as hundred\\\", \\\"count(get_json_object(json,'$.1000')) as thousand\\\", \\\"count(get_json_object(json,'$.1001')) as thousandAndOne\\\", \\\"avg(octet_length(json)) as len\\\")\\n\",\n        \"    if i == 0:\\n\",\n        \"      df.show()\\n\",\n        \"    else:\\n\",\n        \"      df.collect()\\n\",\n        \"    end = time.time()\\n\",\n        \"    cpu_times.append(end - start)\\n\",\n        \"\\n\",\n        \"print(f\\\"Median execution time of {iters} runs for CPU get_json_object: {statistics.median(cpu_times):.3f}\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"lUmVe12Wic5X\"\n      },\n      \"id\": \"lUmVe12Wic5X\",\n      \"execution_count\": null,\n      \"outputs\": []\n    }\n  ],\n  \"metadata\": {\n    \"accelerator\": \"GPU\",\n    \"colab\": {\n      \"provenance\": []\n    },\n    \"gpuClass\": \"standard\",\n    \"kernelspec\": {\n      \"display_name\": \"Python 3.9.12 ('base')\",\n      \"language\": \"python\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"codemirror_mode\": {\n        \"name\": \"ipython\",\n        \"version\": 3\n      },\n      
\"file_extension\": \".py\",\n      \"mimetype\": \"text/x-python\",\n      \"name\": \"python\",\n      \"nbconvert_exporter\": \"python\",\n      \"pygments_lexer\": \"ipython3\",\n      \"version\": \"3.9.12\"\n    },\n    \"vscode\": {\n      \"interpreter\": {\n        \"hash\": \"5327a248d9883bedf47bfd9e608af95bf318797e621edcc550c6b5b3fdc820cc\"\n      }\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 5\n}"
  },
  {
    "path": "examples/SQL+DF-Examples/demo/Spark_parquet_microkernels.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"raw\",\n   \"id\": \"Td_alkbOv3Aj\",\n   \"metadata\": {\n    \"id\": \"Td_alkbOv3Aj\"\n   },\n   \"source\": [\n    \"{\\n\",\n    \"  \\\"cells\\\": [\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"      \\\"id\\\": \\\"Td_alkbOv3Aj\\\",\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"Td_alkbOv3Aj\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"# Spark RAPIDS Parquet acceleration\\\\n\\\",\\n\",\n    \"        \\\"\\\\n\\\"\\n\",\n    \"      ]\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"      \\\"id\\\": \\\"c6ed860b\\\",\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"c6ed860b\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"<a target=\\\\\\\"_blank\\\\\\\" href=\\\\\\\"https://colab.research.google.com/github/NVIDIA/spark-rapids-examples/blob/main/examples/SQL%2BDF-Examples/demo/Spark_parquet_microkernels.ipynb\\\\\\\">\\\\n\\\",\\n\",\n    \"        \\\"  <img src=\\\\\\\"https://colab.research.google.com/assets/colab-badge.svg\\\\\\\" alt=\\\\\\\"Open In Colab\\\\\\\"/>\\\\n\\\",\\n\",\n    \"        \\\"</a>\\\\n\\\"\\n\",\n    \"      ]\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"      \\\"id\\\": \\\"AhUsdz6jLdMi\\\",\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"AhUsdz6jLdMi\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"\\\\n\\\",\\n\",\n    \"        \\\"Before getting started - be sure to change your runtime to use a GPU Hardware accelerator! 
Use the Runtime -> \\\\\\\"Change runtime type\\\\\\\" menu option to add a GPU.\\\"\\n\",\n    \"      ]\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"      \\\"id\\\": \\\"ZfNDlz0SM0DB\\\",\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"ZfNDlz0SM0DB\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"# Let's get started using the RAPIDS Accelerator for Apache Spark\\\"\\n\",\n    \"      ]\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"      \\\"execution_count\\\": null,\\n\",\n    \"      \\\"id\\\": \\\"PzW61-K04A1E\\\",\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"PzW61-K04A1E\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"outputs\\\": [],\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"!nvidia-smi\\\"\\n\",\n    \"      ]\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"!cat /proc/cpuinfo\\\"\\n\",\n    \"      ],\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"OIEun51OCyC4\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"id\\\": \\\"OIEun51OCyC4\\\",\\n\",\n    \"      \\\"execution_count\\\": null,\\n\",\n    \"      \\\"outputs\\\": []\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"spark_version='3.5.0'\\\\n\\\",\\n\",\n    \"        \\\"rapids_version='24.12.0'\\\"\\n\",\n    \"      ],\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"NEGt46X7nEqf\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"id\\\": \\\"NEGt46X7nEqf\\\",\\n\",\n    \"      \\\"execution_count\\\": null,\\n\",\n    \"      \\\"outputs\\\": []\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"     
 \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"%pip install --quiet \\\\\\\\\\\\n\\\",\\n\",\n    \"        \\\"  pyspark=={spark_version}\\\"\\n\",\n    \"      ],\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"g9XK28gcnHiG\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"id\\\": \\\"g9XK28gcnHiG\\\",\\n\",\n    \"      \\\"execution_count\\\": null,\\n\",\n    \"      \\\"outputs\\\": []\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"from importlib.resources import files\\\\n\\\",\\n\",\n    \"        \\\"from pyspark.sql import SparkSession\\\\n\\\",\\n\",\n    \"        \\\"import glob\\\\n\\\",\\n\",\n    \"        \\\"import os\\\\n\\\",\\n\",\n    \"        \\\"import re\\\\n\\\",\\n\",\n    \"        \\\"import time\\\\n\\\",\\n\",\n    \"        \\\"import statistics\\\"\\n\",\n    \"      ],\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"gr2msGD1nLh-\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"id\\\": \\\"gr2msGD1nLh-\\\",\\n\",\n    \"      \\\"execution_count\\\": null,\\n\",\n    \"      \\\"outputs\\\": []\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"pyspark_files = files('pyspark')\\\\n\\\",\\n\",\n    \"        \\\"spark_sql_jar_path, *_ = glob.glob(f\\\\\\\"{pyspark_files}/*/spark-sql_*jar\\\\\\\")\\\\n\\\",\\n\",\n    \"        \\\"spark_sql_jar = os.path.basename(spark_sql_jar_path)\\\\n\\\",\\n\",\n    \"        \\\"scala_version = re.search(r'^spark-sql_(\\\\\\\\d+.\\\\\\\\d+)-.*\\\\\\\\.jar$', spark_sql_jar).group(1)\\\"\\n\",\n    \"      ],\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"0uXK6z8KoFUt\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"id\\\": \\\"0uXK6z8KoFUt\\\",\\n\",\n    \"    
  \\\"execution_count\\\": null,\\n\",\n    \"      \\\"outputs\\\": []\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"spark = (\\\\n\\\",\\n\",\n    \"        \\\"    SparkSession.builder\\\\n\\\",\\n\",\n    \"        \\\"      .appName('Parquet Spark GPU Acceleration')\\\\n\\\",\\n\",\n    \"        \\\"      .master('local[*]')\\\\n\\\",\\n\",\n    \"        \\\"      .config('spark.driver.memory', '5g')\\\\n\\\",\\n\",\n    \"        \\\"      .config('spark.plugins', 'com.nvidia.spark.SQLPlugin')\\\\n\\\",\\n\",\n    \"        \\\"      .config('spark.jars.packages', f\\\\\\\"com.nvidia:rapids-4-spark_{scala_version}:{rapids_version}\\\\\\\")\\\\n\\\",\\n\",\n    \"        \\\"      .getOrCreate()\\\\n\\\",\\n\",\n    \"        \\\")\\\\n\\\",\\n\",\n    \"        \\\"spark\\\"\\n\",\n    \"      ],\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"ayT5VJQvnQv4\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"id\\\": \\\"ayT5VJQvnQv4\\\",\\n\",\n    \"      \\\"execution_count\\\": null,\\n\",\n    \"      \\\"outputs\\\": []\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"location = \\\\\\\"./TMP_DATA\\\\\\\"\\\\n\\\",\\n\",\n    \"        \\\"iters = 5\\\"\\n\",\n    \"      ],\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"3VsYyTATpNG1\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"id\\\": \\\"3VsYyTATpNG1\\\",\\n\",\n    \"      \\\"execution_count\\\": null,\\n\",\n    \"      \\\"outputs\\\": []\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"from pyspark.sql.types import IntegerType, StringType, StructType, StructField\\\\n\\\",\\n\",\n    \"        \\\"from pyspark.sql import functions 
as F\\\\n\\\",\\n\",\n    \"        \\\"import random\\\\n\\\",\\n\",\n    \"        \\\"import string\\\\n\\\",\\n\",\n    \"        \\\"\\\\n\\\",\\n\",\n    \"        \\\"# Define schema\\\\n\\\",\\n\",\n    \"        \\\"schema = StructType([\\\\n\\\",\\n\",\n    \"        \\\"    StructField(\\\\\\\"id\\\\\\\", IntegerType(), False),\\\\n\\\",\\n\",\n    \"        \\\"    StructField(\\\\\\\"name\\\\\\\", StringType(), False),\\\\n\\\",\\n\",\n    \"        \\\"    StructField(\\\\\\\"age\\\\\\\", IntegerType(), False),\\\\n\\\",\\n\",\n    \"        \\\"    StructField(\\\\\\\"salary\\\\\\\", IntegerType(), False)\\\\n\\\",\\n\",\n    \"        \\\"])\\\\n\\\",\\n\",\n    \"        \\\"\\\\n\\\",\\n\",\n    \"        \\\"# Function to generate random strings\\\\n\\\",\\n\",\n    \"        \\\"def random_string(length=10):\\\\n\\\",\\n\",\n    \"        \\\"    return ''.join(random.choices(string.ascii_letters, k=length))\\\\n\\\",\\n\",\n    \"        \\\"\\\\n\\\",\\n\",\n    \"        \\\"# Generate DataFrame with 20M rows\\\\n\\\",\\n\",\n    \"        \\\"df = spark.range(0, 20_000_000).toDF(\\\\\\\"id\\\\\\\") \\\\\\\\\\\\n\\\",\\n\",\n    \"        \\\"    .withColumn(\\\\\\\"name\\\\\\\", F.udf(lambda: random_string(), StringType())()) \\\\\\\\\\\\n\\\",\\n\",\n    \"        \\\"    .withColumn(\\\\\\\"age\\\\\\\", (F.rand() * 50 + 20).cast(IntegerType())) \\\\\\\\\\\\n\\\",\\n\",\n    \"        \\\"    .withColumn(\\\\\\\"salary\\\\\\\", (F.rand() * 100000 + 30000).cast(IntegerType()))\\\\n\\\",\\n\",\n    \"        \\\"\\\\n\\\",\\n\",\n    \"        \\\"df.write.mode(\\\\\\\"overwrite\\\\\\\").parquet(location)\\\"\\n\",\n    \"      ],\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"diUi3mxWh91X\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"id\\\": \\\"diUi3mxWh91X\\\",\\n\",\n    \"      \\\"execution_count\\\": null,\\n\",\n    \"      \\\"outputs\\\": []\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      
\\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"# Run the Parquet scan test on the GPU\\\\n\\\",\\n\",\n    \"        \\\"spark.conf.set(\\\\\\\"spark.rapids.sql.enabled\\\\\\\",True)\\\\n\\\",\\n\",\n    \"        \\\"gpu_times = []\\\\n\\\",\\n\",\n    \"        \\\"for i in range(iters):\\\\n\\\",\\n\",\n    \"        \\\"    start = time.time()\\\\n\\\",\\n\",\n    \"        \\\"    df = spark.read.parquet(location).selectExpr(\\\\\\\"count(name) as rows\\\\\\\", \\\\\\\"avg(salary) as average_salary\\\\\\\", \\\\\\\"median(salary) as median_salary\\\\\\\", \\\\\\\"sum(salary) as total_salary\\\\\\\", \\\\\\\"avg(age) as average_age\\\\\\\", \\\\\\\"median(age) as median_age\\\\\\\")\\\\n\\\",\\n\",\n    \"        \\\"    if i == 0:\\\\n\\\",\\n\",\n    \"        \\\"      df.show()\\\\n\\\",\\n\",\n    \"        \\\"    else:\\\\n\\\",\\n\",\n    \"        \\\"      df.collect()\\\\n\\\",\\n\",\n    \"        \\\"    end = time.time()\\\\n\\\",\\n\",\n    \"        \\\"    gpu_times.append(end - start)\\\\n\\\",\\n\",\n    \"        \\\"\\\\n\\\",\\n\",\n    \"        \\\"gpu_median = statistics.median(gpu_times)\\\\n\\\",\\n\",\n    \"        \\\"\\\\n\\\",\\n\",\n    \"        \\\"print(f\\\\\\\"Median execution time of {iters} runs for GPU Parquet scan: {gpu_median:.3f}\\\\\\\")\\\"\\n\",\n    \"      ],\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"iXaXVgBNt4pK\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"id\\\": \\\"iXaXVgBNt4pK\\\",\\n\",\n    \"      \\\"execution_count\\\": null,\\n\",\n    \"      \\\"outputs\\\": []\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"# Run the Parquet scan test on the CPU\\\\n\\\",\\n\",\n    \"        \\\"spark.conf.set(\\\\\\\"spark.rapids.sql.enabled\\\\\\\",False)\\\\n\\\",\\n\",\n    \"        \\\"cpu_times = []\\\\n\\\",\\n\",\n    
\"        \\\"for i in range(iters):\\\\n\\\",\\n\",\n    \"        \\\"    start = time.time()\\\\n\\\",\\n\",\n    \"        \\\"    df = spark.read.parquet(location).selectExpr(\\\\\\\"count(name) as rows\\\\\\\", \\\\\\\"avg(salary) as average_salary\\\\\\\", \\\\\\\"median(salary) as median_salary\\\\\\\", \\\\\\\"sum(salary) as total_salary\\\\\\\", \\\\\\\"avg(age) as average_age\\\\\\\", \\\\\\\"median(age) as median_age\\\\\\\")\\\\n\\\",\\n\",\n    \"        \\\"    if i == 0:\\\\n\\\",\\n\",\n    \"        \\\"      df.show()\\\\n\\\",\\n\",\n    \"        \\\"    else:\\\\n\\\",\\n\",\n    \"        \\\"      df.collect()\\\\n\\\",\\n\",\n    \"        \\\"    end = time.time()\\\\n\\\",\\n\",\n    \"        \\\"    cpu_times.append(end - start)\\\\n\\\",\\n\",\n    \"        \\\"\\\\n\\\",\\n\",\n    \"        \\\"cpu_median = statistics.median(cpu_times)\\\\n\\\",\\n\",\n    \"        \\\"print(f\\\\\\\"Median execution time of {iters} runs for CPU Parquet scan: {cpu_median:.3f}\\\\\\\")\\\"\\n\",\n    \"      ],\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"lUmVe12Wic5X\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"id\\\": \\\"lUmVe12Wic5X\\\",\\n\",\n    \"      \\\"execution_count\\\": null,\\n\",\n    \"      \\\"outputs\\\": []\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"      \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"      \\\"source\\\": [\\n\",\n    \"        \\\"# GPU speedup should be in the range of 5-10x\\\\n\\\",\\n\",\n    \"        \\\"speedup = cpu_median / gpu_median\\\\n\\\",\\n\",\n    \"        \\\"print(f\\\\\\\"GPU speedup: {speedup:.2f}x\\\\\\\")\\\"\\n\",\n    \"      ],\\n\",\n    \"      \\\"metadata\\\": {\\n\",\n    \"        \\\"id\\\": \\\"CxROFk_AoQQl\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"id\\\": \\\"CxROFk_AoQQl\\\",\\n\",\n    \"      \\\"execution_count\\\": null,\\n\",\n    \"      \\\"outputs\\\": []\\n\",\n    \"    }\\n\",\n    \"  ],\\n\",\n    \"  
\\\"metadata\\\": {\\n\",\n    \"    \\\"accelerator\\\": \\\"GPU\\\",\\n\",\n    \"    \\\"colab\\\": {\\n\",\n    \"      \\\"provenance\\\": []\\n\",\n    \"    },\\n\",\n    \"    \\\"gpuClass\\\": \\\"standard\\\",\\n\",\n    \"    \\\"kernelspec\\\": {\\n\",\n    \"      \\\"display_name\\\": \\\"Python 3.9.12 ('base')\\\",\\n\",\n    \"      \\\"language\\\": \\\"python\\\",\\n\",\n    \"      \\\"name\\\": \\\"python3\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"language_info\\\": {\\n\",\n    \"      \\\"codemirror_mode\\\": {\\n\",\n    \"        \\\"name\\\": \\\"ipython\\\",\\n\",\n    \"        \\\"version\\\": 3\\n\",\n    \"      },\\n\",\n    \"      \\\"file_extension\\\": \\\".py\\\",\\n\",\n    \"      \\\"mimetype\\\": \\\"text/x-python\\\",\\n\",\n    \"      \\\"name\\\": \\\"python\\\",\\n\",\n    \"      \\\"nbconvert_exporter\\\": \\\"python\\\",\\n\",\n    \"      \\\"pygments_lexer\\\": \\\"ipython3\\\",\\n\",\n    \"      \\\"version\\\": \\\"3.9.12\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"vscode\\\": {\\n\",\n    \"      \\\"interpreter\\\": {\\n\",\n    \"        \\\"hash\\\": \\\"5327a248d9883bedf47bfd9e608af95bf318797e621edcc550c6b5b3fdc820cc\\\"\\n\",\n    \"      }\\n\",\n    \"    }\\n\",\n    \"  },\\n\",\n    \"  \\\"nbformat\\\": 4,\\n\",\n    \"  \\\"nbformat_minor\\\": 5\\n\",\n    \"}\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"AhUsdz6jLdMi\",\n   \"metadata\": {\n    \"id\": \"AhUsdz6jLdMi\"\n   },\n   \"source\": [\n    \"\\n\",\n    \"Before getting started - be sure to change your runtime to use a GPU Hardware accelerator! 
Use the Runtime -> \\\"Change runtime type\\\" menu option to add a GPU.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"PzW61-K04A1E\",\n   \"metadata\": {\n    \"id\": \"PzW61-K04A1E\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!nvidia-smi\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"source\": [\n    \"spark_version='3.5.0'\\n\",\n    \"rapids_version='24.12.0'\"\n   ],\n   \"metadata\": {\n    \"id\": \"NEGt46X7nEqf\"\n   },\n   \"id\": \"NEGt46X7nEqf\",\n   \"execution_count\": null,\n   \"outputs\": []\n  },\n  {\n   \"cell_type\": \"code\",\n   \"source\": [\n    \"from importlib.resources import files\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"import glob\\n\",\n    \"import os\\n\",\n    \"import re\\n\",\n    \"import time\\n\",\n    \"import statistics\"\n   ],\n   \"metadata\": {\n    \"id\": \"gr2msGD1nLh-\"\n   },\n   \"id\": \"gr2msGD1nLh-\",\n   \"execution_count\": null,\n   \"outputs\": []\n  },\n  {\n   \"cell_type\": \"code\",\n   \"source\": [\n    \"spark = (\\n\",\n    \"    SparkSession.builder\\n\",\n    \"      .appName('Parquet Spark GPU Acceleration')\\n\",\n    \"      .master('local[*]')\\n\",\n    \"      .config('spark.driver.memory', '5g')\\n\",\n    \"      .config('spark.plugins', 'com.nvidia.spark.SQLPlugin')\\n\",\n    \"      .config('spark.jars.packages', f\\\"com.nvidia:rapids-4-spark_{scala_version}:{rapids_version}\\\")\\n\",\n    \"      .getOrCreate()\\n\",\n    \")\\n\",\n    \"spark\"\n   ],\n   \"metadata\": {\n    \"id\": \"ayT5VJQvnQv4\"\n   },\n   \"id\": \"ayT5VJQvnQv4\",\n   \"execution_count\": null,\n   \"outputs\": []\n  },\n  {\n   \"cell_type\": \"code\",\n   \"source\": [\n    \"from pyspark.sql.types import IntegerType, StringType, StructType, StructField\\n\",\n    \"from pyspark.sql import functions as F\\n\",\n    \"import random\\n\",\n    \"import string\\n\",\n    \"\\n\",\n    \"# Define schema\\n\",\n    \"schema = 
StructType([\\n\",\n    \"    StructField(\\\"id\\\", IntegerType(), False),\\n\",\n    \"    StructField(\\\"name\\\", StringType(), False),\\n\",\n    \"    StructField(\\\"age\\\", IntegerType(), False),\\n\",\n    \"    StructField(\\\"salary\\\", IntegerType(), False)\\n\",\n    \"])\\n\",\n    \"\\n\",\n    \"# Function to generate random strings\\n\",\n    \"def random_string(length=10):\\n\",\n    \"    return ''.join(random.choices(string.ascii_letters, k=length))\\n\",\n    \"\\n\",\n    \"# Generate DataFrame with 20M rows\\n\",\n    \"df = spark.range(0, 20_000_000).toDF(\\\"id\\\") \\\\\\n\",\n    \"    .withColumn(\\\"name\\\", F.udf(lambda: random_string(), StringType())()) \\\\\\n\",\n    \"    .withColumn(\\\"age\\\", (F.rand() * 50 + 20).cast(IntegerType())) \\\\\\n\",\n    \"    .withColumn(\\\"salary\\\", (F.rand() * 100000 + 30000).cast(IntegerType()))\\n\",\n    \"\\n\",\n    \"df.write.mode(\\\"overwrite\\\").parquet(location)\"\n   ],\n   \"metadata\": {\n    \"id\": \"diUi3mxWh91X\"\n   },\n   \"id\": \"diUi3mxWh91X\",\n   \"execution_count\": null,\n   \"outputs\": []\n  },\n  {\n   \"cell_type\": \"code\",\n   \"source\": [\n    \"# Run the Parquet scan test on the CPU\\n\",\n    \"spark.conf.set(\\\"spark.rapids.sql.enabled\\\",False)\\n\",\n    \"cpu_times = []\\n\",\n    \"for i in range(iters):\\n\",\n    \"    start = time.time()\\n\",\n    \"    df = spark.read.parquet(location).selectExpr(\\\"count(name) as rows\\\", \\\"avg(salary) as average_salary\\\", \\\"median(salary) as median_salary\\\", \\\"sum(salary) as total_salary\\\", \\\"avg(age) as average_age\\\", \\\"median(age) as median_age\\\")\\n\",\n    \"    if i == 0:\\n\",\n    \"      df.show()\\n\",\n    \"    else:\\n\",\n    \"      df.collect()\\n\",\n    \"    end = time.time()\\n\",\n    \"    cpu_times.append(end - start)\\n\",\n    \"\\n\",\n    \"cpu_median = statistics.median(cpu_times)\\n\",\n    \"print(f\\\"Median execution time of {iters} runs for CPU 
Parquet scan: {cpu_median:.3f}\\\")\"\n   ],\n   \"metadata\": {\n    \"id\": \"lUmVe12Wic5X\"\n   },\n   \"id\": \"lUmVe12Wic5X\",\n   \"execution_count\": null,\n   \"outputs\": []\n  }\n ],\n \"metadata\": {\n  \"accelerator\": \"GPU\",\n  \"colab\": {\n   \"provenance\": []\n  },\n  \"gpuClass\": \"standard\",\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.9.12 ('base')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.12\"\n  },\n  \"vscode\": {\n   \"interpreter\": {\n    \"hash\": \"5327a248d9883bedf47bfd9e608af95bf318797e621edcc550c6b5b3fdc820cc\"\n   }\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/SQL+DF-Examples/micro-benchmarks/README.md",
    "content": "# Microbenchmark\n\nStandard industry benchmarks are a great way to measure performance over \na period of time but another barometer to measure performance is to measure\nperformance of common operators that are used in the data preprocessing stage or in data analytics.\nThe microbenchmark notebook in this repo uses five such queries in the chart shown below:\n\n- **Count Distinct**: a function used to estimate the number of unique page views or \n  unique customers visiting an e-commerce site.\n- **Window**: a critical operator necessary for preprocessing components in analyzing\n  timestamped event data in marketing or financial industry.\n- **Intersect**: an operator used to remove duplicates in a dataframe.\n- **Cross-join**: A common use for a cross join is to obtain all combinations of items.\n- **Hash-join**: Joining two tables together by matching rows based on a common column.\n\nThese queries were run on a standard eight-node CPU cluster with 2 CPUs (128 cores),\n512GB memory and 1xA100 GPUs per node. The dataset used was of size 3TB with multiple different data types.\nThe queries are based on several tables in NDS parquet format with Decimal. \nThese five queries show not only performance and cost benefits but also the range of\nspeed-up (27x to 1.5x) varies depending on compute intensity. \nThese queries vary in compute and network utilization similar to a practical use case in\ndata preprocessing. To test these queries, you can generate the parquet format dataset using\nthis NDS dataset generator tool. 
All the queries are running on the SF3000 (Scale Factor 3000) dataset.\nYou can generate it with the following command:\n```\n# Assuming your platform is Linux\n# Install sbt\necho \"deb https://repo.scala-sbt.org/scalasbt/debian all main\" | sudo tee /etc/apt/sources.list.d/sbt.list\necho \"deb https://repo.scala-sbt.org/scalasbt/debian /\" | sudo tee /etc/apt/sources.list.d/sbt_old.list\ncurl -sL \"https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x2EE0EA64E40A89B84B2DF73499E82A75642AC823\" | sudo apt-key add -\nsudo apt-get update\nsudo apt-get install sbt\n\n# Install jdk \nsudo apt-get install openjdk-8-jdk\n\n# clone related repos\ngit clone https://github.com/databricks/spark-sql-perf.git\ngit clone https://github.com/databricks/tpcds-kit.git\n\n# build \ncd tpcds-kit/tools\nmake OS=LINUX\n\nsbt \"test:runMain com.databricks.spark.sql.perf.tpcds.GenTPCDSData -d /databricks-tpcds-kit-path -s 3000G -l /your-dataset-path -f parquet\"\n```\n\n![microbenchmark-speedup](/docs/img/guides/microbm.png)\n"
  },
  {
    "path": "examples/SQL+DF-Examples/micro-benchmarks/notebooks/micro-benchmarks-cpu.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d89df9bf\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Microbenchmarks on CPU\\n\",\n    \"This is a notebook for microbenchmarks running on CPU.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"d08c8bae\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.conf import SparkConf\\n\",\n    \"from time import time\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"# Change to your cluster ip:port\\n\",\n    \"SPARK_MASTER_URL = os.getenv(\\\"SPARK_MASTER_URL\\\", \\\"spark://your-ip:port\\\")\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6842522a\",\n   \"metadata\": {},\n   \"source\": [\n    \"Run the microbenchmark with retry times\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"45f50252\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def runMicroBenchmark(spark, appName, query, retryTimes):\\n\",\n    \"    count = 0\\n\",\n    \"    total_time = 0\\n\",\n    \"    # You can print the physical plan of each query\\n\",\n    \"    # spark.sql(query).explain()\\n\",\n    \"    while count < retryTimes:\\n\",\n    \"        start = time()\\n\",\n    \"        spark.sql(query).show(5)\\n\",\n    \"        end = time()\\n\",\n    \"        total_time += round(end - start, 2)\\n\",\n    \"        count = count + 1\\n\",\n    \"        print(\\\"Retry times : {}, \\\".format(count) + appName + \\\" Microbenchmark takes {} seconds\\\".format(round(end - start, 2)))\\n\",\n    \"    print(appName + \\\" Microbenchmark takes average {} seconds after {} retries\\\".format(round(total_time/retryTimes),retryTimes))\\n\",\n    \"    \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"682c67b1\",\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"# You need to update data path with your real path and hardware resource!\\n\",\n    \"dataRoot = os.getenv(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"driverMem = os.getenv(\\\"DRIVER_MEM\\\", \\\"50g\\\")\\n\",\n    \"executorMem = os.getenv(\\\"EXECUTOR_MEM\\\", \\\"12g\\\")\\n\",\n    \"maxPartionBytes = os.getenv(\\\"MAX_PARTITION_BYTES\\\", \\\"1g\\\")\\n\",\n    \"executorCores = int(os.getenv(\\\"EXECUTOR_CORES\\\", \\\"4\\\"))\\n\",\n    \"# common spark settings\\n\",\n    \"conf = SparkConf()\\n\",\n    \"conf.setMaster(SPARK_MASTER_URL)\\n\",\n    \"conf.setAppName(\\\"Microbenchmark on CPU\\\")\\n\",\n    \"conf.set(\\\"spark.driver.memory\\\", driverMem)\\n\",\n    \"conf.set(\\\"spark.executor.memory\\\", executorMem)\\n\",\n    \"conf.set(\\\"spark.executor.cores\\\", executorCores)\\n\",\n    \" \\n\",\n    \"conf.set(\\\"spark.locality.wait\\\", \\\"0\\\")\\n\",\n    \"conf.set(\\\"spark.sql.files.maxPartitionBytes\\\", maxPartionBytes) \\n\",\n    \"conf.set(\\\"spark.dynamicAllocation.enabled\\\", \\\"false\\\") \\n\",\n    \"conf.set(\\\"spark.sql.adaptive.enabled\\\", \\\"true\\\")  \\n\",\n    \"\\n\",\n    \"# create spark session\\n\",\n    \"spark = SparkSession.builder.config(conf=conf).getOrCreate()\\n\",\n    \"# Load dataframe and create tempView\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/tpcds/store_sales\\\").createOrReplaceTempView(\\\"store_sales\\\")\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/tpcds/catalog_sales\\\").createOrReplaceTempView(\\\"catalog_sales\\\")\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/tpcds/web_sales\\\").createOrReplaceTempView(\\\"web_sales\\\")\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/tpcds/item\\\").createOrReplaceTempView(\\\"item\\\")\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/tpcds/date_dim\\\").createOrReplaceTempView(\\\"date_dim\\\")\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"89512b77\",\n   \"metadata\": {},\n   \"source\": 
[\n    \"### Expand&HashAggregate\\n\",\n    \"This is a microbenchmark about Expand&HashAggregate expressions running on the CPU. The query calculates the distinct value of some dimension columns and average birth year by different c_salutation of customers after grouping by c_current_hdemo_sk.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"3272ef56\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"--------------------------------------------------\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# As a part of this query the size of the data in each task grows a lot. \\n\",\n    \"# By default, Spark will try to distribute the data among all the tasks in the cluster, \\n\",\n    \"# but on large clusters with large parquet files the splittable portions of the parquet files end up not being distributed evenly \\n\",\n    \"# and it is faster to re-partition the data to redistribute it than to deal with skew.\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/tpcds/customer\\\").repartition(512).createOrReplaceTempView(\\\"customer\\\")\\n\",\n    \"\\n\",\n    \"print(\\\"-\\\"*50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"dd12d749\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"--------------------------------------------------\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"query = '''\\n\",\n    \"select c_current_hdemo_sk,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_salutation,null)) as c1,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_salutation,null)) as c12,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_salutation,null)) as c13,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_first_name,null)) as c2,\\n\",\n    
\"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_first_name,null)) as c22,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_first_name,null)) as c23,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_last_name,null)) as c3,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_last_name,null)) as c32,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_last_name,null)) as c33,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_birth_country,null)) as c4,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_birth_country,null)) as c42,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_birth_country,null)) as c43,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_email_address,null)) as c5,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_email_address,null)) as c52,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_email_address,null)) as c53,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_login,null)) as c6,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_login,null)) as c62,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_login,null)) as c63,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_preferred_cust_flag,null)) as c7,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_preferred_cust_flag,null)) as c72,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_preferred_cust_flag,null)) as c73,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_birth_month,null)) as c8,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_birth_month,null)) as c82,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_birth_month,null)) as c83,\\n\",\n    \"\\n\",\n    \"avg(if(c_salutation==\\\"Ms.\\\",c_birth_year,null)) as avg1,\\n\",\n    \"avg(if(c_salutation==\\\"Mr.\\\",c_birth_year,null)) as avg2,\\n\",\n    
\"avg(if(c_salutation==\\\"Dr.\\\",c_birth_year,null)) as avg3,\\n\",\n    \"avg(if(c_salutation==\\\"Miss.\\\",c_birth_year,null)) as avg4,\\n\",\n    \"avg(if(c_salutation==\\\"Mrs.\\\",c_birth_year,null)) as avg5,\\n\",\n    \"avg(if(c_salutation==\\\"Sir.\\\",c_birth_year,null)) as avg6,\\n\",\n    \"avg(if(c_salutation==\\\"Professor.\\\",c_birth_year,null)) as avg7,\\n\",\n    \"avg(if(c_salutation==\\\"Teacher.\\\",c_birth_year,null)) as avg8,\\n\",\n    \"avg(if(c_salutation==\\\"Agent.\\\",c_birth_year,null)) as avg9,\\n\",\n    \"avg(if(c_salutation==\\\"Director.\\\",c_birth_year,null)) as avg10\\n\",\n    \"from customer group by c_current_hdemo_sk\\n\",\n    \"'''\\n\",\n    \"print(\\\"-\\\"*50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"2e105bf8\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\\n\",\n      \"|c_current_hdemo_sk| c1|c12|c13| c2|c22|c23| c3|c32|c33| c4|c42|c43| c5|c52| c53| c6|c62|c63| c7|c72|c73| c8|c82|c83|              avg1|              avg2|              avg3|avg4|              avg5|avg6|avg7|avg8|avg9|avg10|\\n\",\n      \"+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\\n\",\n      \"|              5803|  1|  1|  1|285|272|592|358|496|791|185|202|210|458|674|1177|  0|  0|  0|  2|  2|  2| 12| 12| 12|1959.5225806451613|1959.6557863501484|1958.5581196581197|null| 1958.873873873874|null|null|null|null| null|\\n\",\n      \"|              1591|  1|  1|  
1|283|237|544|374|489|739|193|206|211|476|664|1144|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.3514644351465|1958.2278860569716|1958.6174672489083|null|1958.4357894736843|null|null|null|null| null|\\n\",\n      \"|              3918|  1|  1|  1|300|266|539|392|499|755|190|203|210|507|675|1140|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.6745562130177|1958.2998522895125|1958.8992994746059|null|1959.4233009708737|null|null|null|null| null|\\n\",\n      \"|              1580|  1|  1|  1|296|256|562|392|499|808|190|203|211|499|692|1222|  0|  0|  0|  2|  2|  2| 12| 12| 12|1958.5771543086173|  1957.53591954023|1957.3303278688525|null|1958.3611691022963|null|null|null|null| null|\\n\",\n      \"|               148|  1|  1|  1|309|260|562|392|501|772|187|207|211|488|668|1154|  0|  0|  0|  2|  2|  2| 12| 12| 12| 1956.219008264463|1958.9161676646706|1957.8076256499132|null|1958.3412017167382|null|null|null|null| null|\\n\",\n      \"+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\",\n      \"Retry times : 1, Expand&HashAggregate Microbenchmark takes 65.21 seconds\\n\",\n      \"Expand&HashAggregate Microbenchmark takes average 65 seconds after 1 retries\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run microbenchmark with n retry time\\n\",\n    \"runMicroBenchmark(spark,\\\"Expand&HashAggregate\\\",query ,1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"57da403a\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Windowing (without data skew)\\n\",\n    \"This is a microbenchmark about windowing expressions running on CPU mode. 
The sub-query calculates the average ss_sales_price of a fixed window function partition by ss_customer_sk, and the parent query calculates the average price of the sub-query grouping by each customer.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"68169e7f\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"--------------------------------------------------\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"query = '''\\n\",\n    \"select ss_customer_sk,avg(avg_price) as avg_price\\n\",\n    \"from\\n\",\n    \"(\\n\",\n    \"SELECT ss_customer_sk ,avg(ss_sales_price) OVER (PARTITION BY ss_customer_sk order by ss_sold_date_sk ROWS BETWEEN 50 PRECEDING AND 50 FOLLOWING ) as avg_price\\n\",\n    \"FROM store_sales\\n\",\n    \"where ss_customer_sk is not null\\n\",\n    \") group by ss_customer_sk order by 2 desc \\n\",\n    \"'''\\n\",\n    \"print(\\\"-\\\"*50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"id\": \"f4d1d9ea\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------+------------------+\\n\",\n      \"|ss_customer_sk|         avg_price|\\n\",\n      \"+--------------+------------------+\\n\",\n      \"|      15924921|52.453036568858586|\\n\",\n      \"|      24796404|52.406491887877976|\\n\",\n      \"|      10174233|52.217149302596276|\\n\",\n      \"|      27571451| 52.14256448618126|\\n\",\n      \"|      14299506| 52.09827897444722|\\n\",\n      \"+--------------+------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\",\n      \"Retry times : 1, Windowing without skew Microbenchmark takes 176.61 seconds\\n\",\n      \"Windowing without skew Microbenchmark takes average 177 seconds after 1 retries\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run 
microbenchmark with n retry time\\n\",\n    \"runMicroBenchmark(spark,\\\"Windowing without skew\\\",query , 1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7df0e850\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Windowing(with data skew)\\n\",\n    \"Data skew is caused by many null values in the ss_customer_sk column.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"12ec99fb\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"--------------------------------------------------\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"query = '''\\n\",\n    \"select ss_customer_sk,avg(avg_price) as avg_price\\n\",\n    \"from\\n\",\n    \"(\\n\",\n    \"SELECT ss_customer_sk ,avg(ss_sales_price) OVER (PARTITION BY ss_customer_sk order by ss_sold_date_sk ROWS BETWEEN 50 PRECEDING AND 50 FOLLOWING ) as avg_price\\n\",\n    \"FROM store_sales\\n\",\n    \") group by ss_customer_sk order by 2 desc \\n\",\n    \"'''\\n\",\n    \"print(\\\"-\\\"*50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"id\": \"86e12b88\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------+------------------+\\n\",\n      \"|ss_customer_sk|         avg_price|\\n\",\n      \"+--------------+------------------+\\n\",\n      \"|      15924921| 52.44865972015809|\\n\",\n      \"|      24796404|52.406491887877976|\\n\",\n      \"|      10174233|52.215293069577626|\\n\",\n      \"|      27571451| 52.14256448618126|\\n\",\n      \"|      14299506| 52.09827897444722|\\n\",\n      \"+--------------+------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\",\n      \"Retry times : 1, Windowing with skew Microbenchmark takes 1666.07 seconds\\n\",\n      \"Windowing with skew Microbenchmark takes 
average 1666 seconds after 1 retries\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run microbenchmark with n retry time\\n\",\n    \"runMicroBenchmark(spark,\\\"Windowing with skew\\\",query ,1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2ef292cc\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Intersection\\n\",\n    \"This is a microbenchmark about intersection operation running on CPU mode. The query calculates items in the same brand, class, and category that are sold in all three sales channels in two consecutive years.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"id\": \"30c8eb8e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"query = '''\\n\",\n    \"select i_item_sk ss_item_sk\\n\",\n    \" from item,\\n\",\n    \"    (select iss.i_brand_id brand_id, iss.i_class_id class_id, iss.i_category_id category_id\\n\",\n    \"     from store_sales, item iss, date_dim d1\\n\",\n    \"     where ss_item_sk = iss.i_item_sk\\n\",\n    \"                    and ss_sold_date_sk = d1.d_date_sk\\n\",\n    \"       and d1.d_year between 1999 AND 1999 + 2\\n\",\n    \"   intersect\\n\",\n    \"     select ics.i_brand_id, ics.i_class_id, ics.i_category_id\\n\",\n    \"     from catalog_sales, item ics, date_dim d2\\n\",\n    \"     where cs_item_sk = ics.i_item_sk\\n\",\n    \"       and cs_sold_date_sk = d2.d_date_sk\\n\",\n    \"       and d2.d_year between 1999 AND 1999 + 2\\n\",\n    \"   intersect\\n\",\n    \"     select iws.i_brand_id, iws.i_class_id, iws.i_category_id\\n\",\n    \"     from web_sales, item iws, date_dim d3\\n\",\n    \"     where ws_item_sk = iws.i_item_sk\\n\",\n    \"       and ws_sold_date_sk = d3.d_date_sk\\n\",\n    \"       and d3.d_year between 1999 AND 1999 + 2) x\\n\",\n    \" where i_brand_id = brand_id\\n\",\n    \"   and i_class_id = class_id\\n\",\n    \"   and i_category_id = category_id\\n\",\n    \"'''\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": 20,\n   \"id\": \"d4f9f669\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------+\\n\",\n      \"|ss_item_sk|\\n\",\n      \"+----------+\\n\",\n      \"|    326835|\\n\",\n      \"|    248465|\\n\",\n      \"|    174935|\\n\",\n      \"|    130715|\\n\",\n      \"|     78159|\\n\",\n      \"+----------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\",\n      \"Retry times : 1, NDS Q14a subquery Microbenchmark takes 62.42 seconds\\n\",\n      \"NDS Q14a subquery Microbenchmark takes average 62 seconds after 1 retries\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run microbenchmark with n retry time\\n\",\n    \"runMicroBenchmark(spark,\\\"NDS Q14a subquery\\\",query ,1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5b051d6b\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Crossjoin\\n\",\n    \"This is a microbenchmark for a 1-million rows crossjoin with itself.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"56af3f00\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"--------------------------------------------------\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# You have to stop the sparksession and create a new one \\n\",\n    \"# because in this query we need to create more executors with less cores to get the best performance\\n\",\n    \"spark.stop()\\n\",\n    \"conf = SparkConf()\\n\",\n    \"# Common spark settings\\n\",\n    \"conf.setMaster(SPARK_MASTER_URL)\\n\",\n    \"conf.setAppName(\\\"Crossjoin Microbenchmark on CPU\\\")\\n\",\n    \" \\n\",\n    \"conf.set(\\\"spark.driver.memory\\\", driverMem)\\n\",\n    \"conf.set(\\\"spark.executor.memory\\\", executorMem)\\n\",\n    \"conf.set(\\\"spark.executor.cores\\\", executorCores)\\n\",\n  
  \" \\n\",\n    \"conf.set(\\\"spark.locality.wait\\\", \\\"0\\\")\\n\",\n    \"conf.set(\\\"spark.sql.files.maxPartitionBytes\\\", maxPartionBytes) \\n\",\n    \"conf.set(\\\"spark.dynamicAllocation.enabled\\\", \\\"false\\\") \\n\",\n    \"conf.set(\\\"spark.sql.adaptive.enabled\\\", \\\"true\\\")\\n\",\n    \"# We can get better performance by broadcasting one table to change CartesianJoin to BroadcastNestedLoopJoin\\n\",\n    \"conf.set(\\\"spark.sql.autoBroadcastJoinThreshold\\\",1000000000)\\n\",\n    \"# Get or create spark session\\n\",\n    \"spark = SparkSession.builder.config(conf=conf).getOrCreate()\\n\",\n    \"\\n\",\n    \"print(\\\"-\\\"*50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"id\": \"ae9cdc08\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"scanning and writing parquet cost : 18.18 seconds\\n\",\n      \"--------------------------------------------------\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Load dataframe and create tempView\\n\",\n    \"start = time() \\n\",\n    \"spark.read.parquet(dataRoot + \\\"/tpcds/customer\\\").limit(1000000).write.format(\\\"parquet\\\").mode(\\\"overwrite\\\").save(\\\"/data/tmp/customer1m\\\")\\n\",\n    \"end = time()\\n\",\n    \"print(\\\"scanning and writing parquet cost : {} seconds\\\".format(round(end - start, 2)))\\n\",\n    \"# We need to tune the partition number to get the best performance.\\n\",\n    \"spark.read.parquet(\\\"/data/tmp/customer1m\\\").repartition(16000).createOrReplaceTempView(\\\"costomer_df_1_million\\\")\\n\",\n    \"query = '''\\n\",\n    \"select count(*) from costomer_df_1_million c1 inner join costomer_df_1_million c2 on c1.c_customer_sk>c2.c_customer_sk\\n\",\n    \"'''\\n\",\n    \"print(\\\"-\\\"*50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"id\": \"0571d861\",\n   \"metadata\": {},\n   
\"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+------------+\\n\",\n      \"|    count(1)|\\n\",\n      \"+------------+\\n\",\n      \"|499999500000|\\n\",\n      \"+------------+\\n\",\n      \"\\n\",\n      \"Retry times : 1, Crossjoin Microbenchmark takes 78.8 seconds\\n\",\n      \"Crossjoin Microbenchmark takes average 79 seconds after 1 retries\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run microbenchmark with n retry time\\n\",\n    \"runMicroBenchmark(spark,\\\"Crossjoin\\\",query ,1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"56f915c2-9b9a-4982-8c4e-5b570c17bfeb\",\n   \"metadata\": {},\n   \"source\": [\n    \"### HashJoin\\n\",\n    \"This is a microbenchmark for a HashJoin. The query on GPU will be more than 10x times faster than CPU based on the cluster in the readme.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"040603c9-a96f-4017-bcdb-5f93e12996a4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.read.parquet(dataRoot + \\\"/tpcds/store_sales\\\").createOrReplaceTempView(\\\"store_sales\\\")\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/tpcds/store_returns\\\").createOrReplaceTempView(\\\"store_returns\\\")\\n\",\n    \"\\n\",\n    \"print(\\\"-\\\"*50)\\n\",\n    \"query = '''\\n\",\n    \"select  sum(store_sales.ss_ext_wholesale_cost)\\n\",\n    \"from store_sales\\n\",\n    \"join store_returns on (ss_item_sk = sr_item_sk) and (ss_addr_sk=sr_addr_sk)\\n\",\n    \"'''\\n\",\n    \"runMicroBenchmark(spark,\\\"HashJoin\\\",query,1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"id\": \"7c118cc9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"c9e43255\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": 
[]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.12.3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/SQL+DF-Examples/micro-benchmarks/notebooks/micro-benchmarks-gpu.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"62787244\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Microbenchmarks on GPU\\n\",\n    \"This is a notebook for microbenchmarks running on GPU. \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"1c3a15d7\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.conf import SparkConf\\n\",\n    \"from time import time\\n\",\n    \"import os\\n\",\n    \"# Change to your cluster ip:port and directories\\n\",\n    \"SPARK_MASTER_URL = os.getenv(\\\"SPARK_MASTER_URL\\\", \\\"spark:your-ip:port\\\")\\n\",\n    \"RAPIDS_JAR = os.getenv(\\\"RAPIDS_JAR\\\", \\\"/your-path/rapids-4-spark_2.12-26.02.0.jar\\\")\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b10a2ad1\",\n   \"metadata\": {},\n   \"source\": [\n    \"Run the microbenchmark with retryTimes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"0c3536ad\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def runMicroBenchmark(spark, appName, query, retryTimes):\\n\",\n    \"    count = 0\\n\",\n    \"    total_time = 0\\n\",\n    \"    # You can print the physical plan of each query\\n\",\n    \"    # spark.sql(query).explain()\\n\",\n    \"    while count < retryTimes:\\n\",\n    \"        start = time()\\n\",\n    \"        spark.sql(query).show(5)\\n\",\n    \"        end = time()\\n\",\n    \"        total_time += round(end - start, 2)\\n\",\n    \"        count = count + 1\\n\",\n    \"        print(\\\"Retry times : {}, \\\".format(count) + appName + \\\" microbenchmark takes {} seconds\\\".format(round(end - start, 2)))\\n\",\n    \"    print(appName + \\\" microbenchmark takes average {} seconds after {} retries\\\".format(round(total_time/retryTimes),retryTimes))\\n\",\n    \"    with open('result.txt', 'a') as file:\\n\",\n 
   \"        file.write(\\\"{},{},{}\\\\n\\\".format(appName, round(total_time/retryTimes), retryTimes))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"975717da\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"--------------------------------------------------\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# You need to update with your real hardware resource \\n\",\n    \"driverMem = os.getenv(\\\"DRIVER_MEM\\\", \\\"50g\\\")\\n\",\n    \"executorMem = os.getenv(\\\"EXECUTOR_MEM\\\", \\\"16g\\\")\\n\",\n    \"maxPartionBytes = os.getenv(\\\"MAX_PARTITION_BYTES\\\", \\\"1g\\\")\\n\",\n    \"pinnedPoolSize = os.getenv(\\\"PINNED_POOL_SIZE\\\", \\\"8g\\\")\\n\",\n    \"concurrentGpuTasks = os.getenv(\\\"CONCURRENT_GPU_TASKS\\\", \\\"4\\\")\\n\",\n    \"executorCores = int(os.getenv(\\\"EXECUTOR_CORES\\\", \\\"16\\\"))\\n\",\n    \"eventlogDir = \\\"file:\\\"+os.getenv(\\\"EVENTLOG_DIR\\\")\\n\",\n    \"gpuPerExecutor = 1/executorCores\\n\",\n    \"# Common spark settings\\n\",\n    \"conf = SparkConf()\\n\",\n    \"conf.setMaster(SPARK_MASTER_URL)\\n\",\n    \"conf.setAppName(\\\"Microbenchmark on GPU\\\")\\n\",\n    \"conf.set(\\\"spark.driver.memory\\\", driverMem)\\n\",\n    \"## The tasks will run on GPU memory, so there is no need to set a high host memory\\n\",\n    \"conf.set(\\\"spark.executor.memory\\\", executorMem)\\n\",\n    \"## The tasks will run on GPU cores, so there is no need to use many cpu cores\\n\",\n    \"conf.set(\\\"spark.executor.cores\\\", executorCores)\\n\",\n    \"conf.set(\\\"spark.locality.wait\\\", \\\"0\\\")\\n\",\n    \"conf.set(\\\"spark.sql.files.maxPartitionBytes\\\", maxPartionBytes) \\n\",\n    \"conf.set(\\\"spark.dynamicAllocation.enabled\\\", \\\"false\\\") \\n\",\n    \"conf.set(\\\"spark.sql.adaptive.enabled\\\", \\\"true\\\") \\n\",\n    \"\\n\",\n    \"# Plugin settings\\n\",\n    
\"conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"# 4 tasks will run concurrently per GPU\\n\",\n    \"conf.set(\\\"spark.rapids.sql.concurrentGpuTasks\\\", concurrentGpuTasks)\\n\",\n    \"# Pinned 8g host memory to transfer data between GPU and host memory\\n\",\n    \"conf.set(\\\"spark.rapids.memory.pinnedPool.size\\\", pinnedPoolSize)\\n\",\n    \"# 16 tasks will run concurrently per executor, as we set spark.executor.cores=16\\n\",\n    \"conf.set(\\\"spark.task.resource.gpu.amount\\\", gpuPerExecutor) \\n\",\n    \"conf.set(\\\"spark.rapids.sql.enabled\\\", \\\"true\\\") \\n\",\n    \"conf.set(\\\"spark.plugins\\\", \\\"com.nvidia.spark.SQLPlugin\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.sql.variableFloatAgg.enabled\\\", \\\"true\\\")\\n\",\n    \"conf.set(\\\"spark.driver.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.executor.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.jars\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.eventLog.enabled\\\", \\\"true\\\")\\n\",\n    \"conf.set(\\\"spark.eventLog.dir\\\", eventlogDir)\\n\",\n    \"# Create spark session\\n\",\n    \"spark = SparkSession.builder.config(conf=conf).getOrCreate()\\n\",\n    \"# Load dataframe and create tempView\\n\",\n    \"# You need to update data path to your real path!\\n\",\n    \"dataRoot = os.getenv(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/customer.dat\\\").createOrReplaceTempView(\\\"customer\\\")\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/store_sales.dat\\\").createOrReplaceTempView(\\\"store_sales\\\")\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/catalog_sales.dat\\\").createOrReplaceTempView(\\\"catalog_sales\\\")\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/web_sales.dat\\\").createOrReplaceTempView(\\\"web_sales\\\")\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/item.dat\\\").createOrReplaceTempView(\\\"item\\\")\\n\",\n    \"spark.read.parquet(dataRoot + 
\\\"/date_dim.dat\\\").createOrReplaceTempView(\\\"date_dim\\\")\\n\",\n    \"print(\\\"-\\\"*50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7136eb63\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Expand&HashAggregate\\n\",\n    \"This is a microbenchmark about Expand&HashAggregate expressions running on the GPU. The query calculates the distinct value of some dimension columns and average birth year by different c_salutation of customers after grouping by c_current_hdemo_sk. You will see about 10x speedups in this query. Because an additional shuffle involved by the repartition operator in CPU mode. And GPUExpand and GPUHashAggregate is much faster than Expand and HashAggregate because GPU algorithms allow us to parallelize the computation and we can utilize most of the GPU cores. The tasks' duration in the third stage is less than one second but will cost 20x-40x while running on CPU. There will be a more significant performance improvement along with the increasing number of count distinct columns and aggregate functions.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"dd12d749\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"query = '''\\n\",\n    \"select c_current_hdemo_sk,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_salutation,null)) as c1,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_salutation,null)) as c12,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_salutation,null)) as c13,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_first_name,null)) as c2,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_first_name,null)) as c22,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_first_name,null)) as c23,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_last_name,null)) as c3,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_last_name,null)) as 
c32,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_last_name,null)) as c33,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_birth_country,null)) as c4,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_birth_country,null)) as c42,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_birth_country,null)) as c43,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_email_address,null)) as c5,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_email_address,null)) as c52,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_email_address,null)) as c53,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_login,null)) as c6,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_login,null)) as c62,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_login,null)) as c63,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_preferred_cust_flag,null)) as c7,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_preferred_cust_flag,null)) as c72,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_preferred_cust_flag,null)) as c73,\\n\",\n    \"\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Ms.\\\",c_birth_month,null)) as c8,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Mr.\\\",c_birth_month,null)) as c82,\\n\",\n    \"count(DISTINCT if(c_salutation==\\\"Dr.\\\",c_birth_month,null)) as c83,\\n\",\n    \"\\n\",\n    \"avg(if(c_salutation==\\\"Ms.\\\",c_birth_year,null)) as avg1,\\n\",\n    \"avg(if(c_salutation==\\\"Mr.\\\",c_birth_year,null)) as avg2,\\n\",\n    \"avg(if(c_salutation==\\\"Dr.\\\",c_birth_year,null)) as avg3,\\n\",\n    \"avg(if(c_salutation==\\\"Miss.\\\",c_birth_year,null)) as avg4,\\n\",\n    \"avg(if(c_salutation==\\\"Mrs.\\\",c_birth_year,null)) as avg5,\\n\",\n    \"avg(if(c_salutation==\\\"Sir.\\\",c_birth_year,null)) as avg6,\\n\",\n    
\"avg(if(c_salutation==\\\"Professor.\\\",c_birth_year,null)) as avg7,\\n\",\n    \"avg(if(c_salutation==\\\"Teacher.\\\",c_birth_year,null)) as avg8,\\n\",\n    \"avg(if(c_salutation==\\\"Agent.\\\",c_birth_year,null)) as avg9,\\n\",\n    \"avg(if(c_salutation==\\\"Director.\\\",c_birth_year,null)) as avg10\\n\",\n    \"from customer group by c_current_hdemo_sk\\n\",\n    \"'''\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"2e105bf8\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\\n\",\n      \"|c_current_hdemo_sk| c1|c12|c13| c2|c22|c23| c3|c32|c33| c4|c42|c43| c5|c52| c53| c6|c62|c63| c7|c72|c73| c8|c82|c83|              avg1|              avg2|              avg3|avg4|              avg5|avg6|avg7|avg8|avg9|avg10|\\n\",\n      \"+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\\n\",\n      \"|              1238|  1|  1|  1|284|255|562|358|467|772|194|203|211|452|664|1157|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.2444933920706|1958.8547655068078|1957.2870771899393|null| 1958.042643923241|null|null|null|null| null|\\n\",\n      \"|              6658|  1|  1|  1|318|253|541|384|492|752|190|203|210|516|647|1115|  0|  0|  0|  2|  2|  2| 12| 12| 12|1959.0155945419103|1958.9720930232559|1958.0089525514773|null|1959.2618025751074|null|null|null|null| null|\\n\",\n      \"|              1088|  1|  1|  1|302|263|547|374|476|736|191|206|210|487|648|1074|  0|  0|  0|  2|  2|  2| 12| 12| 
12|1957.7084188911704|1959.1323076923077|1957.2780898876404|null|1958.5641025641025|null|null|null|null| null|\\n\",\n      \"|              4818|  1|  1|  1|276|248|542|368|514|747|183|204|211|460|691|1093|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.8954248366013|1958.1313131313132|1957.5018315018315|null|1958.0252293577983|null|null|null|null| null|\\n\",\n      \"|               148|  1|  1|  1|309|260|562|392|501|772|187|207|211|488|668|1154|  0|  0|  0|  2|  2|  2| 12| 12| 12| 1956.219008264463|1958.9161676646706|1957.8076256499132|null|1958.3412017167382|null|null|null|null| null|\\n\",\n      \"+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\",\n      \"Retry times : 1, Expand&HashAggregate microbenchmark takes 11.13 seconds\\n\",\n      \"+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\\n\",\n      \"|c_current_hdemo_sk| c1|c12|c13| c2|c22|c23| c3|c32|c33| c4|c42|c43| c5|c52| c53| c6|c62|c63| c7|c72|c73| c8|c82|c83|              avg1|              avg2|              avg3|avg4|              avg5|avg6|avg7|avg8|avg9|avg10|\\n\",\n      \"+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\\n\",\n      \"|              1238|  1|  1|  1|284|255|562|358|467|772|194|203|211|452|664|1157|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.2444933920706|1958.8547655068078|1957.2870771899393|null| 1958.042643923241|null|null|null|null| null|\\n\",\n      \"|              6658|  1|  1|  
1|318|253|541|384|492|752|190|203|210|516|647|1115|  0|  0|  0|  2|  2|  2| 12| 12| 12|1959.0155945419103|1958.9720930232559|1958.0089525514773|null|1959.2618025751074|null|null|null|null| null|\\n\",\n      \"|              4818|  1|  1|  1|276|248|542|368|514|747|183|204|211|460|691|1093|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.8954248366013|1958.1313131313132|1957.5018315018315|null|1958.0252293577983|null|null|null|null| null|\\n\",\n      \"|              1088|  1|  1|  1|302|263|547|374|476|736|191|206|210|487|648|1074|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.7084188911704|1959.1323076923077|1957.2780898876404|null|1958.5641025641025|null|null|null|null| null|\\n\",\n      \"|               148|  1|  1|  1|309|260|562|392|501|772|187|207|211|488|668|1154|  0|  0|  0|  2|  2|  2| 12| 12| 12| 1956.219008264463|1958.9161676646706|1957.8076256499132|null|1958.3412017167382|null|null|null|null| null|\\n\",\n      \"+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\",\n      \"Retry times : 2, Expand&HashAggregate microbenchmark takes 7.74 seconds\\n\",\n      \"Expand&HashAggregate microbenchmark takes average 9 seconds after 2 retries\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run microbenchmark with n retry time\\n\",\n    \"runMicroBenchmark(spark,\\\"Expand&HashAggregate\\\",query,2)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f50ec183\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Windowing(without data skew)\\n\",\n    \"This is a microbenchmark about windowing expressions running on GPU mode. 
The sub-query calculates the average ss_sales_price of a fixed window function partition by ss_customer_sk, and the parent query calculates the average price of the sub-query grouping by each customer. You will see about 25x speedups in this query. The speedup mainly comes from GPUSort/GPUWindow/GPUHashAggregate. The avg aggregation function evaluates all rows which are generated by the sub-query's window function. There will be a more significant performance improvement along with the increasing number of sub-query aggregate functions.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"31bd0635\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"--------------------------------------------------\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"query = '''\\n\",\n    \"select ss_customer_sk,avg(avg_price) as avg_price\\n\",\n    \"from\\n\",\n    \"(\\n\",\n    \"SELECT ss_customer_sk ,avg(ss_sales_price) OVER (PARTITION BY ss_customer_sk order by ss_sold_date_sk ROWS BETWEEN 50 PRECEDING AND 50 FOLLOWING ) as avg_price\\n\",\n    \"FROM store_sales\\n\",\n    \"where ss_customer_sk is not null\\n\",\n    \") group by ss_customer_sk order by 2 desc \\n\",\n    \"'''\\n\",\n    \"print(\\\"-\\\"*50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"f9e93983\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------+------------------+\\n\",\n      \"|ss_customer_sk|         avg_price|\\n\",\n      \"+--------------+------------------+\\n\",\n      \"|      15924921|52.375180502283705|\\n\",\n      \"|      24796404| 52.21073975966333|\\n\",\n      \"|      14299506| 52.16263537127018|\\n\",\n      \"|      27571451|52.156112032252395|\\n\",\n      \"|      10174233| 52.06401030721082|\\n\",\n      
\"+--------------+------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\",\n      \"Retry times : 1, Windowing without skew microbenchmark takes 11.39 seconds\\n\",\n      \"+--------------+-----------------+\\n\",\n      \"|ss_customer_sk|        avg_price|\\n\",\n      \"+--------------+-----------------+\\n\",\n      \"|      15924921|52.53781291335107|\\n\",\n      \"|      24796404|52.39683466140243|\\n\",\n      \"|      27571451|52.18830023174899|\\n\",\n      \"|      14299506|52.10829141087412|\\n\",\n      \"|      10174233|51.92766214818386|\\n\",\n      \"+--------------+-----------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\",\n      \"Retry times : 2, Windowing without skew microbenchmark takes 9.53 seconds\\n\",\n      \"Windowing without skew microbenchmark takes average 10 seconds after 2 retries\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run microbenchmark with n retry time\\n\",\n    \"runMicroBenchmark(spark,\\\"Windowing without skew\\\",query,2)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"dcf08e47\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Windowing(with data skew)\\n\",\n    \"Data skew is caused by many null values in the ss_customer_sk column. You will see about 80x speedups in this query. 
The heavier skew task a query has, the more improved performance we will get because GPU parallelizes the computation, CPU is limited to just a single core because of how the algorithms are written.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"2b9d223c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"--------------------------------------------------\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"query = '''\\n\",\n    \"select ss_customer_sk,avg(avg_price) as avg_price\\n\",\n    \"from\\n\",\n    \"(\\n\",\n    \"SELECT ss_customer_sk ,avg(ss_sales_price) OVER (PARTITION BY ss_customer_sk order by ss_sold_date_sk ROWS BETWEEN 50 PRECEDING AND 50 FOLLOWING ) as avg_price\\n\",\n    \"FROM store_sales\\n\",\n    \") group by ss_customer_sk order by 2 desc \\n\",\n    \"'''\\n\",\n    \"print(\\\"-\\\"*50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"0d7c65ee\",\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------+------------------+\\n\",\n      \"|ss_customer_sk|         avg_price|\\n\",\n      \"+--------------+------------------+\\n\",\n      \"|      24796404| 52.40675225109215|\\n\",\n      \"|      27571451|52.396675141359374|\\n\",\n      \"|      15924921| 52.30557497833058|\\n\",\n      \"|      10174233|52.088916933379096|\\n\",\n      \"|      14299506|51.995045713009794|\\n\",\n      \"+--------------+------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\",\n      \"Retry times : 1, Windowing with skew microbenchmark takes 17.46 seconds\\n\",\n      \"+--------------+------------------+\\n\",\n      \"|ss_customer_sk|         avg_price|\\n\",\n      \"+--------------+------------------+\\n\",\n      \"|      
24796404|52.403564615099896|\\n\",\n      \"|      15924921|52.262694645994465|\\n\",\n      \"|      27571451| 52.14256448618127|\\n\",\n      \"|      10174233| 52.11346591610992|\\n\",\n      \"|      14299506| 51.99180221022445|\\n\",\n      \"+--------------+------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\",\n      \"Retry times : 2, Windowing with skew microbenchmark takes 16.63 seconds\\n\",\n      \"Windowing with skew microbenchmark takes average 17 seconds after 2 retries\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run microbenchmark with n retry time\\n\",\n    \"runMicroBenchmark(spark,\\\"Windowing with skew\\\",query,2)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"53c0ed28\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Intersection\\n\",\n    \"This is a microbenchmark about intersection operation running on GPU mode. The query calculates items in the same brand, class, and category that are sold in all three sales channels in two consecutive years. You will see about 10x speedups in this query. This is a competition between high cardinality SortMergeJoin vs GpuShuffleHashJoin. 
The mainly improved performance comes from two SortMergeJoin(s) in this query running on CPU get converted to GpuShuffleHashJoin running on GPU.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"643c2e8a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"query = '''\\n\",\n    \"select i_item_sk ss_item_sk\\n\",\n    \" from item,\\n\",\n    \"    (select iss.i_brand_id brand_id, iss.i_class_id class_id, iss.i_category_id category_id\\n\",\n    \"     from store_sales, item iss, date_dim d1\\n\",\n    \"     where ss_item_sk = iss.i_item_sk\\n\",\n    \"                    and ss_sold_date_sk = d1.d_date_sk\\n\",\n    \"       and d1.d_year between 1999 AND 1999 + 2\\n\",\n    \"   intersect\\n\",\n    \"     select ics.i_brand_id, ics.i_class_id, ics.i_category_id\\n\",\n    \"     from catalog_sales, item ics, date_dim d2\\n\",\n    \"     where cs_item_sk = ics.i_item_sk\\n\",\n    \"       and cs_sold_date_sk = d2.d_date_sk\\n\",\n    \"       and d2.d_year between 1999 AND 1999 + 2\\n\",\n    \"   intersect\\n\",\n    \"     select iws.i_brand_id, iws.i_class_id, iws.i_category_id\\n\",\n    \"     from web_sales, item iws, date_dim d3\\n\",\n    \"     where ws_item_sk = iws.i_item_sk\\n\",\n    \"       and ws_sold_date_sk = d3.d_date_sk\\n\",\n    \"       and d3.d_year between 1999 AND 1999 + 2) x\\n\",\n    \" where i_brand_id = brand_id\\n\",\n    \"   and i_class_id = class_id\\n\",\n    \"   and i_category_id = category_id\\n\",\n    \"'''\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"id\": \"61bc2260\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+----------+\\n\",\n      \"|ss_item_sk|\\n\",\n      \"+----------+\\n\",\n      \"|      4323|\\n\",\n      \"|      4324|\\n\",\n      \"|      4325|\\n\",\n      \"|      4327|\\n\",\n      \"|      4328|\\n\",\n     
 \"+----------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\",\n      \"Retry times : 1, NDS Q14a subquery microbenchmark takes 6.71 seconds\\n\",\n      \"+----------+\\n\",\n      \"|ss_item_sk|\\n\",\n      \"+----------+\\n\",\n      \"|     14103|\\n\",\n      \"|     14104|\\n\",\n      \"|     14105|\\n\",\n      \"|     14107|\\n\",\n      \"|     14108|\\n\",\n      \"+----------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\",\n      \"Retry times : 2, NDS Q14a subquery microbenchmark takes 6.11 seconds\\n\",\n      \"NDS Q14a subquery microbenchmark takes average 6 seconds after 2 retries\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run microbenchmark with n retry time\\n\",\n    \"runMicroBenchmark(spark,\\\"NDS Q14a subquery\\\",query,2)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1346d126\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Crossjoin\\n\",\n    \"This is a microbenchmark for a 1-million rows crossjoin with itself. You will see about 10x speedups in this query. 
The main performance improvement comes from converting BroadcastNestedLoopJoin running on CPU to GpuBroadcastNestedLoopJoin running on GPU.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"286ea45d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"scanning and writing parquet cost : 5.31 seconds\\n\",\n      \"--------------------------------------------------\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"start = time() \\n\",\n    \"spark.read.parquet(dataRoot + \\\"/customer.dat\\\").limit(1000000).write.format(\\\"parquet\\\").mode(\\\"overwrite\\\").save(\\\"/data/tmp/customer1m\\\")\\n\",\n    \"end = time()\\n\",\n    \"# Parquet file scanning and writing will be about 3 times faster running on GPU\\n\",\n    \"print(\\\"scanning and writing parquet cost : {} seconds\\\".format(round(end - start, 2)))\\n\",\n    \"spark.read.parquet(\\\"/data/tmp/customer1m\\\").repartition(200).createOrReplaceTempView(\\\"costomer_df_1_million\\\")\\n\",\n    \"query = '''\\n\",\n    \"select count(*) from costomer_df_1_million c1 inner join costomer_df_1_million c2 on c1.c_customer_sk>c2.c_customer_sk\\n\",\n    \"'''\\n\",\n    \"print(\\\"-\\\"*50)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"f41b8d54\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+------------+\\n\",\n      \"|    count(1)|\\n\",\n      \"+------------+\\n\",\n      \"|499999500000|\\n\",\n      \"+------------+\\n\",\n      \"\\n\",\n      \"Retry times : 1, Crossjoin microbenchmark takes 6.7 seconds\\n\",\n      \"+------------+\\n\",\n      \"|    count(1)|\\n\",\n      \"+------------+\\n\",\n      \"|499999500000|\\n\",\n      \"+------------+\\n\",\n      \"\\n\",\n      \"Retry times : 2, Crossjoin microbenchmark takes 6.37 
seconds\\n\",\n      \"Crossjoin microbenchmark takes average 7 seconds after 2 retries\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Run microbenchmark with n retry time\\n\",\n    \"runMicroBenchmark(spark,\\\"Crossjoin\\\",query,2)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"06b351e6-b7bd-4063-a20b-fe4fd71221f9\",\n   \"metadata\": {},\n   \"source\": [\n    \"### HashJoin\\n\",\n    \"This is a microbenchmark for a HashJoin. The query on GPU will be more than 10x times faster than CPU based on the cluster in the readme.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"191d0c9a-2d3a-40f4-89aa-f61dab5caa90\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.read.parquet(dataRoot + \\\"/store_sales.dat\\\").createOrReplaceTempView(\\\"store_sales\\\")\\n\",\n    \"spark.read.parquet(dataRoot + \\\"/store_returns.dat\\\").createOrReplaceTempView(\\\"store_returns\\\")\\n\",\n    \"\\n\",\n    \"print(\\\"-\\\"*50)\\n\",\n    \"query = '''\\n\",\n    \"select  sum(store_sales.ss_ext_wholesale_cost)\\n\",\n    \"from store_sales\\n\",\n    \"join store_returns on (ss_item_sk = sr_item_sk) and (ss_addr_sk=sr_addr_sk)\\n\",\n    \"'''\\n\",\n    \"runMicroBenchmark(spark,\\\"HashJoin\\\",query,1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"fc2092e8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.stop()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.12.3\"\n  }\n },\n \"nbformat\": 
4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/SQL+DF-Examples/retail-analytics/README.md",
    "content": "\n# Overview Retail Analytics \nThis repository contains two Jupyter notebooks:\n\nData Generation: This notebook generates sample data that can be used for analysis. It demonstrates how to use various Python libraries to create synthetic data sets that can be used for testing and experimentation. This notebook can be run in GCP n1-standard-32 instance type\n\nData Cleaning and Analysis: This notebook takes the generated data and performs a series of cleaning and analysis tasks. It demonstrates how to use Spark RAPIDS library to manipulate and analyze data sets.\n"
  },
  {
    "path": "examples/SQL+DF-Examples/retail-analytics/notebooks/python/retail-analytic.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import random\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark import broadcast, SparkConf\\n\",\n    \"import time\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"RAPIDS_JAR = os.getenv(\\\"RAPIDS_JAR\\\", \\\"/path/to/your/jars/rapids.jar\\\")\\n\",\n    \"SPARK_MASTER = os.getenv(\\\"SPARK_MASTER_URL\\\", \\\"spark://ip:port\\\")\\n\",\n    \"print(\\\"RAPIDS_JAR: {}\\\".format(RAPIDS_JAR))\\n\",\n    \"if \\\"sc\\\" in globals():\\n\",\n    \"    sc.stop()\\n\",\n    \"\\n\",\n    \"### Configure the parameters based on your dataproc cluster ###\\n\",\n    \"conf = SparkConf().setAppName(\\\"Retail Analytics\\\")\\n\",\n    \"conf.setMaster(SPARK_MASTER)\\n\",\n    \"conf.set(\\\"spark.driver.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.executor.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.jars\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.executor.instances\\\", \\\"1\\\")\\n\",\n    \"conf.set(\\\"spark.executor.cores\\\", \\\"4\\\")\\n\",\n    \"conf.set(\\\"spark.task.resource.gpu.amount\\\", \\\"0.25\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.sql.concurrentGpuTasks\\\", \\\"2\\\")\\n\",\n    \"conf.set(\\\"spark.executor.memory\\\", \\\"4g\\\")\\n\",\n    \"conf.set(\\\"spark.sql.files.maxPartitionBytes\\\", \\\"128m\\\")\\n\",\n    \"conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.memory.pinnedPool.size\\\", \\\"2048m\\\")\\n\",\n    \"conf.set(\\\"spark.executor.memoryOverhead\\\", \\\"4096m\\\")\\n\",\n    \"conf.set(\\\"spark.dynamicAllocation.enabled\\\", \\\"false\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.sql.format.json.read.enabled\\\",True)\\n\",\n    \"conf.set(\\\"spark.rapids.sql.castStringToTimestamp.enabled\\\",True)\\n\",\n    
\"conf.set(\\\"spark.rapids.sql.expression.PercentRank\\\",False)\\n\",\n    \"conf.set(\\\"spark.rapids.sql.castDecimalToString.enabled\\\",True)\\n\",\n    \"conf.set(\\\"spark.rapids.sql.hasExtendedYearValues\\\",False)\\n\",\n    \"conf.set(\\\"spark.rapids.sql.enabled\\\",True)\\n\",\n    \"conf.set(\\\"spark.plugins\\\", \\\"com.nvidia.spark.SQLPlugin\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.sql.allowMultipleJars\\\", \\\"ALWAYS\\\")\\n\",\n    \"\\n\",\n    \"spark = SparkSession.builder \\\\\\n\",\n    \"                    .config(conf=conf) \\\\\\n\",\n    \"                    .getOrCreate()\\n\",\n    \"# create a SparkSession\\n\",\n    \"spark = SparkSession.builder.appName(\\\"RetailInvMgmt\\\").getOrCreate()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"# You need to update these to your real paths!\\n\",\n    \"dataRoot = os.getenv(\\\"DATA_ROOT\\\", 'path/to/your/datasets')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.functions import *\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark.sql.window import Window\\n\",\n    \"\\n\",\n    \"start = time.time()\\n\",\n    \"\\n\",\n    \"def clean_data(df):\\n\",\n    \"    # remove missing values\\n\",\n    \"    df = df.dropna()\\n\",\n    \"    # remove duplicate data\\n\",\n    \"    df = df.dropDuplicates()\\n\",\n    \"    return df\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def read_data(spark, format, file_path):\\n\",\n    \"    if format==\\\"csv\\\":\\n\",\n    \"        return spark.read.format(format).load(file_path,header=True)\\n\",\n    \"    else:\\n\",\n    \"        return spark.read.format(format).load(file_path)\\n\",\n    \"\\n\",\n    \"# read sales data\\n\",\n    \"sales_df = read_data(spark, \\\"csv\\\", 
dataRoot+\\\"/sales/\\\")\\n\",\n    \"\\n\",\n    \"# read stock data\\n\",\n    \"stock_df = read_data(spark, \\\"json\\\", dataRoot+\\\"/stock/\\\")\\n\",\n    \"\\n\",\n    \"# read supplier data\\n\",\n    \"supplier_df = read_data(spark, \\\"json\\\", dataRoot+\\\"/supplier/\\\")\\n\",\n    \"\\n\",\n    \"# read customer data\\n\",\n    \"customer_df = read_data(spark, \\\"csv\\\", dataRoot+\\\"/customer/\\\")\\n\",\n    \"\\n\",\n    \"# read market data\\n\",\n    \"market_df = read_data(spark, \\\"csv\\\", dataRoot+\\\"/market/\\\")\\n\",\n    \"\\n\",\n    \"# read logistic data\\n\",\n    \"logistic_df = read_data(spark, \\\"csv\\\", dataRoot+\\\"/logistic/\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# data cleaning\\n\",\n    \"sales_df = clean_data(sales_df)\\n\",\n    \"stock_df = clean_data(stock_df)\\n\",\n    \"supplier_df = clean_data(supplier_df)\\n\",\n    \"customer_df = clean_data(customer_df)\\n\",\n    \"market_df = clean_data(market_df)\\n\",\n    \"logistic_df = clean_data(logistic_df)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# convert date columns to date type\\n\",\n    \"sales_df = sales_df.withColumn(\\\"date_of_sale\\\", to_date(col(\\\"date_of_sale\\\")))\\n\",\n    \"stock_df = stock_df.withColumn(\\\"date_received\\\", to_date(col(\\\"date_received\\\")))\\n\",\n    \"supplier_df = supplier_df.withColumn(\\\"date_ordered\\\", to_date(col(\\\"date_ordered\\\")))\\n\",\n    \"\\n\",\n    \"# standardize case of string columns\\n\",\n    \"sales_df = sales_df.withColumn(\\\"product_name\\\", upper(col(\\\"product_name\\\")))\\n\",\n    \"stock_df = stock_df.withColumn(\\\"product_name\\\", upper(col(\\\"product_name\\\")))\\n\",\n    \"stock_df = stock_df.withColumn(\\\"location\\\", upper(col(\\\"location\\\")))\\n\",\n    \"supplier_df = supplier_df.withColumn(\\\"product_name\\\", upper(col(\\\"product_name\\\")))\\n\",\n    \"customer_df = customer_df.withColumn(\\\"customer_name\\\", upper(col(\\\"customer_name\\\")))\\n\",\n  
  \"market_df = market_df.withColumn(\\\"product_name\\\", upper(col(\\\"product_name\\\")))\\n\",\n    \"logistic_df = logistic_df.withColumn(\\\"product_name\\\", upper(col(\\\"product_name\\\")))\\n\",\n    \"\\n\",\n    \"# remove leading and trailing whitespaces\\n\",\n    \"sales_df = sales_df.withColumn(\\\"product_name\\\", trim(col(\\\"product_name\\\")))\\n\",\n    \"stock_df = stock_df.withColumn(\\\"location\\\", trim(col(\\\"location\\\")))\\n\",\n    \"\\n\",\n    \"supplier_df = supplier_df.withColumn(\\\"product_name\\\", trim(col(\\\"product_name\\\")))\\n\",\n    \"customer_df = customer_df.withColumn(\\\"customer_name\\\", trim(col(\\\"customer_name\\\")))\\n\",\n    \"market_df = market_df.withColumn(\\\"product_name\\\", trim(col(\\\"product_name\\\")))\\n\",\n    \"logistic_df = logistic_df.withColumn(\\\"product_name\\\", trim(col(\\\"product_name\\\")))\\n\",\n    \"\\n\",\n    \"# check for invalid values\\n\",\n    \"sales_df = sales_df.filter(col(\\\"product_name\\\").isNotNull())\\n\",\n    \"stock_df = stock_df.filter(col(\\\"location\\\").isNotNull())\\n\",\n    \"customer_df = customer_df.filter(col(\\\"gender\\\").isin(\\\"male\\\",\\\"female\\\"))\\n\",\n    \"market_df = market_df.filter(col(\\\"product_name\\\").isNotNull())\\n\",\n    \"logistic_df = logistic_df.filter(col(\\\"product_name\\\").isNotNull())\\n\",\n    \"\\n\",\n    \"#drop extra columns\\n\",\n    \"market_df = market_df.drop(\\\"price\\\")\\n\",\n    \"supplier_df = supplier_df.drop(\\\"price\\\")\\n\",\n    \"\\n\",\n    \"# join all data\\n\",\n    \"data_int = sales_df.join(stock_df, \\\"product_name\\\",\\\"leftouter\\\").join(supplier_df, \\\"product_name\\\",\\\"leftouter\\\").join(market_df, \\\"product_name\\\",\\\"leftouter\\\").join(logistic_df, \\\"product_name\\\",\\\"leftouter\\\").join(customer_df, \\\"customer_id\\\",\\\"leftouter\\\")  \\n\",\n    \"\\n\",\n    \"# write the cleaned data\\n\",\n    \"os.makedirs(dataRoot+\\\"cleaned/\\\", 
exist_ok=True)\\n\",\n    \"data_int.write.mode(\\\"overwrite\\\").format(\\\"parquet\\\").save(dataRoot+\\\"/cleaned/\\\")\\n\",\n    \"\\n\",\n    \"end = time.time()\\n\",\n    \"\\n\",\n    \"print(\\\"Time taken on GPU for Data Cleaning: \\\", end - start)\\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.sql.functions import *\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark.sql.window import Window\\n\",\n    \"\\n\",\n    \"#DO VARIOUS RETAIL DATA ANALYTICS \\n\",\n    \"\\n\",\n    \"start = time.time()\\n\",\n    \"\\n\",\n    \"# read cleaned data\\n\",\n    \"\\n\",\n    \"data = spark.read.format(\\\"parquet\\\").load(dataRoot+\\\"/cleaned/\\\")\\n\",\n    \"\\n\",\n    \"#Case when statement to create a new column to indicate whether the product is perishable or not:\\n\",\n    \"\\n\",\n    \"data = data.withColumn(\\\"perishable\\\", when(col(\\\"shelf_life\\\") <= 30, \\\"yes\\\").otherwise(\\\"no\\\"))\\n\",\n    \"\\n\",\n    \"# You can use the when() and otherwise() functions to create new columns based on certain conditions:\\n\",\n    \"\\n\",\n    \"data = data.withColumn(\\\"sales_status\\\", when(col(\\\"quantity_sold\\\") > 50, \\\"good\\\").otherwise(\\\"bad\\\"))\\n\",\n    \"\\n\",\n    \"# create a window to perform time series analysis\\n\",\n    \"window = Window.partitionBy(\\\"product_name\\\").orderBy(\\\"date_of_sale\\\")\\n\",\n    \"\\n\",\n    \"# calculate the rolling average of sales for each product\\n\",\n    \"time_series_df = data.withColumn(\\\"rolling_avg_sales\\\", avg(\\\"quantity_sold\\\").over(window))\\n\",\n    \"\\n\",\n    \"# use window function for forecasting\\n\",\n    \"\\n\",\n    \"forecast_df = time_series_df.withColumn(\\\"prev_sales\\\", lag(\\\"rolling_avg_sales\\\").over(window))\\\\\\n\",\n    \"    .withColumn(\\\"next_sales\\\", 
lead(\\\"rolling_avg_sales\\\").over(window))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Calculate the average price of a product, grouped by supplier\\n\",\n    \"forecast_df.groupBy(\\\"sup_id\\\").agg({\\\"price\\\": \\\"avg\\\"}).show()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Calculate the total quantity in stock and total sales by supplier\\n\",\n    \"forecast_df.groupBy(\\\"sup_id\\\").agg({\\\"quantity_in_stock\\\": \\\"sum\\\", \\\"price\\\": \\\"sum\\\"}).show()\\n\",\n    \"\\n\",\n    \"#Calculate the number of perishable v/s non-perishable product per location\\n\",\n    \"forecast_df.groupBy(\\\"perishable\\\").agg({\\\"perishable\\\": \\\"count\\\"}).show()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"#Calculate number of good v/s bad sales status per location\\n\",\n    \"forecast_df.groupBy(\\\"sales_status\\\").agg({\\\"sales_status\\\": \\\"count\\\"}).show()\\n\",\n    \"\\n\",\n    \"# Count the number of sales that contain a 10% off promotion\\n\",\n    \"countt = forecast_df.filter(forecast_df[\\\"contains_promotion\\\"].contains(\\\"10% off\\\")).count()\\n\",\n    \"print(countt)\\n\",\n    \"# Perform some complex analysis on the DataFrame\\n\",\n    \"\\n\",\n    \"# Calculate the total sales, quantity sold by product and location\\n\",\n    \"total_sales_by_product_location = forecast_df.groupBy(\\\"product_name\\\", \\\"location\\\").agg(sum(\\\"price\\\").alias(\\\"total_price\\\"),sum(\\\"quantity_ordered\\\").alias(\\\"total_quantity_sold\\\"),avg(\\\"quantity_sold\\\").alias(\\\"avg_quantity_sold\\\")).sort(desc(\\\"total_price\\\"))\\n\",\n    \"\\n\",\n    \"# Group the data by product_name\\n\",\n    \"grouped_df = forecast_df.groupBy(\\\"product_name\\\")\\n\",\n    \"\\n\",\n    \"#Sum the quantity_in_stock, quantity_ordered, quantity_sold, and (price * quantity_sold) for each group\\n\",\n    \"aggregated_df = 
grouped_df.agg(sum(\\\"quantity_in_stock\\\").alias(\\\"total_quantity_in_stock\\\"),avg(\\\"price\\\").alias(\\\"average_price\\\"),sum(\\\"quantity_ordered\\\").alias(\\\"total_quantity_ordered\\\"),sum(\\\"quantity_sold\\\").alias(\\\"total_quantity_sold\\\"),sum(col(\\\"price\\\") * col(\\\"quantity_sold\\\")).alias(\\\"total_sales\\\"),sum(\\\"prev_sales\\\").alias(\\\"total_prev_sales\\\"),sum(\\\"next_sales\\\").alias(\\\"total_next_sales\\\"),).sort(desc(\\\"total_sales\\\"))\\n\",\n    \"\\n\",\n    \"#WRITE THE AGGREGATES TO DISK\\n\",\n    \"aggregated_df.write.mode(\\\"overwrite\\\").format(\\\"parquet\\\").save(dataRoot+\\\"/app/data.parquet\\\")\\n\",\n    \"total_sales_by_product_location.write.mode(\\\"overwrite\\\").format(\\\"parquet\\\").save(dataRoot+\\\"/app1/data.parquet\\\")\\n\",\n    \"\\n\",\n    \"end = time.time()\\n\",\n    \"\\n\",\n    \"print(\\\"Time taken on GPU for Data Analysis: \\\", end - start)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.stop()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.19\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/SQL+DF-Examples/retail-analytics/notebooks/python/retail-datagen.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Generating and Writing Data to GCS\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import pandas as pd\\n\",\n    \"import multiprocessing as mp\\n\",\n    \"import random\\n\",\n    \"\\n\",\n    \"# You need to update these to your real paths!\\n\",\n    \"dataRoot = os.getenv(\\\"DATA_ROOT\\\", '/path/to/your/datasets')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#We define the generate_data function which takes an integer i as input and generates sales data using random numbers. The generated data includes sales ID, product name, price, quantity sold, date of sale, and customer ID. The function returns a tuple of the generated data.\\n\",\n    \"def generate_data(i):\\n\",\n    \"    sales_id = \\\"s_{}\\\".format(i)\\n\",\n    \"    product_name = \\\"Product_{}\\\".format(i)\\n\",\n    \"    price = random.uniform(1,10)\\n\",\n    \"    quantity_sold = random.randint(1,10)\\n\",\n    \"    date_of_sale = \\\"2022-{}-{}\\\".format(random.randint(1,12), random.randint(1,28))\\n\",\n    \"    customer_id = \\\"c_{}\\\".format(random.randint(1,10000))\\n\",\n    \"    return (sales_id, product_name, price, quantity_sold, date_of_sale, customer_id)\\n\",\n    \"\\n\",\n    \"with mp.Pool(mp.cpu_count()) as p:\\n\",\n    \"    sales_data = p.map(generate_data, range(1000000))\\n\",\n    \"    sales_data = list(sales_data)\\n\",\n    \"    \\n\",\n    \"print(\\\"write to gcs started\\\")\\n\",\n    \"sales_df = pd.DataFrame(sales_data, columns=[\\\"sales_id\\\", \\\"product_name\\\", \\\"price\\\", \\\"quantity_sold\\\", \\\"date_of_sale\\\", \\\"customer_id\\\"])\\n\",\n    \"os.makedirs(dataRoot+\\\"/sales/\\\", 
exist_ok=True)\\n\",\n    \"sales_df.to_csv(dataRoot+\\\"/sales/data.csv\\\", index=False, header=True)\\n\",\n    \"print(\\\"Write to gcs completed\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def generate_data(i):\\n\",\n    \"    product_name = \\\"Product_{}\\\".format(i)\\n\",\n    \"    shelf_life = random.randint(1,365)\\n\",\n    \"    contains_promotion = \\\"{} % off\\\".format(random.randint(0,10))\\n\",\n    \"    quantity_in_stock = random.randint(1,100)\\n\",\n    \"    location = \\\"Location_{}\\\".format(random.randint(1,100))\\n\",\n    \"    date_received = \\\"2022-{}-{}\\\".format(random.randint(1,12), random.randint(1,28))\\n\",\n    \"    return (product_name,shelf_life,contains_promotion,quantity_in_stock, location, date_received)\\n\",\n    \"\\n\",\n    \"with mp.Pool(mp.cpu_count()) as p:\\n\",\n    \"    stock_data = p.map(generate_data, range(50000))\\n\",\n    \"    stock_data = list(stock_data)\\n\",\n    \"    \\n\",\n    \"stock_df = pd.DataFrame(stock_data,  columns=[\\\"product_name\\\",\\\"shelf_life\\\",\\\"contains_promotion\\\",\\\"quantity_in_stock\\\", \\\"location\\\", \\\"date_received\\\"])\\n\",\n    \"os.makedirs(dataRoot+\\\"/stock/\\\", exist_ok=True)\\n\",\n    \"stock_df.to_json(dataRoot+\\\"/stock/stock.json\\\", orient='records')\\n\",\n    \"print(\\\"Write to gcs completed\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def generate_data(i):\\n\",\n    \"    sup_id = \\\"s_{}\\\".format(i)\\n\",\n    \"    product_name = \\\"Product_{}\\\".format(i)\\n\",\n    \"    quantity_ordered = random.randint(1,100)\\n\",\n    \"    price = random.uniform(1,10)\\n\",\n    \"    date_ordered = \\\"2022-{}-{}\\\".format(random.randint(1,12), random.randint(1,28))\\n\",\n    \"    return (sup_id,product_name, 
quantity_ordered, price, date_ordered)\\n\",\n    \"\\n\",\n    \"with mp.Pool(mp.cpu_count()) as p:\\n\",\n    \"    supplier_data = p.map(generate_data, range(50000))\\n\",\n    \"    supplier_data = list(supplier_data)\\n\",\n    \"    \\n\",\n    \"supplier_df = pd.DataFrame(supplier_data,  columns=[\\\"sup_id\\\",\\\"product_name\\\", \\\"quantity_ordered\\\", \\\"price\\\", \\\"date_ordered\\\"])\\n\",\n    \"os.makedirs(dataRoot+\\\"/supplier/\\\", exist_ok=True)\\n\",\n    \"supplier_df.to_json(dataRoot+\\\"/supplier/supplier.json\\\", orient='records')\\n\",\n    \"print(\\\"Write to gcs completed\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def generate_data(i):\\n\",\n    \"    customer_id = \\\"c_{}\\\".format(i)\\n\",\n    \"    customer_name = \\\"Customer_{}\\\".format(i)\\n\",\n    \"    age = random.randint(20,70)\\n\",\n    \"    gender = random.choice([\\\"male\\\", \\\"female\\\"])\\n\",\n    \"    purchase_history = random.randint(1,100)\\n\",\n    \"    contact_info = \\\"email_{}@gmail.com\\\".format(i)\\n\",\n    \"    return (customer_id,customer_name, age, gender, purchase_history, contact_info)\\n\",\n    \"\\n\",\n    \"with mp.Pool(mp.cpu_count()) as p:\\n\",\n    \"    customer_data = p.map(generate_data, range(1000))\\n\",\n    \"    customer_data = list(customer_data)\\n\",\n    \"    \\n\",\n    \"customer_df = pd.DataFrame(customer_data,  columns=[\\\"customer_id\\\",\\\"customer_name\\\", \\\"age\\\", \\\"gender\\\", \\\"purchase_history\\\", \\\"contact_info\\\"])\\n\",\n    \"os.makedirs(dataRoot+\\\"/customer/\\\", exist_ok=True)\\n\",\n    \"customer_df.to_csv(dataRoot+\\\"/customer/customer.csv\\\", index=False,header=True)\\n\",\n    \"print(\\\"Write to gcs completed\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"def generate_data(i):\\n\",\n    \"    product_name = \\\"Product_{}\\\".format(i)\\n\",\n    \"    competitor_price = random.uniform(1,100)\\n\",\n    \"    sales_trend = random.randint(1,100)\\n\",\n    \"    demand_forecast = random.randint(1,100)\\n\",\n    \"    return (product_name, competitor_price, sales_trend, demand_forecast)\\n\",\n    \"\\n\",\n    \"with mp.Pool(mp.cpu_count()) as p:\\n\",\n    \"    market_data = p.map(generate_data, range(500000))\\n\",\n    \"    market_data = list(market_data)\\n\",\n    \"    \\n\",\n    \"market_df = pd.DataFrame(market_data,  columns=[\\\"product_name\\\", \\\"competitor_price\\\", \\\"sales_trend\\\", \\\"demand_forecast\\\"])\\n\",\n    \"os.makedirs(dataRoot+\\\"/market/\\\", exist_ok=True)\\n\",\n    \"market_df.to_csv(dataRoot+\\\"/market/market.csv\\\", index=False,header=True)\\n\",\n    \"print(\\\"Write to gcs completed\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def generate_data(i):\\n\",\n    \"    product_name = \\\"Product_{}\\\".format(i)\\n\",\n    \"    shipping_cost = random.uniform(1,10)\\n\",\n    \"    transportation_cost = random.uniform(1,10)\\n\",\n    \"    warehouse_cost = random.uniform(1,10)\\n\",\n    \"    return (product_name, shipping_cost, transportation_cost, warehouse_cost)\\n\",\n    \"\\n\",\n    \"with mp.Pool(mp.cpu_count()) as p:\\n\",\n    \"    logistic_data = p.map(generate_data, range(500000))\\n\",\n    \"    logistic_data = list(logistic_data)\\n\",\n    \"    \\n\",\n    \"logistic_df = pd.DataFrame(logistic_data,  columns=[\\\"product_name\\\", \\\"shipping_cost\\\", \\\"transportation_cost\\\", \\\"warehouse_cost\\\"])\\n\",\n    \"os.makedirs(dataRoot+\\\"/logistic/\\\", exist_ok=True)\\n\",\n    \"logistic_df.to_csv(dataRoot+\\\"/logistic/logistic.csv\\\", index=False,header=True)\\n\",\n    \"print(\\\"Write to gcs completed\\\")\"\n   ]\n  }\n ],\n 
\"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.19\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/SQL+DF-Examples/tpcds/README.md",
    "content": "# TPC-DS Scale Factor 10 (GiB) - CPU Spark vs GPU Spark\n\n[TPC-DS](https://www.tpc.org/tpcds/) is a decision support benchmark often used to evaluate\nperformance of OLAP Databases and Big Data systems.\n\nThe notebook in this folder runs a user-specified subset of the TPC-DS queries on the\nScale Factor 10 (GiB) dataset. It uses [TPCDS PySpark](https://github.com/cerndb/SparkTraining/blob/master/notebooks/TPCDS_PySpark_CERN_SWAN_getstarted.ipynb)\nto execute TPC-DS queries with SparkSQL on GPU and CPU capturing the metrics\nas a Pandas dataframe. It then plots a comparison bar chart visualizing\nthe GPU acceleration achieved for the queries run with RAPIDS Spark in this\nvery notebook.\n\nThis notebook can be opened and executed using standard\n\n- Jupyter(Lab)\n- in VSCode with Jupyter [extension](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter)\n\nIt can also be opened and evaluated on hosted Notebook environments. Use the link below to launch on\nGoogle Colab and connect it to a [GPU instance](https://research.google.com/colaboratory/faq.html).\n\n <a target=\"_blank\" href=\"https://colab.research.google.com/github/NVIDIA/spark-rapids-examples/blob/main/examples/SQL%2BDF-Examples/tpcds/notebooks/TPCDS-SF10.ipynb\">\n  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n</a>\n\nHere is the bar chart from a recent execution on Google Colab's T4 High RAM instance using\nRAPIDS Spark 26.02.0 with Apache Spark 3.5.0\n\n![tpcds-speedup](/docs/img/guides/tpcds.png)\n"
  },
  {
    "path": "examples/SQL+DF-Examples/tpcds/notebooks/TPCDS-SF10.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"editable\": true,\n    \"id\": \"HtgYO0bXEBrN\",\n    \"slideshow\": {\n     \"slide_type\": \"\"\n    },\n    \"tags\": []\n   },\n   \"source\": [\n    \"# TPC-DS 10GiB - Apache Spark acceleration on GPU with RAPIDS Spark\\n\",\n    \"\\n\",\n    \"based on https://colab.research.google.com/github/LucaCanali/Miscellaneous/blob/master/Performance_Testing/TPCDS_PySpark/Labs_and_Notes/TPCDS_PySpark_getstarted.ipynb#scrollTo=6bab7772\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Install packages\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark_version='3.5.5'\\n\",\n    \"rapids_version='26.02.0'\\n\",\n    \"sparkmeasure_version='0.27'\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 1630,\n     \"status\": \"ok\",\n     \"timestamp\": 1729291037060,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"Yq230e1Nho_M\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"%pip install --quiet \\\\\\n\",\n    \"  tpcds_pyspark \\\\\\n\",\n    \"  pyspark=={spark_version} \\\\\\n\",\n    \"  pandas \\\\\\n\",\n    \"  sparkmeasure=={sparkmeasure_version}.0 \\\\\\n\",\n    \"  matplotlib\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Import modules\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 1052,\n     \"status\": \"ok\",\n     \"timestamp\": 1729291488008,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     
},\n     \"user_tz\": 420\n    },\n    \"id\": \"uq_LmKsB36R_\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from importlib.resources import files\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from tpcds_pyspark import TPCDS\\n\",\n    \"import glob\\n\",\n    \"import os\\n\",\n    \"import pandas as pd\\n\",\n    \"import re\\n\",\n    \"import time\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"edMCFrhvgDS8\"\n   },\n   \"source\": [\n    \"# Download TPC-DS 10GiB Scale Parquet Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 41530,\n     \"status\": \"ok\",\n     \"timestamp\": 1729292943990,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"DY8TkhPQTjbB\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"if not os.path.isdir('tpcds_10'):\\n\",\n    \"  if not os.path.isfile('tpcds_10.zip'):\\n\",\n    \"    !wget https://sparkdltrigger.web.cern.ch/sparkdltrigger/TPCDS/tpcds_10.zip\\n\",\n    \"  !unzip -q tpcds_10.zip\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"6tgF9LWcgUEs\"\n   },\n   \"source\": [\n    \"# Init a SparkSession with RAPIDS Spark\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Detect Scala Version used in PySpark package\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"pyspark_files = files('pyspark')\\n\",\n    \"spark_sql_jar_path, *_ = glob.glob(f\\\"{pyspark_files}/*/spark-sql_*jar\\\")\\n\",\n    \"spark_sql_jar = os.path.basename(spark_sql_jar_path)\\n\",\n    \"scala_version = re.search(r'^spark-sql_(\\\\d+.\\\\d+)-.*\\\\.jar$', spark_sql_jar).group(1)\"\n   ]\n  },\n  
{\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create Spark Session\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 39420,\n     \"status\": \"ok\",\n     \"timestamp\": 1729289098419,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"-L-wMZTpfYxs\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"extra_packages = [\\n\",\n    \"  f\\\"com.nvidia:rapids-4-spark_{scala_version}:{rapids_version}\\\",\\n\",\n    \"  f\\\"ch.cern.sparkmeasure:spark-measure_{scala_version}:{sparkmeasure_version}\\\"\\n\",\n    \"]\\n\",\n    \"spark = (\\n\",\n    \"    SparkSession.builder\\n\",\n    \"      .appName('TPCDS PySpark RAPIDS=ON/OFF')\\n\",\n    \"      .config('spark.driver.memory', '5g')\\n\",\n    \"      .config('spark.plugins', 'com.nvidia.spark.SQLPlugin')\\n\",\n    \"      .config('spark.jars.packages', ','.join(extra_packages))\\n\",\n    \"      .getOrCreate()\\n\",\n    \")\\n\",\n    \"spark\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"_4sYje2NiNA7\"\n   },\n   \"source\": [\n    \"# Verify SQL Acceleration on GPU can be enabled by checking the query plan\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 0\n    },\n    \"executionInfo\": {\n     \"elapsed\": 5921,\n     \"status\": \"ok\",\n     \"timestamp\": 1729289104337,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"nUyQBKtkga9y\",\n    \"outputId\": \"5d493a51-58de-4aed-bbaf-d73c82769836\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    
\"spark.conf.set('spark.rapids.sql.enabled', True)\\n\",\n    \"sum_df = spark.range(1000).selectExpr('SUM(*)')\\n\",\n    \"sum_df.collect()\\n\",\n    \"sum_df.explain()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# TPCDS App\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 0\n    },\n    \"executionInfo\": {\n     \"elapsed\": 4,\n     \"status\": \"ok\",\n     \"timestamp\": 1729289104337,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"BYPgafupcxaY\",\n    \"outputId\": \"fdfb427f-6cc0-4dff-9295-dc44e6ead132\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# https://github.com/LucaCanali/Miscellaneous/tree/master/Performance_Testing/TPCDS_PySpark/tpcds_pyspark/Queries\\n\",\n    \"\\n\",\n    \"# queries = None to run all (takes much longer)\\n\",\n    \"queries = None\\n\",\n    \"queries = [\\n\",\n    \"    'q14a',\\n\",\n    \"    'q14b',\\n\",\n    \"    'q23a',\\n\",\n    \"    'q23b',\\n\",\n    \"    # 'q24a',\\n\",\n    \"    # 'q24b',\\n\",\n    \"    # 'q88',\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"demo_start = time.time()\\n\",\n    \"tpcds = TPCDS(data_path='./tpcds_10', num_runs=1, queries_repeat_times=1, queries=queries)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"0Yaaw2GfliC5\"\n   },\n   \"source\": [\n    \"## Register TPC-DS tables before running queries\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 0\n    },\n    \"executionInfo\": {\n     \"elapsed\": 2992,\n     \"status\": \"ok\",\n     \"timestamp\": 1729289107327,\n     \"user\": {\n      
\"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"kfsHodFqdDl7\",\n    \"outputId\": \"5a810f9d-e353-456c-b7bb-48ae3290178a\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"tpcds.map_tables()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"bs6X_54UhuqJ\"\n   },\n   \"source\": [\n    \"## Measure Apache Spark GPU\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 0\n    },\n    \"executionInfo\": {\n     \"elapsed\": 45658,\n     \"status\": \"ok\",\n     \"timestamp\": 1729290819190,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"8vXDasUom70g\",\n    \"outputId\": \"adccdd7f-99f0-4c82-d600-056b59f53933\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"tpcds.spark.conf.set('spark.rapids.sql.enabled', True)\\n\",\n    \"%time tpcds.run_TPCDS()\\n\",\n    \"gpu_grouped_results = tpcds.grouped_results_pdf.copy()\\n\",\n    \"gpu_grouped_results\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"ulyFidEPhg_l\"\n   },\n   \"source\": [\n    \"## Measure Apache Spark CPU\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 0\n    },\n    \"executionInfo\": {\n     \"elapsed\": 135425,\n     \"status\": \"ok\",\n     \"timestamp\": 1729289242749,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"Dg0itS7cdIf4\",\n    \"outputId\": \"4ce1f8a2-5ac7-4805-e6f6-37a8acb7e039\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    
\"tpcds.spark.conf.set('spark.rapids.sql.enabled', False)\\n\",\n    \"%time tpcds.run_TPCDS()\\n\",\n    \"cpu_grouped_results = tpcds.grouped_results_pdf.copy()\\n\",\n    \"cpu_grouped_results\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"PcZ12b13h3cq\"\n   },\n   \"source\": [\n    \"## Show Speedup Factors achieved by GPU\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 5,\n     \"status\": \"ok\",\n     \"timestamp\": 1729289293047,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"cJxS9Nqi3AQj\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"res = pd.merge(cpu_grouped_results, gpu_grouped_results, on='query', how='inner', suffixes=['_cpu', '_gpu'])\\n\",\n    \"res['speedup'] = res['elapsedTime_cpu'] / res['elapsedTime_gpu']\\n\",\n    \"res = res.sort_values(by='elapsedTime_cpu', ascending=False)\\n\",\n    \"res\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"demo_dur = time.time() - demo_start\\n\",\n    \"print(f\\\"CPU and GPU run took: {demo_dur=} seconds\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 510\n    },\n    \"executionInfo\": {\n     \"elapsed\": 1041,\n     \"status\": \"ok\",\n     \"timestamp\": 1729289294084,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"wn7u33fZlUJL\",\n    \"outputId\": \"8d1ef757-e5c2-4761-fc58-f65f833bdffc\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"res.plot(title='TPC-DS query elapsedTime on CPU vs GPU 
(lower is better)', \\n\",\n    \"         kind='bar', x='query', y=['elapsedTime_cpu', 'elapsedTime_gpu'],\\n\",\n    \"         color=['blue', '#76B900'])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 510\n    },\n    \"executionInfo\": {\n     \"elapsed\": 381,\n     \"status\": \"ok\",\n     \"timestamp\": 1729289294462,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"hW-LZponwGQE\",\n    \"outputId\": \"ed456120-ca7f-4c91-a2bf-de87c2401f0c\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"res.plot(title='Speedup factors of TPC-DS queries on GPU', kind='bar', \\n\",\n    \"         x='query', y='speedup', color='#76B900')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"Pk2TR4yimNqP\"\n   },\n   \"source\": [\n    \"# Run Queries interactively\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tpcds_pyspark_files = files('tpcds_pyspark')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"executionInfo\": {\n     \"elapsed\": 4,\n     \"status\": \"ok\",\n     \"timestamp\": 1729289294462,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"RpIl6NyNzqYU\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"query = 'q88'\\n\",\n    \"with open(f\\\"{tpcds_pyspark_files}/Queries/{query}.sql\\\") as f:\\n\",\n    \"  q = f.read()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 0\n    },\n  
  \"executionInfo\": {\n     \"elapsed\": 3,\n     \"status\": \"ok\",\n     \"timestamp\": 1729289294462,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"cuCVL1Ed1lQd\",\n    \"outputId\": \"d256f4b7-e0e2-450c-ba88-aff0d7571510\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"print(q)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 0\n    },\n    \"editable\": true,\n    \"executionInfo\": {\n     \"elapsed\": 1470,\n     \"status\": \"ok\",\n     \"timestamp\": 1729289295930,\n     \"user\": {\n      \"displayName\": \"Gera Shegalov\",\n      \"userId\": \"07399839501144323282\"\n     },\n     \"user_tz\": 420\n    },\n    \"id\": \"n4QUdq17040i\",\n    \"outputId\": \"7d7c7562-fae6-4426-97a7-ec23b8fe2f0d\",\n    \"slideshow\": {\n     \"slide_type\": \"\"\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"spark.conf.set('spark.rapids.sql.enabled', True)\\n\",\n    \"df  = spark.sql(q)\\n\",\n    \"%time df.collect()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"accelerator\": \"GPU\",\n  \"colab\": {\n   \"gpuType\": \"T4\",\n   \"machine_shape\": \"hm\",\n   \"provenance\": []\n  },\n  \"kernelspec\": {\n   \"display_name\": \".venv\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.12.3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/Dockerfile",
    "content": "#\n# Copyright (c) 2021-2026, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# A container that can be used to build UDF native code against libcudf\nARG CUDA_VERSION=12.9.1\nARG LINUX_VERSION=rockylinux8\n\nFROM nvidia/cuda:${CUDA_VERSION}-devel-${LINUX_VERSION}\n\nARG TOOLSET_VERSION=13\nENV TOOLSET_VERSION=13\nARG PARALLEL_LEVEL=10\nENV PARALLEL_LEVEL=10\n\n### Install basic requirements\nRUN dnf --enablerepo=powertools install -y \\\n  gcc-toolset-${TOOLSET_VERSION} \\\n  git \\\n  java-1.8.0-openjdk \\\n  maven \\\n  ninja-build \\\n  patch \\\n  python39 \\\n  scl-utils \\\n  tar \\\n  wget \\\n  zlib-devel \\\n  && alternatives --set python /usr/bin/python3\n\n# 3.22.3: CUDA architecture 'native' support + flexible CMAKE_<LANG>_*_LAUNCHER for ccache\nARG CMAKE_VERSION=3.30.4\n# default x86_64 from x86 build, aarch64 cmake for arm build\nARG CMAKE_ARCH=x86_64\nRUN cd /usr/local && wget --quiet https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-linux-${CMAKE_ARCH}.tar.gz && \\\n   tar zxf cmake-${CMAKE_VERSION}-linux-${CMAKE_ARCH}.tar.gz && \\\n   rm cmake-${CMAKE_VERSION}-linux-${CMAKE_ARCH}.tar.gz\nENV PATH /usr/local/cmake-${CMAKE_VERSION}-linux-${CMAKE_ARCH}/bin:$PATH\n\n# ccache for interactive builds\nARG CCACHE_VERSION=4.11.2\nRUN cd /tmp && wget --quiet 
https://github.com/ccache/ccache/releases/download/v${CCACHE_VERSION}/ccache-${CCACHE_VERSION}.tar.gz && \\\n   tar zxf ccache-${CCACHE_VERSION}.tar.gz && \\\n   rm ccache-${CCACHE_VERSION}.tar.gz && \\\n   cd ccache-${CCACHE_VERSION} && \\\n   mkdir build && \\\n   cd build && \\\n   scl enable gcc-toolset-${TOOLSET_VERSION} \\\n      \"cmake .. \\\n         -DCMAKE_BUILD_TYPE=Release \\\n         -DZSTD_FROM_INTERNET=ON \\\n         -DREDIS_STORAGE_BACKEND=OFF && \\\n      cmake --build . --parallel ${PARALLEL_LEVEL} --target install\" && \\\n   cd ../.. && \\\n   rm -rf ccache-${CCACHE_VERSION}\n\nENTRYPOINT /usr/bin/scl enable gcc-toolset-${TOOLSET_VERSION} -- bash\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/README.md",
    "content": "# RAPIDS Accelerated UDF Examples\n\nThis project contains sample implementations of RAPIDS accelerated user-defined functions.\n\nThe ideal solution would be to replace the UDF with a series of DataFrame or SQL operations. If that\nis not possible, we also provide\na [UDF compiler extension](https://nvidia.github.io/spark-rapids/docs/additional-functionality/udf-to-catalyst-expressions.html)\nto translate UDFs to Catalyst expressions. The extension is limited to only support compiling simple\noperations. For complicated cases, you can choose to implement a RAPIDS accelerated UDF.\n\n## Spark Scala UDF Examples\n\n[URLDecode](src/main/scala/com/nvidia/spark/rapids/udf/scala/URLDecode.scala)\nis the simplest demo for getting started. From the code you can see there is an original CPU\nimplementation provided by the `apply` method. We only need to implement the RapidsUDF interface\nwhich provides a single method we need to override called\n`evaluateColumnar`. The CPU URLDecode function processes the input row by row, but the GPU\nevaluateColumnar returns a cudf ColumnVector, because the GPU get its speed by performing operations\non many rows at a time. In the `evaluateColumnar` function, there is a cudf implementation of URL\ndecode that we're leveraging, so we don't need to write any native C++ code. This is all done\nthrough the [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy). The benefit to\nimplement via the Java API is ease of development, but the memory model is not friendly for doing\nGPU operations because the JVM makes the assumption that everything we're trying to do is in heap\nmemory. We need to free the GPU resources in a timely manner with try-finally blocks. 
Note that we\nneed to implement both CPU and GPU functions so the UDF will still work if a higher-level operation\ninvolving the RAPIDS accelerated UDF falls back to the CPU.\n\n- [URLDecode](src/main/scala/com/nvidia/spark/rapids/udf/scala/URLDecode.scala)\n  decodes URL-encoded strings using the\n  [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy)\n- [URLEncode](src/main/scala/com/nvidia/spark/rapids/udf/scala/URLEncode.scala)\n  URL-encodes strings using the\n  [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy)\n\n## Spark Java UDF Examples\n\nBelow are some examples for implementing RAPIDS accelerated Scala UDF via JNI and native code. If\nthere is no existing simple Java API we could leverage, we can write native custom code.\nTake [CosineSimilarity](src/main/java/com/nvidia/spark/rapids/udf/java/CosineSimilarity.java) as the\nexample, the Java class for the UDF is similar as the previous URLDecode/URLEncode demo. We need to\nimplement a cosineSimilarity function in C++ code and goes into the native code as quickly as\npossible, because it is easier to write the code safely. In the native code, it `reinterpret_cast`\nthe input to a column view, do some sanity checking and convert to list column views, then compute\nthe cosine similarity, finally return the unique pointer to a column, release the underlying\nresources. On Java side we are going to wrap it in a column vector and own that resource.\nIn `cosine_similarity.cu` we implement the computation as the actual CUDA kernel. In the CUDA kernel\nwe can leverage the [Thrust template library](https://docs.nvidia.com/cuda/thrust/index.html) to\nwrite the standard algorithms for GPU parallelizing code. The benefit of implementing the UDF in\nnative code is for maximum control over GPU memory utilization and performance. However the\ntrade-off is a more complicated build environment, as we need to build against libcudf with\nsignificantly longer build times. 
Implementing a RAPIDS accelerated UDF in native code is a\nsignificant effort.\n\n- [URLDecode](src/main/java/com/nvidia/spark/rapids/udf/java/URLDecode.java)\n  decodes URL-encoded strings using the\n  [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy)\n- [URLEncode](src/main/java/com/nvidia/spark/rapids/udf/java/URLEncode.java)\n  URL-encodes strings using the\n  [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy)\n- [CosineSimilarity](src/main/java/com/nvidia/spark/rapids/udf/java/CosineSimilarity.java)\n  computes the [cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)\n  between two float vectors using [native code](src/main/cpp/src)\n\n## Hive UDF Examples\n\nBelow are some examples for implementing RAPIDS accelerated Hive UDF via JNI and native code.\n\n- [URLDecode](src/main/java/com/nvidia/spark/rapids/udf/hive/URLDecode.java)\n  implements a Hive simple UDF using the\n  [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy)\n  to decode URL-encoded strings\n- [URLEncode](src/main/java/com/nvidia/spark/rapids/udf/hive/URLEncode.java)\n  implements a Hive generic UDF using the\n  [Java APIs of RAPIDS cudf](https://docs.rapids.ai/api/cudf-java/legacy)\n  to URL-encode strings\n- [StringWordCount](src/main/java/com/nvidia/spark/rapids/udf/hive/StringWordCount.java)\n  implements a Hive simple UDF using\n  [native code](src/main/cpp/src) to count words in strings\n\n## Building and run the tests without Native Code Examples\n\nSome UDF examples use native code in their implementation. Building the native code requires a\nlibcudf build environment, so these examples do not build by default.\n\n### Prerequisites\n\nDownload [Apache Spark](https://spark.apache.org/downloads.html) and set `SPARK_HOME` environment variable.\nInstall Python 3.8+, then install `pytest`, `sre_yield` by using pip or conda. 
For\nexample:\n\n```\nexport SPARK_HOME=path-to-spark\npip install pytest              # If running in the docker container, please use pip3\npip install sre_yield           # If running in the docker container, please use pip3\n```\n\nRun the following command to build and run tests\n\n```bash\ncd spark-rapids-examples/examples/UDF-Examples/RAPIDS-accelerated-UDFs\nmvn clean package\n./run_pyspark_from_build.sh -m \"not rapids_udf_example_native\"\n```\n\n## Building with Native Code Examples and run test cases\n\nThe `udf-native-examples` Maven profile can be used to include the native UDF examples in the build,\ni.e.: specify\n`-Pudf-native-examples` on the `mvn` command-line.\n\n### Creating a libcudf Build Environment\n\nBuilding the native code requires a libcudf build environment.  \nThe `Dockerfile` in this directory can be used to setup a Docker image that provides a libcudf build\nenvironment. This repository will either need to be cloned or mounted into a container using that\nDocker image. The `Dockerfile` contains build arguments to control the Linux version, CUDA version,\nand other settings. See the top of the `Dockerfile` for details.\n\nFirst install docker and [nvidia-docker](https://github.com/NVIDIA/nvidia-docker)\n\nRun the following commands to build and start a docker\n\n```bash\ncd spark-rapids-examples/examples/UDF-Examples/RAPIDS-accelerated-UDFs\ndocker build -t my-local:my-udf-example .\nnvidia-docker run -it my-local:my-udf-example\n```\n\n### Build the udf-examples jar\n\n#### Option 1: Fast Build Using Prebuilt libcudf (Recommended)\n\nInstead of building cuDF from source (which takes a long time), you can use the prebuilt `libcudf.so` \nfrom the `rapids-4-spark` jar. This is much faster!\n\n**Prerequisites:**\n- rapids-4-spark jar must be available in your local Maven repository\n\n**Steps:**\n\n1. 
Extract libcudf.so and cuDF headers (automatic with Maven):\n```bash\ncd spark-rapids-examples/examples/UDF-Examples/RAPIDS-accelerated-UDFs\nmvn clean package -Pudf-native-examples\n```\n\nThe build will automatically:\n- Extract `libcudf.so` from the rapids-4-spark jar\n- Clone cuDF repository for headers (shallow clone)\n- Build only your UDF native code against the prebuilt library\n\n**Or manually extract first:**\n```bash\n./extract-cudf-libs.sh\nmvn clean package -Pudf-native-examples\n```\n\nThis approach typically reduces the native cuDF build time by almost **2 hours**!\n\n#### Option 2: Build cuDF from Source (Slow but Complete)\n\nIf you need to build cuDF from source, you can disable the prebuilt library option.\n\n**How it works:**\n- The Maven property `USE_PREBUILT_CUDF` (default: `ON` in pom.xml) is passed to CMake\n- Use `-DUSE_PREBUILT_CUDF=OFF` as a Maven system property to override the default\n- Maven replaces `${USE_PREBUILT_CUDF}` in pom.xml and passes it to CMake as `-DUSE_PREBUILT_CUDF=OFF`\n\n**Build with source:**\n\n```bash\ncd spark-rapids-examples/examples/UDF-Examples/RAPIDS-accelerated-UDFs\nexport LOCAL_CCACHE_DIR=\"$HOME/.ccache\"\nmkdir -p $LOCAL_CCACHE_DIR\nexport CCACHE_DIR=\"$LOCAL_CCACHE_DIR\"\nexport CMAKE_C_COMPILER_LAUNCHER=\"ccache\"\nexport CMAKE_CXX_COMPILER_LAUNCHER=\"ccache\"\nexport CMAKE_CUDA_COMPILER_LAUNCHER=\"ccache\"\nexport CMAKE_CXX_LINKER_LAUNCHER=\"ccache\"\nmvn clean package -Pudf-native-examples -DUSE_PREBUILT_CUDF=OFF\n```\n\n**Alternative: Edit CMakeLists.txt directly**\n\nYou can also edit `src/main/cpp/CMakeLists.txt` and set:\n```cmake\noption(USE_PREBUILT_CUDF \"Use prebuilt libcudf.so from rapids-4-spark jar\" OFF)\n```\n\n#### Configurable Maven Properties\n\nYou can customize the build by passing Maven system properties via `-D<property>=<value>`. 
These properties are defined in `pom.xml` and passed to CMake:\n\n| Maven Property | Default Value | Description |\n|----------------|---------------|-------------|\n| `USE_PREBUILT_CUDF` | `ON` | Use prebuilt libcudf.so from rapids-4-spark jar (faster build) |\n| `GPU_ARCHS` | `RAPIDS` | GPU architectures to compile for (e.g., `60;70;75;80`) |\n| `CPP_PARALLEL_LEVEL` | `10` | Number of parallel compilation jobs |\n| `BUILD_UDF_BENCHMARKS` | `OFF` | Build benchmark executables |\n| `PER_THREAD_DEFAULT_STREAM` | `ON` | Enable per-thread default CUDA streams |\n| `CUDF_ENABLE_ARROW_S3` | `OFF` | Enable Arrow S3 support in cuDF |\n| `cudf.git.branch` | `main` | cuDF git branch to clone for headers |\n| `skipCudfExtraction` | `false` | Skip extracting cuDF dependencies from jar |\n\n**Example usage:**\n```bash\n# Build for specific GPU architectures with more parallel jobs\nmvn clean package -Pudf-native-examples -DGPU_ARCHS=\"75;80;86\" -DCPP_PARALLEL_LEVEL=16\n\n# Skip cuDF extraction and use existing dependencies\nmvn clean package -Pudf-native-examples -DskipCudfExtraction=true\n```\n\n#### Using ccache to Accelerate Builds\n\nThe Docker container has installed ccache 4.6 to accelerate the incremental building.\nYou can change the LOCAL_CCACHE_DIR to a mounted folder so that the cache can persist.\nIf you don't want to use ccache, you can remove or unset the ccache environment variables.\n\n```bash\nunset CCACHE_DIR\nunset CMAKE_C_COMPILER_LAUNCHER\nunset CMAKE_CXX_COMPILER_LAUNCHER\nunset CMAKE_CUDA_COMPILER_LAUNCHER\nunset CMAKE_CXX_LINKER_LAUNCHER\n```\n\nThe first build could take a long time (e.g.: 1.5 hours). 
Then the rapids-4-spark-udf-examples*.jar is\ngenerated under RAPIDS-accelerated-UDFs/target directory.\nThe following build can benefit from ccache if you enable it.\n\nIf you want to enable building with ccache on your own system,\nplease refer to the commands which build ccache from the source code in the Dockerfile.\n\n### Run all the examples including native examples in the docker\n\nSee the above [Prerequisites section](#prerequisites)\n\n```\nexport SPARK_HOME=path-to-spark\npip install pytest\npip install sre_yield\n```\n\nRun the following command to run tests\n\n```\n./run_pyspark_from_build.sh\n```\n\n## How to run the Native UDFs on Spark local mode\n\nFirst finish the steps in \n[Building with Native Code Examples and run test cases](#building-with-native-code-examples-and-run-test-cases) section, \nthen do the following inside the Docker container.\n\n### Get jars from Maven Central\n\n[rapids-4-spark_2.12-26.02.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)\n\n\n### Launch a local mode Spark\n\n```bash\nexport SPARK_RAPIDS_PLUGIN_JAR=path-to-rapids-4-spark-jar\nexport SPARK_RAPIDS_UDF_EXAMPLES_JAR=path-to-udf-examples-jar\n\n$SPARK_HOME/bin/pyspark --master local[*] \\\n--conf spark.executor.cores=6 \\\n--driver-memory 5G  \\\n--executor-memory 5G  \\\n--jars ${SPARK_RAPIDS_PLUGIN_JAR},${SPARK_RAPIDS_UDF_EXAMPLES_JAR} \\\n--conf spark.plugins=com.nvidia.spark.SQLPlugin \\\n--conf spark.rapids.sql.enabled=true\n```\n\n### Test native based UDF\n\nInput the following commands to test wordcount JNI UDF\n\n```python\nfrom pyspark.sql.types import *\nschema = StructType([\n    StructField(\"c1\", StringType()),\n    StructField(\"c2\", IntegerType()),\n])\ndata = [\n    (\"a b c d\",1),\n    (\"\",2),\n    (None,3),\n    (\"the quick brown fox jumped over the lazy dog\",3),\n]\ndf = spark.createDataFrame(\n        SparkContext.getOrCreate().parallelize(data, numSlices=2),\n        
schema)\ndf.createOrReplaceTempView(\"tab\")\n\nspark.sql(\"CREATE TEMPORARY FUNCTION {} AS '{}'\".format(\"wordcount\", \"com.nvidia.spark.rapids.udf.hive.StringWordCount\"))\nspark.sql(\"select c1, wordcount(c1) from tab\").show()\nspark.sql(\"select c1, wordcount(c1) from tab\").explain()\n```\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/clone-cudf-repo.sh",
    "content": "#!/bin/bash\n#\n# Copyright (c) 2026, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n###############################################################################\n# Clone or update cuDF repository for header files\n#\n# This script is called by Maven during the build process to obtain cuDF\n# headers needed for compiling native UDF code.\n#\n# Usage:\n#   clone-cudf-repo.sh <target_directory> <branch_name>\n#\n# Arguments:\n#   target_directory - Directory where cuDF repo will be cloned\n#   branch_name      - Git branch to clone/checkout\n#\n# Exit codes:\n#   0 - Success\n#   1 - Failed to clone, fetch, or checkout\n###############################################################################\n\nset -e\nset -o pipefail\n\n# Parse arguments\nif [ $# -ne 2 ]; then\n    echo \"ERROR: Usage: $0 <target_directory> <branch_name>\" >&2\n    exit 1\nfi\n\nCUDF_DIR=\"$1\"\nBRANCH=\"$2\"\n\necho \"================================================\"\necho \"cuDF Repository Management\"\necho \"  Target directory: $CUDF_DIR\"\necho \"  Branch: $BRANCH\"\necho \"================================================\"\n\n# Check if repository already exists\nif [ ! 
-d \"$CUDF_DIR/.git\" ]; then\n    # Repository doesn't exist - clone it\n    echo \"Cloning cuDF repository ($BRANCH branch)...\"\n    \n    git clone --depth 1 --branch \"$BRANCH\" \\\n        https://github.com/rapidsai/cudf.git \"$CUDF_DIR\" || {\n        echo \"ERROR: Failed to clone cuDF from branch $BRANCH\" >&2\n        echo \"Please check:\" >&2\n        echo \"  1. Network connectivity to GitHub\" >&2\n        echo \"  2. Branch '$BRANCH' exists in cuDF repository\" >&2\n        exit 1\n    }\n    \n    echo \"✓ Successfully cloned cuDF repository\"\nelse\n    # Repository exists - verify and update if needed\n    echo \"cuDF repository exists, verifying branch...\"\n    cd \"$CUDF_DIR\" || {\n        echo \"ERROR: Cannot access directory $CUDF_DIR\" >&2\n        exit 1\n    }\n    \n    # Get current branch\n    CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo \"unknown\")\n    \n    if [ \"$CURRENT_BRANCH\" != \"$BRANCH\" ]; then\n        # Branch mismatch - fetch and switch to correct branch\n        echo \"Branch mismatch detected:\"\n        echo \"  Current branch: $CURRENT_BRANCH\"\n        echo \"  Expected branch: $BRANCH\"\n        echo \"Fetching and switching to $BRANCH...\"\n        \n        git fetch --depth 1 origin \"$BRANCH\" || {\n            echo \"ERROR: Failed to fetch branch $BRANCH from origin\" >&2\n            echo \"Please check:\" >&2\n            echo \"  1. Network connectivity to GitHub\" >&2\n            echo \"  2. 
Branch '$BRANCH' exists in cuDF repository\" >&2\n            exit 1\n        }\n        \n        git checkout \"$BRANCH\" || {\n            echo \"ERROR: Failed to checkout branch $BRANCH\" >&2\n            exit 1\n        }\n        \n        git reset --hard \"origin/$BRANCH\" || {\n            echo \"ERROR: Failed to reset to origin/$BRANCH\" >&2\n            exit 1\n        }\n        \n        echo \"✓ Switched to branch $BRANCH\"\n    else\n        echo \"✓ Already on correct branch ($BRANCH)\"\n    fi\nfi\n\necho \"================================================\"\necho \"✓ cuDF repository ready at: $CUDF_DIR\"\necho \"  Branch: $BRANCH\"\necho \"================================================\"\n\nexit 0\n\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/conftest.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ndef pytest_addoption(parser):\n    \"\"\"Pytest hook to define command line options for pytest\"\"\"\n    parser.addoption(\n        \"--mortgage_format\", action=\"store\", default=\"parquet\", help=\"format of Mortgage data\"\n    )\n    parser.addoption(\n        \"--mortgage_path\", action=\"store\", default=None, help=\"path to Mortgage data\"\n    )\n    parser.addoption(\n        \"--std_input_path\", action=\"store\", default=None, help=\"path to standard input files\"\n    )\n    parser.addoption(\n        \"--tmp_path\", action=\"store\", default=None, help=\"path to store tmp files\"\n    )\n    parser.addoption(\n        \"--debug_tmp_path\", action='store_true', default=False, help=\"if true don't delete tmp_path contents for debugging\"\n    )\n    parser.addoption(\n        \"--runtime_env\", action='store', default=\"Apache\", help=\"the runtime environment for the tests - apache or databricks\"\n    )\n    parser.addoption(\n        \"--cudf_udf\", action='store_true', default=False, help=\"if true enable cudf_udf test\"\n    )\n    parser.addoption(\n        \"--rapids_udf_example_native\", action='store_true', default=False,\n        help=\"if true enable tests for RAPIDS UDF examples with native code\"\n    )\n    parser.addoption(\n        \"--test_type\", action='store', default=\"developer\",\n        help=\"the type of tests 
that are being run to help check all the correct tests are run - developer, pre-commit, or nightly\"\n    )\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/extract-cudf-libs.sh",
    "content": "#!/bin/bash\n#\n# Copyright (c) 2026, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n###############################################################################\n# Extract libcudf.so from rapids-4-spark jar\n#\n# This script extracts prebuilt cuDF libraries from the rapids-4-spark jar\n# to enable faster builds by avoiding building cuDF from source.\n#\n# Configuration values are read from pom.xml by default, but can be overridden\n# using environment variables:\n#\n# Usage:\n#   ./extract-cudf-libs.sh\n#\n# Environment Variables (optional, will use pom.xml values if not set):\n#   RAPIDS4SPARK_VERSION - rapids-4-spark version (e.g., 26.02.0 or 26.06.0-SNAPSHOT)\n#   SCALA_VERSION        - Scala binary version (e.g., 2.12, 2.13)\n#   CUDA_VERSION         - CUDA version (e.g., cuda11, cuda12)\n#   CUDF_BRANCH          - cuDF git branch for headers (e.g., main, branch-26.02)\n#\n# Example with overrides:\n#   RAPIDS4SPARK_VERSION=26.02.0 CUDA_VERSION=cuda11 ./extract-cudf-libs.sh\n###############################################################################\n\nset -e\n\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\nTARGET_DIR=\"$SCRIPT_DIR/target\"\nNATIVE_DEPS_DIR=\"$TARGET_DIR/native-deps\"\nCUDF_REPO_DIR=\"$TARGET_DIR/cudf-repo\"\nPOM_FILE=\"$SCRIPT_DIR/pom.xml\"\n\n# Function to extract property value from pom.xml\n# Usage: extract_pom_property \"property_name\"\nextract_pom_property() {\n   
 local property_name=\"$1\"\n    local value\n    \n    # Use xmllint if available (more reliable)\n    if command -v xmllint >/dev/null 2>&1; then\n        value=$(xmllint --xpath \"string(//*[local-name()='project']/*[local-name()='properties']/*[local-name()='${property_name}'])\" \"$POM_FILE\" 2>/dev/null)\n    else\n        # Fallback to grep/sed (less robust but widely available)\n        value=$(grep -A 1 \"<${property_name}>\" \"$POM_FILE\" | grep -v \"^--$\" | sed -n \"s/.*<${property_name}>\\(.*\\)<\\/${property_name}>.*/\\1/p\" | head -1 | xargs)\n    fi\n    \n    echo \"$value\"\n}\n\necho \"==================================================\"\necho \"Extract cuDF Dependencies for UDF Examples\"\necho \"==================================================\"\necho \"Reading configuration from pom.xml...\"\n\n# Read defaults from pom.xml\nPOM_RAPIDS4SPARK_VERSION=$(extract_pom_property \"rapids4spark.version\")\nPOM_SCALA_VERSION=$(extract_pom_property \"scala.binary.version\")\nPOM_CUDA_VERSION=$(extract_pom_property \"cuda.version\")\nPOM_CUDF_BRANCH=$(extract_pom_property \"cudf.git.branch\")\n\n# Use environment variables if set, otherwise use pom.xml values\nRAPIDS4SPARK_VERSION=\"${RAPIDS4SPARK_VERSION:-${POM_RAPIDS4SPARK_VERSION}}\"\nSCALA_VERSION=\"${SCALA_VERSION:-${POM_SCALA_VERSION}}\"\nCUDA_VERSION=\"${CUDA_VERSION:-${POM_CUDA_VERSION}}\"\nCUDF_BRANCH=\"${CUDF_BRANCH:-${POM_CUDF_BRANCH}}\"\n\n# Validate that we have all required values\nif [ -z \"$RAPIDS4SPARK_VERSION\" ] || [ -z \"$SCALA_VERSION\" ] || [ -z \"$CUDA_VERSION\" ] || [ -z \"$CUDF_BRANCH\" ]; then\n    echo \"ERROR: Failed to read required properties from pom.xml\" >&2\n    echo \"Please ensure pom.xml exists and contains all required properties:\" >&2\n    echo \"  - rapids4spark.version\" >&2\n    echo \"  - scala.binary.version\" >&2\n    echo \"  - cuda.version\" >&2\n    echo \"  - cudf.git.branch\" >&2\n    exit 1\nfi\n\necho \"Configuration:\"\necho \"  RAPIDS4SPARK_VERSION: 
$RAPIDS4SPARK_VERSION\"\necho \"  SCALA_VERSION: $SCALA_VERSION\"\necho \"  CUDA_VERSION: $CUDA_VERSION\"\necho \"  CUDF_BRANCH: $CUDF_BRANCH\"\necho \"==================================================\"\n\n# Create directories\nmkdir -p \"$NATIVE_DEPS_DIR\"\nmkdir -p \"$CUDF_REPO_DIR\"\n\n# Find rapids-4-spark jar in local Maven repository\nMAVEN_REPO=\"${HOME}/.m2/repository\"\n\n# Try multiple naming patterns\nJAR_PATH_WITH_CLASSIFIER=\"$MAVEN_REPO/com/nvidia/rapids-4-spark_${SCALA_VERSION}/${RAPIDS4SPARK_VERSION}/rapids-4-spark_${SCALA_VERSION}-${RAPIDS4SPARK_VERSION}-${CUDA_VERSION}.jar\"\nJAR_PATH_NO_CLASSIFIER=\"$MAVEN_REPO/com/nvidia/rapids-4-spark_${SCALA_VERSION}/${RAPIDS4SPARK_VERSION}/rapids-4-spark_${SCALA_VERSION}-${RAPIDS4SPARK_VERSION}.jar\"\n\necho \"Looking for rapids-4-spark jar...\"\necho \"  Pattern 1 (with classifier): $JAR_PATH_WITH_CLASSIFIER\"\necho \"  Pattern 2 (no classifier):   $JAR_PATH_NO_CLASSIFIER\"\n\nif [ -f \"$JAR_PATH_WITH_CLASSIFIER\" ]; then\n    JAR_PATH=\"$JAR_PATH_WITH_CLASSIFIER\"\n    echo \"✓ Found jar (with classifier): $JAR_PATH\"\nelif [ -f \"$JAR_PATH_NO_CLASSIFIER\" ]; then\n    JAR_PATH=\"$JAR_PATH_NO_CLASSIFIER\"\n    echo \"✓ Found jar (no classifier): $JAR_PATH\"\nelse\n    echo \"\"\n    echo \"ERROR: rapids-4-spark jar not found!\"\n    echo \"Tried:\"\n    echo \"  $JAR_PATH_WITH_CLASSIFIER\"\n    echo \"  $JAR_PATH_NO_CLASSIFIER\"\n    echo \"\"\n    echo \"For SNAPSHOT versions:\"\n    echo \"  cd /path/to/spark-rapids\"\n    echo \"  mvn clean install -DskipTests\"\n    echo \"\"\n    echo \"For release versions:\"\n    echo \"  mvn dependency:get -Dartifact=com.nvidia:rapids-4-spark_${SCALA_VERSION}:${RAPIDS4SPARK_VERSION}:jar:${CUDA_VERSION}\"\n    exit 1\nfi\n\n# Extract libcudf.so and dependencies\necho \"Extracting native libraries from jar...\"\necho \"  Jar: $JAR_PATH\"\necho \"  Looking for: */libcudf.so*, */libnvcomp.so*\"\n\n# Use unzip without -q to capture output, but redirect to log for 
debugging\nUNZIP_OUTPUT=$(unzip -o \"$JAR_PATH\" \"*/libcudf.so*\" \"*/libnvcomp.so*\" -d \"$TARGET_DIR/temp\" 2>&1)\nUNZIP_EXIT_CODE=$?\n\n# Check unzip exit code\nif [ $UNZIP_EXIT_CODE -ne 0 ]; then\n    echo \"ERROR: Failed to extract libraries from jar\" >&2\n    echo \"unzip exit code: $UNZIP_EXIT_CODE\" >&2\n    \n    # Provide helpful diagnostics\n    case $UNZIP_EXIT_CODE in\n        11)\n            echo \"Reason: No matching files found in jar\" >&2\n            echo \"\" >&2\n            echo \"The jar may not contain native libraries for your platform.\" >&2\n            echo \"Expected patterns: */libcudf.so*, */libnvcomp.so*\" >&2\n            echo \"\" >&2\n            echo \"Listing jar contents:\" >&2\n            unzip -l \"$JAR_PATH\" | grep -E '\\.(so|dylib|dll)' || echo \"  No native libraries found\" >&2\n            ;;\n        *)\n            echo \"Reason: unzip command failed\" >&2\n            echo \"Output: $UNZIP_OUTPUT\" >&2\n            ;;\n    esac\n    \n    echo \"\" >&2\n    echo \"Falling back to source build...\" >&2\n    exit 1\nfi\n\n# Verify that we actually extracted some files\nEXTRACTED_COUNT=$(find \"$TARGET_DIR/temp\" -name \"*.so*\" 2>/dev/null | wc -l)\necho \"Extracted $EXTRACTED_COUNT library file(s)\"\n\nif [ \"$EXTRACTED_COUNT\" -eq 0 ]; then\n    echo \"ERROR: No library files were extracted from jar\" >&2\n    echo \"This usually means the jar doesn't contain native libraries.\" >&2\n    echo \"\" >&2\n    echo \"Listing jar contents:\" >&2\n    unzip -l \"$JAR_PATH\" | head -20 >&2\n    exit 1\nfi\n\n# Move libraries to native-deps directory, detecting conflicts\necho \"Moving extracted libraries...\"\nCONFLICT_COUNT=0\n\n# Use process substitution to avoid subshell issues\nwhile IFS= read -r source_file; do\n    filename=$(basename \"$source_file\")\n    dest_file=\"$NATIVE_DEPS_DIR/$filename\"\n    \n    if [ -f \"$dest_file\" ]; then\n        # File already exists - check if it's the same\n        if ! 
cmp -s \"$source_file\" \"$dest_file\"; then\n            echo \"WARNING: Conflicting library detected: $filename\" >&2\n            echo \"  Existing: $dest_file\" >&2\n            echo \"  New:      $source_file\" >&2\n            echo \"  Keeping existing file, skipping new one\" >&2\n            CONFLICT_COUNT=$((CONFLICT_COUNT + 1))\n        fi\n        # Remove the duplicate source file\n        rm -f \"$source_file\"\n    else\n        # No conflict, move the file\n        mv \"$source_file\" \"$dest_file\"\n    fi\ndone < <(find \"$TARGET_DIR/temp\" -name \"*.so*\")\n\nif [ \"$CONFLICT_COUNT\" -gt 0 ]; then\n    echo \"WARNING: $CONFLICT_COUNT library file(s) had conflicts. Review the warnings above.\" >&2\nfi\n\nrm -rf \"$TARGET_DIR/temp\"\n\n# Verify that libcudf.so was successfully moved to final location\nif [ ! -f \"$NATIVE_DEPS_DIR/libcudf.so\" ]; then\n    echo \"ERROR: libcudf.so not found in $NATIVE_DEPS_DIR\" >&2\n    echo \"\" >&2\n    echo \"This could mean:\" >&2\n    echo \"  1. The jar didn't contain libcudf.so\" >&2\n    echo \"  2. Extraction succeeded but moving files failed\" >&2\n    echo \"  3. Wrong architecture (jar might be for a different platform)\" >&2\n    echo \"\" >&2\n    echo \"Contents of $NATIVE_DEPS_DIR:\" >&2\n    ls -lh \"$NATIVE_DEPS_DIR\" >&2 || echo \"  Directory is empty or doesn't exist\" >&2\n    exit 1\nfi\n\necho \"✓ Successfully extracted libraries to: $NATIVE_DEPS_DIR\"\nls -lh \"$NATIVE_DEPS_DIR\"\n\n# Clone cuDF repo for headers (shallow clone)\nif [ ! 
-d \"$CUDF_REPO_DIR/.git\" ]; then\n    echo \"Cloning cuDF repository for headers...\"\n    git clone --depth 1 --branch \"$CUDF_BRANCH\" https://github.com/rapidsai/cudf.git \"$CUDF_REPO_DIR\"\n    echo \"✓ Cloned cuDF repo to: $CUDF_REPO_DIR\"\nelse\n    echo \"✓ cuDF repo already exists at: $CUDF_REPO_DIR\"\n    echo \"  (Delete it to re-clone: rm -rf \\\"$CUDF_REPO_DIR\\\")\"\nfi\n\necho \"\"\necho \"==================================================\"\necho \"Setup complete! You can now build with:\"\necho \"  mvn clean package -P udf-native-examples\"\necho \"\"\necho \"This will use prebuilt libcudf.so and avoid\"\necho \"building cuDF from source (much faster!).\"\necho \"==================================================\"\n\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/pom.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!--\n  Copyright (c) 2020-2026, NVIDIA CORPORATION.\n\n  Licensed under the Apache License, Version 2.0 (the \"License\");\n  you may not use this file except in compliance with the License.\n  You may obtain a copy of the License at\n\n     http://www.apache.org/licenses/LICENSE-2.0\n\n  Unless required by applicable law or agreed to in writing, software\n  distributed under the License is distributed on an \"AS IS\" BASIS,\n  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n  See the License for the specific language governing permissions and\n  limitations under the License.\n-->\n<project xmlns=\"http://maven.apache.org/POM/4.0.0\"\n         xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n         xsi:schemaLocation=\"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd\">\n    <modelVersion>4.0.0</modelVersion>\n    <groupId>com.nvidia</groupId>\n    <artifactId>rapids-4-spark-udf-examples_2.12</artifactId>\n    <name>RAPIDS Accelerator for Apache Spark UDF Examples</name>\n    <description>Sample implementations of RAPIDS accelerated\n        user defined functions for use with the RAPIDS Accelerator\n        for Apache Spark\n    </description>\n    <version>26.06.0-SNAPSHOT</version>\n\n    <properties>\n        <maven.compiler.source>1.8</maven.compiler.source>\n        <maven.compiler.target>1.8</maven.compiler.target>\n        <java.major.version>8</java.major.version>\n        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>\n        <project.reporting.sourceEncoding>UTF-8</project.reporting.sourceEncoding>\n        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>\n        <cuda.version>cuda12</cuda.version>\n        <scala.binary.version>2.12</scala.binary.version>\n        <!-- Depends on release version, Snapshot version is not published to the Maven Central -->\n        
<rapids4spark.version>26.02.0</rapids4spark.version>\n        <spark.version>3.1.1</spark.version>\n        <scala.version>2.12.15</scala.version>\n        <udf.native.build.path>${project.build.directory}/cpp-build</udf.native.build.path>\n        <cudf.git.branch>main</cudf.git.branch>\n        <skipCudfExtraction>false</skipCudfExtraction>\n        <BUILD_UDF_BENCHMARKS>OFF</BUILD_UDF_BENCHMARKS>\n        <CMAKE_CXX_FLAGS/>\n        <GPU_ARCHS>RAPIDS</GPU_ARCHS>\n        <PER_THREAD_DEFAULT_STREAM>ON</PER_THREAD_DEFAULT_STREAM>\n        <CPP_PARALLEL_LEVEL>10</CPP_PARALLEL_LEVEL>\n        <CUDF_ENABLE_ARROW_S3>OFF</CUDF_ENABLE_ARROW_S3>\n        <USE_PREBUILT_CUDF>ON</USE_PREBUILT_CUDF>\n        <target.classifier/>\n    </properties>\n\n    <dependencies>\n        <dependency>\n            <groupId>org.apache.spark</groupId>\n            <artifactId>spark-hive_${scala.binary.version}</artifactId>\n            <version>${spark.version}</version>\n        </dependency>\n        <dependency>\n            <groupId>org.scala-lang</groupId>\n            <artifactId>scala-library</artifactId>\n            <version>${scala.version}</version>\n        </dependency>\n        <dependency>\n            <groupId>com.nvidia</groupId>\n            <artifactId>rapids-4-spark_${scala.binary.version}</artifactId>\n            <version>${rapids4spark.version}</version>\n            <scope>provided</scope>\n        </dependency>\n    </dependencies>\n\n    <build>\n        <resources>\n            <resource>\n                <!-- Include the properties file to provide the build information. 
-->\n                <directory>${project.build.directory}/extra-resources</directory>\n                <filtering>true</filtering>\n            </resource>\n        </resources>\n        <plugins>\n            <plugin>\n                <groupId>org.apache.maven.plugins</groupId>\n                <artifactId>maven-jar-plugin</artifactId>\n                <version>3.2.0</version>\n                <executions>\n                    <execution>\n                        <!-- disable test jar -->\n                        <id>default-test-jar</id>\n                        <phase>none</phase>\n                    </execution>\n                </executions>\n                <configuration>\n                    <!-- Include native libraries in the jar -->\n                    <includes>\n                        <include>**/*</include>\n                    </includes>\n                </configuration>\n            </plugin>\n            <plugin>\n                <groupId>net.alchim31.maven</groupId>\n                <artifactId>scala-maven-plugin</artifactId>\n                <version>4.3.0</version>\n            </plugin>\n            <plugin>\n                <groupId>org.apache.rat</groupId>\n                <artifactId>apache-rat-plugin</artifactId>\n                <version>0.13</version>\n            </plugin>\n            <plugin>\n                <groupId>org.apache.maven.plugins</groupId>\n                <artifactId>maven-antrun-plugin</artifactId>\n                <version>3.0.0</version>\n                <executions>\n                    <execution>\n                        <id>generate-build-info</id>\n                        <phase>none</phase>\n                    </execution>\n                </executions>\n            </plugin>\n            <plugin>\n                <groupId>org.codehaus.mojo</groupId>\n                <artifactId>exec-maven-plugin</artifactId>\n                <version>3.0.0</version>\n                <executions>\n                    
<execution>\n                        <id>run pyspark tests</id>\n                        <phase>verify</phase><!--run after packaging and collecting dependencies-->\n                        <goals>\n                            <goal>exec</goal>\n                        </goals>\n                        <configuration>\n                            <executable>./run_pyspark_from_build.sh</executable>\n                            <workingDirectory>./</workingDirectory>\n                            <environmentVariables>\n                                <SKIP_TESTS>${skipTests}</SKIP_TESTS>\n                            </environmentVariables>\n                        </configuration>\n                    </execution>\n                </executions>\n            </plugin>\n            <plugin>\n                <groupId>org.apache.maven.plugins</groupId>\n                <artifactId>maven-dependency-plugin</artifactId>\n                <!-- copy rapids-4-spark.jar to dependency directory-->\n                <executions>\n                    <execution>\n                        <id>copy-dist-jar</id>\n                        <phase>package</phase>\n                        <goals>\n                            <goal>copy</goal>\n                        </goals>\n                        <configuration>\n                            <useBaseVersion>true</useBaseVersion>\n                            <artifactItems>\n                                <artifactItem>\n                                    <groupId>com.nvidia</groupId>\n                                    <artifactId>rapids-4-spark_${scala.binary.version}</artifactId>\n                                    <version>${rapids4spark.version}</version>\n                                </artifactItem>\n                            </artifactItems>\n                        </configuration>\n                    </execution>\n                </executions>\n            </plugin>\n        </plugins>\n    </build>\n\n    <profiles>\n 
       <profile>\n            <id>udf-native-examples</id>\n            <build>\n                <plugins>\n                    <plugin>\n                        <groupId>org.apache.maven.plugins</groupId>\n                        <artifactId>maven-dependency-plugin</artifactId>\n                        <executions>\n                            <!-- Try downloading with classifier first (for release versions) -->\n                            <execution>\n                                <id>download-rapids-jar-with-classifier</id>\n                                <phase>generate-sources</phase>\n                                <goals>\n                                    <goal>copy</goal>\n                                </goals>\n                                <configuration>\n                                    <artifactItems>\n                                        <artifactItem>\n                                            <groupId>com.nvidia</groupId>\n                                            <artifactId>rapids-4-spark_${scala.binary.version}</artifactId>\n                                            <version>${rapids4spark.version}</version>\n                                            <classifier>${cuda.version}</classifier>\n                                            <type>jar</type>\n                                            <overWrite>false</overWrite>\n                                            <outputDirectory>${project.build.directory}/rapids-jar</outputDirectory>\n                                        </artifactItem>\n                                    </artifactItems>\n                                    <!-- Don't fail if not found - SNAPSHOT versions may not have classifier -->\n                                    <skip>false</skip>\n                                    <ignoreMissingArtifact>true</ignoreMissingArtifact>\n                                </configuration>\n                            </execution>\n                            
<!-- Try downloading without classifier (for SNAPSHOT versions) -->\n                            <execution>\n                                <id>download-rapids-jar-no-classifier</id>\n                                <phase>generate-sources</phase>\n                                <goals>\n                                    <goal>copy</goal>\n                                </goals>\n                                <configuration>\n                                    <artifactItems>\n                                        <artifactItem>\n                                            <groupId>com.nvidia</groupId>\n                                            <artifactId>rapids-4-spark_${scala.binary.version}</artifactId>\n                                            <version>${rapids4spark.version}</version>\n                                            <!-- No classifier for SNAPSHOT versions -->\n                                            <type>jar</type>\n                                            <overWrite>false</overWrite>\n                                            <outputDirectory>${project.build.directory}/rapids-jar</outputDirectory>\n                                        </artifactItem>\n                                    </artifactItems>\n                                    <!-- Don't fail if not found - one of the two attempts should succeed -->\n                                    <skip>false</skip>\n                                    <ignoreMissingArtifact>true</ignoreMissingArtifact>\n                                </configuration>\n                            </execution>\n                        </executions>\n                    </plugin>\n                    <plugin>\n                        <artifactId>maven-antrun-plugin</artifactId>\n                        <executions>\n                            <execution>\n                                <id>extract-cudf-dependencies</id>\n                                
<phase>generate-sources</phase>\n                                <configuration>\n                                    <skip>${skipCudfExtraction}</skip>\n                                    <target>\n                                        <echo message=\"================================================\"/>\n                                        <echo message=\"Extracting libcudf.so from rapids-4-spark jar...\"/>\n                                        <echo message=\"USE_PREBUILT_CUDF=${USE_PREBUILT_CUDF}\"/>\n                                        <echo message=\"================================================\"/>\n\n                                        <!-- Skip if user explicitly wants source build -->\n                                        <condition property=\"use.source.build\">\n                                            <equals arg1=\"${USE_PREBUILT_CUDF}\" arg2=\"OFF\"/>\n                                        </condition>\n\n                                        <sequential if:set=\"use.source.build\" xmlns:if=\"ant:if\">\n                                            <echo message=\"\"/>\n                                            <echo message=\"========================================================\"/>\n                                            <echo message=\"USE_PREBUILT_CUDF=OFF detected\"/>\n                                            <echo message=\"Skipping extraction, will build cuDF from source\"/>\n                                            <echo message=\"========================================================\"/>\n                                            <echo message=\"\"/>\n                                        </sequential>\n\n                                        <!-- Try multiple jar naming patterns -->\n                                        <!-- Normalize paths for cross-platform compatibility (Windows/Linux/macOS) -->\n                                        <property name=\"maven.repo.base\" 
value=\"${settings.localRepository}\"/>\n                                        <pathconvert property=\"maven.repo.normalized\" dirsep=\"/\">\n                                            <path location=\"${maven.repo.base}\"/>\n                                        </pathconvert>\n\n                                        <property name=\"project.target.base\" value=\"${project.build.directory}\"/>\n                                        <pathconvert property=\"project.target.normalized\" dirsep=\"/\">\n                                            <path location=\"${project.target.base}\"/>\n                                        </pathconvert>\n\n                                        <!-- Pattern 1: Maven repo with classifier (release versions) -->\n                                        <property name=\"rapids.jar.maven.with.classifier\" \n                                                  value=\"${maven.repo.normalized}/com/nvidia/rapids-4-spark_${scala.binary.version}/${rapids4spark.version}/rapids-4-spark_${scala.binary.version}-${rapids4spark.version}-${cuda.version}.jar\"/>\n\n                                        <!-- Pattern 2: Maven repo without classifier (SNAPSHOT versions) -->\n                                        <property name=\"rapids.jar.maven.no.classifier\" \n                                                  value=\"${maven.repo.normalized}/com/nvidia/rapids-4-spark_${scala.binary.version}/${rapids4spark.version}/rapids-4-spark_${scala.binary.version}-${rapids4spark.version}.jar\"/>\n\n                                        <!-- Pattern 3: Downloaded with classifier (release versions) -->\n                                        <property name=\"rapids.jar.downloaded.with.classifier\" \n                                                  value=\"${project.target.normalized}/rapids-jar/rapids-4-spark_${scala.binary.version}-${rapids4spark.version}-${cuda.version}.jar\"/>\n\n                                        <!-- Pattern 4: 
Downloaded without classifier (SNAPSHOT versions) -->\n                                        <property name=\"rapids.jar.downloaded.no.classifier\" \n                                                  value=\"${project.target.normalized}/rapids-jar/rapids-4-spark_${scala.binary.version}-${rapids4spark.version}.jar\"/>\n\n                                        <echo message=\"Looking for rapids-4-spark jar...\"/>\n                                        <echo message=\"  Pattern 1 (Maven with classifier):     ${rapids.jar.maven.with.classifier}\"/>\n                                        <echo message=\"  Pattern 2 (Maven no classifier):       ${rapids.jar.maven.no.classifier}\"/>\n                                        <echo message=\"  Pattern 3 (Downloaded with classifier): ${rapids.jar.downloaded.with.classifier}\"/>\n                                        <echo message=\"  Pattern 4 (Downloaded no classifier):  ${rapids.jar.downloaded.no.classifier}\"/>\n\n                                        <!-- Check all locations -->\n                                        <available file=\"${rapids.jar.maven.with.classifier}\" \n                                                   property=\"rapids.jar.maven.with.classifier.exists\"/>\n                                        <available file=\"${rapids.jar.maven.no.classifier}\" \n                                                   property=\"rapids.jar.maven.no.classifier.exists\"/>\n                                        <available file=\"${rapids.jar.downloaded.with.classifier}\" \n                                                   property=\"rapids.jar.downloaded.with.classifier.exists\"/>\n                                        <available file=\"${rapids.jar.downloaded.no.classifier}\" \n                                                   property=\"rapids.jar.downloaded.no.classifier.exists\"/>\n\n                                        <!-- Set the actual path to use (try in order of preference) -->\n          
                              <!-- 1. Try Maven repo with classifier (release) -->\n                                        <condition property=\"rapids.jar.path\" value=\"${rapids.jar.maven.with.classifier}\">\n                                            <isset property=\"rapids.jar.maven.with.classifier.exists\"/>\n                                        </condition>\n                                        <!-- 2. Try Maven repo without classifier (SNAPSHOT) -->\n                                        <condition property=\"rapids.jar.path\" value=\"${rapids.jar.maven.no.classifier}\">\n                                            <and>\n                                                <not><isset property=\"rapids.jar.path\"/></not>\n                                                <isset property=\"rapids.jar.maven.no.classifier.exists\"/>\n                                            </and>\n                                        </condition>\n                                        <!-- 3. Try downloaded with classifier (release) -->\n                                        <condition property=\"rapids.jar.path\" value=\"${rapids.jar.downloaded.with.classifier}\">\n                                            <and>\n                                                <not><isset property=\"rapids.jar.path\"/></not>\n                                                <isset property=\"rapids.jar.downloaded.with.classifier.exists\"/>\n                                            </and>\n                                        </condition>\n                                        <!-- 4. 
Try downloaded without classifier (SNAPSHOT) -->\n                                        <condition property=\"rapids.jar.path\" value=\"${rapids.jar.downloaded.no.classifier}\">\n                                            <and>\n                                                <not><isset property=\"rapids.jar.path\"/></not>\n                                                <isset property=\"rapids.jar.downloaded.no.classifier.exists\"/>\n                                            </and>\n                                        </condition>\n\n                                        <!-- Verify we have the jar from at least one location -->\n                                        <condition property=\"rapids.jar.exists\">\n                                            <or>\n                                                <isset property=\"rapids.jar.maven.with.classifier.exists\"/>\n                                                <isset property=\"rapids.jar.maven.no.classifier.exists\"/>\n                                                <isset property=\"rapids.jar.downloaded.with.classifier.exists\"/>\n                                                <isset property=\"rapids.jar.downloaded.no.classifier.exists\"/>\n                                            </or>\n                                        </condition>\n\n                                        <!-- Skip all extraction if user wants source build -->\n                                        <condition property=\"should.skip.extraction\">\n                                            <isset property=\"use.source.build\"/>\n                                        </condition>\n\n                                        <!-- Only proceed if not skipping -->\n                                        <sequential unless:set=\"should.skip.extraction\" xmlns:unless=\"ant:unless\">\n\n                                        <!-- Create necessary directories for extraction -->\n                                    
    <mkdir dir=\"${project.build.directory}/native-deps\"/>\n                                        <mkdir dir=\"${project.build.directory}/cudf-repo\"/>\n                                        <mkdir dir=\"${project.build.directory}/temp-extract\"/>\n\n                                        <!-- Create inverse property for jar NOT found -->\n                                        <condition property=\"rapids.jar.not.found\">\n                                            <not><isset property=\"rapids.jar.exists\"/></not>\n                                        </condition>\n\n                                        <!-- Show warning if jar not found -->\n                                        <sequential if:set=\"rapids.jar.not.found\" xmlns:if=\"ant:if\">\n                                            <echo message=\"\" level=\"warning\"/>\n                                            <echo message=\"========================================================\" level=\"warning\"/>\n                                            <echo message=\"WARNING: rapids-4-spark jar not found!\" level=\"warning\"/>\n                                            <echo message=\"========================================================\" level=\"warning\"/>\n                                            <echo message=\"Tried locations:\" level=\"warning\"/>\n                                            <echo message=\"  Maven (with classifier):    ${rapids.jar.maven.with.classifier}\" level=\"warning\"/>\n                                            <echo message=\"  Maven (no classifier):      ${rapids.jar.maven.no.classifier}\" level=\"warning\"/>\n                                            <echo message=\"  Downloaded (with classifier): ${rapids.jar.downloaded.with.classifier}\" level=\"warning\"/>\n                                            <echo message=\"  Downloaded (no classifier):  ${rapids.jar.downloaded.no.classifier}\" level=\"warning\"/>\n                                    
        <echo message=\"\" level=\"warning\"/>\n                                            <echo message=\"Note: Release versions use classifier (e.g., cuda12)\" level=\"warning\"/>\n                                            <echo message=\"      SNAPSHOT versions typically do not use classifier\" level=\"warning\"/>\n                                            <echo message=\"\" level=\"warning\"/>\n                                            <echo message=\"For SNAPSHOT versions:\" level=\"warning\"/>\n                                            <echo message=\"  1. Build spark-rapids locally:\" level=\"warning\"/>\n                                            <echo message=\"     cd /path/to/spark-rapids\" level=\"warning\"/>\n                                            <echo message=\"     mvn clean install -DskipTests\" level=\"warning\"/>\n                                            <echo message=\"\" level=\"warning\"/>\n                                            <echo message=\"FALLING BACK to building cuDF from source...\" level=\"warning\"/>\n                                            <echo message=\"========================================================\" level=\"warning\"/>\n                                            <echo message=\"\" level=\"warning\"/>\n\n                                            <!-- Create a marker file to tell CMake to use source build -->\n                                            <touch file=\"${project.build.directory}/USE_SOURCE_BUILD\"/>\n                                        </sequential>\n\n                                        <!-- Extract if jar was found -->\n                                        <sequential if:set=\"rapids.jar.exists\" xmlns:if=\"ant:if\">\n                                            <echo message=\"✓ Found jar at: ${rapids.jar.path}\"/>\n                                            <echo message=\"✓ Extracting libraries...\"/>\n\n                                            <!-- Extract 
libcudf.so from rapids-4-spark jar -->\n                                            <unzip src=\"${rapids.jar.path}\" \n                                                   dest=\"${project.build.directory}/temp-extract\">\n                                                <patternset>\n                                                    <include name=\"**/libcudf.so*\"/>\n                                                    <include name=\"**/libnvcomp.so*\"/>\n                                                </patternset>\n                                            </unzip>\n\n                                            <!-- Move extracted libraries to native-deps (flatten structure) -->\n                                            <copy todir=\"${project.build.directory}/native-deps\" flatten=\"true\">\n                                                <fileset dir=\"${project.build.directory}/temp-extract\">\n                                                    <include name=\"**/*.so*\"/>\n                                                </fileset>\n                                            </copy>\n\n                                            <!-- Clean up temp directory -->\n                                            <delete dir=\"${project.build.directory}/temp-extract\"/>\n\n                                            <!-- Verify extraction -->\n                                            <available file=\"${project.build.directory}/native-deps/libcudf.so\" \n                                                       property=\"libcudf.extracted\"/>\n                                            <fail unless=\"libcudf.extracted\" \n                                                  message=\"Failed to extract libcudf.so from jar\"/>\n\n                                            <echo message=\"✓ libcudf.so extracted to: ${project.build.directory}/native-deps\"/>\n                                            \n                                            <!-- Clone cuDF 
repo for headers only (shallow clone) -->\n                                            <!-- Ensure script is executable -->\n                                            <chmod file=\"${basedir}/clone-cudf-repo.sh\" perm=\"755\"/>\n                                            \n                                            <!-- Execute script to clone/update cuDF repository -->\n                                            <exec executable=\"${basedir}/clone-cudf-repo.sh\" failonerror=\"true\">\n                                                <arg value=\"${project.build.directory}/cudf-repo\"/>\n                                                <arg value=\"${cudf.git.branch}\"/>\n                                            </exec>\n\n                                            <!-- Verify cuDF headers exist after clone/update -->\n                                            <available file=\"${project.build.directory}/cudf-repo/cpp/include\" \n                                                       type=\"dir\" \n                                                       property=\"cudf.headers.exist\"/>\n                                            <!-- Note: <fail> only supports 'message' as attribute, not nested element -->\n                                            <!-- Multi-line message using ${line.separator} for readability in output -->\n                                            <fail unless=\"cudf.headers.exist\"\n                                                  message=\"Failed to obtain cuDF headers from branch ${cudf.git.branch}.${line.separator}The directory ${project.build.directory}/cudf-repo/cpp/include does not exist.${line.separator}${line.separator}Please check:${line.separator}  1) Network connectivity to GitHub${line.separator}  2) Branch '${cudf.git.branch}' exists in cuDF repository${line.separator}  3) Try with -Dcudf.git.branch=main${line.separator}${line.separator}For more information, see README.md\"/>\n\n                                       
     <echo message=\"✓ cuDF headers available at: ${project.build.directory}/cudf-repo\"/>\n                                            <echo message=\"================================================\"/>\n                                            <echo message=\"Dependencies ready! Will use FAST BUILD mode\"/>\n                                            <echo message=\"================================================\"/>\n                                        </sequential>\n\n                                        </sequential><!-- End of: unless should.skip.extraction -->\n                                    </target>\n                                </configuration>\n                                <goals>\n                                    <goal>run</goal>\n                                </goals>\n                            </execution>\n                            <execution>\n                                <id>cmake</id>\n                                <phase>compile</phase>\n                                <configuration>\n                                    <target>\n                                        <mkdir dir=\"${udf.native.build.path}\"/>\n                                        <exec dir=\"${udf.native.build.path}\"\n                                              failonerror=\"true\"\n                                              executable=\"cmake\">\n                                            <arg value=\"${basedir}/src/main/cpp\"/>\n                                            <arg value=\"-DBUILD_UDF_BENCHMARKS=${BUILD_UDF_BENCHMARKS}\"/>\n                                            <arg value=\"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}\"/>\n                                            <arg value=\"-DGPU_ARCHS=${GPU_ARCHS}\"/>\n                                            <arg value=\"-DPER_THREAD_DEFAULT_STREAM=${PER_THREAD_DEFAULT_STREAM}\"/>\n                                            <arg 
value=\"-DCUDF_ENABLE_ARROW_S3=${CUDF_ENABLE_ARROW_S3}\"/>\n                                            <arg value=\"-DUSE_PREBUILT_CUDF=${USE_PREBUILT_CUDF}\"/>\n                                        </exec>\n                                        <exec failonerror=\"true\"\n                                              executable=\"cmake\">\n                                            <arg value=\"--build\"/>\n                                            <arg value=\"${udf.native.build.path}\"/>\n                                            <arg value=\"-j${CPP_PARALLEL_LEVEL}\"/>\n                                            <arg value=\"-v\"/>\n                                        </exec>\n                                    </target>\n                                </configuration>\n                                <goals>\n                                    <goal>run</goal>\n                                </goals>\n                            </execution>\n                        </executions>\n                    </plugin>\n                    <plugin>\n                        <artifactId>maven-resources-plugin</artifactId>\n                        <version>3.2.0</version>\n                        <executions>\n                            <execution>\n                                <id>copy-native-libs-to-deps</id>\n                                <phase>process-classes</phase>\n                                <goals>\n                                    <goal>copy-resources</goal>\n                                </goals>\n                                <configuration>\n                                    <overwrite>true</overwrite>\n                                    <outputDirectory>${project.build.directory}/native-deps/${os.arch}/${os.name}\n                                    </outputDirectory>\n                                    <resources>\n                                        <resource>\n                                            
<directory>${udf.native.build.path}</directory>\n                                            <includes>\n                                                <include>libudfexamplesjni.so</include>\n                                            </includes>\n                                        </resource>\n                                    </resources>\n                                </configuration>\n                            </execution>\n                            <execution>\n                                <id>copy-native-libs-to-classes</id>\n                                <phase>process-classes</phase>\n                                <goals>\n                                    <goal>copy-resources</goal>\n                                </goals>\n                                <configuration>\n                                    <overwrite>true</overwrite>\n                                    <!-- Copy to classes directory so it gets included in jar -->\n                                    <outputDirectory>${project.build.outputDirectory}/${os.arch}/${os.name}\n                                    </outputDirectory>\n                                    <resources>\n                                        <resource>\n                                            <directory>${udf.native.build.path}</directory>\n                                            <includes>\n                                                <include>libudfexamplesjni.so</include>\n                                            </includes>\n                                        </resource>\n                                    </resources>\n                                </configuration>\n                            </execution>\n                        </executions>\n                    </plugin>\n                </plugins>\n            </build>\n        </profile>\n    </profiles>\n</project>\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/pytest.ini",
    "content": "; Copyright (c) 2020-2022, NVIDIA CORPORATION.\n;\n; Licensed under the Apache License, Version 2.0 (the \"License\");\n; you may not use this file except in compliance with the License.\n; You may obtain a copy of the License at\n;\n;     http://www.apache.org/licenses/LICENSE-2.0\n;\n; Unless required by applicable law or agreed to in writing, software\n; distributed under the License is distributed on an \"AS IS\" BASIS,\n; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n; See the License for the specific language governing permissions and\n; limitations under the License.\n\n[pytest]\nmarkers =\n    rapids_udf_example_native: test UDFs that require custom cuda compilation\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/run_pyspark_from_build.sh",
    "content": "#!/bin/bash\n# Copyright (c) 2022-2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nset -ex\n\nSCRIPTPATH=\"$( cd \"$(dirname \"$0\")\" >/dev/null 2>&1 ; pwd -P )\"\ncd \"$SCRIPTPATH\"\n\nif [[ $( echo ${SKIP_TESTS} | tr '[:upper:]' '[:lower:]' ) == \"true\" ]];\nthen\n    echo \"PYTHON INTEGRATION TESTS SKIPPED...\"\n    exit 0\nelif [[ -z \"$SPARK_HOME\" ]];\nthen\n    >&2 echo \"SPARK_HOME IS NOT SET CANNOT RUN PYTHON INTEGRATION TESTS...\"\n    exit 1\nelse\n    echo \"WILL RUN TESTS WITH SPARK_HOME: ${SPARK_HOME}\"\n    # Spark 3.1.1 includes https://github.com/apache/spark/pull/31540\n    # which helps with spurious task failures as observed in our tests. 
If you are running\n    # Spark versions before 3.1.1, this sets the spark.max.taskFailures to 4 to allow for\n    # more lenient configuration, else it will set them to 1 as spurious task failures are not expected\n    # for Spark 3.1.1+\n    VERSION_STRING=`$SPARK_HOME/bin/pyspark --version 2>&1|grep -v Scala|awk '/version\\ [0-9.]+/{print $NF}'`\n    VERSION_STRING=\"${VERSION_STRING/-SNAPSHOT/}\"\n    [[ -z $VERSION_STRING ]] && { echo \"Unable to detect the Spark version at $SPARK_HOME\"; exit 1; }\n    [[ -z $SPARK_SHIM_VER ]] && { SPARK_SHIM_VER=\"spark${VERSION_STRING//./}\"; }\n\n    echo \"Detected Spark version $VERSION_STRING (shim version: $SPARK_SHIM_VER)\"\n\n    PLUGIN_JARS=$(echo \"$SCRIPTPATH\"/target/dependency/rapids-4-spark*.jar)\n    UDF_EXAMPLE_JARS=$(echo \"$SCRIPTPATH\"/target/rapids-4-spark-udf-examples*.jar)\n    ALL_JARS=\"$PLUGIN_JARS $UDF_EXAMPLE_JARS\"\n    echo \"AND PLUGIN JARS: $ALL_JARS\"\n\n    RUN_TESTS_COMMAND=(\"$SCRIPTPATH\"/runtests.py\n      --rootdir\n      \"$SCRIPTPATH\"\n      \"$SCRIPTPATH\"/src/main/python)\n\n    # --ignore=target is used to exclude the target directory which contains unrelated python files.\n    TEST_COMMON_OPTS=(-v\n          -rfExXs\n          \"$TEST_ARGS\"\n          --color=yes\n          --ignore=target\n          \"$@\")\n\n    \"$SPARK_HOME\"/bin/spark-submit --jars \"${ALL_JARS// /,}\" \\\n        --master local[1] \\\n        \"${RUN_TESTS_COMMAND[@]}\" \"${TEST_COMMON_OPTS[@]}\"\n\nfi\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/runtests.py",
    "content": "# Copyright (c) 2022, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport sys\n\nfrom pytest import main\n\n#import cProfile\n\nif __name__ == '__main__':\n    #cProfile.run('main(sys.argv[1:])', 'test_profile')\n    # arguments are the same as for pytest https://docs.pytest.org/en/latest/usage.html\n    # or run pytest -h\n    sys.exit(main(sys.argv[1:]))\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/CMakeLists.txt",
    "content": "#=============================================================================\n# Copyright (c) 2021-2026, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#=============================================================================\n\n# Keep the same with https://github.com/rapidsai/rapids-cmake/blob/main/RAPIDS.cmake\ncmake_minimum_required(VERSION 3.30.4 FATAL_ERROR)\n\n# set to the rapids-cmake-branch\nset(rapids-cmake-branch \"main\")\n\nfile(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/${rapids-cmake-branch}/RAPIDS.cmake\n     ${CMAKE_BINARY_DIR}/RAPIDS.cmake)\ninclude(${CMAKE_BINARY_DIR}/RAPIDS.cmake)\n\ninclude(rapids-cmake)\ninclude(rapids-cpm)\ninclude(rapids-cuda)\ninclude(rapids-export)\ninclude(rapids-find)\n\n# Get the rapids-cmake directory for later use\n# After include(rapids-cmake), CPM will download rapids-cmake to _deps\n# We can get it from the CPM cache\nget_property(rapids-cmake-dir GLOBAL PROPERTY rapids-cmake-dir)\nif(NOT rapids-cmake-dir)\n    # Fallback: rapids-cmake is downloaded by CPM to _deps\n    set(rapids-cmake-dir \"${CMAKE_BINARY_DIR}/_deps/rapids-cmake-src\")\n    message(STATUS \"rapids-cmake property not set, using fallback path\")\nendif()\n\n# Verify rapids-cmake directory exists\nif(NOT EXISTS \"${rapids-cmake-dir}\")\n    message(FATAL_ERROR \n        \"rapids-cmake directory not found: ${rapids-cmake-dir}\\n\"\n        \"This usually means rapids-cmake wasn't 
properly fetched by CPM.\\n\"\n        \"Try deleting the build directory and reconfiguring:\\n\"\n        \"  rm -rf ${CMAKE_BINARY_DIR}\\n\"\n        \"  cmake ..\")\nendif()\n\nmessage(STATUS \"rapids-cmake directory: ${rapids-cmake-dir}\")\n\n# Use GPU_ARCHS if it is defined\nif(DEFINED GPU_ARCHS)\n  set(CMAKE_CUDA_ARCHITECTURES \"${GPU_ARCHS}\")\nendif()\nrapids_cuda_init_architectures(UDFEXAMPLESJNI)\n\nproject(UDFEXAMPLESJNI VERSION 26.06.0 LANGUAGES C CXX CUDA)\n\noption(PER_THREAD_DEFAULT_STREAM \"Build with per-thread default stream\" OFF)\noption(BUILD_UDF_BENCHMARKS \"Build the benchmarks\" OFF)\n\n###################################################################################################\n# - build type ------------------------------------------------------------------------------------\n\n# Set a default build type if none was specified\nset(DEFAULT_BUILD_TYPE \"Release\")\n\n###################################################################################################\n# - compiler options ------------------------------------------------------------------------------\n\nset(CMAKE_POSITION_INDEPENDENT_CODE ON)\nset(CMAKE_CXX_STANDARD 20)\nset(CMAKE_CXX_COMPILER $ENV{CXX})\nset(CMAKE_CXX_STANDARD_REQUIRED ON)\n\nset(CMAKE_CUDA_STANDARD 20)\nset(CMAKE_CUDA_STANDARD_REQUIRED ON)\n\nif(CMAKE_COMPILER_IS_GNUCXX)\n    set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -Wall\")\n    set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -Wno-unknown-pragmas\")\n    set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations\")\nendif(CMAKE_COMPILER_IS_GNUCXX)\n\nif(CMAKE_CUDA_COMPILER_VERSION)\n  # Compute the version. 
from  CMAKE_CUDA_COMPILER_VERSION\n  string(REGEX REPLACE \"([0-9]+)\\\\.([0-9]+).*\" \"\\\\1\" CUDA_VERSION_MAJOR ${CMAKE_CUDA_COMPILER_VERSION})\n  string(REGEX REPLACE \"([0-9]+)\\\\.([0-9]+).*\" \"\\\\2\" CUDA_VERSION_MINOR ${CMAKE_CUDA_COMPILER_VERSION})\n  set(CUDA_VERSION \"${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR}\" CACHE STRING \"Version of CUDA as computed from nvcc.\")\n  mark_as_advanced(CUDA_VERSION)\nendif()\n\nmessage(STATUS \"CUDA_VERSION_MAJOR: ${CUDA_VERSION_MAJOR}\")\nmessage(STATUS \"CUDA_VERSION_MINOR: ${CUDA_VERSION_MINOR}\")\nmessage(STATUS \"CUDA_VERSION: ${CUDA_VERSION}\")\n\n# Always set this convenience variable\nset(CUDA_VERSION_STRING \"${CUDA_VERSION}\")\n\nset(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -w --expt-extended-lambda --expt-relaxed-constexpr\")\n\n####################################################################################################\n# - cudf -------------------------------------------------------------------------------------------\n\n# Check if USE_PREBUILT_CUDF was explicitly set by user (e.g., via -DUSE_PREBUILT_CUDF=...)\n# This must be done BEFORE the option() command\nif(DEFINED USE_PREBUILT_CUDF)\n    set(USER_SET_USE_PREBUILT_CUDF TRUE)\n    message(STATUS \"USE_PREBUILT_CUDF explicitly set by user to: ${USE_PREBUILT_CUDF}\")\nelse()\n    set(USER_SET_USE_PREBUILT_CUDF FALSE)\nendif()\n\noption(USE_PREBUILT_CUDF \"Use prebuilt libcudf.so from rapids-4-spark jar\" ON)\n\nmessage(STATUS \"USE_PREBUILT_CUDF is set to: ${USE_PREBUILT_CUDF}\")\n\n# Check if Maven created a marker to force source build\n# This happens when rapids-4-spark jar is not found\nif(EXISTS \"${CMAKE_CURRENT_SOURCE_DIR}/../../../target/USE_SOURCE_BUILD\")\n    message(STATUS \"Found USE_SOURCE_BUILD marker file from Maven (rapids-4-spark jar not found)\")\n\n    if(USE_PREBUILT_CUDF)\n        if(USER_SET_USE_PREBUILT_CUDF)\n            # User explicitly requested prebuilt mode, but jar is missing - fail fast with clear error\n         
   message(FATAL_ERROR \n                \"\\n\"\n                \"================================================================\\n\"\n                \"ERROR: rapids-4-spark jar not found, but USE_PREBUILT_CUDF=ON\\n\"\n                \"was explicitly set by the user.\\n\"\n                \"\\n\"\n                \"Cannot proceed with prebuilt mode because required libraries\\n\"\n                \"are not available.\\n\"\n                \"\\n\"\n                \"Solutions:\\n\"\n                \"  1. Remove -DUSE_PREBUILT_CUDF=ON to allow automatic fallback\\n\"\n                \"     to building from source\\n\"\n                \"\\n\"\n                \"  2. Build and install rapids-4-spark:\\n\"\n                \"     cd /path/to/spark-rapids\\n\"\n                \"     mvn clean install -DskipTests\\n\"\n                \"\\n\"\n                \"  3. Explicitly use source build:\\n\"\n                \"     -DUSE_PREBUILT_CUDF=OFF\\n\"\n                \"================================================================\\n\")\n        else()\n            # Not explicitly set by user - safe to auto-fallback\n            message(STATUS \"Auto-fallback: Switching to source build due to missing jar\")\n            set(USE_PREBUILT_CUDF OFF CACHE BOOL \"Auto-fallback to source build (jar not found)\" FORCE)\n        endif()\n    endif()\nendif()\n\n# Check prebuilt availability before making final decision\n# This avoids modifying cache variables within conditional blocks\nset(SHOULD_USE_PREBUILT ${USE_PREBUILT_CUDF})\n\nif(USE_PREBUILT_CUDF AND NOT USER_SET_USE_PREBUILT_CUDF)\n    # User didn't explicitly set the option - check if prebuilt components are available\n    # Set paths for prebuilt library and headers\n    set(CUDF_LIB_DIR \"${CMAKE_CURRENT_SOURCE_DIR}/../../../target/native-deps\")\n    set(CUDF_INCLUDE_DIR \"${CMAKE_CURRENT_SOURCE_DIR}/../../../target/cudf-repo/cpp/include\")\n    \n    message(STATUS \"Checking for prebuilt libcudf.so 
from rapids-4-spark jar\")\n    message(STATUS \"Looking in: ${CUDF_LIB_DIR}\")\n    \n    # Check if prebuilt components are available\n    set(PREBUILT_AVAILABLE TRUE)\n    if(NOT EXISTS \"${CUDF_LIB_DIR}\")\n        message(STATUS \"Directory ${CUDF_LIB_DIR} does not exist\")\n        set(PREBUILT_AVAILABLE FALSE)\n    else()\n        # Try to find the library\n        # Note: find_library sets variable to <var>-NOTFOUND on failure, not undefined\n        find_library(CUDF_LIBRARY_CHECK\n            NAMES cudf\n            PATHS ${CUDF_LIB_DIR}\n            NO_DEFAULT_PATH\n        )\n        \n        # Proper check: find_library failure results in <VAR>-NOTFOUND string\n        if(CUDF_LIBRARY_CHECK MATCHES \"-NOTFOUND$\")\n            message(STATUS \"libcudf.so not found in ${CUDF_LIB_DIR}\")\n            set(PREBUILT_AVAILABLE FALSE)\n        else()\n            message(STATUS \"Found libcudf at: ${CUDF_LIBRARY_CHECK}\")\n        endif()\n    endif()\n    \n    # Auto-fallback to source build if components not available\n    if(NOT PREBUILT_AVAILABLE)\n        message(WARNING \n            \"\\n\"\n            \"================================================================\\n\"\n            \"Prebuilt libcudf.so not available.\\n\"\n            \"Automatically falling back to building cuDF from source.\\n\"\n            \"This will take 30+ minutes.\\n\"\n            \"\\n\"\n            \"To use fast build mode in future:\\n\"\n            \"  1. For SNAPSHOT versions: Build and install rapids-4-spark\\n\"\n            \"     cd /path/to/spark-rapids\\n\"\n            \"     mvn clean install -DskipTests\\n\"\n            \"  2. 
Run: mvn clean package -Pudf-native-examples\\n\"\n            \"\\n\"\n            \"NOTE: If you need to reset this decision, delete:\\n\"\n            \"      ${CMAKE_BINARY_DIR}/CMakeCache.txt\\n\"\n            \"================================================================\\n\")\n        set(SHOULD_USE_PREBUILT FALSE)\n        # Update cache for subsequent runs\n        set(USE_PREBUILT_CUDF OFF CACHE BOOL \"Auto-fallback to source build\" FORCE)\n    endif()\nendif()\n\n# Now use the final decision consistently\nif(SHOULD_USE_PREBUILT)\n    # Set paths as cache variables for user customization\n    set(CUDF_LIB_DIR \"${CMAKE_CURRENT_SOURCE_DIR}/../../../target/native-deps\" CACHE PATH \"Path to directory containing libcudf.so\")\n    set(CUDF_INCLUDE_DIR \"${CMAKE_CURRENT_SOURCE_DIR}/../../../target/cudf-repo/cpp/include\" CACHE PATH \"Path to cudf headers\")\n    \n    message(STATUS \"✓ Using FAST BUILD mode with prebuilt libcudf.so\")\n    \n    # Find the prebuilt libcudf.so (should succeed based on earlier check)\n    find_library(CUDF_LIBRARY\n        NAMES cudf\n        PATHS ${CUDF_LIB_DIR}\n        NO_DEFAULT_PATH\n        REQUIRED\n    )\n    \n    message(STATUS \"✓ Found libcudf: ${CUDF_LIBRARY}\")\n    message(STATUS \"✓ cuDF include directory: ${CUDF_INCLUDE_DIR}\")\n    \n    # Verify cuDF source directory exists (cloned by Maven)\n    set(CUDF_SOURCE_DIR \"${CMAKE_CURRENT_SOURCE_DIR}/../../../target/cudf-repo/cpp\")\n    if(NOT EXISTS \"${CUDF_SOURCE_DIR}/CMakeLists.txt\")\n        message(FATAL_ERROR \n            \"cuDF source directory not found: ${CUDF_SOURCE_DIR}\\n\"\n            \"The cuDF repository should have been cloned by Maven.\\n\"\n            \"Check if target/cudf-repo/ exists.\")\n    endif()\n    \n    message(STATUS \"✓ Found cuDF source at: ${CUDF_SOURCE_DIR}\")\n    \n    # We'll use cuDF's dependency fetching mechanism but create our own target\n    # First, let rapids-cpm fetch the dependencies that cuDF needs\n    
message(STATUS \"Fetching cuDF dependencies (this may take a few minutes)...\")\n    \n    rapids_cpm_init()\n    \n    # Set options to avoid building unnecessary components\n    set(BUILD_TESTS OFF CACHE BOOL \"\" FORCE)\n    set(BUILD_BENCHMARKS OFF CACHE BOOL \"\" FORCE)\n    \n    # Use rapids-cmake's helper scripts to get CCCL and RMM\n    # These scripts use versions defined in rapids-cmake (avoiding duplicate version definitions)\n    message(STATUS \"Using rapids-cmake helper scripts for CCCL and RMM\")\n\n    # Get CCCL (Thrust, libcudacxx, CUB) - version defined in rapids-cmake\n    set(CCCL_CMAKE_FILE \"${rapids-cmake-dir}/cpm/cccl.cmake\")\n    if(NOT EXISTS \"${CCCL_CMAKE_FILE}\")\n        message(FATAL_ERROR \n            \"rapids-cmake CCCL helper script not found: ${CCCL_CMAKE_FILE}\\n\"\n            \"Expected location: ${rapids-cmake-dir}/cpm/cccl.cmake\\n\"\n            \"This indicates rapids-cmake directory structure is incomplete or incorrect.\")\n    endif()\n    include(${CCCL_CMAKE_FILE})\n    rapids_cpm_cccl()\n\n    # Use rapids-cpm to get RMM - this is what cuDF uses  \n    set(RMM_CMAKE_FILE \"${rapids-cmake-dir}/cpm/rmm.cmake\")\n    if(NOT EXISTS \"${RMM_CMAKE_FILE}\")\n        message(FATAL_ERROR \n            \"rapids-cmake RMM helper script not found: ${RMM_CMAKE_FILE}\\n\"\n            \"Expected location: ${rapids-cmake-dir}/cpm/rmm.cmake\\n\"\n            \"This indicates rapids-cmake directory structure is incomplete or incorrect.\")\n    endif()\n    include(${RMM_CMAKE_FILE})\n    rapids_cpm_rmm()\n\n    # After rapids_cpm_rmm(), the rmm::rmm target should be available\n    # Verify it exists\n    if(NOT TARGET rmm::rmm)\n        message(FATAL_ERROR \"rmm::rmm target not created by rapids_cpm_rmm()\")\n    endif()\n    \n    # Get RMM include directory from the target\n    get_target_property(RMM_INCLUDE_DIR rmm::rmm INTERFACE_INCLUDE_DIRECTORIES)\n    message(STATUS \"RMM include directories: ${RMM_INCLUDE_DIR}\")\n\n    # 
Now create our own imported target for cudf using the prebuilt library\n    add_library(cudf_imported SHARED IMPORTED GLOBAL)\n    set_target_properties(cudf_imported PROPERTIES\n        IMPORTED_LOCATION ${CUDF_LIBRARY}\n    )\n    \n    # Add include directories to the imported target\n    # Include cuDF headers and RMM headers\n    target_include_directories(cudf_imported INTERFACE\n        ${CUDF_INCLUDE_DIR}\n        ${RMM_INCLUDE_DIR}\n    )\n    \n    # Link against RMM to get other dependencies\n    target_link_libraries(cudf_imported INTERFACE rmm::rmm)\n    \n    # Create an alias to match expected name\n    add_library(cudf::cudf ALIAS cudf_imported)\n    \n    message(STATUS \"✓ Prebuilt cuDF configured with all dependencies\")\n    message(STATUS \"  Prebuilt library: ${CUDF_LIBRARY}\")\n    message(STATUS \"  cuDF headers: ${CUDF_INCLUDE_DIR}\")\n    message(STATUS \"  Dependencies: CCCL, RMM (via rapids-cpm)\")\n    \nelse()\n    message(STATUS \"Building cuDF from source (this will take a long time)\")\n    \n    # Ensure CUDA runtime is dynamic despite statically linking Arrow in libcudf\n    set(CUDA_USE_STATIC_CUDA_RUNTIME ON)\n    \n    rapids_cpm_init()\n    rapids_cpm_find(cudf 26.06.00\n            CPM_ARGS\n            GIT_REPOSITORY  https://github.com/rapidsai/cudf.git\n            GIT_TAG         ${rapids-cmake-branch}\n            GIT_SHALLOW     TRUE\n            SOURCE_SUBDIR   cpp\n            OPTIONS         \"BUILD_TESTS OFF\"\n                            \"BUILD_BENCHMARKS OFF\"\n                            \"CUDF_USE_ARROW_STATIC ON\"\n                            \"JITIFY_USE_CACHE ON\"\n                            \"CUDA_STATIC_RUNTIME ${CUDA_USE_STATIC_CUDA_RUNTIME}\"\n                            \"DISABLE_DEPRECATION_WARNING ON\"\n                            \"AUTO_DETECT_CUDA_ARCHITECTURES OFF\"\n                            \"CUDF_KVIKIO_REMOTE_IO OFF\"\n        
)\nendif()\n\n###################################################################################################\n# - benchmarks ------------------------------------------------------------------------------------\n\nif(BUILD_UDF_BENCHMARKS)\n    # Find or install GoogleBench\n    CPMFindPackage(NAME benchmark\n        VERSION         1.5.2\n        GIT_REPOSITORY  https://github.com/google/benchmark.git\n        GIT_TAG         v1.5.2\n        GIT_SHALLOW     TRUE\n        OPTIONS         \"BENCHMARK_ENABLE_TESTING OFF\"\n                        \"BENCHMARK_ENABLE_INSTALL OFF\")\n    add_subdirectory(benchmarks)\nendif()\n\n###################################################################################################\n# - find JNI -------------------------------------------------------------------------------------\n\nfind_package(JNI REQUIRED)\nif(JNI_FOUND)\n    message(STATUS \"JDK with JNI in ${JNI_INCLUDE_DIRS}\")\nelse()\n    message(FATAL_ERROR \"JDK with JNI not found, please check your settings.\")\nendif(JNI_FOUND)\n\n###################################################################################################\n# - library paths ---------------------------------------------------------------------------------\n\n# CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES is an undocumented/unsupported variable containing the link directories for nvcc\nlink_directories(\"${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES}\"\n                 \"${CMAKE_BINARY_DIR}/lib\")\n\n\n###################################################################################################\n# - library targets -------------------------------------------------------------------------------\n\nset(SOURCE_FILES\n    \"src/CosineSimilarityJni.cpp\"\n    \"src/StringWordCountJni.cpp\"\n    \"src/cosine_similarity.cu\"\n    \"src/string_word_count.cu\")\n\nadd_library(udfexamplesjni SHARED ${SOURCE_FILES})\n\n#Override RPATH for udfexamplesjni\nSET_TARGET_PROPERTIES(udfexamplesjni PROPERTIES 
BUILD_RPATH \"\\$ORIGIN\")\n\n###################################################################################################\n# - build options ---------------------------------------------------------------------------------\n\noption(PER_THREAD_DEFAULT_STREAM \"Build with per-thread default stream\" OFF)\nif(PER_THREAD_DEFAULT_STREAM)\n    message(STATUS \"Using per-thread default stream\")\n    target_compile_definitions(udfexamplesjni PRIVATE CUDA_API_PER_THREAD_DEFAULT_STREAM)\nendif(PER_THREAD_DEFAULT_STREAM)\n\ntarget_include_directories(udfexamplesjni PRIVATE ${JNI_INCLUDE_DIRS})\n\n###################################################################################################\n# - rmm logging level -----------------------------------------------------------------------------\n\nset(RMM_LOGGING_LEVEL \"OFF\" CACHE STRING \"Choose the logging level.\")\n# Set the possible values of build type for cmake-gui\nset_property(CACHE RMM_LOGGING_LEVEL PROPERTY STRINGS\n        \"TRACE\" \"DEBUG\" \"INFO\" \"WARN\" \"ERROR\" \"CRITICAL\" \"OFF\")\nmessage(STATUS \"RMM_LOGGING_LEVEL = '${RMM_LOGGING_LEVEL}'.\")\n\ntarget_compile_definitions(udfexamplesjni\n    PUBLIC SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_${RMM_LOGGING_LEVEL})\n\n###################################################################################################\n# - link libraries --------------------------------------------------------------------------------\n\ntarget_link_libraries(udfexamplesjni cudf::cudf)\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/benchmarks/CMakeLists.txt",
    "content": "#=============================================================================\n# Copyright (c) 2021-2022, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#=============================================================================\n\n# Use an OBJECT library so we only compile these helper source files only once\nadd_library(udf_benchmark_common OBJECT\n    synchronization/synchronization.cpp)\n\ntarget_link_libraries(udf_benchmark_common PUBLIC benchmark::benchmark cudf)\n\ntarget_include_directories(udf_benchmark_common\n    PUBLIC \"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>\"\n           \"$<BUILD_INTERFACE:${UDFEXAMPLESJNI_SOURCE_DIR}>\"\n           \"$<BUILD_INTERFACE:${UDFEXAMPLESJNI_SOURCE_DIR}>/src\")\n\nfunction(ConfigureBench CMAKE_BENCH_NAME)\n    add_executable(${CMAKE_BENCH_NAME} ${ARGN})\n    set_target_properties(${CMAKE_BENCH_NAME}\n        PROPERTIES RUNTIME_OUTPUT_DIRECTORY \"$<BUILD_INTERFACE:${UDFEXAMPLESJNI_BINARY_DIR}/gbenchmarks>\")\n    target_link_libraries(${CMAKE_BENCH_NAME}\n        PRIVATE udf_benchmark_common udfexamplesjni benchmark::benchmark_main)\nendfunction()\n\nConfigureBench(COSINE_SIMILARITY_BENCH cosine_similarity/cosine_similarity_benchmark.cpp)\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/benchmarks/cosine_similarity/cosine_similarity_benchmark.cpp",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"benchmarks/fixture/benchmark_fixture.hpp\"\n#include \"benchmarks/synchronization/synchronization.hpp\"\n#include \"cosine_similarity.hpp\"\n\n#include <cudf/column/column_factories.hpp>\n#include <cudf/filling.hpp>\n#include <cudf/null_mask.hpp>\n#include <cudf/scalar/scalar_factories.hpp>\n\nstatic void cosine_similarity_bench_args(benchmark::internal::Benchmark* b)\n{\n  int const min_rows   = 1 << 12;\n  int const max_rows   = 1 << 24;\n  int const row_mult   = 8;\n  int const min_rowlen = 1 << 0;\n  int const max_rowlen = 1 << 12;\n  int const len_mult   = 8;\n  for (int row_count = min_rows; row_count <= max_rows; row_count *= row_mult) {\n    for (int rowlen = min_rowlen; rowlen <= max_rowlen; rowlen *= len_mult) {\n      // avoid generating combinations that exceed the cudf column limit\n      size_t total_chars = static_cast<size_t>(row_count) * rowlen;\n      if (total_chars < std::numeric_limits<cudf::size_type>::max()) {\n        b->Args({row_count, rowlen});\n      }\n    }\n  }\n}\n\nstatic void BM_cosine_similarity(benchmark::State& state)\n{\n  cudf::size_type const n_rows{static_cast<cudf::size_type>(state.range(0))};\n  cudf::size_type const list_len{static_cast<cudf::size_type>(state.range(1))};\n\n  auto val_start = cudf::make_fixed_width_scalar(1.0f);\n  auto val_step = 
cudf::make_fixed_width_scalar(-1.0f);\n  auto child_rows = n_rows * list_len;\n  auto col1_child = cudf::sequence(child_rows, *val_start);\n  auto col2_child = cudf::sequence(child_rows, *val_start, *val_step);\n  auto offset_start = cudf::make_fixed_width_scalar(static_cast<int32_t>(0));\n  auto offset_step = cudf::make_fixed_width_scalar(list_len);\n  auto offsets = cudf::sequence(n_rows + 1, *offset_start, *offset_step);\n\n  auto col1 = cudf::make_lists_column(\n      n_rows,\n      std::make_unique<cudf::column>(*offsets),\n      std::move(col1_child),\n      0,\n      cudf::create_null_mask(n_rows, cudf::mask_state::ALL_VALID));\n  auto lcol1 = cudf::lists_column_view(*col1);\n  auto col2 = cudf::make_lists_column(\n      n_rows,\n      std::move(offsets),\n      std::move(col2_child),\n      0,\n      cudf::create_null_mask(n_rows, cudf::mask_state::ALL_VALID));\n  auto lcol2 = cudf::lists_column_view(*col2);\n\n  for (auto _ : state) {\n    cuda_event_timer raii(state, true, rmm::cuda_stream_default);\n    auto output = cosine_similarity(lcol1, lcol2);\n  }\n\n  state.SetBytesProcessed(state.iterations() * child_rows * sizeof(float));\n}\n\nclass CosineSimilarity : public native_udf::benchmark {\n};\n\nBENCHMARK_DEFINE_F(CosineSimilarity, cosine_similarity)\n(::benchmark::State& state) { BM_cosine_similarity(state); }\n\nBENCHMARK_REGISTER_F(CosineSimilarity, cosine_similarity)\n  ->Apply(cosine_similarity_bench_args)\n  ->Unit(benchmark::kMillisecond)\n  ->UseManualTime();\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/benchmarks/fixture/benchmark_fixture.hpp",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <benchmark/benchmark.h>\n#include <rmm/mr/device/cuda_memory_resource.hpp>\n#include <rmm/mr/device/owning_wrapper.hpp>\n#include <rmm/mr/device/per_device_resource.hpp>\n#include <rmm/mr/device/pool_memory_resource.hpp>\n\nnamespace native_udf {\n\nnamespace {\n// memory resource factory helpers\ninline auto make_cuda() { return std::make_shared<rmm::mr::cuda_memory_resource>(); }\n\ninline auto make_pool()\n{\n  return rmm::mr::make_owning_wrapper<rmm::mr::pool_memory_resource>(make_cuda());\n}\n}  // namespace\n\n/**\n * @brief Google Benchmark fixture for native UDF benchmarks\n *\n * Native UDF benchmarks should use a fixture derived from this fixture class to\n * ensure that the RAPIDS Memory Manager pool mode is used in benchmarks, which\n * eliminates memory allocation / deallocation performance overhead from the\n * benchmark.\n *\n * The SetUp and TearDown methods of this fixture initialize RMM into pool mode\n * and finalize it, respectively. 
These methods are called automatically by\n * Google Benchmark\n *\n * Example:\n *\n * template <class T>\n * class my_benchmark : public native_udf::benchmark {\n * public:\n *   using TypeParam = T;\n * };\n *\n * Then:\n *\n * BENCHMARK_TEMPLATE_DEFINE_F(my_benchmark, my_test_name, int)\n *   (::benchmark::State& state) {\n *     for (auto _ : state) {\n *       // benchmark stuff\n *     }\n * }\n *\n * BENCHMARK_REGISTER_F(my_benchmark, my_test_name)->Range(128, 512);\n */\nclass benchmark : public ::benchmark::Fixture {\n public:\n  virtual void SetUp(const ::benchmark::State& state)\n  {\n    mr = make_pool();\n    rmm::mr::set_current_device_resource(mr.get());  // set default resource to pool\n  }\n\n  virtual void TearDown(const ::benchmark::State& state)\n  {\n    // reset default resource to the initial resource\n    rmm::mr::set_current_device_resource(nullptr);\n    mr.reset();\n  }\n\n  // eliminate partial override warnings (see benchmark/benchmark.h)\n  virtual void SetUp(::benchmark::State& st) { SetUp(const_cast<const ::benchmark::State&>(st)); }\n  virtual void TearDown(::benchmark::State& st)\n  {\n    TearDown(const_cast<const ::benchmark::State&>(st));\n  }\n\n  std::shared_ptr<rmm::mr::device_memory_resource> mr;\n};\n\n}  // namespace native_udf\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/benchmarks/synchronization/synchronization.cpp",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"synchronization.hpp\"\n\n#include <cudf/utilities/error.hpp>\n\n#include <rmm/cuda_stream_view.hpp>\n#include <rmm/device_buffer.hpp>\n\ncuda_event_timer::cuda_event_timer(benchmark::State& state,\n                                   bool flush_l2_cache,\n                                   rmm::cuda_stream_view stream)\n  : stream(stream), p_state(&state)\n{\n  // flush all of L2$\n  if (flush_l2_cache) {\n    int current_device = 0;\n    CUDA_TRY(cudaGetDevice(&current_device));\n\n    int l2_cache_bytes = 0;\n    CUDA_TRY(cudaDeviceGetAttribute(&l2_cache_bytes, cudaDevAttrL2CacheSize, current_device));\n\n    if (l2_cache_bytes > 0) {\n      const int memset_value = 0;\n      rmm::device_buffer l2_cache_buffer(l2_cache_bytes, stream);\n      CUDA_TRY(\n        cudaMemsetAsync(l2_cache_buffer.data(), memset_value, l2_cache_bytes, stream.value()));\n    }\n  }\n\n  CUDA_TRY(cudaEventCreate(&start));\n  CUDA_TRY(cudaEventCreate(&stop));\n  CUDA_TRY(cudaEventRecord(start, stream.value()));\n}\n\ncuda_event_timer::~cuda_event_timer()\n{\n  CUDA_TRY(cudaEventRecord(stop, stream.value()));\n  CUDA_TRY(cudaEventSynchronize(stop));\n\n  float milliseconds = 0.0f;\n  CUDA_TRY(cudaEventElapsedTime(&milliseconds, start, stop));\n  p_state->SetIterationTime(milliseconds / (1000.0f));\n  CUDA_TRY(cudaEventDestroy(start));\n  
CUDA_TRY(cudaEventDestroy(stop));\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/benchmarks/synchronization/synchronization.hpp",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/**\n * @file synchronization.hpp\n * @brief This is the header file for `cuda_event_timer`.\n */\n\n/**\n * @brief  This class serves as a wrapper for using `cudaEvent_t` as the user\n * defined timer within the framework of google benchmark\n * (https://github.com/google/benchmark).\n *\n * It is built on top of the idea of Resource acquisition is initialization\n * (RAII). In the following we show a minimal example of how to use this class.\n\n    #include <benchmark/benchmark.h>\n\n    static void sample_cuda_benchmark(benchmark::State& state) {\n\n      for (auto _ : state){\n\n        rmm::cuda_stream_view stream{}; // default stream, could be another stream\n\n        // Create (Construct) an object of this class. You HAVE to pass in the\n        // benchmark::State object you are using. 
It measures the time from its\n        // creation to its destruction that is spent on the specified CUDA stream.\n        // It also clears the L2 cache by cudaMemset'ing a device buffer that is of\n        // the size of the L2 cache (if flush_l2_cache is set to true and there is\n        // an L2 cache on the current device).\n        cuda_event_timer raii(state, true, stream); // flush_l2_cache = true\n\n        // Now perform the operations that is to be benchmarked\n        sample_kernel<<<1, 256, 0, stream.value()>>>(); // Possibly launching a CUDA kernel\n\n      }\n    }\n\n    // Register the function as a benchmark. You will need to set the `UseManualTime()`\n    // flag in order to use the timer embedded in this class.\n    BENCHMARK(sample_cuda_benchmark)->UseManualTime();\n\n\n */\n\n#ifndef UDF_BENCH_SYNCHRONIZATION_H\n#define UDF_BENCH_SYNCHRONIZATION_H\n\n// Google Benchmark library\n#include <benchmark/benchmark.h>\n\n#include <rmm/cuda_stream_view.hpp>\n\n#include <driver_types.h>\n\nclass cuda_event_timer {\n public:\n  /**\n   * @brief This c'tor clears the L2$ by cudaMemset'ing a buffer of L2$ size\n   * and starts the timer.\n   *\n   * @param[in,out] state  This is the benchmark::State whose timer we are going\n   * to update.\n   * @param[in] flush_l2_cache whether or not to flush the L2 cache before\n   *                            every iteration.\n   * @param[in] stream The CUDA stream we are measuring time on.\n   */\n  cuda_event_timer(benchmark::State& state,\n                   bool flush_l2_cache,\n                   rmm::cuda_stream_view stream = rmm::cuda_stream_default);\n\n  // The user must provide a benchmark::State object to set\n  // the timer so we disable the default c'tor.\n  cuda_event_timer() = delete;\n\n  // The d'tor stops the timer and performs a synchronization.\n  // Time of the benchmark::State object provided to the c'tor\n  // will be set to the value given by `cudaEventElapsedTime`.\n  
~cuda_event_timer();\n\n private:\n  cudaEvent_t start;\n  cudaEvent_t stop;\n  rmm::cuda_stream_view stream;\n  benchmark::State* p_state;\n};\n\n#endif\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src/CosineSimilarityJni.cpp",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <cudf/column/column.hpp>\n#include <cudf/column/column_view.hpp>\n#include <cudf/lists/lists_column_view.hpp>\n\n#include <memory>\n#include <jni.h>\n\n#include \"cosine_similarity.hpp\"\n\nnamespace {\n\nconstexpr char const* RUNTIME_ERROR_CLASS = \"java/lang/RuntimeException\";\nconstexpr char const* ILLEGAL_ARG_CLASS   = \"java/lang/IllegalArgumentException\";\n\n/**\n * @brief Throw a Java exception\n *\n * @param env The Java environment\n * @param class_name The fully qualified Java class name of the exception\n * @param msg The message string to associate with the exception\n */\nvoid throw_java_exception(JNIEnv* env, char const* class_name, char const* msg) {\n  jclass ex_class = env->FindClass(class_name);\n  if (ex_class != NULL) {\n    env->ThrowNew(ex_class, msg);\n  }\n}\n\n}  // anonymous namespace\n\nextern \"C\" {\n\n/**\n * @brief The native implementation of CosineSimilarity.cosineSimilarity which\n * computes the cosine similarity between two LIST(FLOAT32) columns as a FLOAT32\n * columnar result.\n *\n * @param env The Java environment\n * @param j_view1 The address of the cudf column view of the first LIST column\n * @param j_view2 The address of the cudf column view of the second LIST column\n * @return The address of the cudf column containing the FLOAT32 results\n */\nJNIEXPORT jlong 
JNICALL\nJava_com_nvidia_spark_rapids_udf_java_CosineSimilarity_cosineSimilarity(JNIEnv* env, jclass,\n                                                                        jlong j_view1,\n                                                                        jlong j_view2) {\n  // Use a try block to translate C++ exceptions into Java exceptions to avoid\n  // crashing the JVM if a C++ exception occurs.\n  try {\n    // turn the addresses into column_view pointers\n    auto v1 = reinterpret_cast<cudf::column_view const*>(j_view1);\n    auto v2 = reinterpret_cast<cudf::column_view const*>(j_view2);\n    if (v1->type().id() != v2->type().id() || v1->type().id() != cudf::type_id::LIST) {\n      throw_java_exception(env, ILLEGAL_ARG_CLASS, \"inputs not list columns\");\n      return 0;\n    }\n\n    // run the GPU kernel to compute the cosine similarity\n    auto lv1 = cudf::lists_column_view(*v1);\n    auto lv2 = cudf::lists_column_view(*v2);\n    std::unique_ptr<cudf::column> result = cosine_similarity(lv1, lv2);\n\n    // take ownership of the column and return the column address to Java and release the underlying resources.\n    return reinterpret_cast<jlong>(result.release());\n  } catch (std::bad_alloc const& e) {\n    auto msg = std::string(\"Unable to allocate native memory: \") +\n        (e.what() == nullptr ? \"\" : e.what());\n    throw_java_exception(env, RUNTIME_ERROR_CLASS, msg.c_str());\n  } catch (std::invalid_argument const& e) {\n    throw_java_exception(env, ILLEGAL_ARG_CLASS, e.what() == nullptr ? \"\" : e.what());\n  } catch (std::exception const& e) {\n    auto msg = e.what() == nullptr ? \"\" : e.what();\n    throw_java_exception(env, RUNTIME_ERROR_CLASS, msg);\n  }\n  return 0;\n}\n\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src/StringWordCountJni.cpp",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <cudf/column/column.hpp>\n#include <cudf/column/column_view.hpp>\n\n#include <memory>\n#include <jni.h>\n\n#include \"string_word_count.hpp\"\n\nnamespace {\n\nconstexpr char const* RUNTIME_ERROR_CLASS = \"java/lang/RuntimeException\";\n\n/**\n * @brief Throw a Java exception\n *\n * @param env The Java environment\n * @param class_name The fully qualified Java class name of the exception\n * @param msg The message string to associate with the exception\n */\nvoid throw_java_exception(JNIEnv* env, char const* class_name, char const* msg) {\n  jclass ex_class = env->FindClass(class_name);\n  if (ex_class != NULL) {\n    env->ThrowNew(ex_class, msg);\n  }\n}\n\n}  // anonymous namespace\n\nextern \"C\" {\n\n/**\n * @brief The native implementation of StringWordCount.countWords which counts the\n * number of words per string in a string column.\n *\n * @param env The Java environment\n * @param j_strings The address of the cudf column view of the strings column\n * @return The address of the cudf column containing the word counts\n */\nJNIEXPORT jlong JNICALL\nJava_com_nvidia_spark_rapids_udf_hive_StringWordCount_countWords(JNIEnv* env, jclass,\n                                                                 jlong j_strings) {\n  // Use a try block to translate C++ exceptions into Java exceptions to avoid\n  // crashing 
the JVM if a C++ exception occurs.\n  try {\n    // turn the addresses into column_view pointers\n    auto strs = reinterpret_cast<cudf::column_view const*>(j_strings);\n\n    // run the GPU kernel to compute the word counts\n    std::unique_ptr<cudf::column> result = string_word_count(*strs);\n\n    // take ownership of the column and return the column address to Java\n    return reinterpret_cast<jlong>(result.release());\n  } catch (std::bad_alloc const& e) {\n    auto msg = std::string(\"Unable to allocate native memory: \") +\n        (e.what() == nullptr ? \"\" : e.what());\n    throw_java_exception(env, RUNTIME_ERROR_CLASS, msg.c_str());\n  } catch (std::exception const& e) {\n    auto msg = e.what() == nullptr ? \"\" : e.what();\n    throw_java_exception(env, RUNTIME_ERROR_CLASS, msg);\n  }\n  return 0;\n}\n\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src/cosine_similarity.cu",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"cosine_similarity.hpp\"\n\n#include <cudf/column/column_factories.hpp>\n#include <cudf/column/column_device_view.cuh>\n#include <cudf/lists/list_device_view.cuh>\n#include <cudf/lists/lists_column_device_view.cuh>\n#include <cudf/lists/lists_column_view.hpp>\n#include <cudf/null_mask.hpp>\n#include <cudf/table/table_view.hpp>\n#include <cudf/utilities/bit.hpp>\n\n#include <rmm/cuda_stream_view.hpp>\n#include <rmm/device_buffer.hpp>\n#include <rmm/exec_policy.hpp>\n\n#include <thrust/iterator/counting_iterator.h>\n#include <thrust/logical.h>\n#include <thrust/transform.h>\n\n#include <cmath>\n\nnamespace {\n\n/**\n * @brief Functor for computing the cosine similarity between two list of float columns\n */\nstruct cosine_similarity_functor {\n  float const* const v1;\n  float const* const v2;\n  int32_t const* const v1_offsets;\n  int32_t const* const v2_offsets;\n\n  // This kernel executes thread-per-row which should be fine for relatively short lists\n  // but may need to be revisited for performance if operating on long lists.\n  __device__ float operator()(cudf::size_type row_idx) {\n    auto const v1_start_idx = v1_offsets[row_idx];\n    auto const v1_num_elems = v1_offsets[row_idx + 1] - v1_start_idx;\n    auto const v2_start_idx = v2_offsets[row_idx];\n    auto const v2_num_elems = v2_offsets[row_idx + 1] - 
v2_start_idx;\n    auto const num_elems = std::min(v1_num_elems, v2_num_elems);\n    double mag1 = 0;\n    double mag2 = 0;\n    double dot_product = 0;\n    for (auto i = 0; i < num_elems; i++) {\n      float const f1 = v1[v1_start_idx + i];\n      mag1 += f1 * f1;\n      float const f2 = v2[v2_start_idx + i];\n      mag2 += f2 * f2;\n      dot_product += f1 * f2;\n    }\n    mag1 = std::sqrt(mag1);\n    mag2 = std::sqrt(mag2);\n    return static_cast<float>(dot_product / (mag1 * mag2));\n  }\n};\n\n} // anonymous namespace\n\n/**\n * @brief Compute the cosine similarity between two LIST of FLOAT32 columns\n *\n * The input vectors must have matching shapes, i.e.: same row count and same number of\n * list elements per row. A null list row is supported, but null float entries within a\n * list are not supported.\n *\n * @param lv1 The first LIST of FLOAT32 column view\n * @param lv2 The second LIST of FLOAT32 column view\n * @return A FLOAT32 column containing the cosine similarity corresponding to each input row\n */\nstd::unique_ptr<cudf::column> cosine_similarity(cudf::lists_column_view const& lv1,\n                                                cudf::lists_column_view const& lv2) {\n  // sanity-check the input types\n  if (lv1.child().type().id() != lv2.child().type().id() ||\n      lv1.child().type().id() != cudf::type_id::FLOAT32) {\n    throw std::invalid_argument(\"inputs are not lists of floats\");\n  }\n\n  // sanity check the input shape\n  auto const row_count = lv1.size();\n  if (row_count != lv2.size()) {\n    throw std::invalid_argument(\"input row counts do not match\");\n  }\n  if (row_count == 0) {\n    return cudf::make_empty_column(cudf::data_type{cudf::type_id::FLOAT32});\n  }\n  if (lv1.child().null_count() != 0 || lv2.child().null_count() != 0) {\n    throw std::invalid_argument(\"null floats are not supported\");\n  }\n\n  auto const stream = rmm::cuda_stream_default;\n\n  // Check if list sizes match by comparing offsets differences\n  // 
Need to handle null lists: if either list is null, consider it valid (will be null in output)\n  auto const lv1_offsets_ptr = lv1.offsets().data<int32_t>();\n  auto const lv2_offsets_ptr = lv2.offsets().data<int32_t>();\n  auto const lv1_null_mask = lv1.parent().null_mask();\n  auto const lv2_null_mask = lv2.parent().null_mask();\n  bool const are_offsets_equal =\n    thrust::all_of(rmm::exec_policy(stream),\n                   thrust::make_counting_iterator<cudf::size_type>(0),\n                   thrust::make_counting_iterator<cudf::size_type>(row_count),\n                   [lv1_offsets_ptr, lv2_offsets_ptr, lv1_null_mask, lv2_null_mask]\n                   __device__(cudf::size_type idx) -> bool {\n                     // Check if either list is null - if so, consider valid\n                     // Use cudf::bit_is_set() for proper bitmask handling\n                     bool lv1_is_null = lv1_null_mask != nullptr && !cudf::bit_is_set(lv1_null_mask, idx);\n                     bool lv2_is_null = lv2_null_mask != nullptr && !cudf::bit_is_set(lv2_null_mask, idx);\n                     if (lv1_is_null || lv2_is_null) return true;\n                     \n                     // Both are valid, check sizes\n                     auto lv1_size = lv1_offsets_ptr[idx + 1] - lv1_offsets_ptr[idx];\n                     auto lv2_size = lv2_offsets_ptr[idx + 1] - lv2_offsets_ptr[idx];\n                     return lv1_size == lv2_size;\n                   });\n  if (not are_offsets_equal) {\n    throw std::invalid_argument(\"input list lengths do not match for every row\");\n  }\n\n  // allocate the vector of float results\n  rmm::device_uvector<float> float_results(row_count, stream);\n\n  // compute the cosine similarity\n  auto const lv1_data = lv1.child().data<float>();\n  auto const lv2_data = lv2.child().data<float>();\n  auto const lv1_offsets = lv1.offsets().data<int32_t>();\n  auto const lv2_offsets = lv2.offsets().data<int32_t>();\n  
thrust::transform(rmm::exec_policy(stream),\n                    thrust::make_counting_iterator<cudf::size_type>(0),\n                    thrust::make_counting_iterator<cudf::size_type>(row_count),\n                    float_results.data(),\n                    cosine_similarity_functor({lv1_data, lv2_data, lv1_offsets, lv2_offsets}));\n\n  // the validity of the output is the bitwise-and of the two input validity masks\n  auto [null_mask, null_count] = cudf::bitmask_and(cudf::table_view({lv1.parent(), lv2.parent()}));\n\n  return std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::FLOAT32},\n                                        row_count,\n                                        float_results.release(),\n                                        std::move(null_mask),\n                                        null_count);\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src/cosine_similarity.hpp",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <cudf/column/column.hpp>\n#include <cudf/column/column_view.hpp>\n#include <cudf/lists/lists_column_view.hpp>\n\n/**\n * @brief Compute the cosine similarity between two LIST of FLOAT32 columns\n *\n * The input vectors must have matching shapes, i.e.: same row count and same number of\n * list elements per row. A null list row is supported, but null float entries within a\n * list are not supported.\n *\n * @param lv1 The first LIST of FLOAT32 column view\n * @param lv2 The second LIST of FLOAT32 column view\n * @return A FLOAT32 column containing the cosine similarity corresponding to each input row\n */\nstd::unique_ptr<cudf::column> cosine_similarity(cudf::lists_column_view const& lv1,\n                                                cudf::lists_column_view const& lv2);\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src/string_word_count.cu",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"string_word_count.hpp\"\n\n#include <cudf/column/column_factories.hpp>\n#include <cudf/column/column_device_view.cuh>\n#include <cudf/strings/string_view.cuh>\n\n#include <rmm/cuda_stream_view.hpp>\n#include <rmm/device_buffer.hpp>\n#include <rmm/exec_policy.hpp>\n\n#include <thrust/iterator/counting_iterator.h>\n#include <thrust/transform.h>\n\nnamespace {\n\n// count the words separated by whitespace characters\n__device__ cudf::size_type count_words(cudf::column_device_view const& d_strings,\n                                       cudf::size_type idx) {\n  if (d_strings.is_null(idx)) return 0;\n  cudf::string_view const d_str = d_strings.element<cudf::string_view>(idx);\n  cudf::size_type word_count    = 0;\n  // run of whitespace is considered a single delimiter\n  bool spaces = true;\n  auto itr    = d_str.begin();\n  while (itr != d_str.end()) {\n    cudf::char_utf8 ch = *itr;\n    if (spaces == (ch <= ' ')) {\n      itr++;\n    } else {\n      word_count += static_cast<cudf::size_type>(spaces);\n      spaces = !spaces;\n    }\n  }\n\n  return word_count;\n}\n\n\n} // anonymous namespace\n\n/**\n * @brief Count the words in a string using whitespace as word boundaries\n *\n * @param strs The column containing the strings\n * @param stream The CUDA stream to use\n * @return The INT32 column containing the word 
count results per string\n */\nstd::unique_ptr<cudf::column> string_word_count(cudf::column_view const& strs) {\n  auto strings_count = strs.size();\n  if (strings_count == 0) {\n    return cudf::make_empty_column(cudf::data_type{cudf::type_id::INT32});\n  }\n\n  // the validity of the output matches the validity of the input\n  rmm::device_buffer null_mask = cudf::copy_bitmask(strs);\n\n  // allocate the column that will contain the word count results\n  std::unique_ptr<cudf::column> result =\n    cudf::make_numeric_column(\n      cudf::data_type{cudf::type_id::INT32},\n      strs.size(),\n      std::move(null_mask),\n      strs.null_count());\n\n  // compute the word counts, writing into the result column data buffer\n  auto stream = rmm::cuda_stream_default;\n  auto strs_device_view = cudf::column_device_view::create(strs, stream);\n  auto d_strs_view = *strs_device_view;\n  thrust::transform(\n    rmm::exec_policy(stream),\n    thrust::make_counting_iterator<cudf::size_type>(0),\n    thrust::make_counting_iterator<cudf::size_type>(strings_count),\n    result->mutable_view().data<cudf::size_type>(),\n    [d_strs_view] __device__(cudf::size_type idx) -> cudf::size_type {\n      return count_words(d_strs_view, idx);\n    });\n\n  return result;\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/cpp/src/string_word_count.hpp",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <cudf/column/column.hpp>\n#include <cudf/column/column_view.hpp>\n\n/**\n * @brief Count the words in a string separated by whitespace\n *\n * @param strs The column containing the strings to be examined\n * @return The INT32 column containing the word count results for each string\n */\nstd::unique_ptr<cudf::column> string_word_count(cudf::column_view const& strs);\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/hive/DecimalFraction.java",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.rapids.udf.hive;\n\nimport ai.rapids.cudf.ColumnVector;\nimport ai.rapids.cudf.Scalar;\nimport com.nvidia.spark.RapidsUDF;\nimport org.apache.hadoop.hive.common.type.HiveDecimal;\nimport org.apache.hadoop.hive.ql.exec.UDFArgumentException;\nimport org.apache.hadoop.hive.ql.metadata.HiveException;\nimport org.apache.hadoop.hive.ql.udf.generic.GenericUDF;\nimport org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;\nimport org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;\nimport org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;\nimport org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;\nimport org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;\n\nimport java.math.BigDecimal;\n\n\n/**\n * A simple HiveGenericUDF demo for DecimalType, which extracts and returns\n * the fraction part of the input Decimal data. 
So, the output data has the\n * same precision and scale as the input one.\n */\npublic class DecimalFraction extends GenericUDF implements RapidsUDF {\n  private transient PrimitiveObjectInspector inputOI;\n\n  @Override\n  public String getDisplayString(String[] strings) {\n    return getStandardDisplayString(\"DecimalFraction\", strings);\n  }\n\n  @Override\n  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {\n    if (arguments.length != 1) {\n      throw new UDFArgumentException(\"One argument is supported, found: \" + arguments.length);\n    }\n    if (!(arguments[0] instanceof PrimitiveObjectInspector)) {\n      throw new UDFArgumentException(\"Unsupported argument type: \" + arguments[0].getTypeName());\n    }\n\n    inputOI = (PrimitiveObjectInspector) arguments[0];\n    if (inputOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.DECIMAL) {\n      throw new UDFArgumentException(\"Unsupported primitive type: \" + inputOI.getPrimitiveCategory());\n    }\n\n    DecimalTypeInfo inputTypeInfo = (DecimalTypeInfo) inputOI.getTypeInfo();\n\n    return PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(inputTypeInfo);\n  }\n\n  @Override\n  public Object evaluate(GenericUDF.DeferredObject[] arguments) throws HiveException {\n    if (arguments[0] == null || arguments[0].get() == null) {\n      return null;\n    }\n\n    Object input = arguments[0].get();\n    HiveDecimalWritable decimalWritable = (HiveDecimalWritable) inputOI.getPrimitiveWritableObject(input);\n    BigDecimal decimalInput = decimalWritable.getHiveDecimal().bigDecimalValue();\n    BigDecimal decimalResult = decimalInput.subtract(new BigDecimal(decimalInput.toBigInteger()));\n    HiveDecimalWritable result = new HiveDecimalWritable(decimalWritable);\n    result.set(HiveDecimal.create(decimalResult));\n\n    return result;\n  }\n\n  @Override\n  public ColumnVector evaluateColumnar(int numRows, ColumnVector... 
args) {\n    if (args.length != 1) {\n      throw new IllegalArgumentException(\"Unexpected argument count: \" + args.length);\n    }\n    ColumnVector input = args[0];\n    if (numRows != input.getRowCount()) {\n      throw new IllegalArgumentException(\"Expected \" + numRows + \" rows, received \" + input.getRowCount());\n    }\n    if (!input.getType().isDecimalType()) {\n      throw new IllegalArgumentException(\"Argument type is not a decimal column: \" +\n          input.getType());\n    }\n\n    try (Scalar nullScalar = Scalar.fromNull(input.getType());\n         ColumnVector nullPredicate = input.isNull();\n         ColumnVector integral = input.floor();\n         ColumnVector fraction = input.sub(integral, input.getType())) {\n      return nullPredicate.ifElse(nullScalar, fraction);\n    }\n  }\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/hive/StringWordCount.java",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.rapids.udf.hive;\n\nimport ai.rapids.cudf.ColumnVector;\nimport ai.rapids.cudf.DType;\nimport ai.rapids.cudf.NativeDepsLoader;\nimport com.nvidia.spark.RapidsUDF;\nimport com.nvidia.spark.rapids.udf.java.NativeUDFExamplesLoader;\nimport org.apache.hadoop.hive.ql.exec.UDF;\n\nimport java.io.IOException;\n\n/**\n * A user-defined function (UDF) that counts the words in a string.\n * This avoids the manifestation of intermediate results required when\n * splitting the string on whitespace and counting the split results.\n * <p>\n * This class demonstrates how to implement a Hive UDF with a RAPIDS\n * implementation that uses custom native code.\n */\npublic class StringWordCount extends UDF implements RapidsUDF {\n  private volatile boolean isNativeCodeLoaded = false;\n\n  /** Row-by-row implementation that executes on the CPU */\n  public Integer evaluate(String str) {\n    if (str == null) {\n      return null;\n    }\n\n    int numWords = 0;\n    // run of whitespace is considered a single delimiter\n    boolean spaces = true;\n    for (int idx = 0; idx < str.length(); idx++) {\n      char ch = str.charAt(idx);\n      if (spaces != (ch <= ' ')) {\n        if (spaces) {\n          numWords++;\n        }\n        spaces = !spaces;\n      }\n    }\n    return numWords;\n  }\n\n  /** Columnar implementation 
that runs on the GPU */\n  @Override\n  public ColumnVector evaluateColumnar(int numRows, ColumnVector... args) {\n    // The CPU implementation takes a single string argument, so similarly\n    // there should only be one column argument of type STRING.\n    if (args.length != 1) {\n      throw new IllegalArgumentException(\"Unexpected argument count: \" + args.length);\n    }\n    ColumnVector strs = args[0];\n    if (numRows != strs.getRowCount()) {\n      throw new IllegalArgumentException(\"Expected \" + numRows + \" rows, received \" + strs.getRowCount());\n    }\n    if (!strs.getType().equals(DType.STRING)) {\n      throw new IllegalArgumentException(\"type mismatch, expected strings but found \" +\n          strs.getType());\n    }\n\n    // Load the native code if it has not been already loaded. This is done here\n    // rather than in a static code block since the driver may not have the\n    // required CUDA environment.\n    NativeUDFExamplesLoader.ensureLoaded();\n\n    return new ColumnVector(countWords(strs.getNativeView()));\n  }\n\n  private static native long countWords(long stringsView);\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/hive/URLDecode.java",
    "content": "/*\n * Copyright (c) 2020-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.rapids.udf.hive;\n\nimport ai.rapids.cudf.ColumnVector;\nimport ai.rapids.cudf.DType;\nimport ai.rapids.cudf.Scalar;\nimport com.nvidia.spark.RapidsUDF;\nimport org.apache.hadoop.hive.ql.exec.UDF;\n\nimport java.io.UnsupportedEncodingException;\nimport java.net.URLDecoder;\n\n/**\n * A Hive user-defined function (UDF) that decodes URL-encoded strings.\n * This class demonstrates how to implement a simple Hive UDF that also\n * provides a RAPIDS implementation that can run on the GPU when the query\n * is executed with the RAPIDS Accelerator for Apache Spark.\n */\npublic class URLDecode extends UDF implements RapidsUDF {\n\n  /** Row-by-row implementation that executes on the CPU */\n  public String evaluate(String s) {\n    String result = null;\n    if (s != null) {\n      try {\n        result = URLDecoder.decode(s, \"utf-8\");\n      } catch (IllegalArgumentException ignored) {\n        result = s;\n      } catch (UnsupportedEncodingException e) {\n        // utf-8 is a builtin, standard encoding, so this should never happen\n        throw new RuntimeException(e);\n      }\n    }\n    return result;\n  }\n\n  /** Columnar implementation that runs on the GPU */\n  @Override\n  public ColumnVector evaluateColumnar(int numRows, ColumnVector... 
args) {\n    // The CPU implementation takes a single string argument, so similarly\n    // there should only be one column argument of type STRING.\n    if (args.length != 1) {\n      throw new IllegalArgumentException(\"Unexpected argument count: \" + args.length);\n    }\n    ColumnVector input = args[0];\n    if (numRows != input.getRowCount()) {\n      throw new IllegalArgumentException(\"Expected \" + numRows + \" rows, received \" + input.getRowCount());\n    }\n    if (!input.getType().equals(DType.STRING)) {\n      throw new IllegalArgumentException(\"Argument type is not a string column: \" +\n          input.getType());\n    }\n\n    // The cudf urlDecode does not convert '+' to a space, so do that as a pre-pass first.\n    // All intermediate results are closed to avoid leaking GPU resources.\n    try (Scalar plusScalar = Scalar.fromString(\"+\");\n         Scalar spaceScalar = Scalar.fromString(\" \");\n         ColumnVector replaced = input.stringReplace(plusScalar, spaceScalar)) {\n      return replaced.urlDecode();\n    }\n  }\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/hive/URLEncode.java",
    "content": "/*\n * Copyright (c) 2020-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.rapids.udf.hive;\n\nimport ai.rapids.cudf.ColumnVector;\nimport ai.rapids.cudf.DType;\nimport com.nvidia.spark.RapidsUDF;\nimport org.apache.hadoop.hive.ql.exec.UDFArgumentException;\nimport org.apache.hadoop.hive.ql.metadata.HiveException;\nimport org.apache.hadoop.hive.ql.udf.generic.GenericUDF;\nimport org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;\nimport org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;\nimport org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorConverter;\nimport org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;\nimport org.apache.hadoop.io.Text;\n\nimport java.io.UnsupportedEncodingException;\nimport java.net.URLEncoder;\n\n/**\n * A Hive user-defined function (UDF) that URL-encodes strings.\n * This class demonstrates how to implement a Hive GenericUDF that also\n * provides a RAPIDS implementation that can run on the GPU when the query\n * is executed with the RAPIDS Accelerator for Apache Spark.\n */\npublic class URLEncode extends GenericUDF implements RapidsUDF {\n  private transient PrimitiveObjectInspectorConverter.TextConverter converter;\n  private final Text textResult = new Text();\n\n  /** Standard getDisplayString method for implementing GenericUDF */\n  
@Override\n  public String getDisplayString(String[] children) {\n    return getStandardDisplayString(\"urlencode\", children);\n  }\n\n  /** Standard initialize method for implementing GenericUDF for a single string parameter */\n  @Override\n  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {\n    if (arguments.length != 1) {\n      throw new UDFArgumentException(\"One argument is supported, found: \" + arguments.length);\n    }\n    if (!(arguments[0] instanceof PrimitiveObjectInspector)) {\n      throw new UDFArgumentException(\"Unsupported argument type: \" + arguments[0].getTypeName());\n    }\n    PrimitiveObjectInspector poi = (PrimitiveObjectInspector) arguments[0];\n    switch (poi.getPrimitiveCategory()) {\n      case STRING:\n      case CHAR:\n      case VARCHAR:\n        break;\n      default:\n        throw new UDFArgumentException(\"Unsupported primitive type: \" + poi.getPrimitiveCategory());\n    }\n\n    converter = new PrimitiveObjectInspectorConverter.TextConverter(poi);\n    return PrimitiveObjectInspectorFactory.writableStringObjectInspector;\n  }\n\n  /** Row-by-row implementation that executes on the CPU */\n  @Override\n  public Object evaluate(GenericUDF.DeferredObject[] arguments) throws HiveException {\n    Text text = converter.convert(arguments[0].get());\n    if (text == null) {\n      return null;\n    }\n    String encoded;\n    try {\n      encoded = URLEncoder.encode(text.toString(), \"utf-8\")\n          .replace(\"+\", \"%20\")\n          .replace(\"*\", \"%2A\")\n          .replace(\"%7E\", \"~\");\n    } catch (UnsupportedEncodingException e) {\n      // utf-8 is a builtin, standard encoding, so this should never happen\n      throw new RuntimeException(e);\n    }\n    textResult.set(encoded);\n    return textResult;\n  }\n\n  /** Columnar implementation that runs on the GPU */\n  @Override\n  public ColumnVector evaluateColumnar(int numRows, ColumnVector... 
args) {\n    // The CPU implementation takes a single string argument, so similarly\n    // there should only be one column argument of type STRING.\n    if (args.length != 1) {\n      throw new IllegalArgumentException(\"Unexpected argument count: \" + args.length);\n    }\n    ColumnVector input = args[0];\n    if (numRows != input.getRowCount()) {\n      throw new IllegalArgumentException(\"Expected \" + numRows + \" rows, received \" + input.getRowCount());\n    }\n    if (!input.getType().equals(DType.STRING)) {\n      throw new IllegalArgumentException(\"Argument type is not a string column: \" +\n          input.getType());\n    }\n\n    return input.urlEncode();\n  }\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/java/CosineSimilarity.java",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.rapids.udf.java;\n\nimport ai.rapids.cudf.ColumnVector;\nimport com.nvidia.spark.RapidsUDF;\nimport org.apache.spark.sql.api.java.UDF2;\nimport scala.collection.mutable.WrappedArray;\n\n/**\n * A Spark Java UDF that computes the cosine similarity between two float vectors.\n * The input vectors must have matching shapes, i.e.: same number of elements.\n * A null vector is supported, but null entries within the vector are not supported.\n */\npublic class CosineSimilarity\n    implements UDF2<WrappedArray<Float>, WrappedArray<Float>, Float>, RapidsUDF {\n\n  /** Row-by-row implementation that executes on the CPU */\n  @Override\n  public Float call(WrappedArray<Float> v1, WrappedArray<Float> v2) {\n    if (v1 == null || v2 == null) {\n      return null;\n    }\n    if (v1.length() != v2.length()) {\n      throw new IllegalArgumentException(\"Array lengths must match: \" +\n          v1.length() + \" != \" + v2.length());\n    }\n\n    double dotProduct = 0;\n    for (int i = 0; i < v1.length(); i++) {\n      float f1 = v1.apply(i);\n      float f2 = v2.apply(i);\n      dotProduct += f1 * f2;\n    }\n    double magProduct = magnitude(v1) * magnitude(v2);\n    return (float) (dotProduct / magProduct);\n  }\n\n  private double magnitude(WrappedArray<Float> v) {\n    double sum = 0;\n    for (int i = 0; i < 
v.length(); i++) {\n      float x = v.apply(i);\n      sum += x * x;\n    }\n    return Math.sqrt(sum);\n  }\n\n  /** Columnar implementation that processes data on the GPU */\n  @Override\n  public ColumnVector evaluateColumnar(int numRows, ColumnVector... args) {\n    if (args.length != 2) {\n      throw new IllegalArgumentException(\"Unexpected argument count: \" + args.length);\n    }\n\n    // Load the native code if it has not been already loaded. This is done here\n    // rather than in a static code block since the driver may not have the\n    // required CUDA environment. \n    NativeUDFExamplesLoader.ensureLoaded();\n    \n    // We need to go into the native code as quickly as possible\n    // because it is easier to write the code safely.\n    // Then wrap returns in a column vector and own that resource.\n    return new ColumnVector(cosineSimilarity(args[0].getNativeView(), args[1].getNativeView()));\n  }\n\n  /** Native implementation that computes on the GPU */\n  private static native long cosineSimilarity(long vectorView1, long vectorView2);\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/java/DecimalFraction.java",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.rapids.udf.java;\n\nimport ai.rapids.cudf.ColumnVector;\nimport ai.rapids.cudf.Scalar;\nimport com.nvidia.spark.RapidsUDF;\nimport org.apache.spark.sql.api.java.UDF1;\n\nimport java.math.BigDecimal;\n\n/**\n * A simple Java UDF demo for DecimalType, which extracts and returns the\n * fraction part of the input Decimal data. So, the output data has the\n * same precision and scale as the input one.\n */\npublic class DecimalFraction implements UDF1<BigDecimal, BigDecimal>, RapidsUDF {\n\n  @Override\n  public BigDecimal call(BigDecimal dec) throws Exception {\n    if (dec == null) {\n      return null;\n    }\n    BigDecimal integral = new BigDecimal(dec.toBigInteger());\n    return dec.subtract(integral);\n  }\n\n  @Override\n  public ColumnVector evaluateColumnar(int numRows, ColumnVector... 
args) {\n    if (args.length != 1) {\n      throw new IllegalArgumentException(\"Unexpected argument count: \" + args.length);\n    }\n    ColumnVector input = args[0];\n    if (!input.getType().isDecimalType()) {\n      throw new IllegalArgumentException(\"Argument type is not a decimal column: \" +\n          input.getType());\n    }\n\n    try (Scalar nullScalar = Scalar.fromNull(input.getType());\n         ColumnVector nullPredicate = input.isNull();\n         ColumnVector integral = input.floor();\n         ColumnVector fraction = input.sub(integral, input.getType())) {\n      return nullPredicate.ifElse(nullScalar, fraction);\n    }\n  }\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/java/NativeUDFExamplesLoader.java",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.rapids.udf.java;\n\nimport ai.rapids.cudf.NativeDepsLoader;\n\nimport java.io.IOException;\n\n/** Loads the native dependencies for UDF examples with a native implementation */\npublic class NativeUDFExamplesLoader {\n  private static boolean isLoaded;\n\n  /** Loads native UDF code if necessary */\n  public static synchronized void ensureLoaded() {\n    if (!isLoaded) {\n      try {\n        NativeDepsLoader.loadNativeDeps(new String[]{\"udfexamplesjni\"});\n        isLoaded = true;\n      } catch (IOException e) {\n        throw new RuntimeException(e);\n      }\n    }\n  }\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/java/URLDecode.java",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.rapids.udf.java;\n\nimport ai.rapids.cudf.ColumnVector;\nimport ai.rapids.cudf.DType;\nimport ai.rapids.cudf.Scalar;\nimport com.nvidia.spark.RapidsUDF;\nimport org.apache.spark.sql.api.java.UDF1;\n\nimport java.io.UnsupportedEncodingException;\nimport java.net.URLDecoder;\n\n/**\n * A Java user-defined function (UDF) that decodes URL-encoded strings.\n * This class demonstrates how to implement a Java UDF that also\n * provides a RAPIDS implementation that can run on the GPU when the query\n * is executed with the RAPIDS Accelerator for Apache Spark.\n */\npublic class URLDecode implements UDF1<String, String>, RapidsUDF {\n  /** Row-by-row implementation that executes on the CPU */\n  @Override\n  public String call(String s) {\n    String result = null;\n    if (s != null) {\n      try {\n        result = URLDecoder.decode(s, \"utf-8\");\n      } catch (IllegalArgumentException ignored) {\n        result = s;\n      } catch (UnsupportedEncodingException e) {\n        // utf-8 is a builtin, standard encoding, so this should never happen\n        throw new RuntimeException(e);\n      }\n    }\n    return result;\n  }\n\n  /** Columnar implementation that runs on the GPU */\n  @Override\n  public ColumnVector evaluateColumnar(int numRows, ColumnVector... 
args) {\n    // The CPU implementation takes a single string argument, so similarly\n    // there should only be one column argument of type STRING.\n    if (args.length != 1) {\n      throw new IllegalArgumentException(\"Unexpected argument count: \" + args.length);\n    }\n    ColumnVector input = args[0];\n    if (numRows != input.getRowCount()) {\n      throw new IllegalArgumentException(\"Expected \" + numRows + \" rows, received \" + input.getRowCount());\n    }\n    if (!input.getType().equals(DType.STRING)) {\n      throw new IllegalArgumentException(\"Argument type is not a string column: \" +\n          input.getType());\n    }\n\n    // The cudf urlDecode does not convert '+' to a space, so do that as a pre-pass first.\n    // All intermediate results are closed to avoid leaking GPU resources.\n    try (Scalar plusScalar = Scalar.fromString(\"+\");\n         Scalar spaceScalar = Scalar.fromString(\" \");\n         ColumnVector replaced = input.stringReplace(plusScalar, spaceScalar)) {\n      return replaced.urlDecode();\n    }\n  }\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/java/com/nvidia/spark/rapids/udf/java/URLEncode.java",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.rapids.udf.java;\n\nimport ai.rapids.cudf.ColumnVector;\nimport ai.rapids.cudf.DType;\nimport com.nvidia.spark.RapidsUDF;\nimport org.apache.spark.sql.api.java.UDF1;\n\nimport java.io.UnsupportedEncodingException;\nimport java.net.URLEncoder;\n\n/**\n * A Java user-defined function (UDF) that URL-encodes strings.\n * This class demonstrates how to implement a Java UDF that also\n * provides a RAPIDS implementation that can run on the GPU when the query\n * is executed with the RAPIDS Accelerator for Apache Spark.\n */\npublic class URLEncode implements UDF1<String, String>, RapidsUDF {\n  /** Row-by-row implementation that executes on the CPU */\n  @Override\n  public String call(String s) {\n    if (s == null) {\n      return null;\n    }\n    try {\n      return URLEncoder.encode(s, \"utf-8\")\n          .replace(\"+\", \"%20\")\n          .replace(\"*\", \"%2A\")\n          .replace(\"%7E\", \"~\");\n    } catch (UnsupportedEncodingException e) {\n      // utf-8 is a builtin, standard encoding, so this should never happen\n      throw new RuntimeException(e);\n    }\n  }\n\n  /** Columnar implementation that runs on the GPU */\n  @Override\n  public ColumnVector evaluateColumnar(int numRows, ColumnVector... 
args) {\n    // The CPU implementation takes a single string argument, so similarly\n    // there should only be one column argument of type STRING.\n    if (args.length != 1) {\n      throw new IllegalArgumentException(\"Unexpected argument count: \" + args.length);\n    }\n    ColumnVector input = args[0];\n    if (numRows != input.getRowCount()) {\n      throw new IllegalArgumentException(\"Expected \" + numRows + \" rows, received \" + input.getRowCount());\n    }\n    if (!input.getType().equals(DType.STRING)) {\n      throw new IllegalArgumentException(\"Argument type is not a string column: \" +\n          input.getType());\n    }\n\n    return input.urlEncode();\n  }\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/python/asserts.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom conftest import is_incompat, should_sort_on_spark, should_sort_locally, get_float_check, get_limit, spark_jvm\nfrom datetime import date, datetime\nfrom decimal import Decimal\nimport math\nfrom pyspark.sql import Row\nfrom py4j.protocol import Py4JJavaError\n\nimport pytest\nfrom spark_session import with_cpu_session, with_gpu_session\nimport time\nimport types as pytypes\nimport data_gen\n\ndef _assert_equal(cpu, gpu, float_check, path):\n    t = type(cpu)\n    if (t is Row):\n        assert len(cpu) == len(gpu), \"CPU and GPU row have different lengths at {} CPU: {} GPU: {}\".format(path, len(cpu), len(gpu))\n        if hasattr(cpu, \"__fields__\") and hasattr(gpu, \"__fields__\"):\n            assert cpu.__fields__ == gpu.__fields__, \"CPU and GPU row have different fields at {} CPU: {} GPU: {}\".format(path, cpu.__fields__, gpu.__fields__)\n            for field in cpu.__fields__:\n                _assert_equal(cpu[field], gpu[field], float_check, path + [field])\n        else:\n            for index in range(len(cpu)):\n                _assert_equal(cpu[index], gpu[index], float_check, path + [index])\n    elif (t is list):\n        assert len(cpu) == len(gpu), \"CPU and GPU list have different lengths at {} CPU: {} GPU: {}\".format(path, len(cpu), len(gpu))\n        for index in range(len(cpu)):\n            _assert_equal(cpu[index], 
gpu[index], float_check, path + [index])\n    elif (t is tuple):\n        assert len(cpu) == len(gpu), \"CPU and GPU list have different lengths at {} CPU: {} GPU: {}\".format(path, len(cpu), len(gpu))\n        for index in range(len(cpu)):\n            _assert_equal(cpu[index], gpu[index], float_check, path + [index])\n    elif (t is pytypes.GeneratorType):\n        index = 0\n        # generator has no zip :( so we have to do this the hard way\n        done = False\n        while not done:\n            sub_cpu = None\n            sub_gpu = None\n            try:\n                sub_cpu = next(cpu)\n            except StopIteration:\n                done = True\n\n            try:\n                sub_gpu = next(gpu)\n            except StopIteration:\n                done = True\n\n            if done:\n                assert sub_cpu == sub_gpu and sub_cpu == None, \"CPU and GPU generators have different lengths at {}\".format(path)\n            else:\n                _assert_equal(sub_cpu, sub_gpu, float_check, path + [index])\n\n            index = index + 1\n    elif (t is dict):\n        # The order of key/values is not guaranteed in python dicts, nor are they guaranteed by Spark\n        # so sort the items to do our best with ignoring the order of dicts\n        cpu_items = list(cpu.items()).sort(key=_RowCmp)\n        gpu_items = list(gpu.items()).sort(key=_RowCmp)\n        _assert_equal(cpu_items, gpu_items, float_check, path + [\"map\"])\n    elif (t is int):\n        assert cpu == gpu, \"GPU and CPU int values are different at {}\".format(path)\n    elif (t is float):\n        if (math.isnan(cpu)):\n            assert math.isnan(gpu), \"GPU and CPU float values are different at {}\".format(path)\n        else:\n            assert float_check(cpu, gpu), \"GPU and CPU float values are different {}\".format(path)\n    elif isinstance(cpu, str):\n        assert cpu == gpu, \"GPU and CPU string values are different at {}\".format(path)\n    elif 
isinstance(cpu, datetime):\n        assert cpu == gpu, \"GPU and CPU timestamp values are different at {}\".format(path)\n    elif isinstance(cpu, date):\n        assert cpu == gpu, \"GPU and CPU date values are different at {}\".format(path)\n    elif isinstance(cpu, bool):\n        assert cpu == gpu, \"GPU and CPU boolean values are different at {}\".format(path)\n    elif isinstance(cpu, Decimal):\n        assert cpu == gpu, \"GPU and CPU decimal values are different at {}\".format(path)\n    elif isinstance(cpu, bytearray):\n        assert cpu == gpu, \"GPU and CPU bytearray values are different at {}\".format(path)\n    elif (cpu == None):\n        assert cpu == gpu, \"GPU and CPU are not both null at {}\".format(path)\n    else:\n        assert False, \"Found unexpected type {} at {}\".format(t, path)\n\ndef assert_equal(cpu, gpu):\n    \"\"\"Verify that the result from the CPU and the GPU are equal\"\"\"\n    try:\n      _assert_equal(cpu, gpu, float_check=get_float_check(), path=[])\n    except:\n      print(\"CPU OUTPUT: %s\" % cpu)\n      print(\"GPU OUTPUT: %s\" % gpu)\n      raise\n\ndef _has_incompat_conf(conf):\n    return ('spark.rapids.sql.incompatibleOps.enabled' in conf and\n            conf['spark.rapids.sql.incompatibleOps.enabled'].lower() == 'true')\n\nclass _RowCmp(object):\n    \"\"\"Allows for sorting Rows in a consistent way\"\"\"\n    def __init__(self, wrapped):\n        if isinstance(wrapped, Row) or isinstance(wrapped, list) or isinstance(wrapped, tuple):\n            self.wrapped = [_RowCmp(c) for c in wrapped]\n        elif isinstance(wrapped, dict):\n            def sort_dict(e):\n                return _RowCmp(e)\n            tmp = [(k, v) for k, v in wrapped.items()]\n            tmp.sort(key=sort_dict)\n            self.wrapped = [_RowCmp(c) for c in tmp]\n        else:\n            self.wrapped = wrapped\n\n        if isinstance(wrapped, float):\n            self.is_nan = math.isnan(wrapped)\n        else:\n            
self.is_nan = False\n\n    def cmp(self, other):\n        try:\n            #None comes before anything else\n            #NaN comes next\n            if (self.wrapped is None and other.wrapped is None):\n                return 0\n            elif (self.wrapped is None):\n                return -1\n            elif (other.wrapped is None):\n                return 1\n            elif self.is_nan and other.is_nan:\n                return 0\n            elif self.is_nan:\n                return -1\n            elif other.is_nan:\n                return 1\n            elif self.wrapped == other.wrapped:\n                return 0\n            elif self.wrapped < other.wrapped:\n                return -1\n            else:\n                return 1\n        except TypeError as te:\n            print(\"ERROR TRYING TO COMPARE {} to {} {}\".format(self.wrapped, other.wrapped, te))\n            raise te\n\n\n    def __lt__(self, other):\n        return self.cmp(other) < 0\n\n    def __gt__(self, other):\n        return self.cmp(other) > 0\n\n    def __eq__(self, other):\n        return self.cmp(other) == 0\n\n    def __le__(self, other):\n        return self.cmp(other) <= 0\n\n    def __ge__(self, other):\n        return self.cmp(other) >= 0\n\n    def __ne__(self, other):\n        return self.cmp(other) != 0\n\ndef _prep_func_for_compare(func, mode):\n    sort_locally = should_sort_locally()\n    if should_sort_on_spark():\n        def with_sorted(spark):\n            df = func(spark)\n            return df.sort(df.columns)\n\n        sorted_func = with_sorted\n    else:\n        sorted_func = func\n\n    limit_val = get_limit()\n    if limit_val > 0:\n        def with_limit(spark):\n            df = sorted_func(spark)\n            return df.limit(limit_val)\n        limit_func = with_limit\n    else:\n        limit_func = sorted_func\n\n    if mode == 'COLLECT':\n        bring_back = lambda spark: limit_func(spark).collect()\n        collect_type = 'COLLECT'\n    elif 
mode == 'COUNT':\n        bring_back = lambda spark: limit_func(spark).count()\n        collect_type = 'COUNT'\n    elif mode == 'COLLECT_WITH_DATAFRAME':\n        def bring_back(spark):\n            df = limit_func(spark)\n            return (df.collect(), df)\n        collect_type = 'COLLECT'\n        return (bring_back, collect_type)\n    else:\n        bring_back = lambda spark: limit_func(spark).toLocalIterator()\n        collect_type = 'ITERATOR'\n        if sort_locally:\n            raise RuntimeError('Local Sort is only supported on a collect')\n    return (bring_back, collect_type)\n\ndef _prep_incompat_conf(conf):\n    if is_incompat():\n        conf = dict(conf) # Make a copy before we change anything\n        conf['spark.rapids.sql.incompatibleOps.enabled'] = 'true'\n    elif _has_incompat_conf(conf):\n        raise AssertionError(\"incompat must be enabled by the incompat fixture\")\n    return conf\n\ndef _assert_gpu_and_cpu_writes_are_equal(\n        write_func,\n        read_func,\n        base_path,\n        mode,\n        conf={}):\n    conf = _prep_incompat_conf(conf)\n\n    print('### CPU RUN ###')\n    cpu_start = time.time()\n    cpu_path = base_path + '/CPU'\n    with_cpu_session(lambda spark : write_func(spark, cpu_path), conf=conf)\n    cpu_end = time.time()\n    print('### GPU RUN ###')\n    gpu_start = time.time()\n    gpu_path = base_path + '/GPU'\n    with_gpu_session(lambda spark : write_func(spark, gpu_path), conf=conf)\n    gpu_end = time.time()\n    print('### WRITE: GPU TOOK {} CPU TOOK {} ###'.format(\n        gpu_end - gpu_start, cpu_end - cpu_start))\n\n    (cpu_bring_back, cpu_collect_type) = _prep_func_for_compare(\n            lambda spark: read_func(spark, cpu_path), mode)\n    (gpu_bring_back, gpu_collect_type) = _prep_func_for_compare(\n            lambda spark: read_func(spark, gpu_path), mode)\n\n    from_cpu = with_cpu_session(cpu_bring_back, conf=conf)\n    from_gpu = with_cpu_session(gpu_bring_back, conf=conf)\n    
if should_sort_locally():\n        from_cpu.sort(key=_RowCmp)\n        from_gpu.sort(key=_RowCmp)\n\n    assert_equal(from_cpu, from_gpu)\n\ndef assert_gpu_and_cpu_writes_are_equal_collect(write_func, read_func, base_path, conf={}):\n    \"\"\"\n    Assert when running write_func on both the CPU and the GPU and reading using read_func\n    on the CPU that the results are equal.\n    In this case the data is collected back to the driver and compared here, so be\n    careful about the amount of data returned.\n    \"\"\"\n    _assert_gpu_and_cpu_writes_are_equal(write_func, read_func, base_path, 'COLLECT', conf=conf)\n\ndef assert_gpu_and_cpu_writes_are_equal_iterator(write_func, read_func, base_path, conf={}):\n    \"\"\"\n    Assert when running write_func on both the CPU and the GPU and reading using read_func\n    on the CPU that the results are equal.\n    In this case the data is pulled back to the driver in chunks and compared here\n    so any amount of data can work, just be careful about how long it might take.\n    \"\"\"\n    _assert_gpu_and_cpu_writes_are_equal(write_func, read_func, base_path, 'ITERATOR', conf=conf)\n\ndef assert_gpu_fallback_write(write_func,\n        read_func,\n        base_path,\n        cpu_fallback_class_name,\n        conf={}):\n    conf = _prep_incompat_conf(conf)\n\n    print('### CPU RUN ###')\n    cpu_start = time.time()\n    cpu_path = base_path + '/CPU'\n    with_cpu_session(lambda spark : write_func(spark, cpu_path), conf=conf)\n    cpu_end = time.time()\n    print('### GPU RUN ###')\n    jvm = spark_jvm()\n    jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.startCapture()\n    gpu_start = time.time()\n    gpu_path = base_path + '/GPU'\n    with_gpu_session(lambda spark : write_func(spark, gpu_path), conf=conf)\n    gpu_end = time.time()\n    jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertCapturedAndGpuFellBack(cpu_fallback_class_name, 10000)\n    print('### WRITE: GPU TOOK {} CPU TOOK {} 
###'.format(\n        gpu_end - gpu_start, cpu_end - cpu_start))\n\n    (cpu_bring_back, cpu_collect_type) = _prep_func_for_compare(\n            lambda spark: read_func(spark, cpu_path), 'COLLECT')\n    (gpu_bring_back, gpu_collect_type) = _prep_func_for_compare(\n            lambda spark: read_func(spark, gpu_path), 'COLLECT')\n\n    from_cpu = with_cpu_session(cpu_bring_back, conf=conf)\n    from_gpu = with_cpu_session(gpu_bring_back, conf=conf)\n    if should_sort_locally():\n        from_cpu.sort(key=_RowCmp)\n        from_gpu.sort(key=_RowCmp)\n\n    assert_equal(from_cpu, from_gpu)\n\ndef assert_cpu_and_gpu_are_equal_collect_with_capture(func,\n        exist_classes='',\n        non_exist_classes='',\n        conf={}):\n    (bring_back, collect_type) = _prep_func_for_compare(func, 'COLLECT_WITH_DATAFRAME')\n\n    conf = _prep_incompat_conf(conf)\n\n    print('### CPU RUN ###')\n    cpu_start = time.time()\n    from_cpu, cpu_df = with_cpu_session(bring_back, conf=conf)\n    cpu_end = time.time()\n    print('### GPU RUN ###')\n    gpu_start = time.time()\n    from_gpu, gpu_df = with_gpu_session(bring_back, conf=conf)\n    gpu_end = time.time()\n    jvm = spark_jvm()\n    if exist_classes:\n        for clz in exist_classes.split(','):\n            jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertContains(gpu_df._jdf, clz)\n    if non_exist_classes:\n        for clz in non_exist_classes.split(','):\n            jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertNotContain(gpu_df._jdf, clz)\n    print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type,\n        gpu_end - gpu_start, cpu_end - cpu_start))\n    if should_sort_locally():\n        from_cpu.sort(key=_RowCmp)\n        from_gpu.sort(key=_RowCmp)\n\n    assert_equal(from_cpu, from_gpu)\n\ndef assert_cpu_and_gpu_are_equal_sql_with_capture(df_fun,\n        sql,\n        table_name,\n        exist_classes='',\n        non_exist_classes='',\n        conf=None,\n        
debug=False):\n    if conf is None:\n        conf = {}\n    def do_it_all(spark):\n        df = df_fun(spark)\n        df.createOrReplaceTempView(table_name)\n        if debug:\n            return data_gen.debug_df(spark.sql(sql))\n        else:\n            return spark.sql(sql)\n    assert_cpu_and_gpu_are_equal_collect_with_capture(do_it_all, exist_classes, non_exist_classes, conf)\n\ndef assert_gpu_fallback_collect(func,\n        cpu_fallback_class_name,\n        conf={}):\n    (bring_back, collect_type) = _prep_func_for_compare(func, 'COLLECT_WITH_DATAFRAME')\n\n    conf = _prep_incompat_conf(conf)\n\n    print('### CPU RUN ###')\n    cpu_start = time.time()\n    from_cpu, cpu_df = with_cpu_session(bring_back, conf=conf)\n    cpu_end = time.time()\n    print('### GPU RUN ###')\n    gpu_start = time.time()\n    from_gpu, gpu_df = with_gpu_session(bring_back, conf=conf)\n    gpu_end = time.time()\n    jvm = spark_jvm()\n    jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertDidFallBack(gpu_df._jdf, cpu_fallback_class_name)\n    print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type,\n        gpu_end - gpu_start, cpu_end - cpu_start))\n    if should_sort_locally():\n        from_cpu.sort(key=_RowCmp)\n        from_gpu.sort(key=_RowCmp)\n\n    assert_equal(from_cpu, from_gpu)\n\ndef assert_gpu_sql_fallback_collect(df_fun, cpu_fallback_class_name, table_name, sql, conf=None, debug=False):\n    if conf is None:\n        conf = {}\n    def do_it_all(spark):\n        df = df_fun(spark)\n        df.createOrReplaceTempView(table_name)\n        if debug:\n            return data_gen.debug_df(spark.sql(sql))\n        else:\n            return spark.sql(sql)\n    assert_gpu_fallback_collect(do_it_all, cpu_fallback_class_name, conf)\n\ndef _assert_gpu_and_cpu_are_equal(func,\n    mode,\n    conf={},\n    is_cpu_first=True):\n    (bring_back, collect_type) = _prep_func_for_compare(func, mode)\n    conf = _prep_incompat_conf(conf)\n\n    def 
run_on_cpu():\n        print('### CPU RUN ###')\n        global cpu_start\n        cpu_start = time.time()\n        global from_cpu\n        from_cpu = with_cpu_session(bring_back, conf=conf)\n        global cpu_end\n        cpu_end = time.time()\n\n    def run_on_gpu():\n        print('### GPU RUN ###')\n        global gpu_start\n        gpu_start = time.time()\n        global from_gpu\n        from_gpu = with_gpu_session(bring_back, conf=conf)\n        global gpu_end\n        gpu_end = time.time()\n\n    if is_cpu_first:\n        run_on_cpu()\n        run_on_gpu()\n    else:\n        run_on_gpu()\n        run_on_cpu()\n\n    print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type,\n        gpu_end - gpu_start, cpu_end - cpu_start))\n    if should_sort_locally():\n        from_cpu.sort(key=_RowCmp)\n        from_gpu.sort(key=_RowCmp)\n\n    assert_equal(from_cpu, from_gpu)\n\ndef run_with_cpu(func,\n    mode,\n    conf={}):\n    (bring_back, collect_type) = _prep_func_for_compare(func, mode)\n    conf = _prep_incompat_conf(conf)\n\n    print(\"run_with_cpu\")\n\n    def run_on_cpu():\n        print('### CPU RUN ###')\n        global cpu_start\n        cpu_start = time.time()\n        global from_cpu\n        from_cpu = with_cpu_session(bring_back, conf=conf)\n        global cpu_end\n        cpu_end = time.time()\n\n    run_on_cpu()\n\n    print('### {}: CPU TOOK {} ###'.format(collect_type,\n        cpu_end - cpu_start))\n    if should_sort_locally():\n        from_cpu.sort(key=_RowCmp)\n\n    return from_cpu\n\ndef run_with_cpu_and_gpu(func,\n    mode,\n    conf={}):\n    (bring_back, collect_type) = _prep_func_for_compare(func, mode)\n    conf = _prep_incompat_conf(conf)\n\n    def run_on_cpu():\n        print('### CPU RUN ###')\n        global cpu_start\n        cpu_start = time.time()\n        global from_cpu\n        from_cpu = with_cpu_session(bring_back, conf=conf)\n        global cpu_end\n        cpu_end = time.time()\n\n    def run_on_gpu():\n     
   print('### GPU RUN ###')\n        global gpu_start\n        gpu_start = time.time()\n        global from_gpu\n        from_gpu = with_gpu_session(bring_back, conf=conf)\n        global gpu_end\n        gpu_end = time.time()\n\n    run_on_cpu()\n    run_on_gpu()\n\n    print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type,\n        gpu_end - gpu_start, cpu_end - cpu_start))\n    if should_sort_locally():\n        from_cpu.sort(key=_RowCmp)\n        from_gpu.sort(key=_RowCmp)\n\n    return (from_cpu, from_gpu)\n\ndef assert_gpu_and_cpu_are_equal_collect(func, conf={}, is_cpu_first=True):\n    \"\"\"\n    Assert when running func on both the CPU and the GPU that the results are equal.\n    In this case the data is collected back to the driver and compared here, so be\n    careful about the amount of data returned.\n    \"\"\"\n    _assert_gpu_and_cpu_are_equal(func, 'COLLECT', conf=conf, is_cpu_first=is_cpu_first)\n\ndef assert_gpu_and_cpu_are_equal_iterator(func, conf={}, is_cpu_first=True):\n    \"\"\"\n    Assert when running func on both the CPU and the GPU that the results are equal.\n    In this case the data is pulled back to the driver in chunks and compared here\n    so any amount of data can work, just be careful about how long it might take.\n    \"\"\"\n    _assert_gpu_and_cpu_are_equal(func, 'ITERATOR', conf=conf, is_cpu_first=is_cpu_first)\n\ndef assert_gpu_and_cpu_row_counts_equal(func, conf={}, is_cpu_first=True):\n    \"\"\"\n    Assert that the row counts from running the func are the same on both the CPU and GPU.\n    This function runs count() to only get the number of rows and compares that count\n    between the CPU and GPU. 
It does NOT compare any underlying data.\n    \"\"\"\n    _assert_gpu_and_cpu_are_equal(func, 'COUNT', conf=conf, is_cpu_first=is_cpu_first)\n\ndef assert_gpu_and_cpu_are_equal_sql(df_fun, table_name, sql, conf=None, debug=False, is_cpu_first=True, validate_execs_in_gpu_plan=[]):\n    \"\"\"\n    Assert that the specified SQL query produces equal results on CPU and GPU.\n    :param df_fun: a function that will create the dataframe\n    :param table_name: Name of table to be created with the dataframe\n    :param sql: SQL query to be run on the specified table\n    :param conf: Any user-specified confs. Empty by default.\n    :param debug: Boolean to indicate if the SQL output should be printed\n    :param is_cpu_first: Boolean to indicate if the CPU should be run first or not\n    :param validate_execs_in_gpu_plan: String list of expressions to be validated in the GPU plan.\n    :return: Assertion failure, if results from CPU and GPU do not match.\n    \"\"\"\n    if conf is None:\n        conf = {}\n    def do_it_all(spark):\n        df = df_fun(spark)\n        df.createOrReplaceTempView(table_name)\n        # we hold off on setting the validate execs until after creating the temp view\n\n        spark.conf.set('spark.rapids.sql.test.validateExecsInGpuPlan', ','.join(validate_execs_in_gpu_plan))\n        if debug:\n            return data_gen.debug_df(spark.sql(sql))\n        else:\n            return spark.sql(sql)\n    assert_gpu_and_cpu_are_equal_collect(do_it_all, conf, is_cpu_first=is_cpu_first)\n\ndef assert_py4j_exception(func, error_message):\n    \"\"\"\n    Assert that a specific Java exception is thrown\n    :param func: a function to be verified\n    :param error_message: a string such as the one produced by java.lang.Exception.toString\n    :return: Assertion failure if no exception matching error_message has occurred.\n    \"\"\"\n    with pytest.raises(Py4JJavaError) as py4jError:\n        func()\n    assert error_message in 
str(py4jError.value.java_exception)\n\ndef assert_gpu_and_cpu_error(df_fun, conf, error_message):\n    \"\"\"\n    Assert that GPU and CPU execution results in a specific Java exception thrown\n    :param df_fun: a function to be verified\n    :param conf: Spark config\n    :param error_message: a string such as the one produced by java.lang.Exception.toString\n    :return: Assertion failure if either GPU or CPU versions has not generated error messages\n             expected\n    \"\"\"\n    assert_py4j_exception(lambda: with_cpu_session(df_fun, conf), error_message)\n    assert_py4j_exception(lambda: with_gpu_session(df_fun, conf), error_message)\n\ndef with_cpu_sql(df_fun, table_name, sql, conf=None, debug=False):\n    if conf is None:\n        conf = {}\n    def do_it_all(spark):\n        df = df_fun(spark)\n        df.createOrReplaceTempView(table_name)\n        if debug:\n            return data_gen.debug_df(spark.sql(sql))\n        else:\n            return spark.sql(sql)\n    assert_gpu_and_cpu_are_equal_collect(do_it_all, conf, is_cpu_first=is_cpu_first)\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/python/conftest.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport random\nfrom spark_init_internal import get_spark_i_know_what_i_am_doing\nfrom pyspark.sql.dataframe import DataFrame\n\n_approximate_float_args = None\n\ndef get_float_check():\n    if not _approximate_float_args is None:\n        return lambda lhs,rhs: lhs == pytest.approx(rhs, **_approximate_float_args)\n    else:\n        return lambda lhs,rhs: lhs == rhs\n\n_incompat = False\n\ndef is_incompat():\n    return _incompat\n\n_sort_on_spark = False\n_sort_locally = False\n\ndef should_sort_on_spark():\n    return _sort_on_spark\n\ndef should_sort_locally():\n    return _sort_locally\n\n_allow_any_non_gpu = False\n_non_gpu_allowed = []\n\ndef is_allowing_any_non_gpu():\n    return _allow_any_non_gpu\n\ndef get_non_gpu_allowed():\n    return _non_gpu_allowed\n\ndef get_validate_execs_in_gpu_plan():\n    return _validate_execs_in_gpu_plan\n\n_runtime_env = \"apache\"\n\ndef runtime_env():\n    return _runtime_env.lower()\n\ndef is_apache_runtime():\n    return runtime_env() == \"apache\"\n\ndef is_databricks_runtime():\n    return runtime_env() == \"databricks\"\n\ndef is_emr_runtime():\n    return runtime_env() == \"emr\"\n\ndef is_dataproc_runtime():\n    return runtime_env() == \"dataproc\"\n\n_is_nightly_run = False\n_is_precommit_run = False\n\ndef is_nightly_run():\n    return _is_nightly_run\n\ndef 
is_at_least_precommit_run():\n    return _is_nightly_run or _is_precommit_run\n\ndef skip_unless_nightly_tests(description):\n    if (_is_nightly_run):\n        raise AssertionError(description + ' during nightly test run')\n    else:\n        pytest.skip(description)\n\ndef skip_unless_precommit_tests(description):\n    if (_is_nightly_run):\n        raise AssertionError(description + ' during nightly test run')\n    elif (_is_precommit_run):\n        raise AssertionError(description + ' during pre-commit test run')\n    else:\n        pytest.skip(description)\n\n_limit = -1\n\ndef get_limit():\n    return _limit\n\ndef _get_limit_from_mark(mark):\n    if mark.args:\n        return mark.args[0]\n    else:\n        return mark.kwargs.get('num_rows', 100000)\n\ndef pytest_runtest_setup(item):\n    global _sort_on_spark\n    global _sort_locally\n    order = item.get_closest_marker('ignore_order')\n    if order:\n        if order.kwargs.get('local', False):\n            _sort_on_spark = False\n            _sort_locally = True\n        else:\n            _sort_on_spark = True\n            _sort_locally = False\n    else:\n        _sort_on_spark = False\n        _sort_locally = False\n\n    global _incompat\n    if item.get_closest_marker('incompat'):\n        _incompat = True\n    else:\n        _incompat = False\n\n    global _approximate_float_args\n    app_f = item.get_closest_marker('approximate_float')\n    if app_f:\n        _approximate_float_args = app_f.kwargs\n    else:\n        _approximate_float_args = None\n\n    global _allow_any_non_gpu\n    global _non_gpu_allowed\n    _non_gpu_allowed_databricks = []\n    _allow_any_non_gpu_databricks = False\n    non_gpu_databricks = item.get_closest_marker('allow_non_gpu_databricks')\n    non_gpu = item.get_closest_marker('allow_non_gpu')\n    if non_gpu_databricks:\n        if is_databricks_runtime():\n            if non_gpu_databricks.kwargs and non_gpu_databricks.kwargs['any']:\n                
_allow_any_non_gpu_databricks = True\n            elif non_gpu_databricks.args:\n                _non_gpu_allowed_databricks = non_gpu_databricks.args\n            else:\n                pytest.warn('allow_non_gpu_databricks marker without anything allowed')\n    if non_gpu:\n        if non_gpu.kwargs and non_gpu.kwargs['any']:\n            _allow_any_non_gpu = True\n            _non_gpu_allowed = []\n        elif non_gpu.args:\n            _allow_any_non_gpu = False\n            _non_gpu_allowed = non_gpu.args\n        else:\n            pytest.warn('allow_non_gpu marker without anything allowed')\n            _allow_any_non_gpu = False\n            _non_gpu_allowed = []\n    else:\n        _allow_any_non_gpu = False\n        _non_gpu_allowed = []\n\n    _allow_any_non_gpu = _allow_any_non_gpu | _allow_any_non_gpu_databricks\n    if _non_gpu_allowed and _non_gpu_allowed_databricks:\n        _non_gpu_allowed = _non_gpu_allowed + _non_gpu_allowed_databricks\n    elif _non_gpu_allowed_databricks:\n        _non_gpu_allowed = _non_gpu_allowed_databricks\n\n    global _validate_execs_in_gpu_plan\n    validate_execs = item.get_closest_marker('validate_execs_in_gpu_plan')\n    if validate_execs and validate_execs.args:\n        _validate_execs_in_gpu_plan = validate_execs.args\n    else:\n        _validate_execs_in_gpu_plan = []\n\n    global _limit\n    limit_mrk = item.get_closest_marker('limit')\n    if limit_mrk:\n        _limit = _get_limit_from_mark(limit_mrk)\n    else:\n        _limit = -1\n\ndef pytest_configure(config):\n    global _runtime_env\n    _runtime_env = config.getoption('runtime_env')\n    global _is_nightly_run\n    global _is_precommit_run\n    test_type = config.getoption('test_type').lower()\n    if \"nightly\" == test_type:\n        _is_nightly_run = True\n    elif \"pre-commit\" == test_type:\n        _is_precommit_run = True\n    elif \"developer\" != test_type:\n        raise Exception(\"not supported test type {}\".format(test_type))\n\ndef 
pytest_collection_modifyitems(config, items):\n    for item in items:\n        extras = []\n        order = item.get_closest_marker('ignore_order')\n        if order:\n            if order.kwargs:\n                extras.append('IGNORE_ORDER(' + str(order.kwargs) + ')')\n            else:\n                extras.append('IGNORE_ORDER')\n        if item.get_closest_marker('incompat'):\n            extras.append('INCOMPAT')\n        app_f = item.get_closest_marker('approximate_float')\n        if app_f:\n            if app_f.kwargs:\n                extras.append('APPROXIMATE_FLOAT(' + str(app_f.kwargs) + ')')\n            else:\n                extras.append('APPROXIMATE_FLOAT')\n        non_gpu = item.get_closest_marker('allow_non_gpu')\n        if non_gpu:\n            if non_gpu.kwargs and non_gpu.kwargs['any']:\n                extras.append('ALLOW_NON_GPU(ANY)')\n            elif non_gpu.args:\n                extras.append('ALLOW_NON_GPU(' + ','.join(non_gpu.args) + ')')\n\n        limit_mrk = item.get_closest_marker('limit')\n        if limit_mrk:\n            extras.append('LIMIT({})'.format(_get_limit_from_mark(limit_mrk)))\n\n        if extras:\n            # This is not ideal because we are reaching into an internal value\n            item._nodeid = item.nodeid + '[' + ', '.join(extras) + ']'\n\n@pytest.fixture(scope=\"session\")\ndef std_input_path(request):\n    path = request.config.getoption(\"std_input_path\")\n    if path is None:\n        skip_unless_precommit_tests(\"std_input_path is not configured\")\n    else:\n        yield path\n\n@pytest.fixture\ndef spark_tmp_path(request):\n    debug = request.config.getoption('debug_tmp_path')\n    ret = request.config.getoption('tmp_path')\n    if ret is None:\n        ret = '/tmp/pyspark_tests/'\n    ret = ret + '/' + str(random.randint(0, 1000000)) + '/'\n    # Make sure it is there and accessible\n    sc = get_spark_i_know_what_i_am_doing().sparkContext\n    config = sc._jsc.hadoopConfiguration()\n    
path = sc._jvm.org.apache.hadoop.fs.Path(ret)\n    fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(config)\n    fs.mkdirs(path)\n    yield ret\n    if not debug:\n        fs.delete(path)\n\nclass TmpTableFactory:\n  def __init__(self, base_id):\n      self.base_id = base_id\n      self.running_id = 0\n\n  def get(self):\n      ret = '{}_{}'.format(self.base_id, self.running_id)\n      self.running_id = self.running_id + 1\n      return ret\n\n@pytest.fixture\ndef spark_tmp_table_factory(request):\n    base_id = 'tmp_table_{}'.format(random.randint(0, 1000000))\n    yield TmpTableFactory(base_id)\n    sp = get_spark_i_know_what_i_am_doing()\n    tables = sp.sql(\"SHOW TABLES\".format(base_id)).collect()\n    for row in tables:\n        t_name = row['tableName']\n        if (t_name.startswith(base_id)):\n            sp.sql(\"DROP TABLE IF EXISTS {}\".format(t_name))\n\ndef _get_jvm_session(spark):\n    return spark._jsparkSession\n\ndef _get_jvm(spark):\n    return spark.sparkContext._jvm\n\ndef spark_jvm():\n    return _get_jvm(get_spark_i_know_what_i_am_doing())\n\nclass MortgageRunner:\n  def __init__(self, mortgage_format, mortgage_acq_path, mortgage_perf_path):\n    self.mortgage_format = mortgage_format\n    self.mortgage_acq_path = mortgage_acq_path\n    self.mortgage_perf_path = mortgage_perf_path\n\n  def do_test_query(self, spark):\n    jvm_session = _get_jvm_session(spark)\n    jvm = _get_jvm(spark)\n    acq = self.mortgage_acq_path\n    perf = self.mortgage_perf_path\n    run = jvm.com.nvidia.spark.rapids.tests.mortgage.Run\n    if self.mortgage_format == 'csv':\n        df = run.csv(jvm_session, perf, acq)\n    elif self.mortgage_format == 'parquet':\n        df = run.parquet(jvm_session, perf, acq)\n    elif self.mortgage_format == 'orc':\n        df = run.orc(jvm_session, perf, acq)\n    else:\n        raise AssertionError('Not Supported Format {}'.format(self.mortgage_format))\n\n    return DataFrame(df, spark.getActiveSession())\n   
\n@pytest.fixture(scope=\"session\")\ndef mortgage(request):\n    mortgage_format = request.config.getoption(\"mortgage_format\")\n    mortgage_path = request.config.getoption(\"mortgage_path\")\n    if mortgage_path is None:\n        std_path = request.config.getoption(\"std_input_path\")\n        if std_path is None:\n            skip_unless_precommit_tests(\"Mortgage tests are not configured to run\")\n        else:\n            yield MortgageRunner('parquet', std_path + '/parquet_acq', std_path + '/parquet_perf')\n    else:\n        yield MortgageRunner(mortgage_format, mortgage_path + '/acq', mortgage_path + '/perf')\n\n@pytest.fixture(scope=\"session\")\ndef enable_cudf_udf(request):\n    enable_udf_cudf = request.config.getoption(\"cudf_udf\")\n    if not enable_udf_cudf:\n        # cudf_udf tests are not required for any test runs\n        pytest.skip(\"cudf_udf not configured to run\")\n\n@pytest.fixture(scope=\"session\")\ndef enable_rapids_udf_example_native(request):\n    native_enabled = request.config.getoption(\"rapids_udf_example_native\")\n    if not native_enabled:\n        # udf_example_native tests are not required for any test runs\n        pytest.skip(\"rapids_udf_example_native is not configured to run\")\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/python/data_gen.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport copy\nfrom datetime import date, datetime, timedelta, timezone\nfrom decimal import *\nimport math\nfrom pyspark.context import SparkContext\nfrom pyspark.sql import Row\nfrom pyspark.sql.types import *\nimport pyspark.sql.functions as f\nimport pytest\nimport random\nfrom spark_session import is_tz_utc\nimport sre_yield\nimport struct\nfrom conftest import skip_unless_precommit_tests\n\nclass DataGen:\n    \"\"\"Base class for data generation\"\"\"\n\n    def __repr__(self):\n        if not self.nullable:\n            return self.__class__.__name__[:-3] + '(not_null)'\n        return self.__class__.__name__[:-3]\n\n    def __hash__(self):\n        return hash(str(self))\n\n    def __eq__(self, other):\n        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__\n\n    def __ne__(self, other):\n        return not self.__eq__(other)\n\n    def __init__(self, data_type, nullable=True, special_cases =[]):\n        self.data_type = data_type\n        self.list_of_special_cases = special_cases\n        self._special_cases = []\n        if isinstance(nullable, tuple):\n            self.nullable = nullable[0]\n            weight = nullable[1]\n        else:\n            self.nullable = nullable\n            weight = 5.0\n        if self.nullable:\n            self.with_special_case(None, weight)\n\n        # Special cases 
can be a value or a tuple of (value, weight). If the\n        # special_case itself is a tuple as in the case of StructGen, it MUST be added with a\n        # weight like : ((special_case_tuple_v1, special_case_tuple_v2), weight).\n        for element in special_cases:\n            if isinstance(element, tuple):\n                self.with_special_case(element[0], element[1])\n            else:\n                self.with_special_case(element)\n\n    def copy_special_case(self, special_case, weight=1.0):\n        # it would be good to do a deepcopy, but sre_yield is not happy with that.\n        c = copy.copy(self)\n        c._special_cases = copy.deepcopy(self._special_cases)\n\n        return c.with_special_case(special_case, weight=weight)\n\n    def with_special_case(self, special_case, weight=1.0):\n        \"\"\"\n        Add in a special case with a given weight. A special case can either be\n        a function that takes an instance of Random and returns the generated data\n        or it can be a constant.  By default the weight is 1.0, and the default\n        number generation's weight is 100.0.  
The number of lines that are generate in\n        the data set should be proportional to the its weight/sum weights\n        \"\"\"\n        if callable(special_case):\n            sc = special_case\n        else:\n            sc = lambda rand: special_case\n        self._special_cases.append((weight, sc))\n        return self\n\n    def get_types(self):\n        return 'DataType: {}, nullable: {}, special_cases: {}'.format(self.data_type,\n          self.nullable, self.list_of_special_cases)\n\n    def start(self, rand):\n        \"\"\"Start data generation using the given rand\"\"\"\n        raise TypeError('Children should implement this method and call _start')\n\n    def _start(self, rand, gen_func):\n        \"\"\"Start internally, but use the given gen_func as the base\"\"\"\n        if not self._special_cases:\n            self._gen_func = gen_func\n        else:\n            weighted_choices = [(100.0, lambda rand: gen_func())]\n            weighted_choices.extend(self._special_cases)\n            total = float(sum(weight for weight,gen in weighted_choices))\n            normalized_choices = [(weight/total, gen) for weight,gen in weighted_choices]\n\n            def choose_one():\n                pick = rand.random()\n                total = 0\n                for (weight, gen) in normalized_choices:\n                    total += weight\n                    if total >= pick:\n                        return gen(rand)\n                raise RuntimeError('Random did not pick something we expected')\n            self._gen_func = choose_one\n\n    def gen(self, force_no_nulls=False):\n        \"\"\"generate the next line\"\"\"\n        if not self._gen_func:\n            raise RuntimeError('start must be called before generating any data')\n        v = self._gen_func()\n        if force_no_nulls:\n            while v is None:\n                v = self._gen_func()\n        return v\n\n    def contains_ts(self):\n        \"\"\"Checks if this contains a 
TimestampGen\"\"\"\n        return False\n\nclass ConvertGen(DataGen):\n    \"\"\"Provides a way to modify the data before it is returned\"\"\"\n    def __init__(self, child_gen, func, data_type=None, nullable=True):\n        if data_type is None:\n            data_type = child_gen.data_type\n        super().__init__(data_type, nullable=nullable)\n        self._child_gen = child_gen\n        self._func = func\n\n    def __repr__(self):\n        return super().__repr__() + '(' + str(self._child_gen) + ')'\n\n    def start(self, rand):\n        self._child_gen.start(rand)\n        def modify():\n            return self._func(self._child_gen.gen())\n\n        self._start(rand, modify)\n\n_MAX_CHOICES = 1 << 64\nclass StringGen(DataGen):\n    \"\"\"Generate strings that match a pattern\"\"\"\n    def __init__(self, pattern=\"(.|\\n){1,30}\", flags=0, charset=sre_yield.CHARSET, nullable=True):\n        super().__init__(StringType(), nullable=nullable)\n        self.base_strs = sre_yield.AllStrings(pattern, flags=flags, charset=charset, max_count=_MAX_CHOICES)\n\n    def with_special_pattern(self, pattern, flags=0, charset=sre_yield.CHARSET, weight=1.0):\n        \"\"\"\n        Like with_special_case but you can provide a regexp pattern\n        instead of a hard coded string value.\n        \"\"\"\n        strs = sre_yield.AllStrings(pattern, flags=flags, charset=charset, max_count=_MAX_CHOICES)\n        try:\n            length = int(len(strs))\n        except OverflowError:\n            length = _MAX_CHOICES\n        return self.with_special_case(lambda rand : strs[rand.randrange(0, length)], weight=weight)\n\n    def start(self, rand):\n        strs = self.base_strs\n        try:\n            length = int(len(strs))\n        except OverflowError:\n            length = _MAX_CHOICES\n        self._start(rand, lambda : strs[rand.randrange(0, length)])\n\nBYTE_MIN = -(1 << 7)\nBYTE_MAX = (1 << 7) - 1\nclass ByteGen(DataGen):\n    \"\"\"Generate Bytes\"\"\"\n    def 
__init__(self, nullable=True, min_val = BYTE_MIN, max_val = BYTE_MAX, special_cases=[]):\n        super().__init__(ByteType(), nullable=nullable, special_cases=special_cases)\n        self._min_val = min_val\n        self._max_val = max_val\n\n    def start(self, rand):\n        self._start(rand, lambda : rand.randint(self._min_val, self._max_val))\n\nSHORT_MIN = -(1 << 15)\nSHORT_MAX = (1 << 15) - 1\nclass ShortGen(DataGen):\n    \"\"\"Generate Shorts, which some built in corner cases.\"\"\"\n    def __init__(self, nullable=True, min_val = SHORT_MIN, max_val = SHORT_MAX,\n                 special_cases = [SHORT_MIN, SHORT_MAX, 0, 1, -1]):\n        super().__init__(ShortType(), nullable=nullable, special_cases=special_cases)\n        self._min_val = min_val\n        self._max_val = max_val\n\n    def start(self, rand):\n        self._start(rand, lambda : rand.randint(self._min_val, self._max_val))\n\nINT_MIN = -(1 << 31)\nINT_MAX = (1 << 31) - 1\nclass IntegerGen(DataGen):\n    \"\"\"Generate Ints, which some built in corner cases.\"\"\"\n    def __init__(self, nullable=True, min_val = INT_MIN, max_val = INT_MAX,\n                 special_cases = [INT_MIN, INT_MAX, 0, 1, -1]):\n        super().__init__(IntegerType(), nullable=nullable, special_cases=special_cases)\n        self._min_val = min_val\n        self._max_val = max_val\n\n    def start(self, rand):\n        self._start(rand, lambda : rand.randint(self._min_val, self._max_val))\n\nclass DecimalGen(DataGen):\n    \"\"\"Generate Decimals, with some built in corner cases.\"\"\"\n    def __init__(self, precision=None, scale=None, nullable=True, special_cases=[]):\n        if precision is None:\n            #Maximum number of decimal digits a Long can represent is 18\n            precision = 18\n            scale = 0\n        DECIMAL_MIN = Decimal('-' + ('9' * precision) + 'e' + str(-scale))\n        DECIMAL_MAX = Decimal(('9'* precision) + 'e' + str(-scale))\n        super().__init__(DecimalType(precision, 
scale), nullable=nullable, special_cases=special_cases)\n        self.scale = scale\n        self.precision = precision\n        pattern = \"[0-9]{1,\"+ str(precision) + \"}e\" + str(-scale)\n        self.base_strs = sre_yield.AllStrings(pattern, flags=0, charset=sre_yield.CHARSET, max_count=_MAX_CHOICES)\n\n    def __repr__(self):\n        return super().__repr__() + '(' + str(self.precision) + ',' + str(self.scale) + ')'\n\n    def start(self, rand):\n        strs = self.base_strs\n        try:\n            length = int(strs.length)\n        except OverflowError:\n            length = _MAX_CHOICES\n        self._start(rand, lambda : Decimal(strs[rand.randrange(0, length)]))\n\nLONG_MIN = -(1 << 63)\nLONG_MAX = (1 << 63) - 1\nclass LongGen(DataGen):\n    \"\"\"Generate Longs, which some built in corner cases.\"\"\"\n    def __init__(self, nullable=True, min_val = LONG_MIN, max_val = LONG_MAX, special_cases = []):\n        _special_cases = [min_val, max_val, 0, 1, -1] if not special_cases else special_cases\n        super().__init__(LongType(), nullable=nullable, special_cases=_special_cases)\n        self._min_val = min_val\n        self._max_val = max_val\n\n    def start(self, rand):\n        self._start(rand, lambda : rand.randint(self._min_val, self._max_val))\n\nclass LongRangeGen(DataGen):\n    \"\"\"Generate Longs in incrementing order.\"\"\"\n    def __init__(self, nullable=False, start_val=0, direction=\"inc\"):\n        super().__init__(LongType(), nullable=nullable)\n        self._start_val = start_val\n        self._current_val = start_val\n        if (direction == \"dec\"):\n            def dec_it():\n                tmp = self._current_val\n                self._current_val -= 1\n                return tmp\n            self._do_it = dec_it\n        else:\n            def inc_it():\n                tmp = self._current_val\n                self._current_val += 1\n                return tmp\n            self._do_it = inc_it\n\n    def start(self, 
rand):\n        self._current_val = self._start_val\n        self._start(rand, self._do_it)\n\nclass RepeatSeqGen(DataGen):\n    \"\"\"Generate Repeated seq of `length` random items\"\"\"\n    def __init__(self, child, length):\n        super().__init__(child.data_type, nullable=False)\n        self.nullable = child.nullable\n        self._child = child\n        self._vals = []\n        self._length = length\n        self._index = 0\n\n    def __repr__(self):\n        return super().__repr__() + '(' + str(self._child) + ')'\n\n    def _loop_values(self):\n        ret = self._vals[self._index]\n        self._index = (self._index + 1) % self._length\n        return ret\n\n    def start(self, rand):\n        self._index = 0\n        self._child.start(rand)\n        self._start(rand, self._loop_values)\n        self._vals = [self._child.gen() for _ in range(0, self._length)]\n\nclass SetValuesGen(DataGen):\n    \"\"\"A set of values that are randomly selected\"\"\"\n    def __init__(self, data_type, data):\n        super().__init__(data_type, nullable=False)\n        self.nullable = any(x is None for x in data)\n        self._vals = data\n\n    def __repr__(self):\n        return super().__repr__() + '(' + str(self._child) + ')'\n\n    def start(self, rand):\n        data = self._vals\n        length = len(data)\n        self._start(rand, lambda : data[rand.randrange(0, length)])\n\nFLOAT_MIN = -3.4028235E38\nFLOAT_MAX = 3.4028235E38\nNEG_FLOAT_NAN_MIN_VALUE = struct.unpack('f', struct.pack('I', 0xffffffff))[0]\nNEG_FLOAT_NAN_MAX_VALUE = struct.unpack('f', struct.pack('I', 0xff800001))[0]\nPOS_FLOAT_NAN_MIN_VALUE = struct.unpack('f', struct.pack('I', 0x7f800001))[0]\nPOS_FLOAT_NAN_MAX_VALUE = struct.unpack('f', struct.pack('I', 0x7fffffff))[0]\nclass FloatGen(DataGen):\n    \"\"\"Generate floats, which some built in corner cases.\"\"\"\n    def __init__(self, nullable=True,\n            no_nans=False, special_cases=None):\n        self._no_nans = no_nans\n        if 
special_cases is None:\n            special_cases = [FLOAT_MIN, FLOAT_MAX, 0.0, -0.0, 1.0, -1.0]\n            if not no_nans:\n                special_cases.append(float('inf'))\n                special_cases.append(float('-inf'))\n                special_cases.append(float('nan'))\n                special_cases.append(NEG_FLOAT_NAN_MAX_VALUE)\n        super().__init__(FloatType(), nullable=nullable, special_cases=special_cases)\n\n    def _fixup_nans(self, v):\n        if self._no_nans and (math.isnan(v) or v == math.inf or v == -math.inf):\n            v = None if self.nullable else 0.0\n        return v\n\n    def start(self, rand):\n        def gen_float():\n            i = rand.randint(INT_MIN, INT_MAX)\n            p = struct.pack('i', i)\n            return self._fixup_nans(struct.unpack('f', p)[0])\n        self._start(rand, gen_float)\n\nDOUBLE_MIN_EXP = -1022\nDOUBLE_MAX_EXP = 1023\nDOUBLE_MAX_FRACTION = int('1'*52, 2)\nDOUBLE_MIN = -1.7976931348623157E308\nDOUBLE_MAX = 1.7976931348623157E308\nNEG_DOUBLE_NAN_MIN_VALUE = struct.unpack('d', struct.pack('L', 0xffffffffffffffff))[0]\nNEG_DOUBLE_NAN_MAX_VALUE = struct.unpack('d', struct.pack('L', 0xfff0000000000001))[0]\nPOS_DOUBLE_NAN_MIN_VALUE = struct.unpack('d', struct.pack('L', 0x7ff0000000000001))[0]\nPOS_DOUBLE_NAN_MAX_VALUE = struct.unpack('d', struct.pack('L', 0x7fffffffffffffff))[0]\nclass DoubleGen(DataGen):\n    \"\"\"Generate doubles, which some built in corner cases.\"\"\"\n    def __init__(self, min_exp=DOUBLE_MIN_EXP, max_exp=DOUBLE_MAX_EXP, no_nans=False,\n            nullable=True, special_cases = None):\n        self._min_exp = min_exp\n        self._max_exp = max_exp\n        self._no_nans = no_nans\n        self._use_full_range = (self._min_exp == DOUBLE_MIN_EXP) and (self._max_exp == DOUBLE_MAX_EXP)\n        if special_cases is None:\n            special_cases = [\n                self.make_from(1, self._max_exp, DOUBLE_MAX_FRACTION),\n                self.make_from(0, self._max_exp, 
DOUBLE_MAX_FRACTION),\n                self.make_from(1, self._min_exp, DOUBLE_MAX_FRACTION),\n                self.make_from(0, self._min_exp, DOUBLE_MAX_FRACTION)\n            ]\n            if self._min_exp <= 0 and self._max_exp >= 0:\n                special_cases.append(0.0)\n                special_cases.append(-0.0)\n            if self._min_exp <= 3 and self._max_exp >= 3:\n                special_cases.append(1.0)\n                special_cases.append(-1.0)\n            if not no_nans:\n                special_cases.append(float('inf'))\n                special_cases.append(float('-inf'))\n                special_cases.append(float('nan'))\n                special_cases.append(NEG_DOUBLE_NAN_MAX_VALUE)\n        super().__init__(DoubleType(), nullable=nullable, special_cases=special_cases)\n\n    @staticmethod\n    def make_from(sign, exp, fraction):\n        sign = sign & 1 # 1 bit\n        exp = (exp + 1023) & 0x7FF # add bias and 11 bits\n        fraction = fraction & DOUBLE_MAX_FRACTION\n        i = (sign << 63) | (exp << 52) | fraction\n        p = struct.pack('L', i)\n        ret = struct.unpack('d', p)[0]\n        return ret\n\n    def _fixup_nans(self, v):\n        if self._no_nans and (math.isnan(v) or v == math.inf or v == -math.inf):\n            v = None if self.nullable else 0.0\n        return v\n\n    def start(self, rand):\n        if self._use_full_range:\n            def gen_double():\n                i = rand.randint(LONG_MIN, LONG_MAX)\n                p = struct.pack('l', i)\n                return self._fixup_nans(struct.unpack('d', p)[0])\n            self._start(rand, gen_double)\n        else:\n            def gen_part_double():\n                sign = rand.getrandbits(1)\n                exp = rand.randint(self._min_exp, self._max_exp)\n                fraction = rand.getrandbits(52)\n                return self._fixup_nans(self.make_from(sign, exp, fraction))\n            self._start(rand, gen_part_double)\n\nclass 
BooleanGen(DataGen):\n    \"\"\"Generate Bools (True/False)\"\"\"\n    def __init__(self, nullable=True):\n        super().__init__(BooleanType(), nullable=nullable)\n\n    def start(self, rand):\n        self._start(rand, lambda : bool(rand.getrandbits(1)))\n\nclass StructGen(DataGen):\n    \"\"\"Generate a Struct\"\"\"\n    def __init__(self, children, nullable=True, special_cases=[]):\n        \"\"\"\n        Initialize the struct with children.  The children should be of the form:\n        [('name', Gen),('name_2', Gen2)]\n        Where name is the name of the strict field and Gens are Generators of\n        the type for that entry.\n        \"\"\"\n        tmp = [StructField(name, child.data_type, nullable=child.nullable) for name, child in children]\n        super().__init__(StructType(tmp), nullable=nullable, special_cases=special_cases)\n        self.children = children\n\n    def __repr__(self):\n        return super().__repr__() + '(' + ','.join([str(i) for i in self.children]) + ')'\n\n    def start(self, rand):\n        for name, child in self.children:\n            child.start(rand)\n        def make_tuple():\n            data = [child.gen() for name, child in self.children]\n            return tuple(data)\n        self._start(rand, make_tuple)\n\n    def contains_ts(self):\n        return any(child[1].contains_ts() for child in self.children)\n\nclass DateGen(DataGen):\n    \"\"\"Generate Dates in a given range\"\"\"\n    def __init__(self, start=None, end=None, nullable=True):\n        super().__init__(DateType(), nullable=nullable)\n        if start is None:\n            # Spark supports times starting at\n            # \"0001-01-01 00:00:00.000000\"\n            start = date(1, 1, 1)\n        elif not isinstance(start, date):\n            raise RuntimeError('Unsupported type passed in for start {}'.format(start))\n\n        if end is None:\n            # Spark supports time through\n            # \"9999-12-31 23:59:59.999999\"\n            end = 
date(9999, 12, 31)\n        elif isinstance(end, timedelta):\n            end = start + end\n        elif not isinstance(start, date):\n            raise RuntimeError('Unsupported type passed in for end {}'.format(end))\n\n        self._start_day = self._to_days_since_epoch(start)\n        self._end_day = self._to_days_since_epoch(end)\n\n        self.with_special_case(start)\n        self.with_special_case(end)\n\n        # we want a few around the leap year if possible\n        step = int((end.year - start.year) / 5.0)\n        if (step != 0):\n            years = {self._guess_leap_year(y) for y in range(start.year, end.year, step)}\n            for y in years:\n                leap_day = date(y, 2, 29)\n                if (leap_day > start and leap_day < end):\n                    self.with_special_case(leap_day)\n                next_day = date(y, 3, 1)\n                if (next_day > start and next_day < end):\n                    self.with_special_case(next_day)\n\n    @staticmethod\n    def _guess_leap_year(t):\n        y = int(math.ceil(t/4.0)) * 4\n        if ((y % 100) == 0) and ((y % 400) != 0):\n            y = y + 4\n        if (y == 10000):\n            y = y - 4\n        return y\n\n    _epoch = date(1970, 1, 1)\n    _days = timedelta(days=1)\n    def _to_days_since_epoch(self, val):\n        return int((val - self._epoch)/self._days)\n\n    def _from_days_since_epoch(self, days):\n        return self._epoch + timedelta(days=days)\n\n    def start(self, rand):\n        start = self._start_day\n        end = self._end_day\n        self._start(rand, lambda : self._from_days_since_epoch(rand.randint(start, end)))\n\nclass TimestampGen(DataGen):\n    \"\"\"Generate Timestamps in a given range. 
All timezones are UTC by default.\"\"\"\n    def __init__(self, start=None, end=None, nullable=True):\n        super().__init__(TimestampType(), nullable=nullable)\n        if start is None:\n            # Spark supports times starting at\n            # \"0001-01-01 00:00:00.000000\"\n            # but it has issues if you get really close to that because it tries to do things\n            # in a different format which causes roundoff, so we have to add a few days,\n            # just to be sure\n            start = datetime(1, 1, 3, tzinfo=timezone.utc)\n        elif not isinstance(start, datetime):\n            raise RuntimeError('Unsupported type passed in for start {}'.format(start))\n\n        if end is None:\n            # Spark supports time through\n            # \"9999-12-31 23:59:59.999999\"\n            end = datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc)\n        elif isinstance(end, timedelta):\n            end = start + end\n        elif not isinstance(start, date):\n            raise RuntimeError('Unsupported type passed in for end {}'.format(end))\n\n        self._start_time = self._to_ms_since_epoch(start)\n        self._end_time = self._to_ms_since_epoch(end)\n        if (self._epoch >= start and self._epoch <= end):\n            self.with_special_case(self._epoch)\n\n    _epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)\n    _ms = timedelta(milliseconds=1)\n    def _to_ms_since_epoch(self, val):\n        return int((val - self._epoch)/self._ms)\n\n    def _from_ms_since_epoch(self, ms):\n        return self._epoch + timedelta(milliseconds=ms)\n\n    def start(self, rand):\n        start = self._start_time\n        end = self._end_time\n        self._start(rand, lambda : self._from_ms_since_epoch(rand.randint(start, end)))\n\n    def contains_ts(self):\n        return True\n\nclass ArrayGen(DataGen):\n    \"\"\"Generate Arrays of data.\"\"\"\n    def __init__(self, child_gen, min_length=0, max_length=20, nullable=True, 
all_null=False):\n        super().__init__(ArrayType(child_gen.data_type, containsNull=child_gen.nullable), nullable=nullable)\n        self._min_length = min_length\n        self._max_length = max_length\n        self._child_gen = child_gen\n        self.all_null = all_null\n\n    def __repr__(self):\n        return super().__repr__() + '(' + str(self._child_gen) + ')'\n\n    def start(self, rand):\n        self._child_gen.start(rand)\n        def gen_array():\n            if self.all_null:\n                return None\n            length = rand.randint(self._min_length, self._max_length)\n            return [self._child_gen.gen() for _ in range(0, length)]\n        self._start(rand, gen_array)\n\n    def contains_ts(self):\n        return self._child_gen.contains_ts()\n\nclass MapGen(DataGen):\n    \"\"\"Generate a Map\"\"\"\n    def __init__(self, key_gen, value_gen, min_length=0, max_length=20, nullable=True, special_cases=[]):\n        # keys cannot be nullable\n        assert not key_gen.nullable\n        self._min_length = min_length\n        self._max_length = max_length\n        self._key_gen = key_gen\n        self._value_gen = value_gen\n        super().__init__(MapType(key_gen.data_type, value_gen.data_type, valueContainsNull=value_gen.nullable), nullable=nullable, special_cases=special_cases)\n\n    def __repr__(self):\n        return super().__repr__() + '(' + str(self._key_gen) + ',' + str(self._value_gen) + ')'\n\n    def start(self, rand):\n        self._key_gen.start(rand)\n        self._value_gen.start(rand)\n        def make_dict():\n            length = rand.randint(self._min_length, self._max_length)\n            return {self._key_gen.gen(): self._value_gen.gen() for idx in range(0, length)}\n        self._start(rand, make_dict)\n\n    def contains_ts(self):\n        return self._key_gen.contains_ts() or self._value_gen.contains_ts()\n\n\nclass NullGen(DataGen):\n    \"\"\"Generate NullType values\"\"\"\n    def __init__(self):\n        
super().__init__(NullType(), nullable=True)\n\n    def start(self, rand):\n        def make_null():\n            return None\n        self._start(rand, make_null)\n\ndef skip_if_not_utc():\n    if (not is_tz_utc()):\n        skip_unless_precommit_tests('The java system time zone is not set to UTC')\n\ndef gen_df(spark, data_gen, length=2048, seed=0, num_slices=None):\n    \"\"\"Generate a spark dataframe from the given data generators.\"\"\"\n    if isinstance(data_gen, list):\n        src = StructGen(data_gen, nullable=False)\n    else:\n        src = data_gen\n        # we cannot create a data frame from a nullable struct\n        assert not data_gen.nullable\n\n    # Before we get too far we need to verify that we can run with timestamps\n    if src.contains_ts():\n        skip_if_not_utc()\n\n    rand = random.Random(seed)\n    src.start(rand)\n    data = [src.gen() for index in range(0, length)]\n    # We use `numSlices` to create an RDD with the specific number of partitions,\n    # which is then turned into a dataframe. 
If not specified, it is `None` (default spark value)\n    return spark.createDataFrame(\n        SparkContext.getOrCreate().parallelize(data, numSlices=num_slices),\n        src.data_type)\n\ndef _mark_as_lit(data, data_type):\n    # To support nested types, 'data_type' is required.\n    assert data_type is not None\n\n    if data is None:\n        return f.lit(data).cast(data_type)\n\n    if isinstance(data_type, ArrayType):\n        assert isinstance(data, list)\n        # Sadly you cannot create a literal from just an array in pyspark\n        return f.array([_mark_as_lit(x, data_type.elementType) for x in data])\n    elif isinstance(data_type, StructType):\n        assert isinstance(data, tuple) and len(data) == len(data_type.fields)\n        # Sadly you cannot create a literal from just a dict/tuple in pyspark\n        children = zip(data, data_type.fields)\n        return f.struct([_mark_as_lit(x, fd.dataType).alias(fd.name) for x, fd in children])\n    elif isinstance(data_type, DateType):\n        # Due to https://bugs.python.org/issue13305 we need to zero pad for years prior to 1000,\n        # but this works for all of them\n        dateString = data.strftime(\"%Y-%m-%d\").zfill(10)\n        return f.lit(dateString).cast(data_type)\n    elif isinstance(data_type, MapType):\n        assert isinstance(data, dict)\n        # Sadly you cannot create a literal from just a dict/tuple in pyspark\n        col_array = []\n        for k in data:\n            col_array.append(_mark_as_lit(k, data_type.keyType))\n            col_array.append(_mark_as_lit(data[k], data_type.valueType))\n        return f.create_map(*col_array)\n    else:\n        # lit does not take a data type so we might have to cast it\n        return f.lit(data).cast(data_type)\n\ndef _gen_scalars_common(data_gen, count, seed=0):\n    if isinstance(data_gen, list):\n        src = StructGen(data_gen, nullable=False)\n    else:\n        src = data_gen\n\n    # Before we get too far we need to verify 
that we can run with timestamps\n    if src.contains_ts():\n        skip_if_not_utc()\n\n    rand = random.Random(seed)\n    src.start(rand)\n    return src\n\ndef gen_scalars(data_gen, count, seed=0, force_no_nulls=False):\n    \"\"\"Generate scalar values.\"\"\"\n    if force_no_nulls:\n        assert(not isinstance(data_gen, NullGen))\n    src = _gen_scalars_common(data_gen, count, seed=seed)\n    data_type = src.data_type\n    return (_mark_as_lit(src.gen(force_no_nulls=force_no_nulls), data_type) for i in range(0, count))\n\ndef gen_scalar(data_gen, seed=0, force_no_nulls=False):\n    \"\"\"Generate a single scalar value.\"\"\"\n    v = list(gen_scalars(data_gen, 1, seed=seed, force_no_nulls=force_no_nulls))\n    return v[0]\n\ndef gen_scalar_values(data_gen, count, seed=0, force_no_nulls=False):\n    \"\"\"Generate scalar values.\"\"\"\n    src = _gen_scalars_common(data_gen, count, seed=seed)\n    return (src.gen(force_no_nulls=force_no_nulls) for i in range(0, count))\n\ndef gen_scalar_value(data_gen, seed=0, force_no_nulls=False):\n    \"\"\"Generate a single scalar value.\"\"\"\n    v = list(gen_scalar_values(data_gen, 1, seed=seed, force_no_nulls=force_no_nulls))\n    return v[0]\n\ndef debug_df(df, path = None, file_format = 'json', num_parts = 1):\n    \"\"\"Print out or save the contents and the schema of a dataframe for debugging.\"\"\"\n\n    if path is not None:\n        # Save the dataframe and its schema\n        # The schema can be re-created by using DataType.fromJson and used\n        # for loading the dataframe\n        file_name = f\"{path}.{file_format}\"\n        schema_file_name = f\"{path}.schema.json\"\n\n        df.coalesce(num_parts).write.format(file_format).save(file_name)\n        print(f\"SAVED df output for debugging at {file_name}\")\n\n        schema_json = df.schema.json()\n        schema_file = open(schema_file_name , 'w')\n        schema_file.write(schema_json)\n        schema_file.close()\n        print(f\"SAVED df schema 
for debugging along in the output dir\")\n    else:\n        print('COLLECTED\\n{}'.format(df.collect()))\n\n    df.explain()\n    df.printSchema()\n    return df\n\ndef print_params(data_gen):\n    print('Test Datagen Params=' + str([(a, b.get_types()) for a, b in data_gen]))\n\ndef idfn(val):\n    \"\"\"Provide an API to provide display names for data type generators.\"\"\"\n    return str(val)\n\ndef meta_idfn(meta):\n    def tmp(something):\n        return meta + idfn(something)\n    return tmp\n\ndef three_col_df(spark, a_gen, b_gen, c_gen, length=2048, seed=0, num_slices=None):\n    gen = StructGen([('a', a_gen),('b', b_gen),('c', c_gen)], nullable=False)\n    return gen_df(spark, gen, length=length, seed=seed, num_slices=num_slices)\n\ndef two_col_df(spark, a_gen, b_gen, length=2048, seed=0, num_slices=None):\n    gen = StructGen([('a', a_gen),('b', b_gen)], nullable=False)\n    return gen_df(spark, gen, length=length, seed=seed, num_slices=num_slices)\n\ndef binary_op_df(spark, gen, length=2048, seed=0, num_slices=None):\n    return two_col_df(spark, gen, gen, length=length, seed=seed, num_slices=num_slices)\n\ndef unary_op_df(spark, gen, length=2048, seed=0, num_slices=None):\n    return gen_df(spark, StructGen([('a', gen)], nullable=False),\n        length=length, seed=seed, num_slices=num_slices)\n\ndef to_cast_string(spark_type):\n    if isinstance(spark_type, ByteType):\n        return 'BYTE'\n    elif isinstance(spark_type, ShortType):\n        return 'SHORT'\n    elif isinstance(spark_type, IntegerType):\n        return 'INT'\n    elif isinstance(spark_type, LongType):\n        return 'LONG'\n    elif isinstance(spark_type, FloatType):\n        return 'FLOAT'\n    elif isinstance(spark_type, DoubleType):\n        return 'DOUBLE'\n    elif isinstance(spark_type, BooleanType):\n        return 'BOOLEAN'\n    elif isinstance(spark_type, DateType):\n        return 'DATE'\n    elif isinstance(spark_type, TimestampType):\n        return 'TIMESTAMP'\n    
elif isinstance(spark_type, StringType):\n        return 'STRING'\n    elif isinstance(spark_type, DecimalType):\n        return 'DECIMAL({}, {})'.format(spark_type.precision, spark_type.scale)\n    elif isinstance(spark_type, ArrayType):\n        return 'ARRAY<{}>'.format(to_cast_string(spark_type.elementType))\n    elif isinstance(spark_type, StructType):\n        children = [fd.name + ':' + to_cast_string(fd.dataType) for fd in spark_type.fields]\n        return 'STRUCT<{}>'.format(','.join(children))\n    else:\n        raise RuntimeError('CAST TO TYPE {} NOT SUPPORTED YET'.format(spark_type))\n\ndef get_null_lit_string(spark_type):\n    if isinstance(spark_type, NullType):\n        return 'null'\n    else:\n        string_type = to_cast_string(spark_type)\n        return 'CAST(null as {})'.format(string_type)\n\ndef _convert_to_sql(spark_type, data):\n    if isinstance(data, str):\n        d = \"'\" + data.replace(\"'\", \"\\\\'\") + \"'\"\n    elif isinstance(data, datetime):\n        d = \"'\" + data.strftime('%Y-%m-%d T%H:%M:%S.%f').zfill(26) + \"'\"\n    elif isinstance(data, date):\n        d = \"'\" + data.strftime('%Y-%m-%d').zfill(10) + \"'\"\n    elif isinstance(data, list):\n        assert isinstance(spark_type, ArrayType)\n        d = \"array({})\".format(\",\".join([_convert_to_sql(spark_type.elementType, x) for x in data]))\n    elif isinstance(data, tuple):\n        assert isinstance(spark_type, StructType) and len(data) == len(spark_type.fields)\n        # Format of each child: 'name',data\n        children = [\"'{}'\".format(fd.name) + ',' + _convert_to_sql(fd.dataType, x)\n                for fd, x in zip(spark_type.fields, data)]\n        d = \"named_struct({})\".format(','.join(children))\n    elif not data:\n        # data is None\n        d = \"null\"\n    else:\n        d = \"'{}'\".format(str(data))\n\n    if isinstance(spark_type, NullType):\n        return d\n    else:\n        return 'CAST({} as {})'.format(d, 
to_cast_string(spark_type))\n\ndef gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):\n    \"\"\"Generate scalar values, but strings that can be used in selectExpr or SQL\"\"\"\n    src = _gen_scalars_common(data_gen, count, seed=seed)\n    if isinstance(data_gen, NullGen):\n        assert not force_no_nulls\n        return ('null' for i in range(0, count))\n    spark_type = data_gen.data_type\n    return (_convert_to_sql(spark_type, src.gen(force_no_nulls=force_no_nulls)) for i in range(0, count))\n\nbyte_gen = ByteGen()\nshort_gen = ShortGen()\nint_gen = IntegerGen()\nlong_gen = LongGen()\nfloat_gen = FloatGen()\ndouble_gen = DoubleGen()\nstring_gen = StringGen()\nboolean_gen = BooleanGen()\ndate_gen = DateGen()\ntimestamp_gen = TimestampGen()\ndecimal_gen_default = DecimalGen()\ndecimal_gen_neg_scale = DecimalGen(precision=7, scale=-3)\ndecimal_gen_scale_precision = DecimalGen(precision=7, scale=3)\ndecimal_gen_same_scale_precision = DecimalGen(precision=7, scale=7)\ndecimal_gen_64bit = DecimalGen(precision=12, scale=2)\ndecimal_gen_12_2 = DecimalGen(precision=12, scale=2)\ndecimal_gen_18_3 = DecimalGen(precision=18, scale=3)\ndecimal_gen_128bit = DecimalGen(precision=20, scale=2)\ndecimal_gen_20_2 = DecimalGen(precision=20, scale=2)\ndecimal_gen_30_2 = DecimalGen(precision=30, scale=2)\ndecimal_gen_36_5 = DecimalGen(precision=36, scale=5)\ndecimal_gen_36_neg5 = DecimalGen(precision=36, scale=-5)\ndecimal_gen_38_0 = DecimalGen(precision=38, scale=0)\ndecimal_gen_38_10 = DecimalGen(precision=38, scale=10)\ndecimal_gen_38_neg10 = DecimalGen(precision=38, scale=-10)\n\nnull_gen = NullGen()\n\nnumeric_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen]\n\nintegral_gens = [byte_gen, short_gen, int_gen, long_gen]\n# A lot of mathematical expressions only support a double as input\n# by parametrizing even for a single param for the test it makes the tests consistent\ndouble_gens = [double_gen]\ndouble_n_long_gens = [double_gen, 
long_gen]\nint_n_long_gens = [int_gen, long_gen]\ndecimal_gens_no_neg = [decimal_gen_default, decimal_gen_scale_precision,\n        decimal_gen_same_scale_precision, decimal_gen_64bit]\n\ndecimal_gens = [decimal_gen_neg_scale] + decimal_gens_no_neg\n\ndecimal_128_gens_no_neg = [decimal_gen_20_2, decimal_gen_30_2, decimal_gen_36_5,\n        decimal_gen_38_0, decimal_gen_38_10]\n\ndecimal_128_gens = decimal_128_gens_no_neg + [decimal_gen_36_neg5, decimal_gen_38_neg10]\n\n# all of the basic gens\nall_basic_gens_no_null = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,\n                          string_gen, boolean_gen, date_gen, timestamp_gen]\nall_basic_gens = all_basic_gens_no_null + [null_gen]\n\nall_basic_gens_no_nan = [byte_gen, short_gen, int_gen, long_gen, FloatGen(no_nans=True), DoubleGen(no_nans=True),\n        string_gen, boolean_gen, date_gen, timestamp_gen, null_gen]\n\n# TODO add in some array generators to this once that is supported for sorting\n# a selection of generators that should be orderable (sortable and compareable)\norderable_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,\n        string_gen, boolean_gen, date_gen, timestamp_gen, null_gen] + decimal_gens\n\n# TODO add in some array generators to this once that is supported for these operations\n# a selection of generators that can be compared for equality\neq_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,\n        string_gen, boolean_gen, date_gen, timestamp_gen, null_gen]\n\n# Include decimal type while testing equalTo and notEqualTo\neq_gens_with_decimal_gen =  eq_gens + decimal_gens\n\ndate_gens = [date_gen]\ndate_n_time_gens = [date_gen, timestamp_gen]\n\nboolean_gens = [boolean_gen]\n\nsingle_level_array_gens = [ArrayGen(sub_gen) for sub_gen in all_basic_gens + decimal_gens]\nsingle_array_gens_sample_with_decimal128 = [ArrayGen(sub_gen) for sub_gen in decimal_128_gens]\n\nsingle_level_array_gens_no_null = [ArrayGen(sub_gen) 
for sub_gen in all_basic_gens_no_null + decimal_gens_no_neg]\n\nsingle_level_array_gens_no_nan = [ArrayGen(sub_gen) for sub_gen in all_basic_gens_no_nan + decimal_gens]\n\nsingle_level_array_gens_no_decimal = [ArrayGen(sub_gen) for sub_gen in all_basic_gens]\n\nmap_string_string_gen = [MapGen(StringGen(pattern='key_[0-9]', nullable=False), StringGen())]\n\n# Be careful to not make these too large of data generation takes for ever\n# This is only a few nested array gens, because nesting can be very deep\nnested_array_gens_sample = [ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10),\n        ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10),\n        ArrayGen(StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]]))]\n\n# Some array gens, but not all because of nesting\narray_gens_sample = single_level_array_gens + nested_array_gens_sample\narray_gens_sample_with_decimal128 = single_level_array_gens + nested_array_gens_sample + single_array_gens_sample_with_decimal128\n\n# all of the basic types in a single struct\nall_basic_struct_gen = StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(all_basic_gens)])\n\n# Some struct gens, but not all because of nesting\nnonempty_struct_gens_sample = [all_basic_struct_gen,\n        StructGen([['child0', byte_gen], ['child1', all_basic_struct_gen]]),\n        StructGen([['child0', ArrayGen(short_gen)], ['child1', double_gen]])]\n\nstruct_gens_sample = nonempty_struct_gens_sample + [StructGen([])]\nstruct_gen_decimal128 = StructGen(\n    [['child' + str(ind), sub_gen] for ind, sub_gen in enumerate(decimal_128_gens)])\nstruct_gens_sample_with_decimal128 = struct_gens_sample + [\n    struct_gen_decimal128]\n\nsimple_string_to_string_map_gen = MapGen(StringGen(pattern='key_[0-9]', nullable=False),\n        StringGen(), max_length=10)\n\nall_basic_map_gens = [MapGen(f(nullable=False), f()) for f in [BooleanGen, ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, 
DateGen, TimestampGen]] + [simple_string_to_string_map_gen]\ndecimal_64_map_gens = [MapGen(key_gen=gen, value_gen=gen, nullable=False) for gen in [DecimalGen(7, 3, nullable=False), DecimalGen(12, 2, nullable=False), DecimalGen(18, -3, nullable=False)]]\ndecimal_128_map_gens = [MapGen(key_gen=gen, value_gen=gen, nullable=False) for gen in [DecimalGen(20, 2, nullable=False), DecimalGen(36, 5, nullable=False), DecimalGen(38, 38, nullable=False),\n                                                                                       DecimalGen(36, -5, nullable=False)]]\ndecimal_128_no_neg_map_gens = [MapGen(key_gen=gen, value_gen=gen, nullable=False) for gen in [DecimalGen(20, 2, nullable=False), DecimalGen(36, 5, nullable=False), DecimalGen(38, 38, nullable=False)]]\n\n# Some map gens, but not all because of nesting\nmap_gens_sample = all_basic_map_gens + [MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(string_gen), max_length=10),\n        MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10),\n        MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)]\n\nallow_negative_scale_of_decimal_conf = {'spark.sql.legacy.allowNegativeScaleOfDecimal': 'true'}\n\ndef copy_and_update(conf, *more_confs):\n    local_conf = conf.copy()\n    for more in more_confs:\n        local_conf.update(more)\n    return local_conf\n\nall_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(),\n           FloatGen(), DoubleGen(), BooleanGen(), DateGen(), TimestampGen(),\n           decimal_gen_default, decimal_gen_scale_precision, decimal_gen_same_scale_precision,\n           decimal_gen_64bit, decimal_gen_128bit, decimal_gen_36_5, decimal_gen_38_10]\n\n# Pyarrow will complain the error as below if the timestamp is out of range for both CPU and GPU,\n# so narrow down the time range to avoid exceptions causing test failures.\n#\n#     \"pyarrow.lib.ArrowInvalid: Casting from timestamp[us, tz=UTC] to 
timestamp[ns]\n#      would result in out of bounds timestamp: 51496791452587000\"\n#\n# This issue has been fixed in pyarrow by the PR https://github.com/apache/arrow/pull/7169\n# However it still requires PySpark to specify the new argument \"timestamp_as_object\".\narrow_common_gen = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,\n        string_gen, boolean_gen, date_gen,\n        TimestampGen(start=datetime(1970, 1, 1, tzinfo=timezone.utc),\n                     end=datetime(2262, 1, 1, tzinfo=timezone.utc))]\n\narrow_array_gens = [ArrayGen(subGen) for subGen in arrow_common_gen] + nested_array_gens_sample\n\narrow_one_level_struct_gen = StructGen([\n        ['child'+str(i), sub_gen] for i, sub_gen in enumerate(arrow_common_gen)])\n\narrow_struct_gens = [arrow_one_level_struct_gen,\n        StructGen([['child0', ArrayGen(short_gen)], ['child1', arrow_one_level_struct_gen]])]\n\n# This function adds a new column named uniq_int where each row\n# has a new unique integer value. It just starts at 0 and\n# increments by 1 for each row.\n# This can be used to add a column to a dataframe if you need to\n# sort on a column with unique values.\n# This collects the data to driver though so can be expensive.\ndef append_unique_int_col_to_df(spark, dataframe):\n    def append_unique_to_rows(rows):\n        new = []\n        for item in range(len(rows)):\n            row_dict = rows[item].asDict()\n            row_dict['uniq_int'] = item\n            new_row = Row(**row_dict)\n            new.append(new_row)\n        return new\n\n    collected = dataframe.collect()\n    if (len(collected) > INT_MAX):\n        raise RuntimeError('To many rows to add unique integer values starting from 0 to')\n    existing_schema = dataframe.schema\n    new_rows = append_unique_to_rows(collected)\n    new_schema = StructType(existing_schema.fields + [StructField(\"uniq_int\", IntegerType(), False)])\n    return spark.createDataFrame(new_rows, new_schema)\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/python/rapids_udf_test.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\n\nfrom asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql\nfrom data_gen import *\nfrom spark_session import with_spark_session\nfrom pyspark.sql.utils import AnalysisException\n\nencoded_url_gen = StringGen('([^%]{0,1}(%[0-9A-F][0-9A-F]){0,1}){0,30}')\n\ndef drop_udf(spark, udfname):\n    spark.sql(\"DROP TEMPORARY FUNCTION IF EXISTS {}\".format(udfname))\n\ndef skip_if_no_hive(spark):\n    if spark.conf.get(\"spark.sql.catalogImplementation\") != \"hive\":\n        raise RuntimeError('The Spark session does not have Hive support')\n\ndef load_hive_udf_or_skip_test(spark, udfname, udfclass):\n    drop_udf(spark, udfname)\n    spark.sql(\"CREATE TEMPORARY FUNCTION {} AS '{}'\".format(udfname, udfclass))\n\ndef test_hive_simple_udf():\n    with_spark_session(skip_if_no_hive)\n    data_gens = [[\"i\", int_gen], [\"s\", encoded_url_gen]]\n    def evalfn(spark):\n        load_hive_udf_or_skip_test(spark, \"urldecode\", \"com.nvidia.spark.rapids.udf.hive.URLDecode\")\n        return gen_df(spark, data_gens)\n    assert_gpu_and_cpu_are_equal_sql(\n        evalfn,\n        \"hive_simple_udf_test_table\",\n        \"SELECT i, urldecode(s) FROM hive_simple_udf_test_table\")\n\ndef test_hive_generic_udf():\n    with_spark_session(skip_if_no_hive)\n    def evalfn(spark):\n        
load_hive_udf_or_skip_test(spark, \"urlencode\", \"com.nvidia.spark.rapids.udf.hive.URLEncode\")\n        return gen_df(spark, [[\"s\", StringGen('.{0,30}')]])\n    assert_gpu_and_cpu_are_equal_sql(\n        evalfn,\n        \"hive_generic_udf_test_table\",\n        \"SELECT urlencode(s) FROM hive_generic_udf_test_table\")\n\n    def evalfn_decimal(spark):\n        load_hive_udf_or_skip_test(spark, \"fraction\", \"com.nvidia.spark.rapids.udf.hive.DecimalFraction\")\n        return gen_df(spark, [[\"dec\", DecimalGen(38, 18)]])\n    assert_gpu_and_cpu_are_equal_sql(\n        evalfn_decimal,\n        \"hive_generic_udf_test_table\",\n        \"SELECT fraction(dec) FROM hive_generic_udf_test_table\")\n\n@pytest.mark.rapids_udf_example_native\ndef test_hive_simple_udf_native():\n    with_spark_session(skip_if_no_hive)\n    data_gens = [[\"s\", StringGen('.{0,30}')]]\n    def evalfn(spark):\n        load_hive_udf_or_skip_test(spark, \"wordcount\", \"com.nvidia.spark.rapids.udf.hive.StringWordCount\")\n        return gen_df(spark, data_gens)\n    assert_gpu_and_cpu_are_equal_sql(\n        evalfn,\n        \"hive_native_udf_test_table\",\n        \"SELECT wordcount(s) FROM hive_native_udf_test_table\")\n\ndef load_java_udf_or_skip_test(spark, udfname, udfclass, udf_return_type=None):\n    drop_udf(spark, udfname)\n    spark.udf.registerJavaFunction(udfname, udfclass, udf_return_type)\n\ndef test_java_url_decode():\n    def evalfn(spark):\n        load_java_udf_or_skip_test(spark, 'urldecode', 'com.nvidia.spark.rapids.udf.java.URLDecode')\n        return unary_op_df(spark, encoded_url_gen).selectExpr(\"urldecode(a)\")\n    assert_gpu_and_cpu_are_equal_collect(evalfn)\n\ndef test_java_url_encode():\n    def evalfn(spark):\n        load_java_udf_or_skip_test(spark, 'urlencode', 'com.nvidia.spark.rapids.udf.java.URLEncode')\n        return unary_op_df(spark, StringGen('.{0,30}')).selectExpr(\"urlencode(a)\")\n    assert_gpu_and_cpu_are_equal_collect(evalfn)\n\ndef 
test_java_decimal_fraction():\n    def evalfn(spark):\n        from pyspark.sql.types import DecimalType\n        load_java_udf_or_skip_test(spark, 'fraction',\n                                   'com.nvidia.spark.rapids.udf.java.DecimalFraction')\n        load_java_udf_or_skip_test(spark, 'fraction_dec64_s10',\n                                   'com.nvidia.spark.rapids.udf.java.DecimalFraction',\n                                   DecimalType(18, 10))\n        load_java_udf_or_skip_test(spark, 'fraction_dec32_s3',\n                                   'com.nvidia.spark.rapids.udf.java.DecimalFraction',\n                                   DecimalType(8, 3))\n        return three_col_df(spark, DecimalGen(38, 18), DecimalGen(18, 10), DecimalGen(8, 3)\n                            ).selectExpr(\"fraction(a)\", \"fraction_dec64_s10(b)\", \"fraction_dec32_s3(c)\")\n    assert_gpu_and_cpu_are_equal_collect(evalfn)\n\n@pytest.mark.rapids_udf_example_native\ndef test_java_cosine_similarity_reasonable_range():\n    def evalfn(spark):\n        class RangeFloatGen(FloatGen):\n            def start(self, rand):\n                self._start(rand, lambda: rand.uniform(-1000.0, 1000.0))\n        load_java_udf_or_skip_test(spark, \"cosine_similarity\", \"com.nvidia.spark.rapids.udf.java.CosineSimilarity\")\n        arraygen = ArrayGen(RangeFloatGen(nullable=False, no_nans=True, special_cases=[]), min_length=8, max_length=8)\n        df = binary_op_df(spark, arraygen)\n        return df.selectExpr(\"cosine_similarity(a, b)\")\n    assert_gpu_and_cpu_are_equal_collect(evalfn)\n\n@pytest.mark.rapids_udf_example_native\ndef test_java_cosine_similarity_with_nans():\n    def evalfn(spark):\n        load_java_udf_or_skip_test(spark, \"cosine_similarity\", \"com.nvidia.spark.rapids.udf.java.CosineSimilarity\")\n        arraygen = ArrayGen(FloatGen(nullable=False), min_length=8, max_length=8)\n        return binary_op_df(spark, arraygen).selectExpr(\"cosine_similarity(a, b)\")\n    
assert_gpu_and_cpu_are_equal_collect(evalfn)\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/python/spark_init_internal.py",
    "content": "# Copyright (c) 2020-2021, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\ntry:\n    import pyspark\nexcept ImportError as error:\n    import findspark\n    findspark.init()\n    import pyspark\n\n_DRIVER_ENV = 'PYSP_TEST_spark_driver_extraJavaOptions'\n\ndef _spark__init():\n    #Force the RapidsPlugin to be enabled, so it blows up if the classpath is not set properly\n    # DO NOT SET ANY OTHER CONFIGS HERE!!!\n    # due to bugs in pyspark/pytest it looks like any configs set here\n    # can be reset in the middle of a test if specific operations are done (some types of cast etc)\n    _sb = pyspark.sql.SparkSession.builder\n    _sb.config('spark.plugins', 'com.nvidia.spark.SQLPlugin') \\\n            .config(\"spark.sql.adaptive.enabled\", \"false\") \\\n            .config('spark.sql.queryExecutionListeners', 'org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback')\n\n    for key, value in os.environ.items():\n        if key.startswith('PYSP_TEST_') and key != _DRIVER_ENV:\n            _sb.config(key[10:].replace('_', '.'), value)\n\n    driver_opts = os.environ.get(_DRIVER_ENV, \"\")\n\n    if ('PYTEST_XDIST_WORKER' in os.environ):\n        wid = os.environ['PYTEST_XDIST_WORKER']\n        _handle_derby_dir(_sb, driver_opts, wid)\n        _handle_event_log_dir(_sb, wid)\n    else:\n        _sb.config('spark.driver.extraJavaOptions', driver_opts)\n        _handle_event_log_dir(_sb, 'gw0')\n\n    # 
enableHiveSupport() is needed for parquet bucket tests\n    _s = _sb.enableHiveSupport() \\\n            .appName('rapids spark plugin integration tests (python)').getOrCreate()\n    #TODO catch the ClassNotFound error that happens if the classpath is not set up properly and\n    # make it a better error message\n    _s.sparkContext.setLogLevel(\"WARN\")\n    return _s\n\n\ndef _handle_derby_dir(sb, driver_opts, wid):\n    d = \"./derby_{}\".format(wid)\n    if not os.path.exists(d):\n        os.makedirs(d)\n    sb.config('spark.driver.extraJavaOptions', driver_opts + ' -Dderby.system.home={}'.format(d))\n\n\ndef _handle_event_log_dir(sb, wid):\n    if os.environ.get('SPARK_EVENTLOG_ENABLED', str(True)).lower() in [\n        str(False).lower(), 'off', '0'\n    ]:\n        print('Automatic configuration for spark event log disabled')\n        return\n\n    spark_conf = pyspark.SparkConf()\n    master_url = os.environ.get('PYSP_TEST_spark_master',\n                                spark_conf.get(\"spark.master\", 'local'))\n    event_log_config = os.environ.get('PYSP_TEST_spark_eventLog_enabled',\n                                      spark_conf.get('spark.eventLog.enabled', str(False).lower()))\n    event_log_codec = os.environ.get('PYSP_TEST_spark_eventLog_compression_codec', 'zstd')\n\n    if not master_url.startswith('local') or event_log_config != str(False).lower():\n        print(\"SPARK_EVENTLOG_ENABLED is ignored for non-local Spark master and when \"\n              \"it's pre-configured by the user\")\n        return\n    d = \"./eventlog_{}\".format(wid)\n    if not os.path.exists(d):\n        os.makedirs(d)\n\n    print('Spark event logs will appear under {}. 
Set the environmnet variable '\n          'SPARK_EVENTLOG_ENABLED=false if you want to disable it'.format(d))\n\n    sb\\\n        .config('spark.eventLog.dir', \"file://{}\".format(os.path.abspath(d))) \\\n        .config('spark.eventLog.compress', True) \\\n        .config('spark.eventLog.enabled', True) \\\n        .config('spark.eventLog.compression.codec', event_log_codec)\n\n\n_spark = _spark__init()\n\ndef get_spark_i_know_what_i_am_doing():\n    \"\"\"\n    Get the current SparkSession.\n    This should almost never be called directly instead you should call\n    with_spark_session, with_cpu_session, or with_gpu_session for spark_session.\n    This is to guarantee that the session and it's config is setup in a repeatable way.\n    \"\"\"\n    return _spark\n\ndef spark_version():\n    return _spark.version\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/python/spark_session.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom conftest import is_allowing_any_non_gpu, get_non_gpu_allowed, get_validate_execs_in_gpu_plan, is_databricks_runtime\nfrom pyspark.sql import SparkSession, DataFrame\nfrom spark_init_internal import get_spark_i_know_what_i_am_doing, spark_version\n\ndef _from_scala_map(scala_map):\n    ret = {}\n    # The value we get is a scala map, not a java map, so we need to jump through some hoops\n    keys = scala_map.keys().iterator()\n    while keys.hasNext():\n        key = keys.next()\n        ret[key] = scala_map.get(key).get()\n    return ret\n\n_spark = get_spark_i_know_what_i_am_doing()\n# Have to reach into a private member to get access to the API we need\n_orig_conf = _from_scala_map(_spark.conf._jconf.getAll())\n_orig_conf_keys = _orig_conf.keys()\n\ndef is_tz_utc(spark=_spark):\n    \"\"\"\n    true if the tz is UTC else false\n    \"\"\"\n    # Now we have to do some kind of ugly internal java stuff\n    jvm = spark.sparkContext._jvm\n    utc = jvm.java.time.ZoneId.of('UTC').normalized()\n    sys_tz = jvm.java.time.ZoneId.systemDefault().normalized()\n    return utc == sys_tz\n\ndef _set_all_confs(conf):\n    for key, value in conf.items():\n        if _spark.conf.get(key, None) != value:\n            _spark.conf.set(key, value)\n\ndef reset_spark_session_conf():\n    \"\"\"Reset all of the configs for a given spark 
session.\"\"\"\n    _set_all_confs(_orig_conf)\n    #We should clear the cache\n    _spark.catalog.clearCache()\n    # Have to reach into a private member to get access to the API we need\n    current_keys = _from_scala_map(_spark.conf._jconf.getAll()).keys()\n    for key in current_keys:\n        if key not in _orig_conf_keys:\n            _spark.conf.unset(key)\n\ndef _check_for_proper_return_values(something):\n    \"\"\"We don't want to return an DataFrame or Dataset from a with_spark_session. You will not get what you expect\"\"\"\n    if (isinstance(something, DataFrame)):\n        raise RuntimeError(\"You should never return a DataFrame from a with_*_session, you will not get the results that you expect\")\n\ndef with_spark_session(func, conf={}):\n    \"\"\"Run func that takes a spark session as input with the given configs set.\"\"\"\n    reset_spark_session_conf()\n    _add_job_description(conf)\n    _set_all_confs(conf)\n    ret = func(_spark)\n    _check_for_proper_return_values(ret)\n    return ret\n\n\ndef _add_job_description(conf):\n    is_gpu_job = conf.get('spark.rapids.sql.enabled', False)\n    job_type = 'GPU' if str(is_gpu_job).lower() == str(True).lower() else 'CPU'\n    job_desc = '{}[{}]'.format(os.environ.get('PYTEST_CURRENT_TEST'), job_type)\n    _spark.sparkContext.setJobDescription(job_desc)\n\n\ndef with_cpu_session(func, conf={}):\n    \"\"\"Run func that takes a spark session as input with the given configs set on the CPU.\"\"\"\n    copy = dict(conf)\n    copy['spark.rapids.sql.enabled'] = 'false'\n    return with_spark_session(func, conf=copy)\n\ndef with_gpu_session(func, conf={}):\n    \"\"\"\n    Run func that takes a spark session as input with the given configs set on the GPU.\n    Note that this forces you into test mode unless.  
It is not a requirement, but is\n    simplest for right now.\n    \"\"\"\n    copy = dict(conf)\n    copy['spark.rapids.sql.enabled'] = 'true'\n    if is_allowing_any_non_gpu():\n        copy['spark.rapids.sql.test.enabled'] = 'false'\n    else:\n        copy['spark.rapids.sql.test.enabled'] = 'true'\n        copy['spark.rapids.sql.test.allowedNonGpu'] = ','.join(get_non_gpu_allowed())\n\n    copy['spark.rapids.sql.test.validateExecsInGpuPlan'] = ','.join(get_validate_execs_in_gpu_plan())\n    return with_spark_session(func, conf=copy)\n\ndef is_before_spark_311():\n    return spark_version() < \"3.1.0\"\n\ndef is_before_spark_320():\n    return spark_version() < \"3.2.0\"\n\ndef is_before_spark_330():\n    return spark_version() < \"3.3.0\"\n\ndef is_databricks91_or_later():\n    spark = get_spark_i_know_what_i_am_doing()\n    return spark.conf.get(\"spark.databricks.clusterUsageTags.sparkVersion\", \"\") >= \"9.1\"\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/scala/com/nvidia/spark/rapids/udf/scala/URLDecode.scala",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.rapids.udf.scala\n\nimport java.net.URLDecoder\n\nimport ai.rapids.cudf.{ColumnVector, DType, Scalar}\nimport com.nvidia.spark.RapidsUDF\n\n/**\n * A Scala user-defined function (UDF) that decodes URL-encoded strings.\n * This class demonstrates how to implement a Scala UDF that also\n * provides a RAPIDS implementation that can run on the GPU when the query\n * is executed with the RAPIDS Accelerator for Apache Spark.\n */\nclass URLDecode extends Function[String, String] with RapidsUDF with Serializable {\n  /** Row-by-row implementation that executes on the CPU */\n  override def apply(s: String): String = {\n    Option(s).map { s =>\n      try {\n        URLDecoder.decode(s, \"utf-8\")\n      } catch {\n        case _: IllegalArgumentException => s\n      }\n    }.orNull\n  }\n\n  /** Columnar implementation that runs on the GPU */\n  override def evaluateColumnar(numRows: Int, args: ColumnVector*): ColumnVector = {\n    // The CPU implementation takes a single string argument, so similarly\n    // there should only be one column argument of type STRING.\n    require(args.length == 1, s\"Unexpected argument count: ${args.length}\")\n    val input = args.head\n    require(numRows == input.getRowCount, s\"Expected $numRows rows, received ${input.getRowCount}\")\n    require(input.getType == 
DType.STRING, s\"Argument type is not a string: ${input.getType}\")\n\n    // The cudf urlDecode does not convert '+' to a space, so do that as a pre-pass first.\n    // All intermediate results are closed to avoid leaking GPU resources.\n    val plusScalar = Scalar.fromString(\"+\")\n    try {\n      val spaceScalar = Scalar.fromString(\" \")\n      try {\n        val replaced = input.stringReplace(plusScalar, spaceScalar)\n        try {\n          replaced.urlDecode()\n        } finally {\n          replaced.close()\n        }\n      } finally {\n        spaceScalar.close()\n      }\n    } finally {\n      plusScalar.close()\n    }\n  }\n}\n"
  },
  {
    "path": "examples/UDF-Examples/RAPIDS-accelerated-UDFs/src/main/scala/com/nvidia/spark/rapids/udf/scala/URLEncode.scala",
    "content": "/*\n * Copyright (c) 2021-2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.rapids.udf.scala\n\nimport java.net.URLEncoder\n\nimport ai.rapids.cudf.{ColumnVector, DType}\nimport com.nvidia.spark.RapidsUDF\n\n/**\n * A Scala user-defined function (UDF) that URL-encodes strings.\n * This class demonstrates how to implement a Scala UDF that also\n * provides a RAPIDS implementation that can run on the GPU when the query\n * is executed with the RAPIDS Accelerator for Apache Spark.\n */\nclass URLEncode extends Function[String, String] with RapidsUDF with Serializable {\n  /** Row-by-row implementation that executes on the CPU */\n  override def apply(s: String): String = {\n    Option(s).map { s =>\n      URLEncoder.encode(s, \"utf-8\")\n          .replace(\"+\", \"%20\")\n          .replace(\"*\", \"%2A\")\n          .replace(\"%7E\", \"~\")\n    }.orNull\n  }\n\n  /** Columnar implementation that runs on the GPU */\n  override def evaluateColumnar(numRows: Int, args: ColumnVector*): ColumnVector = {\n    // The CPU implementation takes a single string argument, so similarly\n    // there should only be one column argument of type STRING.\n    require(args.length == 1, s\"Unexpected argument count: ${args.length}\")\n    val input = args.head\n    require(numRows == input.getRowCount, s\"Expected $numRows rows, received ${input.getRowCount}\")\n    require(input.getType == 
DType.STRING, s\"Argument type is not a string: ${input.getType}\")\n    input.urlEncode()\n  }\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/.gitignore",
    "content": "samples.zip\n"
  },
  {
    "path": "examples/XGBoost-Examples/README.md",
    "content": "# Spark XGBoost Examples\n\nSpark XGBoost examples here showcase the need for ETL+Training pipeline GPU acceleration.\nThe Scala based XGBoost examples here use [DMLC’s version](https://repo1.maven.org/maven2/ml/dmlc/xgboost4j-spark_2.12/).\nThe pyspark based XGBoost examples requires [installing RAPIDS via pip](https://rapids.ai/pip.html#install).\nMost data scientists spend a lot of time not only on\nTraining models but also processing the large amounts of data needed to train these models.\nAs you can see below, Pyspark+XGBoost training on GPUs can be up to 13X and data processing using\nRAPIDS Accelerator can also be accelerated with an end-to-end speed-up of 11X on GPU compared to CPU.\nIn the public cloud, better performance can lead to significantly lower costs as demonstrated in this [blog](https://developer.nvidia.com/blog/gpu-accelerated-spark-xgboost/).\n\n![mortgage-speedup](/docs/img/guides/mortgage-perf.png)\n\nNote that the Training test result is based on 4 years [Fannie Mea Single-Family Loan Performance Data](https://capitalmarkets.fanniemae.com/credit-risk-transfer/single-family-credit-risk-transfer/fannie-mae-single-family-loan-performance-data) \nwith a 8 A100 GPU and 1024 CPU vcores cluster, the performance is affected by many aspects, \nincluding data size and type of GPU. \n\nIn this folder, there are three blue prints for users to learn about using \nSpark XGBoost and RAPIDS Accelerator on GPUs :\n\n1. Mortgage Prediction\n2. Agaricus Classification\n3. Taxi Fare Prediction\n\nFor each of these examples we have prepared a [sample dataset](/datasets) \nin this folder for testing. These datasets are only provided for convenience. In order to test for performance,\nplease download the larger dataset from their respectives sources.\n\nThere are three sections in this readme section. 
\nIn the first section, we will list the notebooks that can be run on Jupyter with Python or Scala\n([Spylon Kernel](https://pypi.org/project/spylon-kernel/) or [Apache Toree Kernel](https://toree.apache.org/)). \n\nIn the second section, we have sample jar files and source code if users would like to build \nand run this as a Scala or a PySpark Spark-XGBoost application. \n\nIn the last section, we provide basic “Getting Started Guides” for setting up GPU\nSpark-XGBoost on different environments based on the Apache Spark scheduler such as YARN,\nStandalone or Kubernetes.\n\n## SECTION 1: SPARK-XGBOOST EXAMPLE NOTEBOOKS\n\n1. Mortgage Notebooks\n   - Python\n     - [Mortgage ETL](mortgage/notebooks/python/MortgageETL.ipynb)\n     - [Mortgage Training Prediction](mortgage/notebooks/python/mortgage-gpu.ipynb)\n     - [Mortgage ETL + XGBoost Training](mortgage/notebooks/python/MortgageETL+XGBoost.ipynb)\n   - Scala\n     - [Mortgage ETL](mortgage/notebooks/scala/mortgage-ETL.ipynb)\n     - [Mortgage Training Prediction](mortgage/notebooks/scala/mortgage-gpu.ipynb)\n2. Agaricus Notebooks    \n   - Python\n     - [Agaricus Training Classification](agaricus/notebooks/python/agaricus-gpu.ipynb)\n   - Scala\n     - [Agaricus Training Classification](agaricus/notebooks/scala/agaricus-gpu.ipynb)\n3. Taxi Notebook    \n   - Python\n     - [Taxi Training Classification](taxi/notebooks/python/taxi-gpu.ipynb)\n   - Scala    \n     - [Taxi Training Classification](taxi/notebooks/scala/taxi-gpu.ipynb)\n    \n## SECTION 2: BUILDING A PYSPARK OR A SCALA XGBOOST APPLICATION\nThe first step to build a Spark application is preparing packages and datasets\nneeded to build the jars. Please use the instructions below for building the\n\n- [Scala](/docs/get-started/xgboost-examples/prepare-package-data/preparation-scala.md)\n- [Python](/docs/get-started/xgboost-examples/prepare-package-data/preparation-python.md)\n\nIn addition, we have the source code for building reference applications. 
\nBelow are source codes for the example Spark jobs:\n- Mortgage: [Scala](mortgage/scala/src/com/nvidia/spark/examples/mortgage), [Python](mortgage/python/com/nvidia/spark/examples/mortgage)\n- Taxi: [Scala](taxi/scala/src/com/nvidia/spark/examples/taxi), [Python](taxi/python/com/nvidia/spark/examples/taxi)\n- Agaricus: [Scala](agaricus/scala/src/com/nvidia/spark/examples/agaricus), [Python](agaricus/python/com/nvidia/spark/examples/agaricus)\n\n\n## SECTION 3: SETTING UP THE ENVIRONMENT\nPlease follow the steps below to run the example Spark jobs in different Spark environments:\n- Getting started on on-premises clusters\n    - [Standalone cluster for Scala](/docs/get-started/xgboost-examples/on-prem-cluster/standalone-scala.md)\n    - [Standalone cluster for Python](/docs/get-started/xgboost-examples/on-prem-cluster/standalone-python.md)\n    - [YARN for Scala](/docs/get-started/xgboost-examples/on-prem-cluster/yarn-scala.md)\n    - [YARN for Python](/docs/get-started/xgboost-examples/on-prem-cluster/yarn-python.md)\n    - [Kubernetes](/docs/get-started/xgboost-examples/on-prem-cluster/kubernetes-scala.md)\n- Getting started on cloud service providers    \n  - Amazon AWS\n    - [EC2](/docs/get-started/xgboost-examples/csp/aws/ec2.md)\n  - [Databricks](/docs/get-started/xgboost-examples/csp/databricks/databricks.md)\n  - [GCP](/docs/get-started/xgboost-examples/csp/dataproc/gcp.md)\n    \nPlease follow the steps below to run the example notebooks in different notebook environments:\n\n- Getting started for Jupyter Notebook applications\n    - [Apache Toree Notebook for Scala](/docs/get-started/xgboost-examples/notebook/toree.md)\n    - [Jupyter Notebook with spylon kernel](/docs/get-started/xgboost-examples/notebook/spylon.md)\n    - [Jupyter Notebook for Python](/docs/get-started/xgboost-examples/notebook/python-notebook.md)\n    \nNote: \nUpdating the default value of `spark.sql.execution.arrow.maxRecordsPerBatch` to a larger number (such as 200000) will  \nsignificantly 
improve performance by accelerating data transfer between the JVM and the Python process.\n\nFor the CrossValidator job, we need to set `spark.task.resource.gpu.amount=1` to allow only 1 training task running on 1 GPU (executor),\notherwise the customized CrossValidator may schedule more than 1 xgboost training task into one executor simultaneously and trigger \n[issue-131](https://github.com/NVIDIA/spark-rapids-examples/issues/131).\nFor an XGBoost job, if the number of shuffle stage tasks before training is less than the num_worker, \nthe training tasks will be scheduled to run on part of the nodes instead of all nodes due to the Spark Data Locality feature.\nThe workaround is to increase the partitions of the shuffle stage by setting `spark.sql.files.maxPartitionBytes=RightNum`.\nIf you are running XGBoost scala notebooks on Dataproc, please make sure to update the configs below to avoid job failure:\n```\nspark.dynamicAllocation.enabled=false\nspark.task.resource.gpu.amount=1\n```"
  },
  {
    "path": "examples/XGBoost-Examples/agaricus/.gitignore",
    "content": ".idea\ntarget\n*.iml\n"
  },
  {
    "path": "examples/XGBoost-Examples/agaricus/notebooks/python/agaricus-gpu.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction to XGBoost Spark with GPU\\n\",\n    \"\\n\",\n    \"Agaricus is an example of xgboost classifier for multiple classification. This notebook will show you how to load data, train the xgboost model.\\n\",\n    \"\\n\",\n    \"A few libraries required for this notebook:\\n\",\n    \"  1. cudf-cu11\\n\",\n    \"  2. xgboost\\n\",\n    \"  3. scikit-learn\\n\",\n    \"  4. numpy\\n\",\n    \"  \\n\",\n    \"This notebook also illustrates the ease of porting a sample CPU based Spark xgboost4j code into GPU. There is no change required for running Spark XGBoost on GPU because both CPU and GPU call the same API. For CPU run, we need to vectorize the trained dataset before fitting data to classifier.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Import All Libraries\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel\\n\",\n    \"from pyspark.ml.evaluation import MulticlassClassificationEvaluator\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.types import FloatType, StructField, StructType\\n\",\n    \"from time import time\\n\",\n    \"from pyspark.conf import SparkConf\\n\",\n    \"import os\\n\",\n    \"# if you pass/unpack the archive file and enable the environment\\n\",\n    \"# os.environ['PYSPARK_PYTHON'] = \\\"./environment/bin/python\\\"\\n\",\n    \"# os.environ['PYSPARK_DRIVER_PYTHON'] = \\\"./environment/bin/python\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Besides CPU version requires two extra libraries.\\n\",\n    \"```Python\\n\",\n    \"from pyspark.ml.feature import VectorAssembler\\n\",\n    \"from 
pyspark.sql.functions import col\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session and Data Reader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-30 06:57:40,306 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"2022-11-30 06:57:40,550 WARN resource.ResourceUtils: The configuration of cores (exec = 2 task = 1, runnable tasks = 2) will result in wasted resources due to resource gpu limiting the number of runnable tasks per executor to: 1. Please adjust your configuration.\\n\",\n      \"2022-11-30 06:57:54,195 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1.\\n\",\n      \"2022-11-30 06:57:54,210 WARN rapids.RapidsPluginUtils: spark.rapids.sql.multiThreadedRead.numThreads is set to 20.\\n\",\n      \"2022-11-30 06:57:54,214 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\\n\",\n      \"2022-11-30 06:57:54,214 WARN rapids.RapidsPluginUtils: spark.rapids.sql.explain is set to `NOT_ON_GPU`. 
Set it to 'NONE' to suppress the diagnostics logging about the query placement on the GPU.\\n\",\n      \"2022-11-30 06:57:54,685 WARN yarn.Client: Neither spark.yarn.jars nor spark.yarn.archive is set, falling back to uploading libraries under SPARK_HOME.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"SPARK_MASTER_URL = os.getenv(\\\"SPARK_MASTER_URL\\\", \\\"/your-url\\\")\\n\",\n    \"\\n\",\n    \"RAPIDS_JAR = os.getenv(\\\"RAPIDS_JAR\\\", \\\"/your-jar-path\\\")\\n\",\n    \"\\n\",\n    \"# You need to update with your real hardware resource \\n\",\n    \"driverMem = os.getenv(\\\"DRIVER_MEM\\\", \\\"2g\\\")\\n\",\n    \"executorMem = os.getenv(\\\"EXECUTOR_MEM\\\", \\\"2g\\\")\\n\",\n    \"pinnedPoolSize = os.getenv(\\\"PINNED_POOL_SIZE\\\", \\\"2g\\\")\\n\",\n    \"concurrentGpuTasks = os.getenv(\\\"CONCURRENT_GPU_TASKS\\\", \\\"2\\\")\\n\",\n    \"executorCores = int(os.getenv(\\\"EXECUTOR_CORES\\\", \\\"2\\\"))\\n\",\n    \"# Common spark settings\\n\",\n    \"conf = SparkConf()\\n\",\n    \"conf.setMaster(SPARK_MASTER_URL)\\n\",\n    \"conf.setAppName(\\\"Microbenchmark on GPU\\\")\\n\",\n    \"conf.set(\\\"spark.executor.instances\\\",\\\"1\\\")\\n\",\n    \"conf.set(\\\"spark.driver.memory\\\", driverMem)\\n\",\n    \"## The tasks will run on GPU memory, so there is no need to set a high host memory\\n\",\n    \"conf.set(\\\"spark.executor.memory\\\", executorMem)\\n\",\n    \"## The tasks will run on GPU cores, so there is no need to use many cpu cores\\n\",\n    \"conf.set(\\\"spark.executor.cores\\\", executorCores)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Plugin settings\\n\",\n    \"conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.sql.concurrentGpuTasks\\\", concurrentGpuTasks)\\n\",\n    \"conf.set(\\\"spark.rapids.memory.pinnedPool.size\\\", pinnedPoolSize)\\n\",\n    \"# since pyspark and xgboost share the same GPU, we disable RMM to avoid GPU OOM while training \\n\",\n    
\"conf.set(\\\"spark.rapids.memory.gpu.pool\\\", \\\"NONE\\\")\\n\",\n    \"conf.set(\\\"spark.locality.wait\\\",\\\"0\\\")\\n\",\n    \"##############note: only support value=1 https://github.com/dmlc/xgboost/blame/master/python-package/xgboost/spark/core.py#L370-L374\\n\",\n    \"conf.set(\\\"spark.task.resource.gpu.amount\\\", 1) \\n\",\n    \"conf.set(\\\"spark.rapids.sql.enabled\\\", \\\"true\\\") \\n\",\n    \"conf.set(\\\"spark.plugins\\\", \\\"com.nvidia.spark.SQLPlugin\\\")\\n\",\n    \"conf.set(\\\"spark.sql.cache.serializer\\\",\\\"com.nvidia.spark.ParquetCachedBatchSerializer\\\")\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", 200000) \\n\",\n    \"conf.set(\\\"spark.driver.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.executor.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"\\n\",\n    \"# if you pass/unpack the archive file and enable the environment\\n\",\n    \"# conf.set(\\\"spark.yarn.dist.archives\\\", \\\"your_pyspark_venv.tar.gz#environment\\\")\\n\",\n    \"# Create spark session\\n\",\n    \"spark = SparkSession.builder.config(conf=conf).getOrCreate()\\n\",\n    \"\\n\",\n    \"reader = spark.read\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Specify the Data Schema and Load the Data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"label = 'label'\\n\",\n    \"features = [ 'feature_' + str(i) for i in range(0, 126) ]\\n\",\n    \"schema = StructType([ StructField(x, FloatType()) for x in [label] + features ])\\n\",\n    \"\\n\",\n    \"# You need to update them to your real paths!\\n\",\n    \"dataRoot = os.getenv(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"train_path = dataRoot + \\\"/agaricus/csv/train\\\"\\n\",\n    \"eval_path = dataRoot + \\\"/agaricus/csv/eval\\\"\\n\",\n    \"\\n\",\n    \"data_format = 'csv'\\n\",\n    \"has_header = 
'true'\\n\",\n    \"if data_format == 'csv':\\n\",\n    \"    train_data = reader.schema(schema).option('header',has_header).csv(train_path)\\n\",\n    \"    trans_data = reader.schema(schema).option('header',has_header).csv(eval_path)\\n\",\n    \"else :\\n\",\n    \"    train_data = reader.load(train_path)\\n\",\n    \"    trans_data = reader.load(eval_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Note on CPU version, vectorization is required before fitting data to classifier, which means you need to assemble all feature columns into one column.\\n\",\n    \"\\n\",\n    \"```Python\\n\",\n    \"def vectorize(data_frame):\\n\",\n    \"    to_floats = [ col(x.name).cast(FloatType()) for x in data_frame.schema ]\\n\",\n    \"    return (VectorAssembler()\\n\",\n    \"        .setInputCols(features)\\n\",\n    \"        .setOutputCol('features')\\n\",\n    \"        .transform(data_frame.select(to_floats))\\n\",\n    \"        .select(col('features'), col(label)))\\n\",\n    \"\\n\",\n    \"train_data = vectorize(train_data)\\n\",\n    \"trans_data = vectorize(trans_data)\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create a XGBoostClassifier\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"params = { \\n\",\n    \"    \\\"tree_method\\\": \\\"hist\\\",\\n\",\n    \"    \\\"grow_policy\\\": \\\"depthwise\\\",\\n\",\n    \"    \\\"num_workers\\\": 1,\\n\",\n    \"    \\\"device\\\": \\\"cuda\\\",\\n\",\n    \"}\\n\",\n    \"params['features_col'] = features\\n\",\n    \"params['label_col'] = label\\n\",\n    \"    \\n\",\n    \"classifier = SparkXGBClassifier(**params)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The parameter `num_workers` should be set to the number of GPUs in Spark 
cluster for GPU version, while for CPU version it is usually equal to the number of the CPU cores.\\n\",\n    \"\\n\",\n    \"Concerning the device, GPU version supports `cuda` currently, while `cpu` is designed and used here for CPU training.\\n\",\n    \"\\n\",\n    \"An example of CPU classifier:\\n\",\n    \"```\\n\",\n    \"classifier = SparkXGBClassifier(\\n\",\n    \"  feature_col=features,\\n\",\n    \"  label_col=label,  \\n\",\n    \"  num_workers=1024,\\n\",\n    \")\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Train the Data with Benchmark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-30 07:00:45,526 WARN util.package: Truncated the string representation of a plan since it was too large. 
This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\\n\",\n      \"[Stage 5:>                                                          (0 + 1) / 1]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Training takes 13.92 seconds\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\r\",\n      \"                                                                                \\r\",\n      \"/data/home/yuanli/work/reviews/pr252/pyspark_venv_20221125/lib/python3.8/site-packages/xgboost/sklearn.py:808: UserWarning: Loading a native XGBoost model with Scikit-Learn interface.\\n\",\n      \"  warnings.warn(\\\"Loading a native XGBoost model with Scikit-Learn interface.\\\")\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def with_benchmark(phrase, action):\\n\",\n    \"    start = time()\\n\",\n    \"    result = action()\\n\",\n    \"    end = time()\\n\",\n    \"    print('{} takes {} seconds'.format(phrase, round(end - start, 2)))\\n\",\n    \"    return result\\n\",\n    \"model = with_benchmark('Training', lambda: classifier.fit(train_data))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Save and Reload the Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"If features_cols param set, then features_col param is ignored.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"model.write().overwrite().save(dataRoot + '/model/agaricus')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"loaded_model = SparkXGBClassifierModel().load(dataRoot + '/model/agaricus')\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Transformation and Show Result Sample\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-30 07:01:07,030 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction#798, probability#1062]\\n\",\n      \"  @Expression <AttributeReference> label#254 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_0#255 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_1#256 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_2#257 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_3#258 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_4#259 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_5#260 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_6#261 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_7#262 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_8#263 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_9#264 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_10#265 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_11#266 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_12#267 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_13#268 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_14#269 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_15#270 could run on GPU\\n\",\n      \"  @Expression 
<AttributeReference> feature_16#271 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_17#272 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_18#273 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_19#274 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_20#275 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_21#276 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_22#277 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_23#278 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_24#279 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_25#280 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_26#281 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_27#282 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_28#283 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_29#284 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_30#285 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_31#286 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_32#287 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_33#288 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_34#289 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_35#290 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_36#291 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_37#292 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_38#293 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_39#294 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> 
feature_40#295 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_41#296 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_42#297 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_43#298 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_44#299 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_45#300 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_46#301 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_47#302 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_48#303 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_49#304 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_50#305 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_51#306 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_52#307 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_53#308 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_54#309 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_55#310 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_56#311 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_57#312 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_58#313 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_59#314 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_60#315 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_61#316 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_62#317 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_63#318 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_64#319 could run on 
GPU\\n\",\n      \"  @Expression <AttributeReference> feature_65#320 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_66#321 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_67#322 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_68#323 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_69#324 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_70#325 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_71#326 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_72#327 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_73#328 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_74#329 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_75#330 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_76#331 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_77#332 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_78#333 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_79#334 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_80#335 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_81#336 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_82#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_83#338 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_84#339 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_85#340 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_86#341 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_87#342 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_88#343 could run on GPU\\n\",\n      \"  
@Expression <AttributeReference> feature_89#344 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_90#345 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_91#346 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_92#347 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_93#348 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_94#349 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_95#350 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_96#351 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_97#352 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_98#353 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_99#354 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_100#355 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_101#356 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_102#357 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_103#358 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_104#359 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_105#360 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_106#361 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_107#362 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_108#363 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_109#364 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_110#365 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_111#366 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_112#367 could run on GPU\\n\",\n      \"  @Expression 
<AttributeReference> feature_113#368 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_114#369 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_115#370 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_116#371 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_117#372 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_118#373 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_119#374 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_120#375 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_121#376 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_122#377 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_123#378 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_124#379 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_125#380 could run on GPU\\n\",\n      \"  !Expression <Alias> UDF(pythonUDF0#1327.rawPrediction) AS rawPrediction#798 cannot run on GPU because expression Alias UDF(pythonUDF0#1327.rawPrediction) AS rawPrediction#798 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#1327.rawPrediction) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\\n\",\n      \"    !Expression <ScalaUDF> UDF(pythonUDF0#1327.rawPrediction) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3659/488666387 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#1327.rawPrediction) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#1327.rawPrediction could run on GPU\\n\",\n      \"        @Expression 
<AttributeReference> pythonUDF0#1327 could run on GPU\\n\",\n      \"  @Expression <Alias> pythonUDF0#1327.prediction AS prediction#931 could run on GPU\\n\",\n      \"    @Expression <GetStructField> pythonUDF0#1327.prediction could run on GPU\\n\",\n      \"      @Expression <AttributeReference> pythonUDF0#1327 could run on GPU\\n\",\n      \"  !Expression <Alias> UDF(pythonUDF0#1327.probability) AS probability#1062 cannot run on GPU because expression Alias UDF(pythonUDF0#1327.probability) AS probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#1327.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\\n\",\n      \"    !Expression <ScalaUDF> UDF(pythonUDF0#1327.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3659/488666387 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#1327.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#1327.probability could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#1327 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-11-30 07:01:07,071 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <InMemoryTableScanExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction#798, probability#1062]\\n\",\n      \"  @Expression <AttributeReference> label#254 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_0#255 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_1#256 could run on GPU\\n\",\n      \"  @Expression 
<AttributeReference> feature_2#257 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_3#258 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_4#259 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_5#260 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_6#261 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_7#262 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_8#263 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_9#264 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_10#265 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_11#266 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_12#267 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_13#268 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_14#269 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_15#270 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_16#271 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_17#272 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_18#273 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_19#274 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_20#275 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_21#276 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_22#277 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_23#278 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_24#279 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_25#280 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_26#281 
could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_27#282 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_28#283 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_29#284 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_30#285 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_31#286 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_32#287 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_33#288 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_34#289 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_35#290 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_36#291 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_37#292 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_38#293 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_39#294 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_40#295 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_41#296 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_42#297 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_43#298 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_44#299 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_45#300 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_46#301 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_47#302 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_48#303 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_49#304 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_50#305 could run on GPU\\n\",\n      
\"  @Expression <AttributeReference> feature_51#306 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_52#307 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_53#308 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_54#309 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_55#310 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_56#311 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_57#312 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_58#313 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_59#314 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_60#315 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_61#316 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_62#317 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_63#318 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_64#319 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_65#320 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_66#321 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_67#322 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_68#323 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_69#324 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_70#325 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_71#326 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_72#327 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_73#328 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_74#329 could run on GPU\\n\",\n      \"  @Expression 
<AttributeReference> feature_75#330 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_76#331 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_77#332 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_78#333 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_79#334 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_80#335 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_81#336 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_82#337 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_83#338 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_84#339 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_85#340 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_86#341 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_87#342 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_88#343 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_89#344 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_90#345 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_91#346 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_92#347 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_93#348 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_94#349 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_95#350 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_96#351 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_97#352 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_98#353 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> 
feature_99#354 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_100#355 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_101#356 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_102#357 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_103#358 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_104#359 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_105#360 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_106#361 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_107#362 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_108#363 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_109#364 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_110#365 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_111#366 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_112#367 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_113#368 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_114#369 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_115#370 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_116#371 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_117#372 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_118#373 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_119#374 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_120#375 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_121#376 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_122#377 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> 
feature_123#378 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_124#379 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> feature_125#380 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> rawPrediction#798 cannot run on GPU because expression AttributeReference rawPrediction#798 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  @Expression <AttributeReference> prediction#931 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> probability#1062 cannot run on GPU because expression AttributeReference probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-30 07:01:09,857 WARN rapids.GpuOverrides:                               \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. 
Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because unsupported data types in input: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#1062, rawPrediction#798]; not all expressions can be replaced\\n\",\n      \"    @Expression <Alias> cast(label#254 as string) AS label#3936 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(label#254 as string) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> label#254 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(rawPrediction#798 as string) AS rawPrediction#3937 could run on GPU\\n\",\n      \"      !Expression <Cast> cast(rawPrediction#798 as string) cannot run on GPU because Cast from org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 to StringType is not supported\\n\",\n      \"        !Expression <AttributeReference> rawPrediction#798 cannot run on GPU because expression AttributeReference rawPrediction#798 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"    @Expression <Alias> cast(probability#1062 as string) AS probability#3938 could run on GPU\\n\",\n      \"      !Expression <Cast> cast(probability#1062 as string) cannot run on GPU because Cast from org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 to StringType is not supported\\n\",\n      \"        !Expression <AttributeReference> probability#1062 cannot run on GPU because expression AttributeReference probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"    @Expression <Alias> cast(prediction#931 as string) AS prediction#3939 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(prediction#931 as string) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> prediction#931 could run on GPU\\n\",\n      \"    !Exec <InMemoryTableScanExec> cannot 
run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#1062, rawPrediction#798]; not all expressions can be replaced\\n\",\n      \"      @Expression <AttributeReference> label#254 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> prediction#931 could run on GPU\\n\",\n      \"      !Expression <AttributeReference> probability#1062 cannot run on GPU because expression AttributeReference probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"      !Expression <AttributeReference> rawPrediction#798 cannot run on GPU because expression AttributeReference rawPrediction#798 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Transformation takes 3.26 seconds\\n\",\n      \"+-----+--------------------+--------------------+----------+\\n\",\n      \"|label|       rawPrediction|         probability|prediction|\\n\",\n      \"+-----+--------------------+--------------------+----------+\\n\",\n      \"|  1.0|[-9.6646747589111...|[6.35385513305664...|       1.0|\\n\",\n      \"|  0.0|[-8.3923015594482...|[2.26557254791259...|       1.0|\\n\",\n      \"|  0.0|[-8.0568389892578...|[3.16858291625976...|       1.0|\\n\",\n      \"|  0.0|[1.91234850883483...|[0.87128275632858...|       0.0|\\n\",\n      \"|  0.0|[-8.5582475662231...|[1.91867351531982...|       1.0|\\n\",\n      \"+-----+--------------------+--------------------+----------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def transform():\\n\",\n    \"    result = loaded_model.transform(trans_data).cache()\\n\",\n    \"    result.foreachPartition(lambda _: 
None)\\n\",\n    \"    return result\\n\",\n    \"result = with_benchmark('Transformation', transform)\\n\",\n    \"result.select(label, 'rawPrediction', 'probability', 'prediction').show(5)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Evaluation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-30 07:01:10,292 WARN rapids.GpuOverrides: \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! <CreateExternalRow> createexternalrow(prediction#931, label#5899, 1.0#5900, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(label,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#931 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> label#5899 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#5900 could run on GPU\\n\",\n      \"    ! <Invoke> newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\\n\",\n      \"      ! 
<NewInstance> newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\\n\",\n      \"      !Expression <AttributeReference> probability#1062 cannot run on GPU because expression AttributeReference probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  !Expression <AttributeReference> obj#5905 cannot run on GPU because expression AttributeReference obj#5905 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#1062]; unsupported data types in input: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#1062]\\n\",\n      \"    @Expression <AttributeReference> prediction#931 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(label#254 as double) AS label#5899 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(label#254 as double) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> label#254 could run on GPU\\n\",\n      \"    @Expression <Alias> 1.0 AS 1.0#5900 could run on GPU\\n\",\n      \"      @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"    !Expression <AttributeReference> probability#1062 cannot run on GPU because expression AttributeReference probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"    !Exec <InMemoryTableScanExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#1062]\\n\",\n      \"      @Expression <AttributeReference> 
label#254 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> prediction#931 could run on GPU\\n\",\n      \"      !Expression <AttributeReference> probability#1062 cannot run on GPU because expression AttributeReference probability#1062 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Evaluation takes 1.0 seconds\\n\",\n      \"Accuracy is 0.9069677632722861\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\r\",\n      \"[Stage 12:>                                                         (0 + 1) / 1]\\r\",\n      \"\\r\",\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"accuracy = with_benchmark(\\n\",\n    \"    'Evaluation',\\n\",\n    \"    lambda: MulticlassClassificationEvaluator().setLabelCol(label).evaluate(result))\\n\",\n    \"print('Accuracy is ' + str(accuracy))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Stop\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.2\"\n  }\n 
},\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/agaricus/notebooks/scala/agaricus-gpu.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction to XGBoost Spark3.0 with GPU\\n\",\n    \"\\n\",\n    \"Agaricus is an example of XGBoost classifier for multiple classification. This notebook will show you how to load data, train the xgboost model. Comparing to original XGBoost Spark code, there're only one API difference.\\n\",\n    \"\\n\",\n    \"## Load libraries\\n\",\n    \"First load some common libraries will be used by both GPU version and CPU version XGBoost.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}\\n\",\n    \"import org.apache.spark.sql.SparkSession\\n\",\n    \"import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\\n\",\n    \"import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Besides CPU version requires some extra libraries, such as:\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"import org.apache.spark.ml.feature.VectorAssembler\\n\",\n    \"import org.apache.spark.sql.functions._\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Set the dataset paths\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"lastException = null\\n\",\n       \"dataRoot = /data\\n\",\n       \"trainPath = /data/agaricus/csv/train/\\n\",\n       \"evalPath = /data/agaricus/csv/test/\\n\",\n       \"transPath = /data/agaricus/csv/test/\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      
\"text/plain\": [\n       \"/data/agaricus/csv/test/\"\n      ]\n     },\n     \"execution_count\": 2,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// You need to update them to your real paths!\\n\",\n    \"val dataRoot = sys.env.getOrElse(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"val trainPath = dataRoot + \\\"/agaricus/csv/train/\\\"\\n\",\n    \"val evalPath  = dataRoot + \\\"/agaricus/csv/test/\\\"\\n\",\n    \"val transPath = dataRoot + \\\"/agaricus/csv/test/\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Build the schema of the dataset\\n\",\n    \"\\n\",\n    \"For agaricus example, the data has 126 dimensions, being named as \\\"feature_0\\\", \\\"feature_1\\\" ... \\\"feature_125\\\". The schema will be used to load data in the future.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"labelName = label\\n\",\n       \"dataSchema = StructType(StructField(label,DoubleType,true), StructField(feature_0,DoubleType,true), StructField(feature_1,DoubleType,true), StructField(feature_2,DoubleType,true), StructField(feature_3,DoubleType,true), StructField(feature_4,DoubleType,true), StructField(feature_5,DoubleType,true), StructField(feature_6,DoubleType,true), StructField(feature_7,DoubleType,true), StructField(feature_8,DoubleType,true), StructField(feature_9,DoubleType,true), StructField(feature_10,DoubleType,true), StructField(feature_11,DoubleType,true), StructField(feature_12,DoubleType,true), StructField(feature_13,DoubleType,true), StructFiel...\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"columnNames: (length: Int)List[String]\\n\",\n       \"schema: (length: 
Int)org.apache.spark.sql.types.StructType\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType(StructField(label,DoubleType,true), StructField(feature_0,DoubleType,true), StructField(feature_1,DoubleType,true), StructField(feature_2,DoubleType,true), StructField(feature_3,DoubleType,true), StructField(feature_4,DoubleType,true), StructField(feature_5,DoubleType,true), StructField(feature_6,DoubleType,true), StructField(feature_7,DoubleType,true), StructField(feature_8,DoubleType,true), StructField(feature_9,DoubleType,true), StructField(feature_10,DoubleType,true), StructField(feature_11,DoubleType,true), StructField(feature_12,DoubleType,true), StructField(feature_13,DoubleType,true), StructFiel...\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val labelName = \\\"label\\\"\\n\",\n    \"def columnNames(length: Int): List[String] =\\n\",\n    \"  0.until(length).map(i => s\\\"feature_$i\\\").toList.+:(labelName)\\n\",\n    \"\\n\",\n    \"def schema(length: Int): StructType =\\n\",\n    \"  StructType(columnNames(length).map(n => StructField(n, DoubleType)))\\n\",\n    \"\\n\",\n    \"val dataSchema = schema(126)\\n\",\n    \"\\n\",\n    \"// Build the column name list for features.\\n\",\n    \"val featureCols = dataSchema.filter(_.name != labelName).map(_.name).toArray\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create a new spark session and load data\\n\",\n    \"\\n\",\n    \"A new spark session should be created to continue all the following spark operations.\\n\",\n    \"\\n\",\n    \"NOTE: in this notebook, the dependency jars have been loaded when installing toree kernel. 
Alternatively the jars can be loaded into notebook by [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, there's one restriction for `%AddJar`: the jar uploaded can only be available when `AddJar` is called just after a new spark session is created. Do it as below:\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"import org.apache.spark.sql.SparkSession\\n\",\n    \"val spark = SparkSession.builder().appName(\\\"agaricus-GPU\\\").getOrCreate\\n\",\n    \"%AddJar file:/data/libs/rapids-4-spark-XXX.jar\\n\",\n    \"%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\\n\",\n    \"%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\\n\",\n    \"// ...\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"##### Please note the new jar \\\"rapids-4-spark-XXX.jar\\\" is only needed for GPU version, you can not add it to dependence list for CPU version.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"sparkSession = org.apache.spark.sql.SparkSession@3886ba44\\n\",\n       \"dataReader = org.apache.spark.sql.DataFrameReader@5c8be07f\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"org.apache.spark.sql.DataFrameReader@5c8be07f\"\n      ]\n     },\n     \"execution_count\": 4,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// Build the spark session and data reader as usual\\n\",\n    \"val sparkSession = SparkSession.builder.appName(\\\"agaricus-gpu\\\").getOrCreate\\n\",\n    \"val dataReader = sparkSession.read.option(\\\"header\\\", true).schema(dataSchema)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"trainSet = [label: double, 
feature_0: double ... 125 more fields]\\n\",\n       \"evalSet = [label: double, feature_0: double ... 125 more fields]\\n\",\n       \"transSet = [label: double, feature_0: double ... 125 more fields]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[label: double, feature_0: double ... 125 more fields]\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// load all the dataset\\n\",\n    \"val trainSet = dataReader.csv(trainPath)\\n\",\n    \"val evalSet  = dataReader.csv(evalPath)\\n\",\n    \"val transSet = dataReader.csv(transPath)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Set XGBoost parameters and build a XGBoostClassifier\\n\",\n    \"\\n\",\n    \"For CPU version, `num_workers` is recommended being equal to the number of CPU cores, while for GPU version, it should be set to the number of GPUs in Spark cluster.\\n\",\n    \"\\n\",\n    \"Besides the `device` for CPU version is also different from that for GPU version. 
Now only \\\"cuda\\\" is supported for training on GPU.\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"// difference in parameters\\n\",\n    \"  \\\"num_workers\\\" -> 12,\\n\",\n    \"  \\\"device\\\" -> \\\"cpu\\\",\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"paramMap = Map(num_workers -> 1, tree_method -> hist, device -> cuda, num_round -> 100)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Map(num_workers -> 1, tree_method -> hist, device -> cuda, num_round -> 100)\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// build XGBoost classifier\\n\",\n    \"val paramMap = Map(\\n\",\n    \"  \\\"num_workers\\\" -> 1,\\n\",\n    \"  \\\"tree_method\\\" -> \\\"hist\\\",\\n\",\n    \"  \\\"device\\\" -> \\\"cuda\\\",\\n\",\n    \"  \\\"num_round\\\" -> 100\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbClassifier = xgbc_57e2d7fc657a\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbc_57e2d7fc657a\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val xgbClassifier  = new XGBoostClassifier(paramMap)\\n\",\n    \"  .setLabelCol(labelName)\\n\",\n    \"  .setFeaturesCol(featureCols)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Benchmark and train\\n\",\n    \"The object `benchmark` is used to compute 
the elapsed time of some operations.\\n\",\n    \"\\n\",\n    \"Training with evaluation dataset is also supported, the same as CPU version's behavior:\\n\",\n    \"\\n\",\n    \"* Call API `setEvalDataset` after initializing an XGBoostClassifier\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"xgbClassifier.setEvalDataset(evalSet)\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbc_57e2d7fc657a\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"xgbClassifier.setEvalDataset(evalSet)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"defined object Benchmark\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"object Benchmark {\\n\",\n    \"  def time[R](phase: String)(block: => R): (R, Float) = {\\n\",\n    \"    val t0 = System.currentTimeMillis\\n\",\n    \"    val result = block // call-by-name\\n\",\n    \"    val t1 = System.currentTimeMillis\\n\",\n    \"    println(\\\"Elapsed time [\\\" + phase + \\\"]: \\\" + ((t1 - t0).toFloat / 1000) + \\\"s\\\")\\n\",\n    \"    (result, (t1 - t0).toFloat / 1000)\\n\",\n    \"  }\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"------ Training ------\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=34739, DMLC_NUM_WORKER=1}\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       
\"xgbClassificationModel = xgbc_57e2d7fc657a\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Elapsed time [train]: 11.177s\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbc_57e2d7fc657a\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// start training\\n\",\n    \"println(\\\"\\\\n------ Training ------\\\")\\n\",\n    \"val (xgbClassificationModel, _) = Benchmark.time(\\\"train\\\") {\\n\",\n    \"  xgbClassifier.fit(trainSet)\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Transformation and evaluation\\n\",\n    \"Here uses `transSet` to evaluate our model and prints some useful columns to show our prediction result. After that `MulticlassClassificationEvaluator` is used to calculate an overall accuracy of our predictions.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"------ Transforming ------\\n\",\n      \"Elapsed time [transform]: 2.51s\\n\",\n      \"+-----+--------------------+--------------------+----------+\\n\",\n      \"|label|       rawPrediction|         probability|prediction|\\n\",\n      \"+-----+--------------------+--------------------+----------+\\n\",\n      \"|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  0.0|[-4.4405460357666...|[0.99995559453964...|       0.0|\\n\",\n      \"|  0.0|[-0.9999903440475...|[9.65595245361328...|      
 1.0|\\n\",\n      \"|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"+-----+--------------------+--------------------+----------+\\n\",\n      \"only showing top 10 rows\\n\",\n      \"\\n\",\n      \"\\n\",\n      \"------Accuracy of Evaluation------\\n\",\n      \"accuracy == 0.9069677632722861\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"results = [label: double, feature_0: double ... 128 more fields]\\n\",\n       \"evaluator = MulticlassClassificationEvaluator: uid=mcEval_8f89b3a17d4b, metricName=f1, metricLabel=0.0, beta=1.0, eps=1.0E-15\\n\",\n       \"accuracy = 0.9069677632722861\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"0.9069677632722861\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// start transform\\n\",\n    \"println(\\\"\\\\n------ Transforming ------\\\")\\n\",\n    \"val (results, _) = Benchmark.time(\\\"transform\\\") {\\n\",\n    \"  val ret = xgbClassificationModel.transform(transSet).cache()\\n\",\n    \"  ret.foreachPartition((_: Iterator[_]) => ())\\n\",\n    \"  ret\\n\",\n    \"}\\n\",\n    \"results.select(labelName, \\\"rawPrediction\\\", \\\"probability\\\", \\\"prediction\\\").show(10)\\n\",\n    \"\\n\",\n    \"println(\\\"\\\\n------Accuracy of Evaluation------\\\")\\n\",\n    \"val evaluator = new MulticlassClassificationEvaluator()\\n\",\n    \"evaluator.setLabelCol(labelName)\\n\",\n    \"val accuracy = evaluator.evaluate(results)\\n\",\n    
\"\\n\",\n    \"println(s\\\"accuracy == $accuracy\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Save the model to disk and load model\\n\",\n    \"Save the model to disk and then load it to memory. After that use the loaded model to do a new prediction.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Elapsed time [transform2]: 0.069s\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"modelFromDisk = xgbc_57e2d7fc657a\\n\",\n       \"results2 = [label: double, feature_0: double ... 128 more fields]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+-----+--------------------+--------------------+----------+\\n\",\n      \"|label|       rawPrediction|         probability|prediction|\\n\",\n      \"+-----+--------------------+--------------------+----------+\\n\",\n      \"|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  0.0|[-4.4405460357666...|[0.99995559453964...|       0.0|\\n\",\n      \"|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      \"|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\\n\",\n      
\"+-----+--------------------+--------------------+----------+\\n\",\n      \"only showing top 10 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[label: double, feature_0: double ... 128 more fields]\"\n      ]\n     },\n     \"execution_count\": 12,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"xgbClassificationModel.write.overwrite.save(dataRoot + \\\"/model/agaricus\\\")\\n\",\n    \"\\n\",\n    \"val modelFromDisk = XGBoostClassificationModel.load(dataRoot + \\\"/model/agaricus\\\")\\n\",\n    \"val (results2, _) = Benchmark.time(\\\"transform2\\\") {\\n\",\n    \"  modelFromDisk.transform(transSet)\\n\",\n    \"}\\n\",\n    \"results2.select(labelName, \\\"rawPrediction\\\", \\\"probability\\\", \\\"prediction\\\").show(10)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sparkSession.close()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"XGBoost4j-Spark - Scala\",\n   \"language\": \"scala\",\n   \"name\": \"XGBoost4j-Spark_scala\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": \"text/x-scala\",\n   \"file_extension\": \".scala\",\n   \"mimetype\": \"text/x-scala\",\n   \"name\": \"scala\",\n   \"pygments_lexer\": \"scala\",\n   \"version\": \"2.12.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/agaricus/pom.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!--\n  ~ Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.\n  ~\n  ~ Licensed under the Apache License, Version 2.0 (the \"License\");\n  ~ you may not use this file except in compliance with the License.\n  ~ You may obtain a copy of the License at\n  ~\n  ~ http://www.apache.org/licenses/LICENSE-2.0\n  ~\n  ~ Unless required by applicable law or agreed to in writing, software\n  ~ distributed under the License is distributed on an \"AS IS\" BASIS,\n  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n  ~ See the License for the specific language governing permissions and\n  ~ limitations under the License.\n  -->\n\n<project xmlns=\"http://maven.apache.org/POM/4.0.0\"\n         xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n         xsi:schemaLocation=\"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd\">\n    <parent>\n        <artifactId>sample_xgboost_examples</artifactId>\n        <groupId>com.nvidia</groupId>\n        <version>0.2.3-SNAPSHOT</version>\n    </parent>\n    <modelVersion>4.0.0</modelVersion>\n\n    <artifactId>spark_examples_agaricus_${scala.binary.version}</artifactId>\n\n    <properties>\n        <maven.compiler.source>8</maven.compiler.source>\n        <maven.compiler.target>8</maven.compiler.target>\n    </properties>\n\n    <dependencies>\n        <dependency>\n            <groupId>com.nvidia</groupId>\n            <artifactId>spark_examples_utility_${scala.binary.version}</artifactId>\n            <version>${project.version}</version>\n            <scope>compile</scope>\n        </dependency>\n    </dependencies>\n\n    <build>\n        <sourceDirectory>scala/src</sourceDirectory>\n    </build>\n\n</project>"
  },
  {
    "path": "examples/XGBoost-Examples/agaricus/python/com/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/agaricus/python/com/nvidia/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/agaricus/python/com/nvidia/spark/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/agaricus/python/com/nvidia/spark/examples/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/agaricus/python/com/nvidia/spark/examples/agaricus/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/agaricus/python/com/nvidia/spark/examples/agaricus/main.py",
    "content": "#\n# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nfrom pyspark.sql.types import *\n\nfrom com.nvidia.spark.examples.utility.utils import *\nfrom pyspark.sql import SparkSession\n\nfrom xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel\n\nlabel = 'label'\nfeature_names = ['feature_' + str(i) for i in range(0, 126)]\nschema = StructType([StructField(x, FloatType()) for x in [label] + feature_names])\n\n\ndef main(args, xgboost_args):\n    spark = (SparkSession\n             .builder\n             .appName(args.mainClass)\n             .getOrCreate())\n\n    train_data, eval_data, trans_data = valid_input_data(spark, args, '', schema)\n\n    if args.mode in ['all', 'train']:\n        if train_data is None:\n            print('-' * 80)\n            print('Usage: train data path required when mode is all or train')\n            print('-' * 80)\n            exit(1)\n\n        train_data, features = transform_data(train_data, label, args.use_gpu)\n        xgboost_args['features_col'] = features\n        xgboost_args['label_col'] = label\n        classifier = SparkXGBClassifier(**xgboost_args)\n\n        if eval_data:\n            # TODO\n            pass\n\n        model = with_benchmark('Training', lambda: classifier.fit(train_data))\n\n        if args.modelPath:\n            writer = model.write().overwrite() if args.overwrite else model\n            
writer.save(args.modelPath)\n    else:\n        model = SparkXGBClassifierModel.load(args.modelPath)\n\n    if args.mode in ['all', 'transform']:\n        if trans_data is None:\n            print('-' * 80)\n            print('Usage: trans data path required when mode is all or transform')\n            print('-' * 80)\n            exit(1)\n\n        trans_data, _ = transform_data(trans_data, label, args.use_gpu)\n\n        def transform():\n            result = model.transform(trans_data).cache()\n            result.foreachPartition(lambda _: None)\n            return result\n\n        result = with_benchmark('Transformation', transform)\n        show_sample(args, result, label)\n        with_benchmark('Evaluation', lambda: check_classification_accuracy(result, label))\n\n    spark.stop()\n"
  },
  {
    "path": "examples/XGBoost-Examples/agaricus/scala/src/com/nvidia/spark/examples/agaricus/Main.scala",
    "content": "/*\n * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.agaricus\n\nimport com.nvidia.spark.examples.utility.{Benchmark, SparkSetup, XGBoostArgs}\nimport ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}\nimport org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\nimport org.apache.spark.sql.types.{FloatType, StructField, StructType}\n\nobject Main {\n  def main(args: Array[String]): Unit = {\n\n    val labelName = \"label\"\n\n    def featureNames(length: Int): List[String] =\n      0.until(length).map(i => s\"feature_$i\").toList.+:(labelName)\n\n    def schema(length: Int): StructType =\n      StructType(featureNames(length).map(n => StructField(n, FloatType)))\n\n    val dataSchema = schema(126)\n    val xgboostArgs = XGBoostArgs.parse(args)\n    val processor = this.getClass.getSimpleName.stripSuffix(\"$\").substring(0, 3)\n    val appInfo = Seq(\"Agaricus\", processor, xgboostArgs.format)\n\n    // build spark session\n    val spark = SparkSetup(args, appInfo.mkString(\"-\"))\n    val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2))\n\n    // build data reader\n    val dataReader = spark.read\n\n    // load data\n    val pathsArray = xgboostArgs.getDataPaths\n    // train, eval, transform\n    var datasets = pathsArray.map { paths =>\n      if (paths.nonEmpty) {\n        
xgboostArgs.format match {\n          case \"csv\" => Some(dataReader.option(\"header\", xgboostArgs.hasHeader).schema(dataSchema).csv(paths: _*))\n          case \"orc\" => Some(dataReader.orc(paths: _*))\n          case \"parquet\" => Some(dataReader.parquet(paths: _*))\n          case _ => throw new IllegalArgumentException(\"Unsupported data file format!\")\n        }\n      } else None\n    }\n\n    val featureCols = dataSchema.filter(_.name != labelName).map(_.name).toArray\n\n    val xgbClassificationModel = if (xgboostArgs.isToTrain) {\n      // build XGBoost classifier\n      val paramMap = xgboostArgs.xgboostParams(Map(\n        \"objective\" -> \"binary:logistic\",\n      ))\n      val xgbClassifier = new XGBoostClassifier(paramMap)\n        .setLabelCol(labelName)\n        // === diff ===\n        .setFeaturesCol(featureCols)\n\n      datasets(1).foreach(ds => xgbClassifier.setEvalDataset(ds))\n\n      println(\"\\n------ Training ------\")\n      val (model, _) = benchmark.time(\"train\") {\n        xgbClassifier.fit(datasets(0).get)\n      }\n      // Save model if modelPath exists\n      xgboostArgs.modelPath.foreach(path =>\n        if (xgboostArgs.isOverwrite) model.write.overwrite().save(path) else model.save(path))\n      model\n    } else {\n      XGBoostClassificationModel.load(xgboostArgs.modelPath.get)\n    }\n\n    if (xgboostArgs.isToTransform) {\n      // start transform\n      println(\"\\n------ Transforming ------\")\n      var (results, _) = benchmark.time(\"transform\") {\n        val ret = xgbClassificationModel.transform(datasets(2).get).cache()\n        ret.foreachPartition((_: Iterator[_]) => ())\n        ret\n      }\n      results = if (xgboostArgs.isShowFeatures) {\n        results\n      } else {\n        results.select(labelName, \"rawPrediction\", \"probability\", \"prediction\")\n      }\n      results.show(xgboostArgs.numRows)\n\n      println(\"\\n------Accuracy of Evaluation------\")\n      val evaluator = new 
MulticlassClassificationEvaluator().setLabelCol(labelName)\n      evaluator.evaluate(results) match {\n        case accuracy if !accuracy.isNaN =>\n          benchmark.value(accuracy, \"Accuracy\", \"Accuracy for\")\n        // Throw an exception when NaN ?\n      }\n    }\n\n    spark.close()\n  }\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/aggregator/.gitignore",
    "content": ".idea\ntarget\n*.iml\n*.xml\n"
  },
  {
    "path": "examples/XGBoost-Examples/app-parameters/supported_xgboost_parameters_python.md",
    "content": "Supported Parameters\n============================\n\nThis is a description of all the parameters available when you are running examples in this repo:\n\n1. All [xgboost parameters](https://xgboost.readthedocs.io/en/latest/parameter.html) are supported.\n   * Please use the `camelCase`, e.g., `--treeMethod=hist`.\n   * `lambda` is replaced with `lambda_`, because `lambda` is a keyword in Python.\n2. `--mainClass=[app class]`: The entry class of the application to be started. Available value is one of the below classes.\n   * com.nvidia.spark.examples.agaricus.main\n   * com.nvidia.spark.examples.agaricus.main\n   * com.nvidia.spark.examples.mortgage.main\n   * com.nvidia.spark.examples.mortgage.main\n   * com.nvidia.spark.examples.taxi.main\n   * com.nvidia.spark.examples.taxi.main\n   * com.nvidia.spark.examples.mortgage.etl_main\n   * com.nvidia.spark.examples.taxi.etl_main\n3. `--format=[csv|parquet|orc]`: The format of the data for training/transforming, now only supports 'csv', 'parquet' and 'orc'. *Required*.\n4. `--mode=[all|train|transform]`. The behavior of the XGBoost application (meaning CPUMain and GPUMain), default is 'all' if not specified.\n   * all: Do both training and transforming, will save model to 'modelPath' if specified\n   * train: Do training only, will save model to 'modelPath' if specified.\n   * transform: Do transforming only, 'modelPath' is required to locate the model data to be loaded.\n5. `--dataPath=[prefix]::[path]`: Path to input data file(s), or path to output data files. Use it repeatedly to specify multiple data paths.\n   * `--dataPath=train::[path]`: Path to the training data file(s), required when mode is NOT 'transform'.\n   * `--dataPath=trans::[path]`: Path to the transforming data file(s), required when mode is NOT 'train'.\n   * `--dataPath=eval::[path]`: Path to the evaluation data file(s) for training. 
Optional.\n   * `--dataPath=rawTrain::[path]`: Path to the raw data files for training, only used by taxi/CPUMain, taxi/GPUMain now to support E2E train.\n   * `--dataPath=rawTrans::[path]`: Path to the raw data files for transforming, only used by taxi/CPUMain, taxi/GPUMain now to support E2E transformation.\n   * `--dataPath=rawEval::[path]`: Path to the raw data files being used as evaluation data for training. Optional.\n   * `--dataPath=raw::[path]`: Path to the raw data files to be transformed by taxi/ETLMain.\n   * `--dataPath=perf::[path]`,`-dataPath=acq::[path]`: Paths to the raw data files to be transformed by mortgage/ETLMain.\n   * `--dataPath=out::`: Path where to place the output data files for both mortgage/ETLMain and taxi/ETLMain.\n   * `--dataPath=tmp::`: Path where to place the output data files for converting raw csv format to parquet.\n6. `--modelPath=[path]`: Path to save model after training, or where to load model for transforming only. Required only when mode is 'transform'.\n7. `--overwrite=[true|false]`: Whether to overwrite the current model data under 'modelPath'. Default is false. You may need to set to true to avoid IOException when saving the model to a path that already exists.\n8. `--hasHeader=[true|false]`: Indicate whether the csv file has header.\n9. `--numRows=[int value]`: The number of the rows to be shown after transforming done. Default is 5.\n10. `--showFeatures=[true|false]`: Whether to show the features columns after transforming done. Default is true.\n11. `--dataRatios=[trainRatio:transformRatio]`: The ratios of data for train and transform, then the ratio for evaluation is (100-train-test). Default is 80:20, no evaluation. This is only used by taxi/ETLMain now to generate the output data.\n"
  },
  {
    "path": "examples/XGBoost-Examples/app-parameters/supported_xgboost_parameters_scala.md",
    "content": "Supported Parameters\n============================\n\nThis is a description of all the parameters available when you are running examples in this repo:\n\n1. All [xgboost parameters](https://xgboost.readthedocs.io/en/latest/parameter.html) are supported.\n2. `-format=[csv|parquet|orc]`: The format of the data for training/transforming, now only supports 'csv', 'parquet' and 'orc'. *Required*.\n3. `-mode=[all|train|transform]`. The behavior of the XGBoost application (meaning CPUMain and GPUMain), default is 'all' if not specified.\n   * all: Do both training and transforming, will save model to 'modelPath' if specified\n   * train: Do training only, will save model to 'modelPath' if specified.\n   * transform: Do transforming only, 'modelPath' is required to locate the model data to be loaded.\n4. `-dataPath=[prefix]::[path]`: Path to input data file(s), or path to output data files. Use it repeatedly to specify multiple data paths.\n   * `-dataPath=train::[path]`: Path to the training data file(s), required when mode is NOT 'transform'.\n   * `-dataPath=trans::[path]`: Path to the transforming data file(s), required when mode is NOT 'train'.\n   * `-dataPath=eval::[path]`: Path to the evaluation data file(s) for training. Optional.\n   * `-dataPath=rawTrain::[path]`: Path to the raw data files for training, only used by taxi/CPUMain, taxi/GPUMain now to support E2E train.\n   * `-dataPath=rawTrans::[path]`: Path to the raw data files for transforming, only used by taxi/CPUMain, taxi/GPUMain now to support E2E transformation.\n   * `-dataPath=rawEval::[path]`: Path to the raw data files being used as evaluation data for training. 
Optional.\n   * `-dataPath=raw::[path]`: Path to the raw data files to be transformed by taxi/ETLMain.\n   * `-dataPath=perf::[path]`,`-dataPath=acq::[path]`: Paths to the raw data files to be transformed by mortgage/ETLMain.\n   * `-dataPath=out::`: Path where to place the output data files for both mortgage/ETLMain and taxi/ETLMain.\n   * `-dataPath=tmp::`: Path where to place the output data files for converting raw csv format to parquet.\n5. `-modelPath=[path]`: Path to save model after training, or where to load model for transforming only. Required only when mode is 'transform'.\n6. `-overwrite=[true|false]`: Whether to overwrite the current model data under 'modelPath'. Default is false. You may need to set to true to avoid IOException when saving the model to a path that already exists.\n7. `-hasHeader=[true|false]`: Indicate whether the csv file has header.\n8. `-numRows=[int value]`: The number of the rows to be shown after transforming done. Default is 5.\n9. `-showFeatures=[true|false]`: Whether to show the features columns after transforming done. Default is true.\n10. `-dataRatios=[trainRatio:transformRatio]`: The ratios of data for train and transform, then the ratio for evaluation is (100-train-test). Default is 80:20, no evaluation. This is only used by taxi/ETLMain now to generate the output data.\n"
  },
  {
    "path": "examples/XGBoost-Examples/assembly/assembly-no-scala.xml",
    "content": "<!--\n  ~ Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n  ~\n  ~ Licensed under the Apache License, Version 2.0 (the \"License\");\n  ~ you may not use this file except in compliance with the License.\n  ~ You may obtain a copy of the License at\n  ~\n  ~ http://www.apache.org/licenses/LICENSE-2.0\n  ~\n  ~ Unless required by applicable law or agreed to in writing, software\n  ~ distributed under the License is distributed on an \"AS IS\" BASIS,\n  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n  ~ See the License for the specific language governing permissions and\n  ~ limitations under the License.\n  -->\n<assembly xmlns=\"http://maven.apache.org/ASSEMBLY/2.0.0\"\n  xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n  xsi:schemaLocation=\"http://maven.apache.org/ASSEMBLY/2.0.0 http://maven.apache.org/xsd/assembly-2.0.0.xsd\">\n  <id>jar-with-dependencies_${scala.binary.version}</id>\n  <formats>\n    <format>jar</format>\n  </formats>\n  <includeBaseDirectory>false</includeBaseDirectory>\n  <dependencySets>\n    <dependencySet>\n      <excludes>\n        <exclude>org.scala-lang*:scala-*</exclude>\n      </excludes>\n      <outputDirectory>/</outputDirectory>\n      <useProjectArtifact>true</useProjectArtifact>\n      <unpack>true</unpack>\n      <scope>runtime</scope>\n    </dependencySet>\n  </dependencySets>\n</assembly>\n\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/main.py",
    "content": "#\n# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nfrom com.nvidia.spark.examples.main import main\n\nmain()\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/.gitignore",
    "content": ".idea\ntarget\n*.iml\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/notebooks/python/MortgageETL+XGBoost.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Dataset\\n\",\n    \"\\n\",\n    \"Dataset is derived from Fannie Mae’s [Single-Family Loan Performance Data](http://www.fanniemae.com/portal/funding-the-market/data/loan-performance-data.html) with all rights reserved by Fannie Mae. Refer to these [instructions](https://github.com/NVIDIA/spark-rapids-examples/blob/branch-23.10/docs/get-started/xgboost-examples/dataset/mortgage.md) to download the dataset.\\n\",\n    \"\\n\",\n    \"# ETL + XGBoost train & transform\\n\",\n    \"\\n\",\n    \"This notebook is an end-to-end example of ETL + XGBoost Train & Transform by using [Spark-Rapids](https://github.com/NVIDIA/spark-rapids) and [XGBoost](https://github.com/dmlc/xgboost) with GPU accelerated.\\n\",\n    \"<br>The main steps:\\n\",\n    \"1. Run ETL to generate 2 datasets for train and test<br>\\n\",\n    \"   You can choose to save the datasets or not by setting \\\"is_save_dataset\\\" to True or False.<br>\\n\",\n    \"   It means you don't need to save the dataset to disk after ETL and directly feed the dataframe to XGBoost train or transform.\\n\",\n    \"2. Run XGBoost train with the train dataset\\n\",\n    \"3. 
Run XGBoost transform with the test dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import time\\n\",\n    \"import os\\n\",\n    \"from pyspark import broadcast\\n\",\n    \"from pyspark.conf import SparkConf\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.functions import *\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark.sql.window import Window\\n\",\n    \"# if you pass/unpack the archive file and enable the environment\\n\",\n    \"# os.environ['PYSPARK_PYTHON'] = \\\"./environment/bin/python\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Define Part\\n\",\n    \"### 1. Define the paths\\n\",\n    \"You need to update them to your real paths.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# The input path of dataset\\n\",\n    \"dataRoot = os.getenv(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"orig_raw_path = dataRoot + \\\"/mortgage/input/\\\"\\n\",\n    \"orig_raw_path_csv2parquet = dataRoot + \\\"/mortgage/output/csv2parquet/\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"SPARK_MASTER_URL = os.getenv(\\\"SPARK_MASTER_URL\\\", \\\"/your-url\\\")\\n\",\n    \"RAPIDS_JAR = os.getenv(\\\"RAPIDS_JAR\\\", \\\"/your-jar-path\\\")\\n\",\n    \"\\n\",\n    \"# You need to update with your real hardware resource \\n\",\n    \"driverMem = os.getenv(\\\"DRIVER_MEM\\\", \\\"10g\\\")\\n\",\n    \"executorMem = os.getenv(\\\"EXECUTOR_MEM\\\", \\\"10g\\\")\\n\",\n    \"pinnedPoolSize = os.getenv(\\\"PINNED_POOL_SIZE\\\", \\\"2g\\\")\\n\",\n    \"concurrentGpuTasks = os.getenv(\\\"CONCURRENT_GPU_TASKS\\\", \\\"2\\\")\\n\",\n    \"executorCores = 
int(os.getenv(\\\"EXECUTOR_CORES\\\", \\\"4\\\"))\\n\",\n    \"\\n\",\n    \"# Common spark settings\\n\",\n    \"conf = SparkConf()\\n\",\n    \"conf.setMaster(SPARK_MASTER_URL)\\n\",\n    \"conf.setAppName(\\\"Microbenchmark on GPU\\\")\\n\",\n    \"conf.set(\\\"spark.driver.memory\\\", driverMem)\\n\",\n    \"## The tasks will run on GPU memory, so there is no need to set a high host memory\\n\",\n    \"conf.set(\\\"spark.executor.memory\\\", executorMem)\\n\",\n    \"## The tasks will run on GPU cores, so there is no need to use many cpu cores\\n\",\n    \"conf.set(\\\"spark.executor.cores\\\", executorCores)\\n\",\n    \"\\n\",\n    \"# Plugin settings\\n\",\n    \"conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.sql.concurrentGpuTasks\\\", concurrentGpuTasks)\\n\",\n    \"conf.set(\\\"spark.rapids.memory.pinnedPool.size\\\", pinnedPoolSize)\\n\",\n    \"##############note: only support value=1 see https://github.com/dmlc/xgboost/blame/master/python-package/xgboost/spark/core.py#L370-L374\\n\",\n    \"conf.set(\\\"spark.task.resource.gpu.amount\\\", 1) \\n\",\n    \"# since pyspark and xgboost share the same GPU, we disable RMM to avoid GPU OOM while training \\n\",\n    \"conf.set(\\\"spark.rapids.memory.gpu.pool\\\", \\\"NONE\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.sql.enabled\\\", \\\"true\\\") \\n\",\n    \"conf.set(\\\"spark.plugins\\\", \\\"com.nvidia.spark.SQLPlugin\\\")\\n\",\n    \"conf.set(\\\"spark.sql.cache.serializer\\\",\\\"com.nvidia.spark.ParquetCachedBatchSerializer\\\")\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", 200000) \\n\",\n    \"conf.set(\\\"spark.driver.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.executor.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.jars\\\", RAPIDS_JAR)\\n\",\n    \"\\n\",\n    \"# if you pass/unpack the archive file and enable the environment\\n\",\n    \"# 
conf.set(\\\"spark.yarn.dist.archives\\\", \\\"your_pyspark_venv.tar.gz#environment\\\")\\n\",\n    \"\\n\",\n    \"# Create spark session\\n\",\n    \"spark = SparkSession.builder.config(conf=conf).getOrCreate()\\n\",\n    \"reader = spark.read\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Set True to save processed dataset after ETL\\n\",\n    \"# Set False, the dataset after ETL will be directly used in XGBoost train and transform\\n\",\n    \"\\n\",\n    \"is_save_dataset=True\\n\",\n    \"output_path_data=dataRoot + \\\"/mortgage/output/data/\\\"\\n\",\n    \"# the path to save the xgboost model\\n\",\n    \"output_path_model=dataRoot + \\\"/mortgage/output/model/\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 2. Define the constants\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# File schema\\n\",\n    \"\\n\",\n    \"_csv_raw_schema = StructType([\\n\",\n    \"      StructField(\\\"reference_pool_id\\\", StringType()),\\n\",\n    \"      StructField(\\\"loan_id\\\", LongType()),\\n\",\n    \"      StructField(\\\"monthly_reporting_period\\\", StringType()),\\n\",\n    \"      StructField(\\\"orig_channel\\\", StringType()),\\n\",\n    \"      StructField(\\\"seller_name\\\", StringType()),\\n\",\n    \"      StructField(\\\"servicer\\\", StringType()),\\n\",\n    \"      StructField(\\\"master_servicer\\\", StringType()),\\n\",\n    \"      StructField(\\\"orig_interest_rate\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"interest_rate\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"orig_upb\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"upb_at_issuance\\\", StringType()),\\n\",\n    \"      StructField(\\\"current_actual_upb\\\", DoubleType()),\\n\",\n    \"      
StructField(\\\"orig_loan_term\\\", IntegerType()),\\n\",\n    \"      StructField(\\\"orig_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"first_pay_date\\\", StringType()),    \\n\",\n    \"      StructField(\\\"loan_age\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"remaining_months_to_legal_maturity\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"adj_remaining_months_to_maturity\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"maturity_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"orig_ltv\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"orig_cltv\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"num_borrowers\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"dti\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"borrower_credit_score\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"coborrow_credit_score\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"first_home_buyer\\\", StringType()),\\n\",\n    \"      StructField(\\\"loan_purpose\\\", StringType()),\\n\",\n    \"      StructField(\\\"property_type\\\", StringType()),\\n\",\n    \"      StructField(\\\"num_units\\\", IntegerType()),\\n\",\n    \"      StructField(\\\"occupancy_status\\\", StringType()),\\n\",\n    \"      StructField(\\\"property_state\\\", StringType()),\\n\",\n    \"      StructField(\\\"msa\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"zip\\\", IntegerType()),\\n\",\n    \"      StructField(\\\"mortgage_insurance_percent\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"product_type\\\", StringType()),\\n\",\n    \"      StructField(\\\"prepayment_penalty_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"interest_only_loan_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"interest_only_first_principal_and_interest_payment_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"months_to_amortization\\\", StringType()),\\n\",\n    \"      
StructField(\\\"current_loan_delinquency_status\\\", IntegerType()),\\n\",\n    \"      StructField(\\\"loan_payment_history\\\", StringType()),\\n\",\n    \"      StructField(\\\"mod_flag\\\", StringType()),\\n\",\n    \"      StructField(\\\"mortgage_insurance_cancellation_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"zero_balance_code\\\", StringType()),\\n\",\n    \"      StructField(\\\"zero_balance_effective_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"upb_at_the_time_of_removal\\\", StringType()),\\n\",\n    \"      StructField(\\\"repurchase_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"scheduled_principal_current\\\", StringType()),\\n\",\n    \"      StructField(\\\"total_principal_current\\\", StringType()),\\n\",\n    \"      StructField(\\\"unscheduled_principal_current\\\", StringType()),\\n\",\n    \"      StructField(\\\"last_paid_installment_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"foreclosed_after\\\", StringType()),\\n\",\n    \"      StructField(\\\"disposition_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"foreclosure_costs\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"prop_preservation_and_repair_costs\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"asset_recovery_costs\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"misc_holding_expenses\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"holding_taxes\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"net_sale_proceeds\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"credit_enhancement_proceeds\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"repurchase_make_whole_proceeds\\\", StringType()),\\n\",\n    \"      StructField(\\\"other_foreclosure_proceeds\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"non_interest_bearing_upb\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"principal_forgiveness_upb\\\", StringType()),\\n\",\n    \"      
StructField(\\\"original_list_start_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"original_list_price\\\", StringType()),\\n\",\n    \"      StructField(\\\"current_list_start_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"current_list_price\\\", StringType()),\\n\",\n    \"      StructField(\\\"borrower_credit_score_at_issuance\\\", StringType()),\\n\",\n    \"      StructField(\\\"co-borrower_credit_score_at_issuance\\\", StringType()),\\n\",\n    \"      StructField(\\\"borrower_credit_score_current\\\", StringType()),\\n\",\n    \"      StructField(\\\"co-Borrower_credit_score_current\\\", StringType()),\\n\",\n    \"      StructField(\\\"mortgage_insurance_type\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"servicing_activity_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"current_period_modification_loss_amount\\\", StringType()),\\n\",\n    \"      StructField(\\\"cumulative_modification_loss_amount\\\", StringType()),\\n\",\n    \"      StructField(\\\"current_period_credit_event_net_gain_or_loss\\\", StringType()),\\n\",\n    \"      StructField(\\\"cumulative_credit_event_net_gain_or_loss\\\", StringType()),\\n\",\n    \"      StructField(\\\"homeready_program_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"foreclosure_principal_write_off_amount\\\", StringType()),\\n\",\n    \"      StructField(\\\"relocation_mortgage_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"zero_balance_code_change_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"loan_holdback_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"loan_holdback_effective_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"delinquent_accrued_interest\\\", StringType()),\\n\",\n    \"      StructField(\\\"property_valuation_method\\\", StringType()),\\n\",\n    \"      StructField(\\\"high_balance_loan_indicator\\\", StringType()),\\n\",\n    \"      
StructField(\\\"arm_initial_fixed-rate_period_lt_5_yr_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"arm_product_type\\\", StringType()),\\n\",\n    \"      StructField(\\\"initial_fixed-rate_period\\\", StringType()),\\n\",\n    \"      StructField(\\\"interest_rate_adjustment_frequency\\\", StringType()),\\n\",\n    \"      StructField(\\\"next_interest_rate_adjustment_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"next_payment_change_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"index\\\", StringType()),\\n\",\n    \"      StructField(\\\"arm_cap_structure\\\", StringType()),\\n\",\n    \"      StructField(\\\"initial_interest_rate_cap_up_percent\\\", StringType()),\\n\",\n    \"      StructField(\\\"periodic_interest_rate_cap_up_percent\\\", StringType()),\\n\",\n    \"      StructField(\\\"lifetime_interest_rate_cap_up_percent\\\", StringType()),\\n\",\n    \"      StructField(\\\"mortgage_margin\\\", StringType()),\\n\",\n    \"      StructField(\\\"arm_balloon_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"arm_plan_number\\\", StringType()),\\n\",\n    \"      StructField(\\\"borrower_assistance_plan\\\", StringType()),\\n\",\n    \"      StructField(\\\"hltv_refinance_option_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"deal_name\\\", StringType()),\\n\",\n    \"      StructField(\\\"repurchase_make_whole_proceeds_flag\\\", StringType()),\\n\",\n    \"      StructField(\\\"alternative_delinquency_resolution\\\", StringType()),\\n\",\n    \"      StructField(\\\"alternative_delinquency_resolution_count\\\", StringType()),\\n\",\n    \"      StructField(\\\"total_deferral_amount\\\", StringType())\\n\",\n    \"      ])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# name mappings\\n\",\n    \"_name_mapping = [\\n\",\n    \"        (\\\"WITMER FUNDING, LLC\\\", 
\\\"Witmer\\\"),\\n\",\n    \"        (\\\"WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015\\\", \\\"Wells Fargo\\\"),\\n\",\n    \"        (\\\"WELLS FARGO BANK,  NA\\\" , \\\"Wells Fargo\\\"),\\n\",\n    \"        (\\\"WELLS FARGO BANK, N.A.\\\" , \\\"Wells Fargo\\\"),\\n\",\n    \"        (\\\"WELLS FARGO BANK, NA\\\" , \\\"Wells Fargo\\\"),\\n\",\n    \"        (\\\"USAA FEDERAL SAVINGS BANK\\\" , \\\"USAA\\\"),\\n\",\n    \"        (\\\"UNITED SHORE FINANCIAL SERVICES, LLC D\\\\\\\\/B\\\\\\\\/A UNITED WHOLESALE MORTGAGE\\\" , \\\"United Seq(e\\\"),\\n\",\n    \"        (\\\"U.S. BANK N.A.\\\" , \\\"US Bank\\\"),\\n\",\n    \"        (\\\"SUNTRUST MORTGAGE INC.\\\" , \\\"Suntrust\\\"),\\n\",\n    \"        (\\\"STONEGATE MORTGAGE CORPORATION\\\" , \\\"Stonegate Mortgage\\\"),\\n\",\n    \"        (\\\"STEARNS LENDING, LLC\\\" , \\\"Stearns Lending\\\"),\\n\",\n    \"        (\\\"STEARNS LENDING, INC.\\\" , \\\"Stearns Lending\\\"),\\n\",\n    \"        (\\\"SIERRA PACIFIC MORTGAGE COMPANY, INC.\\\" , \\\"Sierra Pacific Mortgage\\\"),\\n\",\n    \"        (\\\"REGIONS BANK\\\" , \\\"Regions\\\"),\\n\",\n    \"        (\\\"RBC MORTGAGE COMPANY\\\" , \\\"RBC\\\"),\\n\",\n    \"        (\\\"QUICKEN LOANS INC.\\\" , \\\"Quicken Loans\\\"),\\n\",\n    \"        (\\\"PULTE MORTGAGE, L.L.C.\\\" , \\\"Pulte Mortgage\\\"),\\n\",\n    \"        (\\\"PROVIDENT FUNDING ASSOCIATES, L.P.\\\" , \\\"Provident Funding\\\"),\\n\",\n    \"        (\\\"PROSPECT MORTGAGE, LLC\\\" , \\\"Prospect Mortgage\\\"),\\n\",\n    \"        (\\\"PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC\\\" , \\\"Principal Residential\\\"),\\n\",\n    \"        (\\\"PNC BANK, N.A.\\\" , \\\"PNC\\\"),\\n\",\n    \"        (\\\"PMT CREDIT RISK TRANSFER TRUST 2015-2\\\" , \\\"PennyMac\\\"),\\n\",\n    \"        (\\\"PHH MORTGAGE CORPORATION\\\" , \\\"PHH Mortgage\\\"),\\n\",\n    \"        (\\\"PENNYMAC CORP.\\\" , \\\"PennyMac\\\"),\\n\",\n    \"        (\\\"PACIFIC UNION FINANCIAL, LLC\\\" , 
\\\"Other\\\"),\\n\",\n    \"        (\\\"OTHER\\\" , \\\"Other\\\"),\\n\",\n    \"        (\\\"NYCB MORTGAGE COMPANY, LLC\\\" , \\\"NYCB\\\"),\\n\",\n    \"        (\\\"NEW YORK COMMUNITY BANK\\\" , \\\"NYCB\\\"),\\n\",\n    \"        (\\\"NETBANK FUNDING SERVICES\\\" , \\\"Netbank\\\"),\\n\",\n    \"        (\\\"NATIONSTAR MORTGAGE, LLC\\\" , \\\"Nationstar Mortgage\\\"),\\n\",\n    \"        (\\\"METLIFE BANK, NA\\\" , \\\"Metlife\\\"),\\n\",\n    \"        (\\\"LOANDEPOT.COM, LLC\\\" , \\\"LoanDepot.com\\\"),\\n\",\n    \"        (\\\"J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"JPMORGAN CHASE BANK, NATIONAL ASSOCIATION\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"JPMORGAN CHASE BANK, NA\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"JP MORGAN CHASE BANK, NA\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"IRWIN MORTGAGE, CORPORATION\\\" , \\\"Irwin Mortgage\\\"),\\n\",\n    \"        (\\\"IMPAC MORTGAGE CORP.\\\" , \\\"Impac Mortgage\\\"),\\n\",\n    \"        (\\\"HSBC BANK USA, NATIONAL ASSOCIATION\\\" , \\\"HSBC\\\"),\\n\",\n    \"        (\\\"HOMEWARD RESIDENTIAL, INC.\\\" , \\\"Homeward Mortgage\\\"),\\n\",\n    \"        (\\\"HOMESTREET BANK\\\" , \\\"Other\\\"),\\n\",\n    \"        (\\\"HOMEBRIDGE FINANCIAL SERVICES, INC.\\\" , \\\"HomeBridge\\\"),\\n\",\n    \"        (\\\"HARWOOD STREET FUNDING I, LLC\\\" , \\\"Harwood Mortgage\\\"),\\n\",\n    \"        (\\\"GUILD MORTGAGE COMPANY\\\" , \\\"Guild Mortgage\\\"),\\n\",\n    \"        (\\\"GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)\\\" , \\\"GMAC\\\"),\\n\",\n    \"        (\\\"GMAC MORTGAGE, LLC\\\" , \\\"GMAC\\\"),\\n\",\n    \"        (\\\"GMAC (USAA)\\\" , \\\"GMAC\\\"),\\n\",\n    \"        (\\\"FREMONT BANK\\\" , \\\"Fremont Bank\\\"),\\n\",\n    \"        (\\\"FREEDOM 
MORTGAGE CORP.\\\" , \\\"Freedom Mortgage\\\"),\\n\",\n    \"        (\\\"FRANKLIN AMERICAN MORTGAGE COMPANY\\\" , \\\"Franklin America\\\"),\\n\",\n    \"        (\\\"FLEET NATIONAL BANK\\\" , \\\"Fleet National\\\"),\\n\",\n    \"        (\\\"FLAGSTAR CAPITAL MARKETS CORPORATION\\\" , \\\"Flagstar Bank\\\"),\\n\",\n    \"        (\\\"FLAGSTAR BANK, FSB\\\" , \\\"Flagstar Bank\\\"),\\n\",\n    \"        (\\\"FIRST TENNESSEE BANK NATIONAL ASSOCIATION\\\" , \\\"Other\\\"),\\n\",\n    \"        (\\\"FIFTH THIRD BANK\\\" , \\\"Fifth Third Bank\\\"),\\n\",\n    \"        (\\\"FEDERAL HOME LOAN BANK OF CHICAGO\\\" , \\\"Fedral Home of Chicago\\\"),\\n\",\n    \"        (\\\"FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB\\\" , \\\"FDIC\\\"),\\n\",\n    \"        (\\\"DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A.\\\" , \\\"Downey Mortgage\\\"),\\n\",\n    \"        (\\\"DITECH FINANCIAL LLC\\\" , \\\"Ditech\\\"),\\n\",\n    \"        (\\\"CITIMORTGAGE, INC.\\\" , \\\"Citi\\\"),\\n\",\n    \"        (\\\"CHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY\\\" , \\\"Chicago Mortgage\\\"),\\n\",\n    \"        (\\\"CHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY\\\" , \\\"Chicago Mortgage\\\"),\\n\",\n    \"        (\\\"CHASE HOME FINANCE, LLC\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"CHASE HOME FINANCE FRANKLIN AMERICAN MORTGAGE COMPANY\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"CHASE HOME FINANCE (CIE 1)\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"CHASE HOME FINANCE\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"CASHCALL, INC.\\\" , \\\"CashCall\\\"),\\n\",\n    \"        (\\\"CAPITAL ONE, NATIONAL ASSOCIATION\\\" , \\\"Capital One\\\"),\\n\",\n    \"        (\\\"CALIBER HOME LOANS, INC.\\\" , \\\"Caliber Funding\\\"),\\n\",\n    \"        (\\\"BISHOPS GATE RESIDENTIAL MORTGAGE TRUST\\\" , \\\"Bishops Gate Mortgage\\\"),\\n\",\n    \"        (\\\"BANK OF AMERICA, N.A.\\\" , \\\"Bank of America\\\"),\\n\",\n    
\"        (\\\"AMTRUST BANK\\\" , \\\"AmTrust\\\"),\\n\",\n    \"        (\\\"AMERISAVE MORTGAGE CORPORATION\\\" , \\\"Amerisave\\\"),\\n\",\n    \"        (\\\"AMERIHOME MORTGAGE COMPANY, LLC\\\" , \\\"AmeriHome Mortgage\\\"),\\n\",\n    \"        (\\\"ALLY BANK\\\" , \\\"Ally Bank\\\"),\\n\",\n    \"        (\\\"ACADEMY MORTGAGE CORPORATION\\\" , \\\"Academy Mortgage\\\"),\\n\",\n    \"        (\\\"NO CASH-OUT REFINANCE\\\" , \\\"OTHER REFINANCE\\\"),\\n\",\n    \"        (\\\"REFINANCE - NOT SPECIFIED\\\" , \\\"OTHER REFINANCE\\\"),\\n\",\n    \"        (\\\"Other REFINANCE\\\" , \\\"OTHER REFINANCE\\\")]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# String columns\\n\",\n    \"cate_col_names = [\\n\",\n    \"        \\\"orig_channel\\\",\\n\",\n    \"        \\\"first_home_buyer\\\",\\n\",\n    \"        \\\"loan_purpose\\\",\\n\",\n    \"        \\\"property_type\\\",\\n\",\n    \"        \\\"occupancy_status\\\",\\n\",\n    \"        \\\"property_state\\\",\\n\",\n    \"        \\\"product_type\\\",\\n\",\n    \"        \\\"relocation_mortgage_indicator\\\",\\n\",\n    \"        \\\"seller_name\\\",\\n\",\n    \"        \\\"mod_flag\\\"\\n\",\n    \"]\\n\",\n    \"# Numeric columns\\n\",\n    \"label_col_name = \\\"delinquency_12\\\"\\n\",\n    \"numeric_col_names = [\\n\",\n    \"        \\\"orig_interest_rate\\\",\\n\",\n    \"        \\\"orig_upb\\\",\\n\",\n    \"        \\\"orig_loan_term\\\",\\n\",\n    \"        \\\"orig_ltv\\\",\\n\",\n    \"        \\\"orig_cltv\\\",\\n\",\n    \"        \\\"num_borrowers\\\",\\n\",\n    \"        \\\"dti\\\",\\n\",\n    \"        \\\"borrower_credit_score\\\",\\n\",\n    \"        \\\"num_units\\\",\\n\",\n    \"        \\\"zip\\\",\\n\",\n    \"        \\\"mortgage_insurance_percent\\\",\\n\",\n    \"        \\\"current_loan_delinquency_status\\\",\\n\",\n    \"        \\\"current_actual_upb\\\",\\n\",\n    \"        
\\\"interest_rate\\\",\\n\",\n    \"        \\\"loan_age\\\",\\n\",\n    \"        \\\"msa\\\",\\n\",\n    \"        \\\"non_interest_bearing_upb\\\",\\n\",\n    \"        label_col_name\\n\",\n    \"]\\n\",\n    \"all_col_names = cate_col_names + numeric_col_names\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 3. Define ETL Process\\n\",\n    \"\\n\",\n    \"Define the function to do the ETL process\\n\",\n    \"\\n\",\n    \"#### 3.1 Define Functions to Read Raw CSV File\\n\",\n    \"\\n\",\n    \"* Define function to get quarter from input CSV file name\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _get_quarter_from_csv_file_name():\\n\",\n    \"    return substring_index(substring_index(input_file_name(), \\\".\\\", 1), \\\"/\\\", -1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define function to read raw CSV data file\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def read_raw_csv(spark, path):\\n\",\n    \"    return spark.read.format('csv') \\\\\\n\",\n    \"            .option('nullValue', '') \\\\\\n\",\n    \"            .option('header', False) \\\\\\n\",\n    \"            .option('delimiter', '|') \\\\\\n\",\n    \"            .schema(_csv_raw_schema) \\\\\\n\",\n    \"            .load(path) \\\\\\n\",\n    \"            .withColumn('quarter', _get_quarter_from_csv_file_name())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Functions to extract perf and acq columns from raw schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def extract_perf_columns(rawDf):\\n\",\n    \"    perfDf = 
rawDf.select(\\n\",\n    \"      col(\\\"loan_id\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"monthly_reporting_period\\\"),\\\"MMyyyy\\\"), \\\"MM/dd/yyyy\\\").alias(\\\"monthly_reporting_period\\\"),\\n\",\n    \"      upper(col(\\\"servicer\\\")).alias(\\\"servicer\\\"),\\n\",\n    \"      col(\\\"interest_rate\\\"),\\n\",\n    \"      col(\\\"current_actual_upb\\\"),\\n\",\n    \"      col(\\\"loan_age\\\"),\\n\",\n    \"      col(\\\"remaining_months_to_legal_maturity\\\"),\\n\",\n    \"      col(\\\"adj_remaining_months_to_maturity\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"maturity_date\\\"),\\\"MMyyyy\\\"), \\\"MM/yyyy\\\").alias(\\\"maturity_date\\\"),\\n\",\n    \"      col(\\\"msa\\\"),\\n\",\n    \"      col(\\\"current_loan_delinquency_status\\\"),\\n\",\n    \"      col(\\\"mod_flag\\\"),\\n\",\n    \"      col(\\\"zero_balance_code\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"zero_balance_effective_date\\\"),\\\"MMyyyy\\\"), \\\"MM/yyyy\\\").alias(\\\"zero_balance_effective_date\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"last_paid_installment_date\\\"),\\\"MMyyyy\\\"), \\\"MM/dd/yyyy\\\").alias(\\\"last_paid_installment_date\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"foreclosed_after\\\"),\\\"MMyyyy\\\"), \\\"MM/dd/yyyy\\\").alias(\\\"foreclosed_after\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"disposition_date\\\"),\\\"MMyyyy\\\"), \\\"MM/dd/yyyy\\\").alias(\\\"disposition_date\\\"),\\n\",\n    \"      col(\\\"foreclosure_costs\\\"),\\n\",\n    \"      col(\\\"prop_preservation_and_repair_costs\\\"),\\n\",\n    \"      col(\\\"asset_recovery_costs\\\"),\\n\",\n    \"      col(\\\"misc_holding_expenses\\\"),\\n\",\n    \"      col(\\\"holding_taxes\\\"),\\n\",\n    \"      col(\\\"net_sale_proceeds\\\"),\\n\",\n    \"      col(\\\"credit_enhancement_proceeds\\\"),\\n\",\n    \"      col(\\\"repurchase_make_whole_proceeds\\\"),\\n\",\n    \"      col(\\\"other_foreclosure_proceeds\\\"),\\n\",\n    
\"      col(\\\"non_interest_bearing_upb\\\"),\\n\",\n    \"      col(\\\"principal_forgiveness_upb\\\"),\\n\",\n    \"      col(\\\"repurchase_make_whole_proceeds_flag\\\"),\\n\",\n    \"      col(\\\"foreclosure_principal_write_off_amount\\\"),\\n\",\n    \"      col(\\\"servicing_activity_indicator\\\"),\\n\",\n    \"      col('quarter')\\n\",\n    \"    )\\n\",\n    \"    return perfDf.select(\\\"*\\\").filter(\\\"current_actual_upb != 0.0\\\")\\n\",\n    \"\\n\",\n    \"def extract_acq_columns(rawDf):\\n\",\n    \"    acqDf = rawDf.select(\\n\",\n    \"      col(\\\"loan_id\\\"),\\n\",\n    \"      col(\\\"orig_channel\\\"),\\n\",\n    \"      upper(col(\\\"seller_name\\\")).alias(\\\"seller_name\\\"),\\n\",\n    \"      col(\\\"orig_interest_rate\\\"),\\n\",\n    \"      col(\\\"orig_upb\\\"),\\n\",\n    \"      col(\\\"orig_loan_term\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"orig_date\\\"),\\\"MMyyyy\\\"), \\\"MM/yyyy\\\").alias(\\\"orig_date\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"first_pay_date\\\"),\\\"MMyyyy\\\"), \\\"MM/yyyy\\\").alias(\\\"first_pay_date\\\"),\\n\",\n    \"      col(\\\"orig_ltv\\\"),\\n\",\n    \"      col(\\\"orig_cltv\\\"),\\n\",\n    \"      col(\\\"num_borrowers\\\"),\\n\",\n    \"      col(\\\"dti\\\"),\\n\",\n    \"      col(\\\"borrower_credit_score\\\"),\\n\",\n    \"      col(\\\"first_home_buyer\\\"),\\n\",\n    \"      col(\\\"loan_purpose\\\"),\\n\",\n    \"      col(\\\"property_type\\\"),\\n\",\n    \"      col(\\\"num_units\\\"),\\n\",\n    \"      col(\\\"occupancy_status\\\"),\\n\",\n    \"      col(\\\"property_state\\\"),\\n\",\n    \"      col(\\\"zip\\\"),\\n\",\n    \"      col(\\\"mortgage_insurance_percent\\\"),\\n\",\n    \"      col(\\\"product_type\\\"),\\n\",\n    \"      col(\\\"coborrow_credit_score\\\"),\\n\",\n    \"      col(\\\"mortgage_insurance_type\\\"),\\n\",\n    \"      col(\\\"relocation_mortgage_indicator\\\"),\\n\",\n    \"      
dense_rank().over(Window.partitionBy(\\\"loan_id\\\").orderBy(to_date(col(\\\"monthly_reporting_period\\\"),\\\"MMyyyy\\\"))).alias(\\\"rank\\\"),\\n\",\n    \"      col('quarter')\\n\",\n    \"      )\\n\",\n    \"\\n\",\n    \"    return acqDf.select(\\\"*\\\").filter(col(\\\"rank\\\")==1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 3.2 Define ETL Process\\n\",\n    \"\\n\",\n    \"* Define function to parse dates in Performance data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _parse_dates(perf):\\n\",\n    \"    return perf \\\\\\n\",\n    \"            .withColumn(\\\"monthly_reporting_period\\\", to_date(col(\\\"monthly_reporting_period\\\"), \\\"MM/dd/yyyy\\\")) \\\\\\n\",\n    \"            .withColumn(\\\"monthly_reporting_period_month\\\", month(col(\\\"monthly_reporting_period\\\"))) \\\\\\n\",\n    \"            .withColumn(\\\"monthly_reporting_period_year\\\", year(col(\\\"monthly_reporting_period\\\"))) \\\\\\n\",\n    \"            .withColumn(\\\"monthly_reporting_period_day\\\", dayofmonth(col(\\\"monthly_reporting_period\\\"))) \\\\\\n\",\n    \"            .withColumn(\\\"last_paid_installment_date\\\", to_date(col(\\\"last_paid_installment_date\\\"), \\\"MM/dd/yyyy\\\")) \\\\\\n\",\n    \"            .withColumn(\\\"foreclosed_after\\\", to_date(col(\\\"foreclosed_after\\\"), \\\"MM/dd/yyyy\\\")) \\\\\\n\",\n    \"            .withColumn(\\\"disposition_date\\\", to_date(col(\\\"disposition_date\\\"), \\\"MM/dd/yyyy\\\")) \\\\\\n\",\n    \"            .withColumn(\\\"maturity_date\\\", to_date(col(\\\"maturity_date\\\"), \\\"MM/yyyy\\\")) \\\\\\n\",\n    \"            .withColumn(\\\"zero_balance_effective_date\\\", to_date(col(\\\"zero_balance_effective_date\\\"), \\\"MM/yyyy\\\"))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* 
Define function to create deliquency data frame from Performance data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _create_perf_deliquency(spark, perf):\\n\",\n    \"    aggDF = perf.select(\\n\",\n    \"            col(\\\"quarter\\\"),\\n\",\n    \"            col(\\\"loan_id\\\"),\\n\",\n    \"            col(\\\"current_loan_delinquency_status\\\"),\\n\",\n    \"            when(col(\\\"current_loan_delinquency_status\\\") >= 1, col(\\\"monthly_reporting_period\\\")).alias(\\\"delinquency_30\\\"),\\n\",\n    \"            when(col(\\\"current_loan_delinquency_status\\\") >= 3, col(\\\"monthly_reporting_period\\\")).alias(\\\"delinquency_90\\\"),\\n\",\n    \"            when(col(\\\"current_loan_delinquency_status\\\") >= 6, col(\\\"monthly_reporting_period\\\")).alias(\\\"delinquency_180\\\")) \\\\\\n\",\n    \"            .groupBy(\\\"quarter\\\", \\\"loan_id\\\") \\\\\\n\",\n    \"            .agg(\\n\",\n    \"                max(\\\"current_loan_delinquency_status\\\").alias(\\\"delinquency_12\\\"),\\n\",\n    \"                min(\\\"delinquency_30\\\").alias(\\\"delinquency_30\\\"),\\n\",\n    \"                min(\\\"delinquency_90\\\").alias(\\\"delinquency_90\\\"),\\n\",\n    \"                min(\\\"delinquency_180\\\").alias(\\\"delinquency_180\\\")) \\\\\\n\",\n    \"            .select(\\n\",\n    \"                col(\\\"quarter\\\"),\\n\",\n    \"                col(\\\"loan_id\\\"),\\n\",\n    \"                (col(\\\"delinquency_12\\\") >= 1).alias(\\\"ever_30\\\"),\\n\",\n    \"                (col(\\\"delinquency_12\\\") >= 3).alias(\\\"ever_90\\\"),\\n\",\n    \"                (col(\\\"delinquency_12\\\") >= 6).alias(\\\"ever_180\\\"),\\n\",\n    \"                col(\\\"delinquency_30\\\"),\\n\",\n    \"                col(\\\"delinquency_90\\\"),\\n\",\n    \"                col(\\\"delinquency_180\\\"))\\n\",\n    \"    
joinedDf = perf \\\\\\n\",\n    \"            .withColumnRenamed(\\\"monthly_reporting_period\\\", \\\"timestamp\\\") \\\\\\n\",\n    \"            .withColumnRenamed(\\\"monthly_reporting_period_month\\\", \\\"timestamp_month\\\") \\\\\\n\",\n    \"            .withColumnRenamed(\\\"monthly_reporting_period_year\\\", \\\"timestamp_year\\\") \\\\\\n\",\n    \"            .withColumnRenamed(\\\"current_loan_delinquency_status\\\", \\\"delinquency_12\\\") \\\\\\n\",\n    \"            .withColumnRenamed(\\\"current_actual_upb\\\", \\\"upb_12\\\") \\\\\\n\",\n    \"            .select(\\\"quarter\\\", \\\"loan_id\\\", \\\"timestamp\\\", \\\"delinquency_12\\\", \\\"upb_12\\\", \\\"timestamp_month\\\", \\\"timestamp_year\\\") \\\\\\n\",\n    \"            .join(aggDF, [\\\"loan_id\\\", \\\"quarter\\\"], \\\"left_outer\\\")\\n\",\n    \"\\n\",\n    \"    # calculate the 12 month delinquency and upb values\\n\",\n    \"    months = 12\\n\",\n    \"    monthArray = [lit(x) for x in range(0, 12)]\\n\",\n    \"    # explode on a small amount of data is actually slightly more efficient than a cross join\\n\",\n    \"    testDf = joinedDf \\\\\\n\",\n    \"            .withColumn(\\\"month_y\\\", explode(array(monthArray))) \\\\\\n\",\n    \"            .select(\\n\",\n    \"                    col(\\\"quarter\\\"),\\n\",\n    \"                    floor(((col(\\\"timestamp_year\\\") * 12 + col(\\\"timestamp_month\\\")) - 24000) / months).alias(\\\"josh_mody\\\"),\\n\",\n    \"                    floor(((col(\\\"timestamp_year\\\") * 12 + col(\\\"timestamp_month\\\")) - 24000 - col(\\\"month_y\\\")) / months).alias(\\\"josh_mody_n\\\"),\\n\",\n    \"                    col(\\\"ever_30\\\"),\\n\",\n    \"                    col(\\\"ever_90\\\"),\\n\",\n    \"                    col(\\\"ever_180\\\"),\\n\",\n    \"                    col(\\\"delinquency_30\\\"),\\n\",\n    \"                    col(\\\"delinquency_90\\\"),\\n\",\n    \"                    
col(\\\"delinquency_180\\\"),\\n\",\n    \"                    col(\\\"loan_id\\\"),\\n\",\n    \"                    col(\\\"month_y\\\"),\\n\",\n    \"                    col(\\\"delinquency_12\\\"),\\n\",\n    \"                    col(\\\"upb_12\\\")) \\\\\\n\",\n    \"            .groupBy(\\\"quarter\\\", \\\"loan_id\\\", \\\"josh_mody_n\\\", \\\"ever_30\\\", \\\"ever_90\\\", \\\"ever_180\\\", \\\"delinquency_30\\\", \\\"delinquency_90\\\", \\\"delinquency_180\\\", \\\"month_y\\\") \\\\\\n\",\n    \"            .agg(max(\\\"delinquency_12\\\").alias(\\\"delinquency_12\\\"), min(\\\"upb_12\\\").alias(\\\"upb_12\\\")) \\\\\\n\",\n    \"            .withColumn(\\\"timestamp_year\\\", floor((lit(24000) + (col(\\\"josh_mody_n\\\") * lit(months)) + (col(\\\"month_y\\\") - 1)) / lit(12))) \\\\\\n\",\n    \"            .selectExpr(\\\"*\\\", \\\"pmod(24000 + (josh_mody_n * {}) + month_y, 12) as timestamp_month_tmp\\\".format(months)) \\\\\\n\",\n    \"            .withColumn(\\\"timestamp_month\\\", when(col(\\\"timestamp_month_tmp\\\") == lit(0), lit(12)).otherwise(col(\\\"timestamp_month_tmp\\\"))) \\\\\\n\",\n    \"            .withColumn(\\\"delinquency_12\\\", ((col(\\\"delinquency_12\\\") > 3).cast(\\\"int\\\") + (col(\\\"upb_12\\\") == 0).cast(\\\"int\\\")).alias(\\\"delinquency_12\\\")) \\\\\\n\",\n    \"            .drop(\\\"timestamp_month_tmp\\\", \\\"josh_mody_n\\\", \\\"month_y\\\")\\n\",\n    \"\\n\",\n    \"    return perf.withColumnRenamed(\\\"monthly_reporting_period_month\\\", \\\"timestamp_month\\\") \\\\\\n\",\n    \"            .withColumnRenamed(\\\"monthly_reporting_period_year\\\", \\\"timestamp_year\\\") \\\\\\n\",\n    \"            .join(testDf, [\\\"quarter\\\", \\\"loan_id\\\", \\\"timestamp_year\\\", \\\"timestamp_month\\\"], \\\"left\\\") \\\\\\n\",\n    \"            .drop(\\\"timestamp_year\\\", \\\"timestamp_month\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define function 
to create acquisition data frame from Acquisition data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _create_acquisition(spark, acq):\\n\",\n    \"    nameMapping = spark.createDataFrame(_name_mapping, [\\\"from_seller_name\\\", \\\"to_seller_name\\\"])\\n\",\n    \"    return acq.join(nameMapping, col(\\\"seller_name\\\") == col(\\\"from_seller_name\\\"), \\\"left\\\") \\\\\\n\",\n    \"      .drop(\\\"from_seller_name\\\") \\\\\\n\",\n    \"      .withColumn(\\\"old_name\\\", col(\\\"seller_name\\\")) \\\\\\n\",\n    \"      .withColumn(\\\"seller_name\\\", coalesce(col(\\\"to_seller_name\\\"), col(\\\"seller_name\\\"))) \\\\\\n\",\n    \"      .drop(\\\"to_seller_name\\\") \\\\\\n\",\n    \"      .withColumn(\\\"orig_date\\\", to_date(col(\\\"orig_date\\\"), \\\"MM/yyyy\\\")) \\\\\\n\",\n    \"      .withColumn(\\\"first_pay_date\\\", to_date(col(\\\"first_pay_date\\\"), \\\"MM/yyyy\\\")) \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 3.3 Define Casting Process\\n\",\n    \"This part is casting String column to Numeric one. 
\\n\",\n    \"Example:\\n\",\n    \"```\\n\",\n    \"col_1\\n\",\n    \" \\\"a\\\"\\n\",\n    \" \\\"b\\\"\\n\",\n    \" \\\"c\\\"\\n\",\n    \" \\\"a\\\"\\n\",\n    \"# After String ====> Numeric\\n\",\n    \"col_1\\n\",\n    \" 0\\n\",\n    \" 1\\n\",\n    \" 2\\n\",\n    \" 0\\n\",\n    \"```  \\n\",\n    \"<br>\\n\",\n    \"\\n\",\n    \"* Define function to get column dictionary\\n\",\n    \"\\n\",\n    \"    Example\\n\",\n    \"    ```\\n\",\n    \"    col1 = [row(data=\\\"a\\\",id=0), row(data=\\\"b\\\",id=1)]\\n\",\n    \"    ```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _gen_dictionary(etl_df, col_names):\\n\",\n    \"    cnt_table = etl_df.select(posexplode(array([col(i) for i in col_names])))\\\\\\n\",\n    \"                    .withColumnRenamed(\\\"pos\\\", \\\"column_id\\\")\\\\\\n\",\n    \"                    .withColumnRenamed(\\\"col\\\", \\\"data\\\")\\\\\\n\",\n    \"                    .filter(\\\"data is not null\\\")\\\\\\n\",\n    \"                    .groupBy(\\\"column_id\\\", \\\"data\\\")\\\\\\n\",\n    \"                    .count()\\n\",\n    \"    windowed = Window.partitionBy(\\\"column_id\\\").orderBy(desc(\\\"count\\\"))\\n\",\n    \"    return cnt_table.withColumn(\\\"id\\\", row_number().over(windowed)).drop(\\\"count\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define function to convert string columns to numeric\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _cast_string_columns_to_numeric(spark, input_df):\\n\",\n    \"    cached_dict_df = _gen_dictionary(input_df, cate_col_names).cache()\\n\",\n    \"    output_df = input_df\\n\",\n    \"    #  Generate the final table with all columns being numeric.\\n\",\n    \"    for col_pos, col_name in 
enumerate(cate_col_names):\\n\",\n    \"        col_dict_df = cached_dict_df.filter(col(\\\"column_id\\\") == col_pos)\\\\\\n\",\n    \"                                    .drop(\\\"column_id\\\")\\\\\\n\",\n    \"                                    .withColumnRenamed(\\\"data\\\", col_name)\\n\",\n    \"        \\n\",\n    \"        output_df = output_df.join(broadcast(col_dict_df), col_name, \\\"left\\\")\\\\\\n\",\n    \"                        .drop(col_name)\\\\\\n\",\n    \"                        .withColumnRenamed(\\\"id\\\", col_name)\\n\",\n    \"    return output_df        \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 3.4 Define Main Function\\n\",\n    \"In this function:\\n\",\n    \"1. Parse date in Performance data by calling _parse_dates (parsed_perf)\\n\",\n    \"2. Create deliqency dataframe(perf_deliqency) form Performance data by calling _create_perf_deliquency\\n\",\n    \"3. Create cleaned acquisition dataframe(cleaned_acq) from Acquisition data by calling _create_acquisition\\n\",\n    \"4. Join deliqency dataframe(perf_deliqency) and cleaned acquisition dataframe(cleaned_acq), get clean_df\\n\",\n    \"5. Cast String column to Numeric in clean_df by calling _cast_string_columns_to_numeric, get casted_clean_df\\n\",\n    \"6. 
Return casted_clean_df as final result\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def run_mortgage(spark, perf, acq):\\n\",\n    \"    parsed_perf = _parse_dates(perf)\\n\",\n    \"    perf_deliqency = _create_perf_deliquency(spark, parsed_perf)\\n\",\n    \"    cleaned_acq = _create_acquisition(spark, acq)\\n\",\n    \"    clean_df = perf_deliqency.join(cleaned_acq, [\\\"loan_id\\\", \\\"quarter\\\"], \\\"inner\\\").drop(\\\"quarter\\\")\\n\",\n    \"    casted_clean_df = _cast_string_columns_to_numeric(spark, clean_df)\\\\\\n\",\n    \"                    .select(all_col_names)\\\\\\n\",\n    \"                    .withColumn(label_col_name, when(col(label_col_name) > 0, 1).otherwise(0))\\\\\\n\",\n    \"                    .fillna(float(0))\\n\",\n    \"    return casted_clean_df\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Run Part\\n\",\n    \"### Run ETL\\n\",\n    \"#### 1. Add additional Spark settings\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# GPU run, set to true\\n\",\n    \"spark.conf.set(\\\"spark.rapids.sql.enabled\\\", \\\"true\\\")\\n\",\n    \"# CPU run, set to false, it can only make ETL run on CPU when is_save_dataset=True.\\n\",\n    \"# spark.conf.set(\\\"spark.rapids.sql.enabled\\\", \\\"false\\\")\\n\",\n    \"spark.conf.set(\\\"spark.sql.files.maxPartitionBytes\\\", \\\"1G\\\")\\n\",\n    \"spark.conf.set(\\\"spark.rapids.sql.explain\\\", \\\"ALL\\\")\\n\",\n    \"spark.conf.set(\\\"spark.rapids.sql.batchSizeBytes\\\", \\\"512M\\\")\\n\",\n    \"spark.conf.set(\\\"spark.rapids.sql.reader.batchSizeBytes\\\", \\\"768M\\\")\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 2. 
Read Raw Data and Run ETL Process, Save the Result\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {\n    \"scrolled\": false\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"ETL takes 135.9117729663849\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"\\n\",\n    \"# read raw dataset\\n\",\n    \"rawDf = read_raw_csv(spark, orig_raw_path)\\n\",\n    \"rawDf.write.parquet(orig_raw_path_csv2parquet, mode='overwrite')\\n\",\n    \"rawDf = spark.read.parquet(orig_raw_path_csv2parquet)\\n\",\n    \"\\n\",\n    \"acq = extract_acq_columns(rawDf)\\n\",\n    \"perf = extract_perf_columns(rawDf)\\n\",\n    \"\\n\",\n    \"# run main function to process data\\n\",\n    \"out = run_mortgage(spark, perf, acq)\\n\",\n    \"\\n\",\n    \"# save processed data\\n\",\n    \"if is_save_dataset:\\n\",\n    \"    start = time.time()\\n\",\n    \"    out.write.parquet(output_path_data, mode=\\\"overwrite\\\")\\n\",\n    \"    end = time.time()\\n\",\n    \"    print(\\\"ETL takes {}\\\".format(end - start))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## XGBoost Spark with GPU\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"###### Import ML Libraries\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel\\n\",\n    \"from pyspark.ml.evaluation import MulticlassClassificationEvaluator\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"###### Create Data Reader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Make sure it runs on GPU\\n\",\n    
\"spark.conf.set(\\\"spark.rapids.sql.enabled\\\", \\\"true\\\")\\n\",\n    \"\\n\",\n    \"reader = spark.read\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"###### Specify the Data Schema and Load the Data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"label = \\\"delinquency_12\\\"\\n\",\n    \"schema = StructType([\\n\",\n    \"    StructField(\\\"orig_channel\\\", FloatType()),\\n\",\n    \"    StructField(\\\"first_home_buyer\\\", FloatType()),\\n\",\n    \"    StructField(\\\"loan_purpose\\\", FloatType()),\\n\",\n    \"    StructField(\\\"property_type\\\", FloatType()),\\n\",\n    \"    StructField(\\\"occupancy_status\\\", FloatType()),\\n\",\n    \"    StructField(\\\"property_state\\\", FloatType()),\\n\",\n    \"    StructField(\\\"product_type\\\", FloatType()),\\n\",\n    \"    StructField(\\\"relocation_mortgage_indicator\\\", FloatType()),\\n\",\n    \"    StructField(\\\"seller_name\\\", FloatType()),\\n\",\n    \"    StructField(\\\"mod_flag\\\", FloatType()),\\n\",\n    \"    StructField(\\\"orig_interest_rate\\\", FloatType()),\\n\",\n    \"    StructField(\\\"orig_upb\\\", DoubleType()),\\n\",\n    \"    StructField(\\\"orig_loan_term\\\", IntegerType()),\\n\",\n    \"    StructField(\\\"orig_ltv\\\", FloatType()),\\n\",\n    \"    StructField(\\\"orig_cltv\\\", FloatType()),\\n\",\n    \"    StructField(\\\"num_borrowers\\\", FloatType()),\\n\",\n    \"    StructField(\\\"dti\\\", FloatType()),\\n\",\n    \"    StructField(\\\"borrower_credit_score\\\", FloatType()),\\n\",\n    \"    StructField(\\\"num_units\\\", IntegerType()),\\n\",\n    \"    StructField(\\\"zip\\\", IntegerType()),\\n\",\n    \"    StructField(\\\"mortgage_insurance_percent\\\", FloatType()),\\n\",\n    \"    StructField(\\\"current_loan_delinquency_status\\\", IntegerType()),\\n\",\n    \"    
StructField(\\\"current_actual_upb\\\", FloatType()),\\n\",\n    \"    StructField(\\\"interest_rate\\\", FloatType()),\\n\",\n    \"    StructField(\\\"loan_age\\\", FloatType()),\\n\",\n    \"    StructField(\\\"msa\\\", FloatType()),\\n\",\n    \"    StructField(\\\"non_interest_bearing_upb\\\", FloatType()),\\n\",\n    \"    StructField(label, IntegerType()),\\n\",\n    \"])\\n\",\n    \"features = [ x.name for x in schema if x.name != label ]\\n\",\n    \"\\n\",\n    \"if is_save_dataset:\\n\",\n    \"    # load dataset from file\\n\",\n    \"    etlDf = reader.parquet(output_path_data)\\n\",\n    \"    splits = etlDf.randomSplit([0.8, 0.2])\\n\",\n    \"    train_data = splits[0]\\n\",\n    \"    test_data = splits[1]\\n\",\n    \"else:\\n\",\n    \"    # use Dataframe from ETL directly\\n\",\n    \"    splits = out.randomSplit([0.8, 0.2])\\n\",\n    \"    train_data = splits[0]\\n\",\n    \"    test_data = splits[1]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# This sample uses 1 worker(GPU) to run XGBoost training, you can change according to your GPU resources\\n\",\n    \"params = { \\n\",\n    \"    \\\"tree_method\\\": \\\"hist\\\",\\n\",\n    \"    \\\"grow_policy\\\": \\\"depthwise\\\",\\n\",\n    \"    \\\"num_workers\\\": 1,\\n\",\n    \"    \\\"device\\\": \\\"cuda\\\",\\n\",\n    \"}\\n\",\n    \"params['features_col'] = features\\n\",\n    \"params['label_col'] = label\\n\",\n    \"    \\n\",\n    \"classifier = SparkXGBClassifier(**params)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Training takes 18.92583155632019 seconds\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def with_benchmark(phrase, action):\\n\",\n    \"    start = time.time()\\n\",\n    \"    result = 
action()\\n\",\n    \"    end = time.time()\\n\",\n    \"    print(\\\"{} takes {} seconds\\\".format(phrase, end - start))\\n\",\n    \"    return result\\n\",\n    \"model = with_benchmark(\\\"Training\\\", lambda: classifier.fit(train_data))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model.write().overwrite().save(output_path_model)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"loaded_model = SparkXGBClassifierModel().load(output_path_model)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 27,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Transformation takes 8.959877967834473 seconds\\n\",\n      \"+--------------+--------------------+--------------------+----------+\\n\",\n      \"|delinquency_12|       rawPrediction|         probability|prediction|\\n\",\n      \"+--------------+--------------------+--------------------+----------+\\n\",\n      \"|             0|[7.92072248458862...|[0.99963699193904...|       0.0|\\n\",\n      \"|             0|[7.92072248458862...|[0.99963699193904...|       0.0|\\n\",\n      \"|             0|[8.43130302429199...|[0.99978211015695...|       0.0|\\n\",\n      \"|             0|[8.20779895782470...|[0.99972755435737...|       0.0|\\n\",\n      \"|             0|[8.885986328125,-...|[0.99986170543706...|       0.0|\\n\",\n      \"+--------------+--------------------+--------------------+----------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def transform():\\n\",\n    \"    result = loaded_model.transform(test_data).cache()\\n\",\n    \"    result.foreachPartition(lambda _: None)\\n\",\n    \"    return result\\n\",\n    \"result = 
with_benchmark(\\\"Transformation\\\", transform)\\n\",\n    \"result.select(label, \\\"rawPrediction\\\", \\\"probability\\\", \\\"prediction\\\").show(5)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Evaluation takes 0.6158628463745117 seconds\\n\",\n      \"Accuracy is 0.9861453808970397\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"accuracy = with_benchmark(\\n\",\n    \"    \\\"Evaluation\\\",\\n\",\n    \"    lambda: MulticlassClassificationEvaluator().setLabelCol(label).evaluate(result))\\n\",\n    \"print(\\\"Accuracy is \\\" + str(accuracy))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.stop()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.2\"\n  },\n  \"name\": \"gpu-mortgage\",\n  \"notebookId\": 4440374682851873\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 1\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/notebooks/python/MortgageETL.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Prerequirement\\n\",\n    \"### 1. Download data\\n\",\n    \"Dataset is derived from Fannie Mae’s [Single-Family Loan Performance Data](http://www.fanniemae.com/portal/funding-the-market/data/loan-performance-data.html) with all rights reserved by Fannie Mae. Refer to these [instructions](https://github.com/NVIDIA/spark-rapids-examples/blob/branch-24.12/docs/get-started/xgboost-examples/dataset/mortgage.md) to download the dataset.\\n\",\n    \"\\n\",\n    \"### 2. Download needed jars\\n\",\n    \"* [rapids-4-spark_2.12-26.02.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"### 3. Start Spark Standalone\\n\",\n    \"Before running the script, please setup Spark standalone mode\\n\",\n    \"\\n\",\n    \"### 4. Add ENV\\n\",\n    \"```\\n\",\n    \"$ export SPARK_JARS=rapids-4-spark_2.12-26.02.0.jar\\n\",\n    \"$ export PYSPARK_DRIVER_PYTHON=jupyter \\n\",\n    \"$ export PYSPARK_DRIVER_PYTHON_OPTS=notebook\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"### 5. 
Start Jupyter Notebook with plugin config\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"$ pyspark --master ${SPARK_MASTER}            \\\\\\n\",\n    \"--jars ${SPARK_JARS}                \\\\\\n\",\n    \"--conf spark.plugins=com.nvidia.spark.SQLPlugin \\\\\\n\",\n    \"--conf spark.rapids.sql.incompatibleDateFormats.enabled=true \\\\\\n\",\n    \"--conf spark.rapids.sql.csv.read.double.enabled=true \\\\\\n\",\n    \"--py-files ${SPARK_PY_FILES}\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"## Import Libs\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import time\\n\",\n    \"import os\\n\",\n    \"from pyspark import broadcast\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.functions import *\\n\",\n    \"from pyspark.sql.types import *\\n\",\n    \"from pyspark.sql.window import Window\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create Spark Session\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark = (SparkSession\\n\",\n    \"    .builder\\n\",\n    \"    .appName(\\\"MortgageETL\\\")\\n\",\n    \"    .getOrCreate())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Function Define\\n\",\n    \"### 1. 
Define the constants\\n\",\n    \"\\n\",\n    \"* Define input file schema (Performance and Acquisition)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# File schema\\n\",\n    \"_csv_raw_schema = StructType([\\n\",\n    \"      StructField(\\\"reference_pool_id\\\", StringType()),\\n\",\n    \"      StructField(\\\"loan_id\\\", LongType()),\\n\",\n    \"      StructField(\\\"monthly_reporting_period\\\", StringType()),\\n\",\n    \"      StructField(\\\"orig_channel\\\", StringType()),\\n\",\n    \"      StructField(\\\"seller_name\\\", StringType()),\\n\",\n    \"      StructField(\\\"servicer\\\", StringType()),\\n\",\n    \"      StructField(\\\"master_servicer\\\", StringType()),\\n\",\n    \"      StructField(\\\"orig_interest_rate\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"interest_rate\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"orig_upb\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"upb_at_issuance\\\", StringType()),\\n\",\n    \"      StructField(\\\"current_actual_upb\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"orig_loan_term\\\", IntegerType()),\\n\",\n    \"      StructField(\\\"orig_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"first_pay_date\\\", StringType()),    \\n\",\n    \"      StructField(\\\"loan_age\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"remaining_months_to_legal_maturity\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"adj_remaining_months_to_maturity\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"maturity_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"orig_ltv\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"orig_cltv\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"num_borrowers\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"dti\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"borrower_credit_score\\\", 
DoubleType()),\\n\",\n    \"      StructField(\\\"coborrow_credit_score\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"first_home_buyer\\\", StringType()),\\n\",\n    \"      StructField(\\\"loan_purpose\\\", StringType()),\\n\",\n    \"      StructField(\\\"property_type\\\", StringType()),\\n\",\n    \"      StructField(\\\"num_units\\\", IntegerType()),\\n\",\n    \"      StructField(\\\"occupancy_status\\\", StringType()),\\n\",\n    \"      StructField(\\\"property_state\\\", StringType()),\\n\",\n    \"      StructField(\\\"msa\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"zip\\\", IntegerType()),\\n\",\n    \"      StructField(\\\"mortgage_insurance_percent\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"product_type\\\", StringType()),\\n\",\n    \"      StructField(\\\"prepayment_penalty_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"interest_only_loan_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"interest_only_first_principal_and_interest_payment_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"months_to_amortization\\\", StringType()),\\n\",\n    \"      StructField(\\\"current_loan_delinquency_status\\\", IntegerType()),\\n\",\n    \"      StructField(\\\"loan_payment_history\\\", StringType()),\\n\",\n    \"      StructField(\\\"mod_flag\\\", StringType()),\\n\",\n    \"      StructField(\\\"mortgage_insurance_cancellation_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"zero_balance_code\\\", StringType()),\\n\",\n    \"      StructField(\\\"zero_balance_effective_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"upb_at_the_time_of_removal\\\", StringType()),\\n\",\n    \"      StructField(\\\"repurchase_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"scheduled_principal_current\\\", StringType()),\\n\",\n    \"      StructField(\\\"total_principal_current\\\", StringType()),\\n\",\n    \"      
StructField(\\\"unscheduled_principal_current\\\", StringType()),\\n\",\n    \"      StructField(\\\"last_paid_installment_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"foreclosed_after\\\", StringType()),\\n\",\n    \"      StructField(\\\"disposition_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"foreclosure_costs\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"prop_preservation_and_repair_costs\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"asset_recovery_costs\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"misc_holding_expenses\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"holding_taxes\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"net_sale_proceeds\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"credit_enhancement_proceeds\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"repurchase_make_whole_proceeds\\\", StringType()),\\n\",\n    \"      StructField(\\\"other_foreclosure_proceeds\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"non_interest_bearing_upb\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"principal_forgiveness_upb\\\", StringType()),\\n\",\n    \"      StructField(\\\"original_list_start_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"original_list_price\\\", StringType()),\\n\",\n    \"      StructField(\\\"current_list_start_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"current_list_price\\\", StringType()),\\n\",\n    \"      StructField(\\\"borrower_credit_score_at_issuance\\\", StringType()),\\n\",\n    \"      StructField(\\\"co-borrower_credit_score_at_issuance\\\", StringType()),\\n\",\n    \"      StructField(\\\"borrower_credit_score_current\\\", StringType()),\\n\",\n    \"      StructField(\\\"co-Borrower_credit_score_current\\\", StringType()),\\n\",\n    \"      StructField(\\\"mortgage_insurance_type\\\", DoubleType()),\\n\",\n    \"      StructField(\\\"servicing_activity_indicator\\\", StringType()),\\n\",\n    
\"      StructField(\\\"current_period_modification_loss_amount\\\", StringType()),\\n\",\n    \"      StructField(\\\"cumulative_modification_loss_amount\\\", StringType()),\\n\",\n    \"      StructField(\\\"current_period_credit_event_net_gain_or_loss\\\", StringType()),\\n\",\n    \"      StructField(\\\"cumulative_credit_event_net_gain_or_loss\\\", StringType()),\\n\",\n    \"      StructField(\\\"homeready_program_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"foreclosure_principal_write_off_amount\\\", StringType()),\\n\",\n    \"      StructField(\\\"relocation_mortgage_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"zero_balance_code_change_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"loan_holdback_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"loan_holdback_effective_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"delinquent_accrued_interest\\\", StringType()),\\n\",\n    \"      StructField(\\\"property_valuation_method\\\", StringType()),\\n\",\n    \"      StructField(\\\"high_balance_loan_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"arm_initial_fixed-rate_period_lt_5_yr_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"arm_product_type\\\", StringType()),\\n\",\n    \"      StructField(\\\"initial_fixed-rate_period\\\", StringType()),\\n\",\n    \"      StructField(\\\"interest_rate_adjustment_frequency\\\", StringType()),\\n\",\n    \"      StructField(\\\"next_interest_rate_adjustment_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"next_payment_change_date\\\", StringType()),\\n\",\n    \"      StructField(\\\"index\\\", StringType()),\\n\",\n    \"      StructField(\\\"arm_cap_structure\\\", StringType()),\\n\",\n    \"      StructField(\\\"initial_interest_rate_cap_up_percent\\\", StringType()),\\n\",\n    \"      StructField(\\\"periodic_interest_rate_cap_up_percent\\\", StringType()),\\n\",\n    \"      
StructField(\\\"lifetime_interest_rate_cap_up_percent\\\", StringType()),\\n\",\n    \"      StructField(\\\"mortgage_margin\\\", StringType()),\\n\",\n    \"      StructField(\\\"arm_balloon_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"arm_plan_number\\\", StringType()),\\n\",\n    \"      StructField(\\\"borrower_assistance_plan\\\", StringType()),\\n\",\n    \"      StructField(\\\"hltv_refinance_option_indicator\\\", StringType()),\\n\",\n    \"      StructField(\\\"deal_name\\\", StringType()),\\n\",\n    \"      StructField(\\\"repurchase_make_whole_proceeds_flag\\\", StringType()),\\n\",\n    \"      StructField(\\\"alternative_delinquency_resolution\\\", StringType()),\\n\",\n    \"      StructField(\\\"alternative_delinquency_resolution_count\\\", StringType()),\\n\",\n    \"      StructField(\\\"total_deferral_amount\\\", StringType())\\n\",\n    \"      ])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define seller name mapping\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# name mappings\\n\",\n    \"_name_mapping = [\\n\",\n    \"        (\\\"WITMER FUNDING, LLC\\\", \\\"Witmer\\\"),\\n\",\n    \"        (\\\"WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015\\\", \\\"Wells Fargo\\\"),\\n\",\n    \"        (\\\"WELLS FARGO BANK,  NA\\\" , \\\"Wells Fargo\\\"),\\n\",\n    \"        (\\\"WELLS FARGO BANK, N.A.\\\" , \\\"Wells Fargo\\\"),\\n\",\n    \"        (\\\"WELLS FARGO BANK, NA\\\" , \\\"Wells Fargo\\\"),\\n\",\n    \"        (\\\"USAA FEDERAL SAVINGS BANK\\\" , \\\"USAA\\\"),\\n\",\n    \"        (\\\"UNITED SHORE FINANCIAL SERVICES, LLC D\\\\\\\\/B\\\\\\\\/A UNITED WHOLESALE MORTGAGE\\\" , \\\"United Seq(e\\\"),\\n\",\n    \"        (\\\"U.S. 
BANK N.A.\\\" , \\\"US Bank\\\"),\\n\",\n    \"        (\\\"SUNTRUST MORTGAGE INC.\\\" , \\\"Suntrust\\\"),\\n\",\n    \"        (\\\"STONEGATE MORTGAGE CORPORATION\\\" , \\\"Stonegate Mortgage\\\"),\\n\",\n    \"        (\\\"STEARNS LENDING, LLC\\\" , \\\"Stearns Lending\\\"),\\n\",\n    \"        (\\\"STEARNS LENDING, INC.\\\" , \\\"Stearns Lending\\\"),\\n\",\n    \"        (\\\"SIERRA PACIFIC MORTGAGE COMPANY, INC.\\\" , \\\"Sierra Pacific Mortgage\\\"),\\n\",\n    \"        (\\\"REGIONS BANK\\\" , \\\"Regions\\\"),\\n\",\n    \"        (\\\"RBC MORTGAGE COMPANY\\\" , \\\"RBC\\\"),\\n\",\n    \"        (\\\"QUICKEN LOANS INC.\\\" , \\\"Quicken Loans\\\"),\\n\",\n    \"        (\\\"PULTE MORTGAGE, L.L.C.\\\" , \\\"Pulte Mortgage\\\"),\\n\",\n    \"        (\\\"PROVIDENT FUNDING ASSOCIATES, L.P.\\\" , \\\"Provident Funding\\\"),\\n\",\n    \"        (\\\"PROSPECT MORTGAGE, LLC\\\" , \\\"Prospect Mortgage\\\"),\\n\",\n    \"        (\\\"PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC\\\" , \\\"Principal Residential\\\"),\\n\",\n    \"        (\\\"PNC BANK, N.A.\\\" , \\\"PNC\\\"),\\n\",\n    \"        (\\\"PMT CREDIT RISK TRANSFER TRUST 2015-2\\\" , \\\"PennyMac\\\"),\\n\",\n    \"        (\\\"PHH MORTGAGE CORPORATION\\\" , \\\"PHH Mortgage\\\"),\\n\",\n    \"        (\\\"PENNYMAC CORP.\\\" , \\\"PennyMac\\\"),\\n\",\n    \"        (\\\"PACIFIC UNION FINANCIAL, LLC\\\" , \\\"Other\\\"),\\n\",\n    \"        (\\\"OTHER\\\" , \\\"Other\\\"),\\n\",\n    \"        (\\\"NYCB MORTGAGE COMPANY, LLC\\\" , \\\"NYCB\\\"),\\n\",\n    \"        (\\\"NEW YORK COMMUNITY BANK\\\" , \\\"NYCB\\\"),\\n\",\n    \"        (\\\"NETBANK FUNDING SERVICES\\\" , \\\"Netbank\\\"),\\n\",\n    \"        (\\\"NATIONSTAR MORTGAGE, LLC\\\" , \\\"Nationstar Mortgage\\\"),\\n\",\n    \"        (\\\"METLIFE BANK, NA\\\" , \\\"Metlife\\\"),\\n\",\n    \"        (\\\"LOANDEPOT.COM, LLC\\\" , \\\"LoanDepot.com\\\"),\\n\",\n    \"        (\\\"J.P. 
MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"JPMORGAN CHASE BANK, NATIONAL ASSOCIATION\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"JPMORGAN CHASE BANK, NA\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"JP MORGAN CHASE BANK, NA\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"IRWIN MORTGAGE, CORPORATION\\\" , \\\"Irwin Mortgage\\\"),\\n\",\n    \"        (\\\"IMPAC MORTGAGE CORP.\\\" , \\\"Impac Mortgage\\\"),\\n\",\n    \"        (\\\"HSBC BANK USA, NATIONAL ASSOCIATION\\\" , \\\"HSBC\\\"),\\n\",\n    \"        (\\\"HOMEWARD RESIDENTIAL, INC.\\\" , \\\"Homeward Mortgage\\\"),\\n\",\n    \"        (\\\"HOMESTREET BANK\\\" , \\\"Other\\\"),\\n\",\n    \"        (\\\"HOMEBRIDGE FINANCIAL SERVICES, INC.\\\" , \\\"HomeBridge\\\"),\\n\",\n    \"        (\\\"HARWOOD STREET FUNDING I, LLC\\\" , \\\"Harwood Mortgage\\\"),\\n\",\n    \"        (\\\"GUILD MORTGAGE COMPANY\\\" , \\\"Guild Mortgage\\\"),\\n\",\n    \"        (\\\"GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)\\\" , \\\"GMAC\\\"),\\n\",\n    \"        (\\\"GMAC MORTGAGE, LLC\\\" , \\\"GMAC\\\"),\\n\",\n    \"        (\\\"GMAC (USAA)\\\" , \\\"GMAC\\\"),\\n\",\n    \"        (\\\"FREMONT BANK\\\" , \\\"Fremont Bank\\\"),\\n\",\n    \"        (\\\"FREEDOM MORTGAGE CORP.\\\" , \\\"Freedom Mortgage\\\"),\\n\",\n    \"        (\\\"FRANKLIN AMERICAN MORTGAGE COMPANY\\\" , \\\"Franklin America\\\"),\\n\",\n    \"        (\\\"FLEET NATIONAL BANK\\\" , \\\"Fleet National\\\"),\\n\",\n    \"        (\\\"FLAGSTAR CAPITAL MARKETS CORPORATION\\\" , \\\"Flagstar Bank\\\"),\\n\",\n    \"        (\\\"FLAGSTAR BANK, FSB\\\" , \\\"Flagstar Bank\\\"),\\n\",\n    \"        (\\\"FIRST TENNESSEE BANK NATIONAL ASSOCIATION\\\" , \\\"Other\\\"),\\n\",\n    \"        (\\\"FIFTH THIRD BANK\\\" , \\\"Fifth Third 
Bank\\\"),\\n\",\n    \"        (\\\"FEDERAL HOME LOAN BANK OF CHICAGO\\\" , \\\"Fedral Home of Chicago\\\"),\\n\",\n    \"        (\\\"FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB\\\" , \\\"FDIC\\\"),\\n\",\n    \"        (\\\"DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A.\\\" , \\\"Downey Mortgage\\\"),\\n\",\n    \"        (\\\"DITECH FINANCIAL LLC\\\" , \\\"Ditech\\\"),\\n\",\n    \"        (\\\"CITIMORTGAGE, INC.\\\" , \\\"Citi\\\"),\\n\",\n    \"        (\\\"CHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY\\\" , \\\"Chicago Mortgage\\\"),\\n\",\n    \"        (\\\"CHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY\\\" , \\\"Chicago Mortgage\\\"),\\n\",\n    \"        (\\\"CHASE HOME FINANCE, LLC\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"CHASE HOME FINANCE FRANKLIN AMERICAN MORTGAGE COMPANY\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"CHASE HOME FINANCE (CIE 1)\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"CHASE HOME FINANCE\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"        (\\\"CASHCALL, INC.\\\" , \\\"CashCall\\\"),\\n\",\n    \"        (\\\"CAPITAL ONE, NATIONAL ASSOCIATION\\\" , \\\"Capital One\\\"),\\n\",\n    \"        (\\\"CALIBER HOME LOANS, INC.\\\" , \\\"Caliber Funding\\\"),\\n\",\n    \"        (\\\"BISHOPS GATE RESIDENTIAL MORTGAGE TRUST\\\" , \\\"Bishops Gate Mortgage\\\"),\\n\",\n    \"        (\\\"BANK OF AMERICA, N.A.\\\" , \\\"Bank of America\\\"),\\n\",\n    \"        (\\\"AMTRUST BANK\\\" , \\\"AmTrust\\\"),\\n\",\n    \"        (\\\"AMERISAVE MORTGAGE CORPORATION\\\" , \\\"Amerisave\\\"),\\n\",\n    \"        (\\\"AMERIHOME MORTGAGE COMPANY, LLC\\\" , \\\"AmeriHome Mortgage\\\"),\\n\",\n    \"        (\\\"ALLY BANK\\\" , \\\"Ally Bank\\\"),\\n\",\n    \"        (\\\"ACADEMY MORTGAGE CORPORATION\\\" , \\\"Academy Mortgage\\\"),\\n\",\n    \"        (\\\"NO CASH-OUT REFINANCE\\\" , \\\"OTHER REFINANCE\\\"),\\n\",\n    \"        (\\\"REFINANCE - NOT SPECIFIED\\\" , \\\"OTHER 
REFINANCE\\\"),\\n\",\n    \"        (\\\"Other REFINANCE\\\" , \\\"OTHER REFINANCE\\\")]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define category (string) column and numeric column\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# String columns\\n\",\n    \"cate_col_names = [\\n\",\n    \"        \\\"orig_channel\\\",\\n\",\n    \"        \\\"first_home_buyer\\\",\\n\",\n    \"        \\\"loan_purpose\\\",\\n\",\n    \"        \\\"property_type\\\",\\n\",\n    \"        \\\"occupancy_status\\\",\\n\",\n    \"        \\\"property_state\\\",\\n\",\n    \"        \\\"product_type\\\",\\n\",\n    \"        \\\"relocation_mortgage_indicator\\\",\\n\",\n    \"        \\\"seller_name\\\",\\n\",\n    \"        \\\"mod_flag\\\"\\n\",\n    \"]\\n\",\n    \"# Numberic columns\\n\",\n    \"label_col_name = \\\"delinquency_12\\\"\\n\",\n    \"numeric_col_names = [\\n\",\n    \"        \\\"orig_interest_rate\\\",\\n\",\n    \"        \\\"orig_upb\\\",\\n\",\n    \"        \\\"orig_loan_term\\\",\\n\",\n    \"        \\\"orig_ltv\\\",\\n\",\n    \"        \\\"orig_cltv\\\",\\n\",\n    \"        \\\"num_borrowers\\\",\\n\",\n    \"        \\\"dti\\\",\\n\",\n    \"        \\\"borrower_credit_score\\\",\\n\",\n    \"        \\\"num_units\\\",\\n\",\n    \"        \\\"zip\\\",\\n\",\n    \"        \\\"mortgage_insurance_percent\\\",\\n\",\n    \"        \\\"current_loan_delinquency_status\\\",\\n\",\n    \"        \\\"current_actual_upb\\\",\\n\",\n    \"        \\\"interest_rate\\\",\\n\",\n    \"        \\\"loan_age\\\",\\n\",\n    \"        \\\"msa\\\",\\n\",\n    \"        \\\"non_interest_bearing_upb\\\",\\n\",\n    \"        label_col_name\\n\",\n    \"]\\n\",\n    \"all_col_names = cate_col_names + numeric_col_names\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* 
Functions to extract perf and acq columns from raw schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 43,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def extract_perf_columns(rawDf):\\n\",\n    \"    perfDf = rawDf.select(\\n\",\n    \"      col(\\\"loan_id\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"monthly_reporting_period\\\"),\\\"MMyyyy\\\"), \\\"MM/dd/yyyy\\\").alias(\\\"monthly_reporting_period\\\"),\\n\",\n    \"      upper(col(\\\"servicer\\\")).alias(\\\"servicer\\\"),\\n\",\n    \"      col(\\\"interest_rate\\\"),\\n\",\n    \"      col(\\\"current_actual_upb\\\"),\\n\",\n    \"      col(\\\"loan_age\\\"),\\n\",\n    \"      col(\\\"remaining_months_to_legal_maturity\\\"),\\n\",\n    \"      col(\\\"adj_remaining_months_to_maturity\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"maturity_date\\\"),\\\"MMyyyy\\\"), \\\"MM/yyyy\\\").alias(\\\"maturity_date\\\"),\\n\",\n    \"      col(\\\"msa\\\"),\\n\",\n    \"      col(\\\"current_loan_delinquency_status\\\"),\\n\",\n    \"      col(\\\"mod_flag\\\"),\\n\",\n    \"      col(\\\"zero_balance_code\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"zero_balance_effective_date\\\"),\\\"MMyyyy\\\"), \\\"MM/yyyy\\\").alias(\\\"zero_balance_effective_date\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"last_paid_installment_date\\\"),\\\"MMyyyy\\\"), \\\"MM/dd/yyyy\\\").alias(\\\"last_paid_installment_date\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"foreclosed_after\\\"),\\\"MMyyyy\\\"), \\\"MM/dd/yyyy\\\").alias(\\\"foreclosed_after\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"disposition_date\\\"),\\\"MMyyyy\\\"), \\\"MM/dd/yyyy\\\").alias(\\\"disposition_date\\\"),\\n\",\n    \"      col(\\\"foreclosure_costs\\\"),\\n\",\n    \"      col(\\\"prop_preservation_and_repair_costs\\\"),\\n\",\n    \"      col(\\\"asset_recovery_costs\\\"),\\n\",\n    \"      col(\\\"misc_holding_expenses\\\"),\\n\",\n    \"      
col(\\\"holding_taxes\\\"),\\n\",\n    \"      col(\\\"net_sale_proceeds\\\"),\\n\",\n    \"      col(\\\"credit_enhancement_proceeds\\\"),\\n\",\n    \"      col(\\\"repurchase_make_whole_proceeds\\\"),\\n\",\n    \"      col(\\\"other_foreclosure_proceeds\\\"),\\n\",\n    \"      col(\\\"non_interest_bearing_upb\\\"),\\n\",\n    \"      col(\\\"principal_forgiveness_upb\\\"),\\n\",\n    \"      col(\\\"repurchase_make_whole_proceeds_flag\\\"),\\n\",\n    \"      col(\\\"foreclosure_principal_write_off_amount\\\"),\\n\",\n    \"      col(\\\"servicing_activity_indicator\\\"),\\n\",\n    \"      col('quarter')\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    return perfDf.select(\\\"*\\\").filter(\\\"current_actual_upb != 0.0\\\")\\n\",\n    \"\\n\",\n    \"def extract_acq_columns(rawDf):\\n\",\n    \"    acqDf = rawDf.select(\\n\",\n    \"      col(\\\"loan_id\\\"),\\n\",\n    \"      col(\\\"orig_channel\\\"),\\n\",\n    \"      upper(col(\\\"seller_name\\\")).alias(\\\"seller_name\\\"),\\n\",\n    \"      col(\\\"orig_interest_rate\\\"),\\n\",\n    \"      col(\\\"orig_upb\\\"),\\n\",\n    \"      col(\\\"orig_loan_term\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"orig_date\\\"),\\\"MMyyyy\\\"), \\\"MM/yyyy\\\").alias(\\\"orig_date\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"first_pay_date\\\"),\\\"MMyyyy\\\"), \\\"MM/yyyy\\\").alias(\\\"first_pay_date\\\"),\\n\",\n    \"      col(\\\"orig_ltv\\\"),\\n\",\n    \"      col(\\\"orig_cltv\\\"),\\n\",\n    \"      col(\\\"num_borrowers\\\"),\\n\",\n    \"      col(\\\"dti\\\"),\\n\",\n    \"      col(\\\"borrower_credit_score\\\"),\\n\",\n    \"      col(\\\"first_home_buyer\\\"),\\n\",\n    \"      col(\\\"loan_purpose\\\"),\\n\",\n    \"      col(\\\"property_type\\\"),\\n\",\n    \"      col(\\\"num_units\\\"),\\n\",\n    \"      col(\\\"occupancy_status\\\"),\\n\",\n    \"      col(\\\"property_state\\\"),\\n\",\n    \"      col(\\\"zip\\\"),\\n\",\n    \"      
col(\\\"mortgage_insurance_percent\\\"),\\n\",\n    \"      col(\\\"product_type\\\"),\\n\",\n    \"      col(\\\"coborrow_credit_score\\\"),\\n\",\n    \"      col(\\\"mortgage_insurance_type\\\"),\\n\",\n    \"      col(\\\"relocation_mortgage_indicator\\\"),\\n\",\n    \"      dense_rank().over(Window.partitionBy(\\\"loan_id\\\").orderBy(to_date(col(\\\"monthly_reporting_period\\\"),\\\"MMyyyy\\\"))).alias(\\\"rank\\\"),\\n\",\n    \"      col('quarter')\\n\",\n    \"      )\\n\",\n    \"\\n\",\n    \"    return acqDf.select(\\\"*\\\").filter(col(\\\"rank\\\")==1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 2. Define ETL Process\\n\",\n    \"\\n\",\n    \"Define the function to do the ETL process\\n\",\n    \"\\n\",\n    \"#### 2.1 Define Functions to Read Raw CSV File\\n\",\n    \"\\n\",\n    \"* Define function to get quarter from input CSV file name\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 44,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _get_quarter_from_csv_file_name():\\n\",\n    \"    return substring_index(substring_index(input_file_name(), '.', 1), '/', -1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define function to read raw CSV data file\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 45,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def read_raw_csv(spark, path):\\n\",\n    \"    return spark.read.format('csv') \\\\\\n\",\n    \"            .option('nullValue', '') \\\\\\n\",\n    \"            .option('header', False) \\\\\\n\",\n    \"            .option('delimiter', '|') \\\\\\n\",\n    \"            .schema(_csv_raw_schema) \\\\\\n\",\n    \"            .load(path) \\\\\\n\",\n    \"            .withColumn('quarter', _get_quarter_from_csv_file_name())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n  
 \"source\": [\n    \"#### 2.2 Define ETL Process\\n\",\n    \"\\n\",\n    \"* Define function to parse dates in Performance data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 48,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _parse_dates(perf):\\n\",\n    \"    return perf \\\\\\n\",\n    \"            .withColumn('monthly_reporting_period', to_date(col('monthly_reporting_period'), 'MM/dd/yyyy')) \\\\\\n\",\n    \"            .withColumn('monthly_reporting_period_month', month(col('monthly_reporting_period'))) \\\\\\n\",\n    \"            .withColumn('monthly_reporting_period_year', year(col('monthly_reporting_period'))) \\\\\\n\",\n    \"            .withColumn('monthly_reporting_period_day', dayofmonth(col('monthly_reporting_period'))) \\\\\\n\",\n    \"            .withColumn('last_paid_installment_date', to_date(col('last_paid_installment_date'), 'MM/dd/yyyy')) \\\\\\n\",\n    \"            .withColumn('foreclosed_after', to_date(col('foreclosed_after'), 'MM/dd/yyyy')) \\\\\\n\",\n    \"            .withColumn('disposition_date', to_date(col('disposition_date'), 'MM/dd/yyyy')) \\\\\\n\",\n    \"            .withColumn('maturity_date', to_date(col('maturity_date'), 'MM/yyyy')) \\\\\\n\",\n    \"            .withColumn('zero_balance_effective_date', to_date(col('zero_balance_effective_date'), 'MM/yyyy'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define function to create deliquency data frame from Performance data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 49,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _create_perf_deliquency(spark, perf):\\n\",\n    \"    aggDF = perf.select(\\n\",\n    \"            col(\\\"quarter\\\"),\\n\",\n    \"            col(\\\"loan_id\\\"),\\n\",\n    \"            col(\\\"current_loan_delinquency_status\\\"),\\n\",\n    \"            
when(col(\\\"current_loan_delinquency_status\\\") >= 1, col(\\\"monthly_reporting_period\\\")).alias(\\\"delinquency_30\\\"),\\n\",\n    \"            when(col(\\\"current_loan_delinquency_status\\\") >= 3, col(\\\"monthly_reporting_period\\\")).alias(\\\"delinquency_90\\\"),\\n\",\n    \"            when(col(\\\"current_loan_delinquency_status\\\") >= 6, col(\\\"monthly_reporting_period\\\")).alias(\\\"delinquency_180\\\")) \\\\\\n\",\n    \"            .groupBy(\\\"quarter\\\", \\\"loan_id\\\") \\\\\\n\",\n    \"            .agg(\\n\",\n    \"                max(\\\"current_loan_delinquency_status\\\").alias(\\\"delinquency_12\\\"),\\n\",\n    \"                min(\\\"delinquency_30\\\").alias(\\\"delinquency_30\\\"),\\n\",\n    \"                min(\\\"delinquency_90\\\").alias(\\\"delinquency_90\\\"),\\n\",\n    \"                min(\\\"delinquency_180\\\").alias(\\\"delinquency_180\\\")) \\\\\\n\",\n    \"            .select(\\n\",\n    \"                col(\\\"quarter\\\"),\\n\",\n    \"                col(\\\"loan_id\\\"),\\n\",\n    \"                (col(\\\"delinquency_12\\\") >= 1).alias(\\\"ever_30\\\"),\\n\",\n    \"                (col(\\\"delinquency_12\\\") >= 3).alias(\\\"ever_90\\\"),\\n\",\n    \"                (col(\\\"delinquency_12\\\") >= 6).alias(\\\"ever_180\\\"),\\n\",\n    \"                col(\\\"delinquency_30\\\"),\\n\",\n    \"                col(\\\"delinquency_90\\\"),\\n\",\n    \"                col(\\\"delinquency_180\\\"))\\n\",\n    \"    joinedDf = perf \\\\\\n\",\n    \"            .withColumnRenamed(\\\"monthly_reporting_period\\\", \\\"timestamp\\\") \\\\\\n\",\n    \"            .withColumnRenamed(\\\"monthly_reporting_period_month\\\", \\\"timestamp_month\\\") \\\\\\n\",\n    \"            .withColumnRenamed(\\\"monthly_reporting_period_year\\\", \\\"timestamp_year\\\") \\\\\\n\",\n    \"            .withColumnRenamed(\\\"current_loan_delinquency_status\\\", \\\"delinquency_12\\\") \\\\\\n\",\n    \"            
.withColumnRenamed(\\\"current_actual_upb\\\", \\\"upb_12\\\") \\\\\\n\",\n    \"            .select(\\\"quarter\\\", \\\"loan_id\\\", \\\"timestamp\\\", \\\"delinquency_12\\\", \\\"upb_12\\\", \\\"timestamp_month\\\", \\\"timestamp_year\\\") \\\\\\n\",\n    \"            .join(aggDF, [\\\"loan_id\\\", \\\"quarter\\\"], \\\"left_outer\\\")\\n\",\n    \"\\n\",\n    \"    # calculate the 12 month delinquency and upb values\\n\",\n    \"    months = 12\\n\",\n    \"    monthArray = [lit(x) for x in range(0, 12)]\\n\",\n    \"    # explode on a small amount of data is actually slightly more efficient than a cross join\\n\",\n    \"    testDf = joinedDf \\\\\\n\",\n    \"            .withColumn(\\\"month_y\\\", explode(array(monthArray))) \\\\\\n\",\n    \"            .select(\\n\",\n    \"                    col(\\\"quarter\\\"),\\n\",\n    \"                    floor(((col(\\\"timestamp_year\\\") * 12 + col(\\\"timestamp_month\\\")) - 24000) / months).alias(\\\"josh_mody\\\"),\\n\",\n    \"                    floor(((col(\\\"timestamp_year\\\") * 12 + col(\\\"timestamp_month\\\")) - 24000 - col(\\\"month_y\\\")) / months).alias(\\\"josh_mody_n\\\"),\\n\",\n    \"                    col(\\\"ever_30\\\"),\\n\",\n    \"                    col(\\\"ever_90\\\"),\\n\",\n    \"                    col(\\\"ever_180\\\"),\\n\",\n    \"                    col(\\\"delinquency_30\\\"),\\n\",\n    \"                    col(\\\"delinquency_90\\\"),\\n\",\n    \"                    col(\\\"delinquency_180\\\"),\\n\",\n    \"                    col(\\\"loan_id\\\"),\\n\",\n    \"                    col(\\\"month_y\\\"),\\n\",\n    \"                    col(\\\"delinquency_12\\\"),\\n\",\n    \"                    col(\\\"upb_12\\\")) \\\\\\n\",\n    \"            .groupBy(\\\"quarter\\\", \\\"loan_id\\\", \\\"josh_mody_n\\\", \\\"ever_30\\\", \\\"ever_90\\\", \\\"ever_180\\\", \\\"delinquency_30\\\", \\\"delinquency_90\\\", \\\"delinquency_180\\\", \\\"month_y\\\") \\\\\\n\",\n    \"  
          .agg(max(\\\"delinquency_12\\\").alias(\\\"delinquency_12\\\"), min(\\\"upb_12\\\").alias(\\\"upb_12\\\")) \\\\\\n\",\n    \"            .withColumn(\\\"timestamp_year\\\", floor((lit(24000) + (col(\\\"josh_mody_n\\\") * lit(months)) + (col(\\\"month_y\\\") - 1)) / lit(12))) \\\\\\n\",\n    \"            .selectExpr('*', 'pmod(24000 + (josh_mody_n * {}) + month_y, 12) as timestamp_month_tmp'.format(months)) \\\\\\n\",\n    \"            .withColumn(\\\"timestamp_month\\\", when(col(\\\"timestamp_month_tmp\\\") == lit(0), lit(12)).otherwise(col(\\\"timestamp_month_tmp\\\"))) \\\\\\n\",\n    \"            .withColumn(\\\"delinquency_12\\\", ((col(\\\"delinquency_12\\\") > 3).cast(\\\"int\\\") + (col(\\\"upb_12\\\") == 0).cast(\\\"int\\\")).alias(\\\"delinquency_12\\\")) \\\\\\n\",\n    \"            .drop(\\\"timestamp_month_tmp\\\", \\\"josh_mody_n\\\", \\\"month_y\\\")\\n\",\n    \"\\n\",\n    \"    return perf.withColumnRenamed(\\\"monthly_reporting_period_month\\\", \\\"timestamp_month\\\") \\\\\\n\",\n    \"            .withColumnRenamed(\\\"monthly_reporting_period_year\\\", \\\"timestamp_year\\\") \\\\\\n\",\n    \"            .join(testDf, [\\\"quarter\\\", \\\"loan_id\\\", \\\"timestamp_year\\\", \\\"timestamp_month\\\"], \\\"left\\\") \\\\\\n\",\n    \"            .drop(\\\"timestamp_year\\\", \\\"timestamp_month\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define function to create acquisition data frame from Acquisition data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 50,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _create_acquisition(spark, acq):\\n\",\n    \"    nameMapping = spark.createDataFrame(_name_mapping, [\\\"from_seller_name\\\", \\\"to_seller_name\\\"])\\n\",\n    \"    return acq.join(nameMapping, col(\\\"seller_name\\\") == col(\\\"from_seller_name\\\"), \\\"left\\\") \\\\\\n\",\n    \"      
.drop(\\\"from_seller_name\\\") \\\\\\n\",\n    \"      .withColumn(\\\"old_name\\\", col(\\\"seller_name\\\")) \\\\\\n\",\n    \"      .withColumn(\\\"seller_name\\\", coalesce(col(\\\"to_seller_name\\\"), col(\\\"seller_name\\\"))) \\\\\\n\",\n    \"      .drop(\\\"to_seller_name\\\") \\\\\\n\",\n    \"      .withColumn(\\\"orig_date\\\", to_date(col(\\\"orig_date\\\"), \\\"MM/yyyy\\\")) \\\\\\n\",\n    \"      .withColumn(\\\"first_pay_date\\\", to_date(col(\\\"first_pay_date\\\"), \\\"MM/yyyy\\\")) \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 2.3 Define Casting Process\\n\",\n    \"This part is casting String column to Numbric. \\n\",\n    \"Example:\\n\",\n    \"```\\n\",\n    \"col_1\\n\",\n    \" \\\"a\\\"\\n\",\n    \" \\\"b\\\"\\n\",\n    \" \\\"c\\\"\\n\",\n    \" \\\"a\\\"\\n\",\n    \"# After String ====> Numberic\\n\",\n    \"col_1\\n\",\n    \" 0\\n\",\n    \" 1\\n\",\n    \" 2\\n\",\n    \" 0\\n\",\n    \"```  \\n\",\n    \"<br>\\n\",\n    \"\\n\",\n    \"* Define function to get column dictionary\\n\",\n    \"\\n\",\n    \"    Example\\n\",\n    \"    ```\\n\",\n    \"    col1 = [row(data=\\\"a\\\",id=0), row(data=\\\"b\\\",id=1)]\\n\",\n    \"    ```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 51,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _gen_dictionary(etl_df, col_names):\\n\",\n    \"    cnt_table = etl_df.select(posexplode(array([col(i) for i in col_names])))\\\\\\n\",\n    \"                    .withColumnRenamed(\\\"pos\\\", \\\"column_id\\\")\\\\\\n\",\n    \"                    .withColumnRenamed(\\\"col\\\", \\\"data\\\")\\\\\\n\",\n    \"                    .filter(\\\"data is not null\\\")\\\\\\n\",\n    \"                    .groupBy(\\\"column_id\\\", \\\"data\\\")\\\\\\n\",\n    \"                    .count()\\n\",\n    \"    windowed = Window.partitionBy(\\\"column_id\\\").orderBy(desc(\\\"count\\\"))\\n\",\n    \"   
 return cnt_table.withColumn(\\\"id\\\", row_number().over(windowed)).drop(\\\"count\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define function to convert string columns to numeric\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 52,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _cast_string_columns_to_numeric(spark, input_df):\\n\",\n    \"    cached_dict_df = _gen_dictionary(input_df, cate_col_names).cache()\\n\",\n    \"    output_df = input_df\\n\",\n    \"    #  Generate the final table with all columns being numeric.\\n\",\n    \"    for col_pos, col_name in enumerate(cate_col_names):\\n\",\n    \"        col_dict_df = cached_dict_df.filter(col(\\\"column_id\\\") == col_pos)\\\\\\n\",\n    \"                                    .drop(\\\"column_id\\\")\\\\\\n\",\n    \"                                    .withColumnRenamed(\\\"data\\\", col_name)\\n\",\n    \"        \\n\",\n    \"        output_df = output_df.join(broadcast(col_dict_df), col_name, \\\"left\\\")\\\\\\n\",\n    \"                        .drop(col_name)\\\\\\n\",\n    \"                        .withColumnRenamed(\\\"id\\\", col_name)\\n\",\n    \"    return output_df        \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 2.4 Define Main Function\\n\",\n    \"In this function:\\n\",\n    \"1. Parse dates in Performance data by calling _parse_dates (parsed_perf)\\n\",\n    \"2. Create delinquency dataframe(perf_deliqency) from Performance data by calling _create_perf_deliquency\\n\",\n    \"3. Create cleaned acquisition dataframe(cleaned_acq) from Acquisition data by calling _create_acquisition\\n\",\n    \"4. Join delinquency dataframe(perf_deliqency) and cleaned acquisition dataframe(cleaned_acq), get clean_df\\n\",\n    \"5. 
Cast String column to Numeric in clean_df by calling _cast_string_columns_to_numeric, get casted_clean_df\\n\",\n    \"6. Return casted_clean_df as final result\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 53,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def run_mortgage(spark, perf, acq):\\n\",\n    \"    parsed_perf = _parse_dates(perf)\\n\",\n    \"    perf_deliqency = _create_perf_deliquency(spark, parsed_perf)\\n\",\n    \"    cleaned_acq = _create_acquisition(spark, acq)\\n\",\n    \"    clean_df = perf_deliqency.join(cleaned_acq, [\\\"loan_id\\\", \\\"quarter\\\"], \\\"inner\\\").drop(\\\"quarter\\\")\\n\",\n    \"    casted_clean_df = _cast_string_columns_to_numeric(spark, clean_df)\\\\\\n\",\n    \"                    .select(all_col_names)\\\\\\n\",\n    \"                    .withColumn(label_col_name, when(col(label_col_name) > 0, 1).otherwise(0))\\\\\\n\",\n    \"                    .fillna(float(0))\\n\",\n    \"    return casted_clean_df\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Script Settings\\n\",\n    \"\\n\",\n    \"### 1. 
File Path Settings\\n\",\n    \"* Define input file path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 54,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# You need to update them to your real paths!\\n\",\n    \"dataRoot = os.getenv(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"orig_raw_path = dataRoot + '/mortgage/input/'\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define output folder path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 56,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"output_path = dataRoot + '/mortgage/output/data/'\\n\",\n    \"output_csv2parquet = dataRoot + '/mortgage/output/csv2parquet/'\\n\",\n    \"output_path_train = dataRoot + '/mortgage/output/train/'\\n\",\n    \"output_path_eval = dataRoot + '/mortgage/output/eval/'\\n\",\n    \"save_train_eval_dataset = True\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 2. 
Common Spark Settings\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 57,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.conf.set('spark.rapids.sql.explain', 'ALL')\\n\",\n    \"spark.conf.set('spark.rapids.sql.batchSizeBytes', '512M')\\n\",\n    \"spark.conf.set('spark.rapids.sql.reader.batchSizeBytes', '768M')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Run Part\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Read Raw File\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"rawDf = read_raw_csv(spark, orig_raw_path)\\n\",\n    \"rawDf.write.parquet(output_csv2parquet, mode='overwrite')\\n\",\n    \"rawDf = spark.read.parquet(output_csv2parquet)\\n\",\n    \"\\n\",\n    \"acq = extract_acq_columns(rawDf)\\n\",\n    \"perf = extract_perf_columns(rawDf)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Run ETL\\n\",\n    \"#### 1. 
Add additional Spark settings\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 60,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# GPU run, set to true\\n\",\n    \"spark.conf.set('spark.rapids.sql.enabled', 'true')\\n\",\n    \"# CPU run, set to false\\n\",\n    \"# spark.conf.set('spark.rapids.sql.enabled', 'false')\\n\",\n    \"spark.conf.set('spark.sql.files.maxPartitionBytes', '1G')\\n\",\n    \"# use GPU to read CSV\\n\",\n    \"spark.conf.set(\\\"spark.rapids.sql.csv.read.double.enabled\\\", \\\"true\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 2. Read Parquet File and Run ETL Process, Save the Result\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 61,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"== Physical Plan ==\\n\",\n      \"GpuColumnarToRow false\\n\",\n      \"+- GpuProject [gpucoalesce(orig_channel#3146, 0) AS orig_channel#5143, gpucoalesce(first_home_buyer#3351, 0) AS first_home_buyer#5144, gpucoalesce(loan_purpose#3556, 0) AS loan_purpose#5145, gpucoalesce(property_type#3761, 0) AS property_type#5146, gpucoalesce(occupancy_status#3966, 0) AS occupancy_status#5147, gpucoalesce(property_state#4171, 0) AS property_state#5148, gpucoalesce(product_type#4376, 0) AS product_type#5149, gpucoalesce(relocation_mortgage_indicator#4581, 0) AS relocation_mortgage_indicator#5150, gpucoalesce(seller_name#4786, 0) AS seller_name#5151, gpucoalesce(id#2956, 0) AS mod_flag#5152, gpucoalesce(gpunanvl(orig_interest_rate#1606, null), 0.0) AS orig_interest_rate#5153, gpucoalesce(orig_upb#1607, 0) AS orig_upb#5154, gpucoalesce(orig_loan_term#1608, 0) AS orig_loan_term#5155, gpucoalesce(gpunanvl(orig_ltv#1611, null), 0.0) AS orig_ltv#5156, gpucoalesce(gpunanvl(orig_cltv#1612, null), 0.0) AS orig_cltv#5157, gpucoalesce(gpunanvl(num_borrowers#1613, 
null), 0.0) AS num_borrowers#5158, gpucoalesce(gpunanvl(dti#1614, null), 0.0) AS dti#5159, gpucoalesce(gpunanvl(borrower_credit_score#1615, null), 0.0) AS borrower_credit_score#5160, gpucoalesce(num_units#1619, 0) AS num_units#5161, gpucoalesce(zip#1622, 0) AS zip#5162, gpucoalesce(gpunanvl(mortgage_insurance_percent#1623, null), 0.0) AS mortgage_insurance_percent#5163, gpucoalesce(current_loan_delinquency_status#1549, 0) AS current_loan_delinquency_status#5164, gpucoalesce(gpunanvl(current_actual_upb#1543, null), 0.0) AS current_actual_upb#5165, gpucoalesce(gpunanvl(interest_rate#1542, null), 0.0) AS interest_rate#5166, ... 4 more fields]\\n\",\n      \"   +- GpuBroadcastHashJoin [mod_flag#1550], [mod_flag#4855], LeftOuter, GpuBuildRight\\n\",\n      \"      :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, num_units#1619, zip#1622, mortgage_insurance_percent#1623, orig_channel#3146, first_home_buyer#3351, loan_purpose#3556, property_type#3761, occupancy_status#3966, ... 4 more fields]\\n\",\n      \"      :  +- GpuBroadcastHashJoin [seller_name#2689], [seller_name#4650], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, num_units#1619, zip#1622, mortgage_insurance_percent#1623, orig_channel#3146, first_home_buyer#3351, loan_purpose#3556, property_type#3761, ... 
4 more fields]\\n\",\n      \"      :     :  +- GpuBroadcastHashJoin [relocation_mortgage_indicator#1627], [relocation_mortgage_indicator#4445], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, num_units#1619, zip#1622, mortgage_insurance_percent#1623, relocation_mortgage_indicator#1627, orig_channel#3146, first_home_buyer#3351, loan_purpose#3556, ... 4 more fields]\\n\",\n      \"      :     :     :  +- GpuBroadcastHashJoin [product_type#1624], [product_type#4240], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, num_units#1619, zip#1622, mortgage_insurance_percent#1623, product_type#1624, relocation_mortgage_indicator#1627, orig_channel#3146, first_home_buyer#3351, ... 
4 more fields]\\n\",\n      \"      :     :     :     :  +- GpuBroadcastHashJoin [property_state#1621], [property_state#4035], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :     :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, num_units#1619, property_state#1621, zip#1622, mortgage_insurance_percent#1623, product_type#1624, relocation_mortgage_indicator#1627, orig_channel#3146, ... 4 more fields]\\n\",\n      \"      :     :     :     :     :  +- GpuBroadcastHashJoin [occupancy_status#1620], [occupancy_status#3830], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :     :     :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, num_units#1619, occupancy_status#1620, property_state#1621, zip#1622, mortgage_insurance_percent#1623, product_type#1624, relocation_mortgage_indicator#1627, ... 
4 more fields]\\n\",\n      \"      :     :     :     :     :     :  +- GpuBroadcastHashJoin [property_type#1618], [property_type#3625], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :     :     :     :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, property_type#1618, num_units#1619, occupancy_status#1620, property_state#1621, zip#1622, mortgage_insurance_percent#1623, product_type#1624, ... 4 more fields]\\n\",\n      \"      :     :     :     :     :     :     :  +- GpuBroadcastHashJoin [loan_purpose#1617], [loan_purpose#3420], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :     :     :     :     :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, loan_purpose#1617, property_type#1618, num_units#1619, occupancy_status#1620, property_state#1621, zip#1622, mortgage_insurance_percent#1623, ... 
4 more fields]\\n\",\n      \"      :     :     :     :     :     :     :     :  +- GpuBroadcastHashJoin [first_home_buyer#1616], [first_home_buyer#3215], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :     :     :     :     :     :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, first_home_buyer#1616, loan_purpose#1617, property_type#1618, num_units#1619, occupancy_status#1620, property_state#1621, zip#1622, ... 4 more fields]\\n\",\n      \"      :     :     :     :     :     :     :     :     :  +- GpuBroadcastHashJoin [orig_channel#1604], [orig_channel#3010], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :- GpuProject [interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353, orig_channel#1604, seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, first_home_buyer#1616, loan_purpose#1617, property_type#1618, num_units#1619, occupancy_status#1620, property_state#1621, ... 
4 more fields]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :  +- GpuShuffledHashJoin [loan_id#1539L, quarter#1570], [loan_id#1603L, quarter#1629], Inner, GpuBuildRight, false\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :  +- GpuColumnarExchange gpuhashpartitioning(loan_id#1539L, quarter#1570, 192), ENSURE_REQUIREMENTS, [id=#3885]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :     +- GpuProject [quarter#1570, loan_id#1539L, interest_rate#1542, current_actual_upb#1543, loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, delinquency_12#2353]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :        +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :           +- SortMergeJoin [quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint)], [quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L], LeftOuter\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :              :- *(2) Sort [quarter#1570 ASC NULLS FIRST, loan_id#1539L ASC NULLS FIRST, cast(timestamp_year#2417 as bigint) ASC NULLS FIRST, cast(timestamp_month#2381 as bigint) ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :              :  +- Exchange hashpartitioning(quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint), 192), ENSURE_REQUIREMENTS, [id=#3847]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :              :     +- *(1) Project [loan_id#1539L, interest_rate#1542, current_actual_upb#1543, 
loan_age#1544, msa#1548, current_loan_delinquency_status#1549, mod_flag#1550, non_interest_bearing_upb#1565, quarter#1570, month(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2381, year(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2417]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :              :        +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :              :           +- GpuFilter (gpuisnotnull(loan_id#1539L) AND gpuisnotnull(quarter#1570)), true\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :              :              +- GpuFileGpuScan parquet [loan_id#1539L,monthly_reporting_period#1540,interest_rate#1542,current_actual_upb#1543,loan_age#1544,msa#1548,current_loan_delinquency_status#1549,mod_flag#1550,non_interest_bearing_upb#1565,quarter#1570] Batched: true, DataFilters: [isnotnull(loan_id#1539L), isnotnull(quarter#1570)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,interest_rate:double,current_actual_upb:dou...\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :              +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                 +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"      
:     :     :     :     :     :     :     :     :     :     :                       +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#3878]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                          +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                             +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[max(delinquency_12#2123), min(upb_12#2159)])\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                +- Exchange hashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#3873]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                   +- *(5) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[partial_max(delinquency_12#2123), partial_min(upb_12#2159)])\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                      +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248, delinquency_12#2123, upb_12#2159]\\n\",\n      \"      :     :     :     :     :     :   
  :     :     :     :     :                                         +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                            +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                               +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, delinquency_12#2123, upb_12#2159, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                  +- GpuProject [loan_id#2453L, quarter#2484, delinquency_12#2123, upb_12#2159, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                     +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                     
   :- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                        :  +- *(3) Project [quarter#2484, loan_id#2453L, current_loan_delinquency_status#2463 AS delinquency_12#2123, current_actual_upb#2457 AS upb_12#2159, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                        :     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                        :        +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                        :           +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,current_actual_upb#2457,current_loan_delinquency_status#2463,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,current_actual_upb:double,current_loan_deli...\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                        +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#3863]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :   
  :                                                           +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                              +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                                 +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#3860]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                                    +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                                       +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                                          +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN 
(current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                                             +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                                                +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     :                                                                                   +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,current_loan_delinquency_status:int,quarter...\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :     +- GpuColumnarExchange gpuhashpartitioning(loan_id#1603L, quarter#1629, 192), ENSURE_REQUIREMENTS, [id=#3894]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :        +- GpuProject [loan_id#1603L, orig_channel#1604, gpucoalesce(to_seller_name#2570, seller_name#1605) AS seller_name#2689, orig_interest_rate#1606, orig_upb#1607, orig_loan_term#1608, orig_ltv#1611, orig_cltv#1612, num_borrowers#1613, dti#1614, borrower_credit_score#1615, first_home_buyer#1616, loan_purpose#1617, property_type#1618, num_units#1619, occupancy_status#1620, property_state#1621, zip#1622, mortgage_insurance_percent#1623, product_type#1624, 
relocation_mortgage_indicator#1627, quarter#1629]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :           +- GpuShuffledHashJoin [seller_name#1605], [from_seller_name#2569], LeftOuter, GpuBuildRight, false\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :              :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :              :  +- GpuColumnarExchange gpuhashpartitioning(seller_name#1605, 192), ENSURE_REQUIREMENTS, [id=#3523]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :              :     +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :              :        +- GpuFilter (gpuisnotnull(loan_id#1603L) AND gpuisnotnull(quarter#1629)), true\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :              :           +- GpuFileGpuScan parquet [loan_id#1603L,orig_channel#1604,seller_name#1605,orig_interest_rate#1606,orig_upb#1607,orig_loan_term#1608,orig_ltv#1611,orig_cltv#1612,num_borrowers#1613,dti#1614,borrower_credit_score#1615,first_home_buyer#1616,loan_purpose#1617,property_type#1618,num_units#1619,occupancy_status#1620,property_state#1621,zip#1622,mortgage_insurance_percent#1623,product_type#1624,relocation_mortgage_indicator#1627,quarter#1629] Batched: true, DataFilters: [isnotnull(loan_id#1603L), isnotnull(quarter#1629)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/acq], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,orig_channel:string,seller_name:string,orig_interest_rate:double,orig_upb:i...\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :              +- GpuColumnarExchange gpuhashpartitioning(from_seller_name#2569, 192), ENSURE_REQUIREMENTS, [id=#3891]\\n\",\n      \"      :     :  
   :     :     :     :     :     :     :     :                 +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :                    +- GpuFilter gpuisnotnull(from_seller_name#2569), true\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :                       +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :     :                          +- *(7) Scan ExistingRDD[from_seller_name#2569,to_seller_name#2570]\\n\",\n      \"      :     :     :     :     :     :     :     :     :     +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#3694]\\n\",\n      \"      :     :     :     :     :     :     :     :     :        +- GpuProject [data#2945 AS orig_channel#3010, id#2956]\\n\",\n      \"      :     :     :     :     :     :     :     :     :           +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :              +- GpuFilter ((column_id#2942 = 0) AND gpuisnotnull(data#2945)), true\\n\",\n      \"      :     :     :     :     :     :     :     :     :                 +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :                    +- InMemoryTableScan [column_id#2942, data#2945, id#2956], [(column_id#2942 = 0), isnotnull(data#2945)]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                          +- InMemoryRelation [column_id#2942, data#2945, id#2956], StorageLevel(disk, memory, deserialized, 1 replicas)\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                   +- GpuProject [column_id#2942, data#2945, id#2956]\\n\",\n      \"      :     : 
    :     :     :     :     :     :     :                                      +- GpuRunningWindow [column_id#2942, data#2945, count#2951L, gpurownumber$() gpuwindowspecdefinition(column_id#2942, count#2951L DESC NULLS LAST, gpuspecifiedwindowframe(RowFrame, gpuspecialframeboundary(unboundedpreceding$()), gpuspecialframeboundary(currentrow$()))) AS id#2956], [column_id#2942], [count#2951L DESC NULLS LAST]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                         +- GpuSort [column_id#2942 ASC NULLS FIRST, count#2951L DESC NULLS LAST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                            +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                               +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, 192), ENSURE_REQUIREMENTS, [id=#1141]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                  +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                     +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                        +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, data#2945, 192), ENSURE_REQUIREMENTS, [id=#1138]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                           +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[partial_gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                             
 +- GpuProject [pos#2938 AS column_id#2942, col#2939 AS data#2945]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                 +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                    +- GpuFilter gpuisnotnull(col#2939), true\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                       +- GpuGenerate gpuposexplode(array(orig_channel#1604, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, seller_name#2689, mod_flag#1550)), false, [pos#2938, col#2939]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                          +- GpuProject [mod_flag#1550, orig_channel#1604, seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                             +- GpuShuffledHashJoin [loan_id#1539L, quarter#1570], [loan_id#1603L, quarter#1629], Inner, GpuBuildRight, false\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :  +- GpuColumnarExchange gpuhashpartitioning(loan_id#1539L, quarter#1570, 192), ENSURE_REQUIREMENTS, [id=#1121]\\n\",\n      \"      :     :     :     :     :     :     :     :     :        
                                                                        :     +- GpuProject [quarter#1570, loan_id#1539L, mod_flag#1550]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :        +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :           +- SortMergeJoin [quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint)], [quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L], LeftOuter\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :              :- *(2) Sort [quarter#1570 ASC NULLS FIRST, loan_id#1539L ASC NULLS FIRST, cast(timestamp_year#2417 as bigint) ASC NULLS FIRST, cast(timestamp_month#2381 as bigint) ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :              :  +- Exchange hashpartitioning(quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint), 192), ENSURE_REQUIREMENTS, [id=#1080]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :              :     +- *(1) Project [loan_id#1539L, mod_flag#1550, quarter#1570, month(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2381, year(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2417]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                       
                                                         :              :        +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :              :           +- GpuFilter (gpuisnotnull(loan_id#1539L) AND gpuisnotnull(quarter#1570)), true\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :              :              +- GpuFileGpuScan parquet [loan_id#1539L,monthly_reporting_period#1540,mod_flag#1550,quarter#1570] Batched: true, DataFilters: [isnotnull(loan_id#1539L), isnotnull(quarter#1570)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,mod_flag:string,quarter:string>\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :              +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                 +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                       +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, 
loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                          +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                             +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                   +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                      +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                         +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), 
filters=ArrayBuffer())\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                            +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                               +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                  +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                           
             +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                           +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                              +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                 :- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                 :  +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                        
         :     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                 :        +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                 :           +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,quarter:string>\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                 +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                    +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                   
                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                          +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                             +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                                +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                                   +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, 
Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                                      +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                                         +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                :                                                                                            +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,current_loan_delinquency_status:int,quarter...\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                +- GpuColumnarExchange gpuhashpartitioning(loan_id#1603L, quarter#1629, 192), ENSURE_REQUIREMENTS, [id=#1130]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                   +- GpuProject [loan_id#1603L, orig_channel#1604, gpucoalesce(to_seller_name#2570, seller_name#1605) AS seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, 
occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, quarter#1629]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                      +- GpuShuffledHashJoin [seller_name#1605], [from_seller_name#2569], LeftOuter, GpuBuildRight, false\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                         :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                         :  +- GpuColumnarExchange gpuhashpartitioning(seller_name#1605, 192), ENSURE_REQUIREMENTS, [id=#862]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                         :     +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                         :        +- GpuFilter (gpuisnotnull(loan_id#1603L) AND gpuisnotnull(quarter#1629)), true\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                         :           +- GpuFileGpuScan parquet [loan_id#1603L,orig_channel#1604,seller_name#1605,first_home_buyer#1616,loan_purpose#1617,property_type#1618,occupancy_status#1620,property_state#1621,product_type#1624,relocation_mortgage_indicator#1627,quarter#1629] Batched: true, DataFilters: [isnotnull(loan_id#1603L), isnotnull(quarter#1629)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/acq], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: 
struct<loan_id:bigint,orig_channel:string,seller_name:string,first_home_buyer:string,loan_purpose...\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                         +- GpuColumnarExchange gpuhashpartitioning(from_seller_name#2569, 192), ENSURE_REQUIREMENTS, [id=#1127]\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                            +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                               +- GpuFilter gpuisnotnull(from_seller_name#2569), true\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                                  +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :     :                                                                                                     +- *(7) Scan ExistingRDD[from_seller_name#2569,to_seller_name#2570]\\n\",\n      \"      :     :     :     :     :     :     :     :     +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#3701]\\n\",\n      \"      :     :     :     :     :     :     :     :        +- GpuProject [data#2945 AS first_home_buyer#3215, id#2956]\\n\",\n      \"      :     :     :     :     :     :     :     :           +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :              +- GpuFilter ((column_id#2942 = 1) AND gpuisnotnull(data#2945)), true\\n\",\n      \"      :     :     :     :     :     :     :     :                 +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :      
              +- InMemoryTableScan [column_id#2942, data#2945, id#2956], [(column_id#2942 = 1), isnotnull(data#2945)]\\n\",\n      \"      :     :     :     :     :     :     :     :                          +- InMemoryRelation [column_id#2942, data#2945, id#2956], StorageLevel(disk, memory, deserialized, 1 replicas)\\n\",\n      \"      :     :     :     :     :     :     :     :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :                                   +- GpuProject [column_id#2942, data#2945, id#2956]\\n\",\n      \"      :     :     :     :     :     :     :     :                                      +- GpuRunningWindow [column_id#2942, data#2945, count#2951L, gpurownumber$() gpuwindowspecdefinition(column_id#2942, count#2951L DESC NULLS LAST, gpuspecifiedwindowframe(RowFrame, gpuspecialframeboundary(unboundedpreceding$()), gpuspecialframeboundary(currentrow$()))) AS id#2956], [column_id#2942], [count#2951L DESC NULLS LAST]\\n\",\n      \"      :     :     :     :     :     :     :     :                                         +- GpuSort [column_id#2942 ASC NULLS FIRST, count#2951L DESC NULLS LAST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :     :     :     :     :                                            +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :                                               +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, 192), ENSURE_REQUIREMENTS, [id=#1141]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                  +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :     :     :     :     :     :                                                     +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :   
  :     :     :     :     :     :     :                                                        +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, data#2945, 192), ENSURE_REQUIREMENTS, [id=#1138]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                           +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[partial_gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :     :     :     :     :     :                                                              +- GpuProject [pos#2938 AS column_id#2942, col#2939 AS data#2945]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                 +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                    +- GpuFilter gpuisnotnull(col#2939), true\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                       +- GpuGenerate gpuposexplode(array(orig_channel#1604, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, seller_name#2689, mod_flag#1550)), false, [pos#2938, col#2939]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                          +- GpuProject [mod_flag#1550, orig_channel#1604, seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                             +- GpuShuffledHashJoin [loan_id#1539L, quarter#1570], [loan_id#1603L, quarter#1629], Inner, 
GpuBuildRight, false\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :  +- GpuColumnarExchange gpuhashpartitioning(loan_id#1539L, quarter#1570, 192), ENSURE_REQUIREMENTS, [id=#1121]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :     +- GpuProject [quarter#1570, loan_id#1539L, mod_flag#1550]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :        +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :           +- SortMergeJoin [quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint)], [quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L], LeftOuter\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :              :- *(2) Sort [quarter#1570 ASC NULLS FIRST, loan_id#1539L ASC NULLS FIRST, cast(timestamp_year#2417 as bigint) ASC NULLS FIRST, cast(timestamp_month#2381 as bigint) ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :              :  +- Exchange hashpartitioning(quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint), 192), ENSURE_REQUIREMENTS, [id=#1080]\\n\",\n      \"      :     :     :     :     :     :     :     :                             
                                                   :              :     +- *(1) Project [loan_id#1539L, mod_flag#1550, quarter#1570, month(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2381, year(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2417]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :              :        +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :              :           +- GpuFilter (gpuisnotnull(loan_id#1539L) AND gpuisnotnull(quarter#1570)), true\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :              :              +- GpuFileGpuScan parquet [loan_id#1539L,monthly_reporting_period#1540,mod_flag#1550,quarter#1570] Batched: true, DataFilters: [isnotnull(loan_id#1539L), isnotnull(quarter#1570)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,mod_flag:string,quarter:string>\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :              +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                 +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, 
com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                       +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                          +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                             +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                   +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                      +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), 
ENSURE_REQUIREMENTS, [id=#1107]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                         +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                            +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                               +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                  +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\\n\",\n      
\"      :     :     :     :     :     :     :     :                                                                                :                                                     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                        +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                           +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                              +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                 :- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                 :  +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) 
as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                 :     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                 :        +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                 :           +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,quarter:string>\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                 +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                    +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], 
functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                       +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                          +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                             +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                                +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                                   +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), 
false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                                      +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                                         +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                :                                                                                            +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,current_loan_delinquency_status:int,quarter...\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                +- GpuColumnarExchange gpuhashpartitioning(loan_id#1603L, quarter#1629, 192), ENSURE_REQUIREMENTS, [id=#1130]\\n\",\n      \"      :    
 :     :     :     :     :     :     :                                                                                   +- GpuProject [loan_id#1603L, orig_channel#1604, gpucoalesce(to_seller_name#2570, seller_name#1605) AS seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, quarter#1629]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                      +- GpuShuffledHashJoin [seller_name#1605], [from_seller_name#2569], LeftOuter, GpuBuildRight, false\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                         :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                         :  +- GpuColumnarExchange gpuhashpartitioning(seller_name#1605, 192), ENSURE_REQUIREMENTS, [id=#862]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                         :     +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                         :        +- GpuFilter (gpuisnotnull(loan_id#1603L) AND gpuisnotnull(quarter#1629)), true\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                         :           +- GpuFileGpuScan parquet [loan_id#1603L,orig_channel#1604,seller_name#1605,first_home_buyer#1616,loan_purpose#1617,property_type#1618,occupancy_status#1620,property_state#1621,product_type#1624,relocation_mortgage_indicator#1627,quarter#1629] Batched: true, DataFilters: 
[isnotnull(loan_id#1603L), isnotnull(quarter#1629)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/acq], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,orig_channel:string,seller_name:string,first_home_buyer:string,loan_purpose...\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                         +- GpuColumnarExchange gpuhashpartitioning(from_seller_name#2569, 192), ENSURE_REQUIREMENTS, [id=#1127]\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                            +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                               +- GpuFilter gpuisnotnull(from_seller_name#2569), true\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                                  +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :     :                                                                                                     +- *(7) Scan ExistingRDD[from_seller_name#2569,to_seller_name#2570]\\n\",\n      \"      :     :     :     :     :     :     :     +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#3708]\\n\",\n      \"      :     :     :     :     :     :     :        +- GpuProject [data#2945 AS loan_purpose#3420, id#2956]\\n\",\n      \"      :     :     :     :     :     :     :           +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :              +- GpuFilter ((column_id#2942 = 2) AND gpuisnotnull(data#2945)), true\\n\",\n      \"   
   :     :     :     :     :     :     :                 +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :                    +- InMemoryTableScan [column_id#2942, data#2945, id#2956], [(column_id#2942 = 2), isnotnull(data#2945)]\\n\",\n      \"      :     :     :     :     :     :     :                          +- InMemoryRelation [column_id#2942, data#2945, id#2956], StorageLevel(disk, memory, deserialized, 1 replicas)\\n\",\n      \"      :     :     :     :     :     :     :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :                                   +- GpuProject [column_id#2942, data#2945, id#2956]\\n\",\n      \"      :     :     :     :     :     :     :                                      +- GpuRunningWindow [column_id#2942, data#2945, count#2951L, gpurownumber$() gpuwindowspecdefinition(column_id#2942, count#2951L DESC NULLS LAST, gpuspecifiedwindowframe(RowFrame, gpuspecialframeboundary(unboundedpreceding$()), gpuspecialframeboundary(currentrow$()))) AS id#2956], [column_id#2942], [count#2951L DESC NULLS LAST]\\n\",\n      \"      :     :     :     :     :     :     :                                         +- GpuSort [column_id#2942 ASC NULLS FIRST, count#2951L DESC NULLS LAST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :     :     :     :                                            +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :                                               +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, 192), ENSURE_REQUIREMENTS, [id=#1141]\\n\",\n      \"      :     :     :     :     :     :     :                                                  +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :     :     :     :     : 
                                                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :                                                        +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, data#2945, 192), ENSURE_REQUIREMENTS, [id=#1138]\\n\",\n      \"      :     :     :     :     :     :     :                                                           +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[partial_gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :     :     :     :     :                                                              +- GpuProject [pos#2938 AS column_id#2942, col#2939 AS data#2945]\\n\",\n      \"      :     :     :     :     :     :     :                                                                 +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :                                                                    +- GpuFilter gpuisnotnull(col#2939), true\\n\",\n      \"      :     :     :     :     :     :     :                                                                       +- GpuGenerate gpuposexplode(array(orig_channel#1604, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, seller_name#2689, mod_flag#1550)), false, [pos#2938, col#2939]\\n\",\n      \"      :     :     :     :     :     :     :                                                                          +- GpuProject [mod_flag#1550, orig_channel#1604, seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627]\\n\",\n      \"      :     :     :     :     :     :     :                                                                             +- GpuShuffledHashJoin [loan_id#1539L, 
quarter#1570], [loan_id#1603L, quarter#1629], Inner, GpuBuildRight, false\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :  +- GpuColumnarExchange gpuhashpartitioning(loan_id#1539L, quarter#1570, 192), ENSURE_REQUIREMENTS, [id=#1121]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :     +- GpuProject [quarter#1570, loan_id#1539L, mod_flag#1550]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :        +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :           +- SortMergeJoin [quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint)], [quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L], LeftOuter\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :              :- *(2) Sort [quarter#1570 ASC NULLS FIRST, loan_id#1539L ASC NULLS FIRST, cast(timestamp_year#2417 as bigint) ASC NULLS FIRST, cast(timestamp_month#2381 as bigint) ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :              :  +- Exchange hashpartitioning(quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint), 192), ENSURE_REQUIREMENTS, [id=#1080]\\n\",\n      \"      :     :     :     :     :     :     :                        
                                                        :              :     +- *(1) Project [loan_id#1539L, mod_flag#1550, quarter#1570, month(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2381, year(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2417]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :              :        +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :              :           +- GpuFilter (gpuisnotnull(loan_id#1539L) AND gpuisnotnull(quarter#1570)), true\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :              :              +- GpuFileGpuScan parquet [loan_id#1539L,monthly_reporting_period#1540,mod_flag#1550,quarter#1570] Batched: true, DataFilters: [isnotnull(loan_id#1539L), isnotnull(quarter#1570)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,mod_flag:string,quarter:string>\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :              +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                 +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, 
com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                       +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                          +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                             +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                   +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                      +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\\n\",\n      \"      :     
:     :     :     :     :     :                                                                                :                                         +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                            +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                               +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                  +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\\n\",\n      \"      :     :     :     :     :     :     :                                      
                                          :                                                     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                        +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                           +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                              +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                 :- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                 :  +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, 
Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                 :     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                 :        +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                 :           +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,quarter:string>\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                 +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                    +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, 
None, None))\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                       +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                          +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                             +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                                +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                                   +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END 
AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                                      +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                                         +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\\n\",\n      \"      :     :     :     :     :     :     :                                                                                :                                                                                            +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,current_loan_delinquency_status:int,quarter...\\n\",\n      \"      :     :     :     :     :     :     :                                                                                +- GpuColumnarExchange gpuhashpartitioning(loan_id#1603L, quarter#1629, 192), ENSURE_REQUIREMENTS, [id=#1130]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                   +- GpuProject [loan_id#1603L, orig_channel#1604, gpucoalesce(to_seller_name#2570, seller_name#1605) AS 
seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, quarter#1629]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                      +- GpuShuffledHashJoin [seller_name#1605], [from_seller_name#2569], LeftOuter, GpuBuildRight, false\\n\",\n      \"      :     :     :     :     :     :     :                                                                                         :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :     :                                                                                         :  +- GpuColumnarExchange gpuhashpartitioning(seller_name#1605, 192), ENSURE_REQUIREMENTS, [id=#862]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                         :     +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :                                                                                         :        +- GpuFilter (gpuisnotnull(loan_id#1603L) AND gpuisnotnull(quarter#1629)), true\\n\",\n      \"      :     :     :     :     :     :     :                                                                                         :           +- GpuFileGpuScan parquet [loan_id#1603L,orig_channel#1604,seller_name#1605,first_home_buyer#1616,loan_purpose#1617,property_type#1618,occupancy_status#1620,property_state#1621,product_type#1624,relocation_mortgage_indicator#1627,quarter#1629] Batched: true, DataFilters: [isnotnull(loan_id#1603L), isnotnull(quarter#1629)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/acq], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: 
struct<loan_id:bigint,orig_channel:string,seller_name:string,first_home_buyer:string,loan_purpose...\\n\",\n      \"      :     :     :     :     :     :     :                                                                                         +- GpuColumnarExchange gpuhashpartitioning(from_seller_name#2569, 192), ENSURE_REQUIREMENTS, [id=#1127]\\n\",\n      \"      :     :     :     :     :     :     :                                                                                            +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :                                                                                               +- GpuFilter gpuisnotnull(from_seller_name#2569), true\\n\",\n      \"      :     :     :     :     :     :     :                                                                                                  +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :     :                                                                                                     +- *(7) Scan ExistingRDD[from_seller_name#2569,to_seller_name#2570]\\n\",\n      \"      :     :     :     :     :     :     +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#3715]\\n\",\n      \"      :     :     :     :     :     :        +- GpuProject [data#2945 AS property_type#3625, id#2956]\\n\",\n      \"      :     :     :     :     :     :           +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :              +- GpuFilter ((column_id#2942 = 3) AND gpuisnotnull(data#2945)), true\\n\",\n      \"      :     :     :     :     :     :                 +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :                    +- InMemoryTableScan [column_id#2942, data#2945, id#2956], [(column_id#2942 = 3), isnotnull(data#2945)]\\n\",\n      \"  
    :     :     :     :     :     :                          +- InMemoryRelation [column_id#2942, data#2945, id#2956], StorageLevel(disk, memory, deserialized, 1 replicas)\\n\",\n      \"      :     :     :     :     :     :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :                                   +- GpuProject [column_id#2942, data#2945, id#2956]\\n\",\n      \"      :     :     :     :     :     :                                      +- GpuRunningWindow [column_id#2942, data#2945, count#2951L, gpurownumber$() gpuwindowspecdefinition(column_id#2942, count#2951L DESC NULLS LAST, gpuspecifiedwindowframe(RowFrame, gpuspecialframeboundary(unboundedpreceding$()), gpuspecialframeboundary(currentrow$()))) AS id#2956], [column_id#2942], [count#2951L DESC NULLS LAST]\\n\",\n      \"      :     :     :     :     :     :                                         +- GpuSort [column_id#2942 ASC NULLS FIRST, count#2951L DESC NULLS LAST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :     :     :                                            +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :                                               +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, 192), ENSURE_REQUIREMENTS, [id=#1141]\\n\",\n      \"      :     :     :     :     :     :                                                  +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :     :     :     :                                                     +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :                                                        +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, data#2945, 192), ENSURE_REQUIREMENTS, [id=#1138]\\n\",\n      \"      :     :     :     :     :     
:                                                           +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[partial_gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :     :     :     :                                                              +- GpuProject [pos#2938 AS column_id#2942, col#2939 AS data#2945]\\n\",\n      \"      :     :     :     :     :     :                                                                 +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :                                                                    +- GpuFilter gpuisnotnull(col#2939), true\\n\",\n      \"      :     :     :     :     :     :                                                                       +- GpuGenerate gpuposexplode(array(orig_channel#1604, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, seller_name#2689, mod_flag#1550)), false, [pos#2938, col#2939]\\n\",\n      \"      :     :     :     :     :     :                                                                          +- GpuProject [mod_flag#1550, orig_channel#1604, seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627]\\n\",\n      \"      :     :     :     :     :     :                                                                             +- GpuShuffledHashJoin [loan_id#1539L, quarter#1570], [loan_id#1603L, quarter#1629], Inner, GpuBuildRight, false\\n\",\n      \"      :     :     :     :     :     :                                                                                :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :                                                                                :  +- GpuColumnarExchange 
gpuhashpartitioning(loan_id#1539L, quarter#1570, 192), ENSURE_REQUIREMENTS, [id=#1121]\\n\",\n      \"      :     :     :     :     :     :                                                                                :     +- GpuProject [quarter#1570, loan_id#1539L, mod_flag#1550]\\n\",\n      \"      :     :     :     :     :     :                                                                                :        +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :                                                                                :           +- SortMergeJoin [quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint)], [quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L], LeftOuter\\n\",\n      \"      :     :     :     :     :     :                                                                                :              :- *(2) Sort [quarter#1570 ASC NULLS FIRST, loan_id#1539L ASC NULLS FIRST, cast(timestamp_year#2417 as bigint) ASC NULLS FIRST, cast(timestamp_month#2381 as bigint) ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :     :     :     :                                                                                :              :  +- Exchange hashpartitioning(quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint), 192), ENSURE_REQUIREMENTS, [id=#1080]\\n\",\n      \"      :     :     :     :     :     :                                                                                :              :     +- *(1) Project [loan_id#1539L, mod_flag#1550, quarter#1570, month(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2381, year(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2417]\\n\",\n      \"      :     :     :  
   :     :     :                                                                                :              :        +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :                                                                                :              :           +- GpuFilter (gpuisnotnull(loan_id#1539L) AND gpuisnotnull(quarter#1570)), true\\n\",\n      \"      :     :     :     :     :     :                                                                                :              :              +- GpuFileGpuScan parquet [loan_id#1539L,monthly_reporting_period#1540,mod_flag#1550,quarter#1570] Batched: true, DataFilters: [isnotnull(loan_id#1539L), isnotnull(quarter#1570)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,mod_flag:string,quarter:string>\\n\",\n      \"      :     :     :     :     :     :                                                                                :              +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :                                                                                :                 +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :     :     :                                                                                :                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :                                                                                :                       +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), 
ENSURE_REQUIREMENTS, [id=#1114]\\n\",\n      \"      :     :     :     :     :     :                                                                                :                          +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :                                                                                :                             +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                   +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                      +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                         +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                        
    +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                               +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                  +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                        +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\\n\",\n      
\"      :     :     :     :     :     :                                                                                :                                                           +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                              +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                 :- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                 :  +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                 :     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                 :        +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\\n\",\n      \"      :     :     :     :     : 
    :                                                                                :                                                                 :           +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,quarter:string>\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                 +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                    +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                       +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                          +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\\n\",\n      \"      :     :     :     :     :     :                                            
                                    :                                                                             +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                                +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                                   +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                                      +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                                         +- 
GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\\n\",\n      \"      :     :     :     :     :     :                                                                                :                                                                                            +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,current_loan_delinquency_status:int,quarter...\\n\",\n      \"      :     :     :     :     :     :                                                                                +- GpuColumnarExchange gpuhashpartitioning(loan_id#1603L, quarter#1629, 192), ENSURE_REQUIREMENTS, [id=#1130]\\n\",\n      \"      :     :     :     :     :     :                                                                                   +- GpuProject [loan_id#1603L, orig_channel#1604, gpucoalesce(to_seller_name#2570, seller_name#1605) AS seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, quarter#1629]\\n\",\n      \"      :     :     :     :     :     :                                                                                      +- GpuShuffledHashJoin [seller_name#1605], [from_seller_name#2569], LeftOuter, GpuBuildRight, false\\n\",\n      \"      :     :     :     :     :     :                                                                                         :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :     :                                                                             
            :  +- GpuColumnarExchange gpuhashpartitioning(seller_name#1605, 192), ENSURE_REQUIREMENTS, [id=#862]\\n\",\n      \"      :     :     :     :     :     :                                                                                         :     +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :                                                                                         :        +- GpuFilter (gpuisnotnull(loan_id#1603L) AND gpuisnotnull(quarter#1629)), true\\n\",\n      \"      :     :     :     :     :     :                                                                                         :           +- GpuFileGpuScan parquet [loan_id#1603L,orig_channel#1604,seller_name#1605,first_home_buyer#1616,loan_purpose#1617,property_type#1618,occupancy_status#1620,property_state#1621,product_type#1624,relocation_mortgage_indicator#1627,quarter#1629] Batched: true, DataFilters: [isnotnull(loan_id#1603L), isnotnull(quarter#1629)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/acq], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,orig_channel:string,seller_name:string,first_home_buyer:string,loan_purpose...\\n\",\n      \"      :     :     :     :     :     :                                                                                         +- GpuColumnarExchange gpuhashpartitioning(from_seller_name#2569, 192), ENSURE_REQUIREMENTS, [id=#1127]\\n\",\n      \"      :     :     :     :     :     :                                                                                            +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :                                                                                               +- GpuFilter gpuisnotnull(from_seller_name#2569), true\\n\",\n      \"      :     :     :     :     :     :                    
                                                                              +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :     :                                                                                                     +- *(7) Scan ExistingRDD[from_seller_name#2569,to_seller_name#2570]\\n\",\n      \"      :     :     :     :     :     +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#3722]\\n\",\n      \"      :     :     :     :     :        +- GpuProject [data#2945 AS occupancy_status#3830, id#2956]\\n\",\n      \"      :     :     :     :     :           +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :              +- GpuFilter ((column_id#2942 = 4) AND gpuisnotnull(data#2945)), true\\n\",\n      \"      :     :     :     :     :                 +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :                    +- InMemoryTableScan [column_id#2942, data#2945, id#2956], [(column_id#2942 = 4), isnotnull(data#2945)]\\n\",\n      \"      :     :     :     :     :                          +- InMemoryRelation [column_id#2942, data#2945, id#2956], StorageLevel(disk, memory, deserialized, 1 replicas)\\n\",\n      \"      :     :     :     :     :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :                                   +- GpuProject [column_id#2942, data#2945, id#2956]\\n\",\n      \"      :     :     :     :     :                                      +- GpuRunningWindow [column_id#2942, data#2945, count#2951L, gpurownumber$() gpuwindowspecdefinition(column_id#2942, count#2951L DESC NULLS LAST, gpuspecifiedwindowframe(RowFrame, gpuspecialframeboundary(unboundedpreceding$()), gpuspecialframeboundary(currentrow$()))) AS id#2956], [column_id#2942], [count#2951L DESC NULLS LAST]\\n\",\n      \"      :     :     :     :     :         
                                +- GpuSort [column_id#2942 ASC NULLS FIRST, count#2951L DESC NULLS LAST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :     :                                            +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :                                               +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, 192), ENSURE_REQUIREMENTS, [id=#1141]\\n\",\n      \"      :     :     :     :     :                                                  +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :     :     :                                                     +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :                                                        +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, data#2945, 192), ENSURE_REQUIREMENTS, [id=#1138]\\n\",\n      \"      :     :     :     :     :                                                           +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[partial_gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :     :     :                                                              +- GpuProject [pos#2938 AS column_id#2942, col#2939 AS data#2945]\\n\",\n      \"      :     :     :     :     :                                                                 +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :                                                                    +- GpuFilter gpuisnotnull(col#2939), true\\n\",\n      \"      :     :     :     :     :                                                                       +- GpuGenerate gpuposexplode(array(orig_channel#1604, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, 
product_type#1624, relocation_mortgage_indicator#1627, seller_name#2689, mod_flag#1550)), false, [pos#2938, col#2939]\\n\",\n      \"      :     :     :     :     :                                                                          +- GpuProject [mod_flag#1550, orig_channel#1604, seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627]\\n\",\n      \"      :     :     :     :     :                                                                             +- GpuShuffledHashJoin [loan_id#1539L, quarter#1570], [loan_id#1603L, quarter#1629], Inner, GpuBuildRight, false\\n\",\n      \"      :     :     :     :     :                                                                                :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :                                                                                :  +- GpuColumnarExchange gpuhashpartitioning(loan_id#1539L, quarter#1570, 192), ENSURE_REQUIREMENTS, [id=#1121]\\n\",\n      \"      :     :     :     :     :                                                                                :     +- GpuProject [quarter#1570, loan_id#1539L, mod_flag#1550]\\n\",\n      \"      :     :     :     :     :                                                                                :        +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :                                                                                :           +- SortMergeJoin [quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint)], [quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L], LeftOuter\\n\",\n      \"      :     :     :     :     :                                                                                :              :- *(2) Sort [quarter#1570 ASC NULLS FIRST, 
loan_id#1539L ASC NULLS FIRST, cast(timestamp_year#2417 as bigint) ASC NULLS FIRST, cast(timestamp_month#2381 as bigint) ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :     :     :                                                                                :              :  +- Exchange hashpartitioning(quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint), 192), ENSURE_REQUIREMENTS, [id=#1080]\\n\",\n      \"      :     :     :     :     :                                                                                :              :     +- *(1) Project [loan_id#1539L, mod_flag#1550, quarter#1570, month(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2381, year(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2417]\\n\",\n      \"      :     :     :     :     :                                                                                :              :        +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :                                                                                :              :           +- GpuFilter (gpuisnotnull(loan_id#1539L) AND gpuisnotnull(quarter#1570)), true\\n\",\n      \"      :     :     :     :     :                                                                                :              :              +- GpuFileGpuScan parquet [loan_id#1539L,monthly_reporting_period#1540,mod_flag#1550,quarter#1570] Batched: true, DataFilters: [isnotnull(loan_id#1539L), isnotnull(quarter#1570)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,mod_flag:string,quarter:string>\\n\",\n      \"      :     :     :     :     : 
                                                                               :              +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :                                                                                :                 +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :     :                                                                                :                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :                                                                                :                       +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\\n\",\n      \"      :     :     :     :     :                                                                                :                          +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :                                                                                :                             +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\\n\",\n      \"      :     :     :     :     :                                                                                :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :                                                                                :                                   +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :                                                          
                      :                                      +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\\n\",\n      \"      :     :     :     :     :                                                                                :                                         +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\\n\",\n      \"      :     :     :     :     :                                                                                :                                            +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :                                                                                :                                               +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\\n\",\n      \"      :     :     :     :     :                                                                                :                                                  +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) 
= 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\\n\",\n      \"      :     :     :     :     :                                                                                :                                                     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :                                                                                :                                                        +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\\n\",\n      \"      :     :     :     :     :                                                                                :                                                           +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\\n\",\n      \"      :     :     :     :     :                                                                                :                                                              +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                 :- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                 :  +- *(3) Project [quarter#2484, loan_id#2453L, 
month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                 :     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                 :        +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                 :           +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,quarter:string>\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                 +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                    +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], 
functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                       +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                          +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                             +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                                +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                                   +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 
>= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                                      +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                                         +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\\n\",\n      \"      :     :     :     :     :                                                                                :                                                                                            +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,current_loan_delinquency_status:int,quarter...\\n\",\n      \"      :     :     :     :     :                                                                                +- GpuColumnarExchange gpuhashpartitioning(loan_id#1603L, quarter#1629, 192), ENSURE_REQUIREMENTS, [id=#1130]\\n\",\n      \"      :     :     :     :     :                                                                                   +- GpuProject [loan_id#1603L, orig_channel#1604, 
gpucoalesce(to_seller_name#2570, seller_name#1605) AS seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, quarter#1629]\\n\",\n      \"      :     :     :     :     :                                                                                      +- GpuShuffledHashJoin [seller_name#1605], [from_seller_name#2569], LeftOuter, GpuBuildRight, false\\n\",\n      \"      :     :     :     :     :                                                                                         :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :     :                                                                                         :  +- GpuColumnarExchange gpuhashpartitioning(seller_name#1605, 192), ENSURE_REQUIREMENTS, [id=#862]\\n\",\n      \"      :     :     :     :     :                                                                                         :     +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :                                                                                         :        +- GpuFilter (gpuisnotnull(loan_id#1603L) AND gpuisnotnull(quarter#1629)), true\\n\",\n      \"      :     :     :     :     :                                                                                         :           +- GpuFileGpuScan parquet [loan_id#1603L,orig_channel#1604,seller_name#1605,first_home_buyer#1616,loan_purpose#1617,property_type#1618,occupancy_status#1620,property_state#1621,product_type#1624,relocation_mortgage_indicator#1627,quarter#1629] Batched: true, DataFilters: [isnotnull(loan_id#1603L), isnotnull(quarter#1629)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/acq], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: 
struct<loan_id:bigint,orig_channel:string,seller_name:string,first_home_buyer:string,loan_purpose...\\n\",\n      \"      :     :     :     :     :                                                                                         +- GpuColumnarExchange gpuhashpartitioning(from_seller_name#2569, 192), ENSURE_REQUIREMENTS, [id=#1127]\\n\",\n      \"      :     :     :     :     :                                                                                            +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :     :                                                                                               +- GpuFilter gpuisnotnull(from_seller_name#2569), true\\n\",\n      \"      :     :     :     :     :                                                                                                  +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :     :                                                                                                     +- *(7) Scan ExistingRDD[from_seller_name#2569,to_seller_name#2570]\\n\",\n      \"      :     :     :     :     +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#3729]\\n\",\n      \"      :     :     :     :        +- GpuProject [data#2945 AS property_state#4035, id#2956]\\n\",\n      \"      :     :     :     :           +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :              +- GpuFilter ((column_id#2942 = 5) AND gpuisnotnull(data#2945)), true\\n\",\n      \"      :     :     :     :                 +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :                    +- InMemoryTableScan [column_id#2942, data#2945, id#2956], [(column_id#2942 = 5), isnotnull(data#2945)]\\n\",\n      \"      :     :     :     :                          +- InMemoryRelation [column_id#2942, data#2945, id#2956], StorageLevel(disk, 
memory, deserialized, 1 replicas)\\n\",\n      \"      :     :     :     :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :                                   +- GpuProject [column_id#2942, data#2945, id#2956]\\n\",\n      \"      :     :     :     :                                      +- GpuRunningWindow [column_id#2942, data#2945, count#2951L, gpurownumber$() gpuwindowspecdefinition(column_id#2942, count#2951L DESC NULLS LAST, gpuspecifiedwindowframe(RowFrame, gpuspecialframeboundary(unboundedpreceding$()), gpuspecialframeboundary(currentrow$()))) AS id#2956], [column_id#2942], [count#2951L DESC NULLS LAST]\\n\",\n      \"      :     :     :     :                                         +- GpuSort [column_id#2942 ASC NULLS FIRST, count#2951L DESC NULLS LAST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :                                            +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :                                               +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, 192), ENSURE_REQUIREMENTS, [id=#1141]\\n\",\n      \"      :     :     :     :                                                  +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :     :                                                     +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :                                                        +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, data#2945, 192), ENSURE_REQUIREMENTS, [id=#1138]\\n\",\n      \"      :     :     :     :                                                           +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[partial_gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :     :                                            
                  +- GpuProject [pos#2938 AS column_id#2942, col#2939 AS data#2945]\\n\",\n      \"      :     :     :     :                                                                 +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :                                                                    +- GpuFilter gpuisnotnull(col#2939), true\\n\",\n      \"      :     :     :     :                                                                       +- GpuGenerate gpuposexplode(array(orig_channel#1604, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, seller_name#2689, mod_flag#1550)), false, [pos#2938, col#2939]\\n\",\n      \"      :     :     :     :                                                                          +- GpuProject [mod_flag#1550, orig_channel#1604, seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627]\\n\",\n      \"      :     :     :     :                                                                             +- GpuShuffledHashJoin [loan_id#1539L, quarter#1570], [loan_id#1603L, quarter#1629], Inner, GpuBuildRight, false\\n\",\n      \"      :     :     :     :                                                                                :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :                                                                                :  +- GpuColumnarExchange gpuhashpartitioning(loan_id#1539L, quarter#1570, 192), ENSURE_REQUIREMENTS, [id=#1121]\\n\",\n      \"      :     :     :     :                                                                                :     +- GpuProject [quarter#1570, loan_id#1539L, mod_flag#1550]\\n\",\n      \"      :     :     :     :                                              
                                  :        +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :                                                                                :           +- SortMergeJoin [quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint)], [quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L], LeftOuter\\n\",\n      \"      :     :     :     :                                                                                :              :- *(2) Sort [quarter#1570 ASC NULLS FIRST, loan_id#1539L ASC NULLS FIRST, cast(timestamp_year#2417 as bigint) ASC NULLS FIRST, cast(timestamp_month#2381 as bigint) ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :     :                                                                                :              :  +- Exchange hashpartitioning(quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint), 192), ENSURE_REQUIREMENTS, [id=#1080]\\n\",\n      \"      :     :     :     :                                                                                :              :     +- *(1) Project [loan_id#1539L, mod_flag#1550, quarter#1570, month(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2381, year(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2417]\\n\",\n      \"      :     :     :     :                                                                                :              :        +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :                                                                                :              :           +- GpuFilter (gpuisnotnull(loan_id#1539L) AND gpuisnotnull(quarter#1570)), true\\n\",\n      \"      :     :     :     :                             
                                                   :              :              +- GpuFileGpuScan parquet [loan_id#1539L,monthly_reporting_period#1540,mod_flag#1550,quarter#1570] Batched: true, DataFilters: [isnotnull(loan_id#1539L), isnotnull(quarter#1570)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,mod_flag:string,quarter:string>\\n\",\n      \"      :     :     :     :                                                                                :              +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :                                                                                :                 +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :     :                                                                                :                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :                                                                                :                       +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\\n\",\n      \"      :     :     :     :                                                                                :                          +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :                                                                                :                             +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, 
delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\\n\",\n      \"      :     :     :     :                                                                                :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :                                                                                :                                   +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :                                                                                :                                      +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\\n\",\n      \"      :     :     :     :                                                                                :                                         +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\\n\",\n      \"      :     :     :     :                                                                                :                                            +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :                                                                                :                                               +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\\n\",\n      \"      :     :     :     :                                         
                                       :                                                  +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\\n\",\n      \"      :     :     :     :                                                                                :                                                     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :                                                                                :                                                        +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\\n\",\n      \"      :     :     :     :                                                                                :                                                           +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\\n\",\n      \"      :     :     :     :                                                                                :                                                              +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, 
GpuBuildRight\\n\",\n      \"      :     :     :     :                                                                                :                                                                 :- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :                                                                                :                                                                 :  +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\\n\",\n      \"      :     :     :     :                                                                                :                                                                 :     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :                                                                                :                                                                 :        +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\\n\",\n      \"      :     :     :     :                                                                                :                                                                 :           +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,quarter:string>\\n\",\n      \"      :     :     :     :                                                                                :                        
                                         +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\\n\",\n      \"      :     :     :     :                                                                                :                                                                    +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :                                                                                :                                                                       +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :                                                                                :                                                                          +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\\n\",\n      \"      :     :     :     :                                                                                :                                                                             +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :     :                                                                                :                                                                                +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :                                                                                :                           
                                                        +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\\n\",\n      \"      :     :     :     :                                                                                :                                                                                      +- GpuColumnarToRow false\\n\",\n      \"      :     :     :     :                                                                                :                                                                                         +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\\n\",\n      \"      :     :     :     :                                                                                :                                                                                            +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,current_loan_delinquency_status:int,quarter...\\n\",\n      \"      :     :     :     :                         
                                                       +- GpuColumnarExchange gpuhashpartitioning(loan_id#1603L, quarter#1629, 192), ENSURE_REQUIREMENTS, [id=#1130]\\n\",\n      \"      :     :     :     :                                                                                   +- GpuProject [loan_id#1603L, orig_channel#1604, gpucoalesce(to_seller_name#2570, seller_name#1605) AS seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, quarter#1629]\\n\",\n      \"      :     :     :     :                                                                                      +- GpuShuffledHashJoin [seller_name#1605], [from_seller_name#2569], LeftOuter, GpuBuildRight, false\\n\",\n      \"      :     :     :     :                                                                                         :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :     :                                                                                         :  +- GpuColumnarExchange gpuhashpartitioning(seller_name#1605, 192), ENSURE_REQUIREMENTS, [id=#862]\\n\",\n      \"      :     :     :     :                                                                                         :     +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :                                                                                         :        +- GpuFilter (gpuisnotnull(loan_id#1603L) AND gpuisnotnull(quarter#1629)), true\\n\",\n      \"      :     :     :     :                                                                                         :           +- GpuFileGpuScan parquet [loan_id#1603L,orig_channel#1604,seller_name#1605,first_home_buyer#1616,loan_purpose#1617,property_type#1618,occupancy_status#1620,property_state#1621,product_type#1624,relocation_mortgage_indicator#1627,quarter#1629] Batched: true, 
DataFilters: [isnotnull(loan_id#1603L), isnotnull(quarter#1629)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/acq], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,orig_channel:string,seller_name:string,first_home_buyer:string,loan_purpose...\\n\",\n      \"      :     :     :     :                                                                                         +- GpuColumnarExchange gpuhashpartitioning(from_seller_name#2569, 192), ENSURE_REQUIREMENTS, [id=#1127]\\n\",\n      \"      :     :     :     :                                                                                            +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :     :                                                                                               +- GpuFilter gpuisnotnull(from_seller_name#2569), true\\n\",\n      \"      :     :     :     :                                                                                                  +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :     :                                                                                                     +- *(7) Scan ExistingRDD[from_seller_name#2569,to_seller_name#2570]\\n\",\n      \"      :     :     :     +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#3736]\\n\",\n      \"      :     :     :        +- GpuProject [data#2945 AS product_type#4240, id#2956]\\n\",\n      \"      :     :     :           +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :              +- GpuFilter ((column_id#2942 = 6) AND gpuisnotnull(data#2945)), true\\n\",\n      \"      :     :     :                 +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :                    +- InMemoryTableScan [column_id#2942, data#2945, id#2956], 
[(column_id#2942 = 6), isnotnull(data#2945)]\\n\",\n      \"      :     :     :                          +- InMemoryRelation [column_id#2942, data#2945, id#2956], StorageLevel(disk, memory, deserialized, 1 replicas)\\n\",\n      \"      :     :     :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :                                   +- GpuProject [column_id#2942, data#2945, id#2956]\\n\",\n      \"      :     :     :                                      +- GpuRunningWindow [column_id#2942, data#2945, count#2951L, gpurownumber$() gpuwindowspecdefinition(column_id#2942, count#2951L DESC NULLS LAST, gpuspecifiedwindowframe(RowFrame, gpuspecialframeboundary(unboundedpreceding$()), gpuspecialframeboundary(currentrow$()))) AS id#2956], [column_id#2942], [count#2951L DESC NULLS LAST]\\n\",\n      \"      :     :     :                                         +- GpuSort [column_id#2942 ASC NULLS FIRST, count#2951L DESC NULLS LAST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :                                            +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :                                               +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, 192), ENSURE_REQUIREMENTS, [id=#1141]\\n\",\n      \"      :     :     :                                                  +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :                                                     +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :                                                        +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, data#2945, 192), ENSURE_REQUIREMENTS, [id=#1138]\\n\",\n      \"      :     :     :                                                           +- GpuHashAggregate(keys=[column_id#2942, data#2945], 
functions=[partial_gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :     :                                                              +- GpuProject [pos#2938 AS column_id#2942, col#2939 AS data#2945]\\n\",\n      \"      :     :     :                                                                 +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :                                                                    +- GpuFilter gpuisnotnull(col#2939), true\\n\",\n      \"      :     :     :                                                                       +- GpuGenerate gpuposexplode(array(orig_channel#1604, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, seller_name#2689, mod_flag#1550)), false, [pos#2938, col#2939]\\n\",\n      \"      :     :     :                                                                          +- GpuProject [mod_flag#1550, orig_channel#1604, seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627]\\n\",\n      \"      :     :     :                                                                             +- GpuShuffledHashJoin [loan_id#1539L, quarter#1570], [loan_id#1603L, quarter#1629], Inner, GpuBuildRight, false\\n\",\n      \"      :     :     :                                                                                :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :                                                                                :  +- GpuColumnarExchange gpuhashpartitioning(loan_id#1539L, quarter#1570, 192), ENSURE_REQUIREMENTS, [id=#1121]\\n\",\n      \"      :     :     :                                                                                :     +- GpuProject [quarter#1570, loan_id#1539L, 
mod_flag#1550]\\n\",\n      \"      :     :     :                                                                                :        +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :                                                                                :           +- SortMergeJoin [quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint)], [quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L], LeftOuter\\n\",\n      \"      :     :     :                                                                                :              :- *(2) Sort [quarter#1570 ASC NULLS FIRST, loan_id#1539L ASC NULLS FIRST, cast(timestamp_year#2417 as bigint) ASC NULLS FIRST, cast(timestamp_month#2381 as bigint) ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :     :                                                                                :              :  +- Exchange hashpartitioning(quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint), 192), ENSURE_REQUIREMENTS, [id=#1080]\\n\",\n      \"      :     :     :                                                                                :              :     +- *(1) Project [loan_id#1539L, mod_flag#1550, quarter#1570, month(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2381, year(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2417]\\n\",\n      \"      :     :     :                                                                                :              :        +- GpuColumnarToRow false\\n\",\n      \"      :     :     :                                                                                :              :           +- GpuFilter (gpuisnotnull(loan_id#1539L) AND gpuisnotnull(quarter#1570)), true\\n\",\n   
   \"      :     :     :                                                                                :              :              +- GpuFileGpuScan parquet [loan_id#1539L,monthly_reporting_period#1540,mod_flag#1550,quarter#1570] Batched: true, DataFilters: [isnotnull(loan_id#1539L), isnotnull(quarter#1570)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,mod_flag:string,quarter:string>\\n\",\n      \"      :     :     :                                                                                :              +- GpuColumnarToRow false\\n\",\n      \"      :     :     :                                                                                :                 +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :     :                                                                                :                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :                                                                                :                       +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\\n\",\n      \"      :     :     :                                                                                :                          +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :                                                                                :                             +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, 
delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\\n\",\n      \"      :     :     :                                                                                :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :     :                                                                                :                                   +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :                                                                                :                                      +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\\n\",\n      \"      :     :     :                                                                                :                                         +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\\n\",\n      \"      :     :     :                                                                                :                                            +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :                                                                                :                                               +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\\n\",\n      \"      :     :     :                                                              
                  :                                                  +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\\n\",\n      \"      :     :     :                                                                                :                                                     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :                                                                                :                                                        +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\\n\",\n      \"      :     :     :                                                                                :                                                           +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\\n\",\n      \"      :     :     :                                                                                :                                                              +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :     
:                                                                                :                                                                 :- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :                                                                                :                                                                 :  +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\\n\",\n      \"      :     :     :                                                                                :                                                                 :     +- GpuColumnarToRow false\\n\",\n      \"      :     :     :                                                                                :                                                                 :        +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\\n\",\n      \"      :     :     :                                                                                :                                                                 :           +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,quarter:string>\\n\",\n      \"      :     :     :                                                                                :                                                                 +- GpuBroadcastExchange 
HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\\n\",\n      \"      :     :     :                                                                                :                                                                    +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :                                                                                :                                                                       +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :                                                                                :                                                                          +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\\n\",\n      \"      :     :     :                                                                                :                                                                             +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :     :                                                                                :                                                                                +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :                                                                                :                                                                                   +- *(4) Project [quarter#2233, 
loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\\n\",\n      \"      :     :     :                                                                                :                                                                                      +- GpuColumnarToRow false\\n\",\n      \"      :     :     :                                                                                :                                                                                         +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\\n\",\n      \"      :     :     :                                                                                :                                                                                            +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,current_loan_delinquency_status:int,quarter...\\n\",\n      \"      :     :     :                                                                                +- GpuColumnarExchange 
gpuhashpartitioning(loan_id#1603L, quarter#1629, 192), ENSURE_REQUIREMENTS, [id=#1130]\\n\",\n      \"      :     :     :                                                                                   +- GpuProject [loan_id#1603L, orig_channel#1604, gpucoalesce(to_seller_name#2570, seller_name#1605) AS seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, quarter#1629]\\n\",\n      \"      :     :     :                                                                                      +- GpuShuffledHashJoin [seller_name#1605], [from_seller_name#2569], LeftOuter, GpuBuildRight, false\\n\",\n      \"      :     :     :                                                                                         :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :     :                                                                                         :  +- GpuColumnarExchange gpuhashpartitioning(seller_name#1605, 192), ENSURE_REQUIREMENTS, [id=#862]\\n\",\n      \"      :     :     :                                                                                         :     +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :                                                                                         :        +- GpuFilter (gpuisnotnull(loan_id#1603L) AND gpuisnotnull(quarter#1629)), true\\n\",\n      \"      :     :     :                                                                                         :           +- GpuFileGpuScan parquet [loan_id#1603L,orig_channel#1604,seller_name#1605,first_home_buyer#1616,loan_purpose#1617,property_type#1618,occupancy_status#1620,property_state#1621,product_type#1624,relocation_mortgage_indicator#1627,quarter#1629] Batched: true, DataFilters: [isnotnull(loan_id#1603L), isnotnull(quarter#1629)], Format: Parquet, Location: 
InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/acq], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,orig_channel:string,seller_name:string,first_home_buyer:string,loan_purpose...\\n\",\n      \"      :     :     :                                                                                         +- GpuColumnarExchange gpuhashpartitioning(from_seller_name#2569, 192), ENSURE_REQUIREMENTS, [id=#1127]\\n\",\n      \"      :     :     :                                                                                            +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :     :                                                                                               +- GpuFilter gpuisnotnull(from_seller_name#2569), true\\n\",\n      \"      :     :     :                                                                                                  +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :     :                                                                                                     +- *(7) Scan ExistingRDD[from_seller_name#2569,to_seller_name#2570]\\n\",\n      \"      :     :     +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#3743]\\n\",\n      \"      :     :        +- GpuProject [data#2945 AS relocation_mortgage_indicator#4445, id#2956]\\n\",\n      \"      :     :           +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :              +- GpuFilter ((column_id#2942 = 7) AND gpuisnotnull(data#2945)), true\\n\",\n      \"      :     :                 +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :                    +- InMemoryTableScan [column_id#2942, data#2945, id#2956], [(column_id#2942 = 7), isnotnull(data#2945)]\\n\",\n      \"      :     :                          +- InMemoryRelation [column_id#2942, data#2945, id#2956], 
StorageLevel(disk, memory, deserialized, 1 replicas)\\n\",\n      \"      :     :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :                                   +- GpuProject [column_id#2942, data#2945, id#2956]\\n\",\n      \"      :     :                                      +- GpuRunningWindow [column_id#2942, data#2945, count#2951L, gpurownumber$() gpuwindowspecdefinition(column_id#2942, count#2951L DESC NULLS LAST, gpuspecifiedwindowframe(RowFrame, gpuspecialframeboundary(unboundedpreceding$()), gpuspecialframeboundary(currentrow$()))) AS id#2956], [column_id#2942], [count#2951L DESC NULLS LAST]\\n\",\n      \"      :     :                                         +- GpuSort [column_id#2942 ASC NULLS FIRST, count#2951L DESC NULLS LAST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :                                            +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :                                               +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, 192), ENSURE_REQUIREMENTS, [id=#1141]\\n\",\n      \"      :     :                                                  +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :                                                     +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :                                                        +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, data#2945, 192), ENSURE_REQUIREMENTS, [id=#1138]\\n\",\n      \"      :     :                                                           +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[partial_gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :     :                                                              +- GpuProject [pos#2938 AS column_id#2942, col#2939 AS data#2945]\\n\",\n      \"      :     : 
                                                                +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :                                                                    +- GpuFilter gpuisnotnull(col#2939), true\\n\",\n      \"      :     :                                                                       +- GpuGenerate gpuposexplode(array(orig_channel#1604, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, seller_name#2689, mod_flag#1550)), false, [pos#2938, col#2939]\\n\",\n      \"      :     :                                                                          +- GpuProject [mod_flag#1550, orig_channel#1604, seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627]\\n\",\n      \"      :     :                                                                             +- GpuShuffledHashJoin [loan_id#1539L, quarter#1570], [loan_id#1603L, quarter#1629], Inner, GpuBuildRight, false\\n\",\n      \"      :     :                                                                                :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :                                                                                :  +- GpuColumnarExchange gpuhashpartitioning(loan_id#1539L, quarter#1570, 192), ENSURE_REQUIREMENTS, [id=#1121]\\n\",\n      \"      :     :                                                                                :     +- GpuProject [quarter#1570, loan_id#1539L, mod_flag#1550]\\n\",\n      \"      :     :                                                                                :        +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :                                                                                :           +- 
SortMergeJoin [quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint)], [quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L], LeftOuter\\n\",\n      \"      :     :                                                                                :              :- *(2) Sort [quarter#1570 ASC NULLS FIRST, loan_id#1539L ASC NULLS FIRST, cast(timestamp_year#2417 as bigint) ASC NULLS FIRST, cast(timestamp_month#2381 as bigint) ASC NULLS FIRST], false, 0\\n\",\n      \"      :     :                                                                                :              :  +- Exchange hashpartitioning(quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint), 192), ENSURE_REQUIREMENTS, [id=#1080]\\n\",\n      \"      :     :                                                                                :              :     +- *(1) Project [loan_id#1539L, mod_flag#1550, quarter#1570, month(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2381, year(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2417]\\n\",\n      \"      :     :                                                                                :              :        +- GpuColumnarToRow false\\n\",\n      \"      :     :                                                                                :              :           +- GpuFilter (gpuisnotnull(loan_id#1539L) AND gpuisnotnull(quarter#1570)), true\\n\",\n      \"      :     :                                                                                :              :              +- GpuFileGpuScan parquet [loan_id#1539L,monthly_reporting_period#1540,mod_flag#1550,quarter#1570] Batched: true, DataFilters: [isnotnull(loan_id#1539L), isnotnull(quarter#1570)], Format: Parquet, Location: 
InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,mod_flag:string,quarter:string>\\n\",\n      \"      :     :                                                                                :              +- GpuColumnarToRow false\\n\",\n      \"      :     :                                                                                :                 +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :     :                                                                                :                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :                                                                                :                       +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\\n\",\n      \"      :     :                                                                                :                          +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :                                                                                :                             +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\\n\",\n      \"      :     :                                                                                :                                +- GpuColumnarToRow false\\n\",\n      \"      :     :                                                                                :                            
       +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :                                                                                :                                      +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\\n\",\n      \"      :     :                                                                                :                                         +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\\n\",\n      \"      :     :                                                                                :                                            +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :                                                                                :                                               +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\\n\",\n      \"      :     :                                                                                :                                                  +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) 
/ 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\\n\",\n      \"      :     :                                                                                :                                                     +- GpuColumnarToRow false\\n\",\n      \"      :     :                                                                                :                                                        +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\\n\",\n      \"      :     :                                                                                :                                                           +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\\n\",\n      \"      :     :                                                                                :                                                              +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\\n\",\n      \"      :     :                                                                                :                                                                 :- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :                                                                                :                                                                 :  +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, 
MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\\n\",\n      \"      :     :                                                                                :                                                                 :     +- GpuColumnarToRow false\\n\",\n      \"      :     :                                                                                :                                                                 :        +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\\n\",\n      \"      :     :                                                                                :                                                                 :           +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,quarter:string>\\n\",\n      \"      :     :                                                                                :                                                                 +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\\n\",\n      \"      :     :                                                                                :                                                                    +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, 
None, None))\\n\",\n      \"      :     :                                                                                :                                                                       +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :                                                                                :                                                                          +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\\n\",\n      \"      :     :                                                                                :                                                                             +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :     :                                                                                :                                                                                +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :                                                                                :                                                                                   +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, 
Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\\n\",\n      \"      :     :                                                                                :                                                                                      +- GpuColumnarToRow false\\n\",\n      \"      :     :                                                                                :                                                                                         +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\\n\",\n      \"      :     :                                                                                :                                                                                            +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,current_loan_delinquency_status:int,quarter...\\n\",\n      \"      :     :                                                                                +- GpuColumnarExchange gpuhashpartitioning(loan_id#1603L, quarter#1629, 192), ENSURE_REQUIREMENTS, [id=#1130]\\n\",\n      \"      :     :                                                                                   +- GpuProject [loan_id#1603L, orig_channel#1604, gpucoalesce(to_seller_name#2570, seller_name#1605) AS seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, quarter#1629]\\n\",\n      \"      :     :                                                                                      +- 
GpuShuffledHashJoin [seller_name#1605], [from_seller_name#2569], LeftOuter, GpuBuildRight, false\\n\",\n      \"      :     :                                                                                         :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :     :                                                                                         :  +- GpuColumnarExchange gpuhashpartitioning(seller_name#1605, 192), ENSURE_REQUIREMENTS, [id=#862]\\n\",\n      \"      :     :                                                                                         :     +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :                                                                                         :        +- GpuFilter (gpuisnotnull(loan_id#1603L) AND gpuisnotnull(quarter#1629)), true\\n\",\n      \"      :     :                                                                                         :           +- GpuFileGpuScan parquet [loan_id#1603L,orig_channel#1604,seller_name#1605,first_home_buyer#1616,loan_purpose#1617,property_type#1618,occupancy_status#1620,property_state#1621,product_type#1624,relocation_mortgage_indicator#1627,quarter#1629] Batched: true, DataFilters: [isnotnull(loan_id#1603L), isnotnull(quarter#1629)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/acq], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,orig_channel:string,seller_name:string,first_home_buyer:string,loan_purpose...\\n\",\n      \"      :     :                                                                                         +- GpuColumnarExchange gpuhashpartitioning(from_seller_name#2569, 192), ENSURE_REQUIREMENTS, [id=#1127]\\n\",\n      \"      :     :                                                                                            +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :     :               
                                                                                +- GpuFilter gpuisnotnull(from_seller_name#2569), true\\n\",\n      \"      :     :                                                                                                  +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :     :                                                                                                     +- *(7) Scan ExistingRDD[from_seller_name#2569,to_seller_name#2570]\\n\",\n      \"      :     +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#3750]\\n\",\n      \"      :        +- GpuProject [data#2945 AS seller_name#4650, id#2956]\\n\",\n      \"      :           +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :              +- GpuFilter ((column_id#2942 = 8) AND gpuisnotnull(data#2945)), true\\n\",\n      \"      :                 +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :                    +- InMemoryTableScan [column_id#2942, data#2945, id#2956], [(column_id#2942 = 8), isnotnull(data#2945)]\\n\",\n      \"      :                          +- InMemoryRelation [column_id#2942, data#2945, id#2956], StorageLevel(disk, memory, deserialized, 1 replicas)\\n\",\n      \"      :                                +- GpuColumnarToRow false\\n\",\n      \"      :                                   +- GpuProject [column_id#2942, data#2945, id#2956]\\n\",\n      \"      :                                      +- GpuRunningWindow [column_id#2942, data#2945, count#2951L, gpurownumber$() gpuwindowspecdefinition(column_id#2942, count#2951L DESC NULLS LAST, gpuspecifiedwindowframe(RowFrame, gpuspecialframeboundary(unboundedpreceding$()), gpuspecialframeboundary(currentrow$()))) AS id#2956], [column_id#2942], [count#2951L DESC NULLS LAST]\\n\",\n      \"      :                                         +- GpuSort [column_id#2942 ASC NULLS FIRST, count#2951L DESC NULLS LAST], 
false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :                                            +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :                                               +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, 192), ENSURE_REQUIREMENTS, [id=#1141]\\n\",\n      \"      :                                                  +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :                                                     +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :                                                        +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, data#2945, 192), ENSURE_REQUIREMENTS, [id=#1138]\\n\",\n      \"      :                                                           +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[partial_gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"      :                                                              +- GpuProject [pos#2938 AS column_id#2942, col#2939 AS data#2945]\\n\",\n      \"      :                                                                 +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :                                                                    +- GpuFilter gpuisnotnull(col#2939), true\\n\",\n      \"      :                                                                       +- GpuGenerate gpuposexplode(array(orig_channel#1604, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, seller_name#2689, mod_flag#1550)), false, [pos#2938, col#2939]\\n\",\n      \"      :                                                                          +- GpuProject [mod_flag#1550, orig_channel#1604, seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, 
occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627]\\n\",\n      \"      :                                                                             +- GpuShuffledHashJoin [loan_id#1539L, quarter#1570], [loan_id#1603L, quarter#1629], Inner, GpuBuildRight, false\\n\",\n      \"      :                                                                                :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :                                                                                :  +- GpuColumnarExchange gpuhashpartitioning(loan_id#1539L, quarter#1570, 192), ENSURE_REQUIREMENTS, [id=#1121]\\n\",\n      \"      :                                                                                :     +- GpuProject [quarter#1570, loan_id#1539L, mod_flag#1550]\\n\",\n      \"      :                                                                                :        +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :                                                                                :           +- SortMergeJoin [quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint)], [quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L], LeftOuter\\n\",\n      \"      :                                                                                :              :- *(2) Sort [quarter#1570 ASC NULLS FIRST, loan_id#1539L ASC NULLS FIRST, cast(timestamp_year#2417 as bigint) ASC NULLS FIRST, cast(timestamp_month#2381 as bigint) ASC NULLS FIRST], false, 0\\n\",\n      \"      :                                                                                :              :  +- Exchange hashpartitioning(quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint), 192), ENSURE_REQUIREMENTS, [id=#1080]\\n\",\n      \"      :                                                                           
     :              :     +- *(1) Project [loan_id#1539L, mod_flag#1550, quarter#1570, month(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2381, year(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2417]\\n\",\n      \"      :                                                                                :              :        +- GpuColumnarToRow false\\n\",\n      \"      :                                                                                :              :           +- GpuFilter (gpuisnotnull(loan_id#1539L) AND gpuisnotnull(quarter#1570)), true\\n\",\n      \"      :                                                                                :              :              +- GpuFileGpuScan parquet [loan_id#1539L,monthly_reporting_period#1540,mod_flag#1550,quarter#1570] Batched: true, DataFilters: [isnotnull(loan_id#1539L), isnotnull(quarter#1570)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,mod_flag:string,quarter:string>\\n\",\n      \"      :                                                                                :              +- GpuColumnarToRow false\\n\",\n      \"      :                                                                                :                 +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"      :                                                                                :                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :                                                         
                       :                       +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\\n\",\n      \"      :                                                                                :                          +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :                                                                                :                             +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\\n\",\n      \"      :                                                                                :                                +- GpuColumnarToRow false\\n\",\n      \"      :                                                                                :                                   +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :                                                                                :                                      +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\\n\",\n      \"      :                                                                                :                                         +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\\n\",\n      \"      :                                                                                :                                            +- GpuRowToColumnar 
targetsize(536870912)\\n\",\n      \"      :                                                                                :                                               +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\\n\",\n      \"      :                                                                                :                                                  +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\\n\",\n      \"      :                                                                                :                                                     +- GpuColumnarToRow false\\n\",\n      \"      :                                                                                :                                                        +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\\n\",\n      \"      :                                                                                :                                                       
    +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\\n\",\n      \"      :                                                                                :                                                              +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\\n\",\n      \"      :                                                                                :                                                                 :- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :                                                                                :                                                                 :  +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\\n\",\n      \"      :                                                                                :                                                                 :     +- GpuColumnarToRow false\\n\",\n      \"      :                                                                                :                                                                 :        +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\\n\",\n      \"      :                                                                                :                                                                 :           +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, 
Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,quarter:string>\\n\",\n      \"      :                                                                                :                                                                 +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\\n\",\n      \"      :                                                                                :                                                                    +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :                                                                                :                                                                       +- GpuShuffleCoalesce 536870912\\n\",\n      \"      :                                                                                :                                                                          +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\\n\",\n      \"      :                                                                                :                                                                             +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"      :                                                                                
:                                                                                +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :                                                                                :                                                                                   +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\\n\",\n      \"      :                                                                                :                                                                                      +- GpuColumnarToRow false\\n\",\n      \"      :                                                                                :                                                                                         +- GpuFilter (gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\\n\",\n      \"      :                                                                                :                                                                                            +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), 
IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,current_loan_delinquency_status:int,quarter...\\n\",\n      \"      :                                                                                +- GpuColumnarExchange gpuhashpartitioning(loan_id#1603L, quarter#1629, 192), ENSURE_REQUIREMENTS, [id=#1130]\\n\",\n      \"      :                                                                                   +- GpuProject [loan_id#1603L, orig_channel#1604, gpucoalesce(to_seller_name#2570, seller_name#1605) AS seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, quarter#1629]\\n\",\n      \"      :                                                                                      +- GpuShuffledHashJoin [seller_name#1605], [from_seller_name#2569], LeftOuter, GpuBuildRight, false\\n\",\n      \"      :                                                                                         :- GpuShuffleCoalesce 536870912\\n\",\n      \"      :                                                                                         :  +- GpuColumnarExchange gpuhashpartitioning(seller_name#1605, 192), ENSURE_REQUIREMENTS, [id=#862]\\n\",\n      \"      :                                                                                         :     +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :                                                                                         :        +- GpuFilter (gpuisnotnull(loan_id#1603L) AND gpuisnotnull(quarter#1629)), true\\n\",\n      \"      :                                                                                         :           +- GpuFileGpuScan parquet 
[loan_id#1603L,orig_channel#1604,seller_name#1605,first_home_buyer#1616,loan_purpose#1617,property_type#1618,occupancy_status#1620,property_state#1621,product_type#1624,relocation_mortgage_indicator#1627,quarter#1629] Batched: true, DataFilters: [isnotnull(loan_id#1603L), isnotnull(quarter#1629)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/acq], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,orig_channel:string,seller_name:string,first_home_buyer:string,loan_purpose...\\n\",\n      \"      :                                                                                         +- GpuColumnarExchange gpuhashpartitioning(from_seller_name#2569, 192), ENSURE_REQUIREMENTS, [id=#1127]\\n\",\n      \"      :                                                                                            +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"      :                                                                                               +- GpuFilter gpuisnotnull(from_seller_name#2569), true\\n\",\n      \"      :                                                                                                  +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"      :                                                                                                     +- *(7) Scan ExistingRDD[from_seller_name#2569,to_seller_name#2570]\\n\",\n      \"      +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false), [id=#3757]\\n\",\n      \"         +- GpuProject [data#2945 AS mod_flag#4855, id#2956]\\n\",\n      \"            +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"               +- GpuFilter ((column_id#2942 = 9) AND gpuisnotnull(data#2945)), true\\n\",\n      \"                  +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"                     +- InMemoryTableScan [column_id#2942, 
data#2945, id#2956], [(column_id#2942 = 9), isnotnull(data#2945)]\\n\",\n      \"                           +- InMemoryRelation [column_id#2942, data#2945, id#2956], StorageLevel(disk, memory, deserialized, 1 replicas)\\n\",\n      \"                                 +- GpuColumnarToRow false\\n\",\n      \"                                    +- GpuProject [column_id#2942, data#2945, id#2956]\\n\",\n      \"                                       +- GpuRunningWindow [column_id#2942, data#2945, count#2951L, gpurownumber$() gpuwindowspecdefinition(column_id#2942, count#2951L DESC NULLS LAST, gpuspecifiedwindowframe(RowFrame, gpuspecialframeboundary(unboundedpreceding$()), gpuspecialframeboundary(currentrow$()))) AS id#2956], [column_id#2942], [count#2951L DESC NULLS LAST]\\n\",\n      \"                                          +- GpuSort [column_id#2942 ASC NULLS FIRST, count#2951L DESC NULLS LAST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"                                             +- GpuShuffleCoalesce 536870912\\n\",\n      \"                                                +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, 192), ENSURE_REQUIREMENTS, [id=#1141]\\n\",\n      \"                                                   +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"                                                      +- GpuShuffleCoalesce 536870912\\n\",\n      \"                                                         +- GpuColumnarExchange gpuhashpartitioning(column_id#2942, data#2945, 192), ENSURE_REQUIREMENTS, [id=#1138]\\n\",\n      \"                                                            +- GpuHashAggregate(keys=[column_id#2942, data#2945], functions=[partial_gpucount(1)]), filters=ArrayBuffer(None))\\n\",\n      \"                                                               +- GpuProject [pos#2938 AS column_id#2942, col#2939 AS 
data#2945]\\n\",\n      \"                                                                  +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"                                                                     +- GpuFilter gpuisnotnull(col#2939), true\\n\",\n      \"                                                                        +- GpuGenerate gpuposexplode(array(orig_channel#1604, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, seller_name#2689, mod_flag#1550)), false, [pos#2938, col#2939]\\n\",\n      \"                                                                           +- GpuProject [mod_flag#1550, orig_channel#1604, seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627]\\n\",\n      \"                                                                              +- GpuShuffledHashJoin [loan_id#1539L, quarter#1570], [loan_id#1603L, quarter#1629], Inner, GpuBuildRight, false\\n\",\n      \"                                                                                 :- GpuShuffleCoalesce 536870912\\n\",\n      \"                                                                                 :  +- GpuColumnarExchange gpuhashpartitioning(loan_id#1539L, quarter#1570, 192), ENSURE_REQUIREMENTS, [id=#1121]\\n\",\n      \"                                                                                 :     +- GpuProject [quarter#1570, loan_id#1539L, mod_flag#1550]\\n\",\n      \"                                                                                 :        +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"                                                                                 :           +- SortMergeJoin [quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), 
cast(timestamp_month#2381 as bigint)], [quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L], LeftOuter\\n\",\n      \"                                                                                 :              :- *(2) Sort [quarter#1570 ASC NULLS FIRST, loan_id#1539L ASC NULLS FIRST, cast(timestamp_year#2417 as bigint) ASC NULLS FIRST, cast(timestamp_month#2381 as bigint) ASC NULLS FIRST], false, 0\\n\",\n      \"                                                                                 :              :  +- Exchange hashpartitioning(quarter#1570, loan_id#1539L, cast(timestamp_year#2417 as bigint), cast(timestamp_month#2381 as bigint), 192), ENSURE_REQUIREMENTS, [id=#1080]\\n\",\n      \"                                                                                 :              :     +- *(1) Project [loan_id#1539L, mod_flag#1550, quarter#1570, month(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2381, year(cast(gettimestamp(monthly_reporting_period#1540, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2417]\\n\",\n      \"                                                                                 :              :        +- GpuColumnarToRow false\\n\",\n      \"                                                                                 :              :           +- GpuFilter (gpuisnotnull(loan_id#1539L) AND gpuisnotnull(quarter#1570)), true\\n\",\n      \"                                                                                 :              :              +- GpuFileGpuScan parquet [loan_id#1539L,monthly_reporting_period#1540,mod_flag#1550,quarter#1570] Batched: true, DataFilters: [isnotnull(loan_id#1539L), isnotnull(quarter#1570)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: 
struct<loan_id:bigint,monthly_reporting_period:string,mod_flag:string,quarter:string>\\n\",\n      \"                                                                                 :              +- GpuColumnarToRow false\\n\",\n      \"                                                                                 :                 +- GpuSort [quarter#2484 ASC NULLS FIRST, loan_id#2453L ASC NULLS FIRST, timestamp_year#2307L ASC NULLS FIRST, timestamp_month#2336L ASC NULLS FIRST], false, com.nvidia.spark.rapids.OutOfCoreSort$@163d9f7d\\n\",\n      \"                                                                                 :                    +- GpuShuffleCoalesce 536870912\\n\",\n      \"                                                                                 :                       +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, timestamp_year#2307L, timestamp_month#2336L, 192), ENSURE_REQUIREMENTS, [id=#1114]\\n\",\n      \"                                                                                 :                          +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"                                                                                 :                             +- *(6) HashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[])\\n\",\n      \"                                                                                 :                                +- GpuColumnarToRow false\\n\",\n      \"                                                                                 :                                   +- GpuShuffleCoalesce 536870912\\n\",\n      \"                                                                                 :                                      +- GpuColumnarExchange gpuhashpartitioning(quarter#2484, loan_id#2453L, 
josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248, 192), ENSURE_REQUIREMENTS, [id=#1107]\\n\",\n      \"                                                                                 :                                         +- GpuHashAggregate(keys=[quarter#2484, loan_id#2453L, josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, month_y#2248], functions=[]), filters=ArrayBuffer())\\n\",\n      \"                                                                                 :                                            +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"                                                                                 :                                               +- *(5) Project [quarter#2484, FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) AS josh_mody_n#2264L, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997, loan_id#2453L, month_y#2248]\\n\",\n      \"                                                                                 :                                                  +- *(5) Filter (isnotnull(FLOOR((cast(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast((month_y#2248 - 1) as bigint)) as double) / 12.0))) AND isnotnull(CASE WHEN (pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) = 0) THEN 12 ELSE pmod(((24000 + (FLOOR((cast(((((timestamp_year#2087 * 12) + timestamp_month#2051) - 24000) - month_y#2248) as double) / 12.0)) * 12)) + cast(month_y#2248 as bigint)), 12) END))\\n\",\n      \"                                            
                                     :                                                     +- GpuColumnarToRow false\\n\",\n      \"                                                                                 :                                                        +- GpuGenerate gpuexplode([0,1,2,3,4,5,6,7,8,9,10,11]), [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997], false, [month_y#2248]\\n\",\n      \"                                                                                 :                                                           +- GpuProject [loan_id#2453L, quarter#2484, timestamp_month#2051, timestamp_year#2087, ever_30#2004, ever_90#2005, ever_180#2006, delinquency_30#1993, delinquency_90#1995, delinquency_180#1997]\\n\",\n      \"                                                                                 :                                                              +- GpuBroadcastHashJoin [loan_id#2453L, quarter#2484], [loan_id#2202L, quarter#2233], LeftOuter, GpuBuildRight\\n\",\n      \"                                                                                 :                                                                 :- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"                                                                                 :                                                                 :  +- *(3) Project [quarter#2484, loan_id#2453L, month(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_month#2051, year(cast(gettimestamp(monthly_reporting_period#2454, MM/dd/yyyy, Some(America/Los_Angeles), false) as date)) AS timestamp_year#2087]\\n\",\n      \"                                                                                 :                                                                 :   
  +- GpuColumnarToRow false\\n\",\n      \"                                                                                 :                                                                 :        +- GpuFilter (gpuisnotnull(quarter#2484) AND gpuisnotnull(loan_id#2453L)), true\\n\",\n      \"                                                                                 :                                                                 :           +- GpuFileGpuScan parquet [loan_id#2453L,monthly_reporting_period#2454,quarter#2484] Batched: true, DataFilters: [isnotnull(quarter#2484), isnotnull(loan_id#2453L)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(quarter), IsNotNull(loan_id)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,quarter:string>\\n\",\n      \"                                                                                 :                                                                 +- GpuBroadcastExchange HashedRelationBroadcastMode(List(input[1, bigint, true], input[0, string, true]),false), [id=#1096]\\n\",\n      \"                                                                                 :                                                                    +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[gpumax(current_loan_delinquency_status#2212), gpumin(delinquency_30#1975), gpumin(delinquency_90#1976), gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"                                                                                 :                                                                       +- GpuShuffleCoalesce 536870912\\n\",\n      \"                                                                                 :                                                                          +- GpuColumnarExchange gpuhashpartitioning(quarter#2233, 
loan_id#2202L, 192), ENSURE_REQUIREMENTS, [id=#1093]\\n\",\n      \"                                                                                 :                                                                             +- GpuHashAggregate(keys=[quarter#2233, loan_id#2202L], functions=[partial_gpumax(current_loan_delinquency_status#2212), partial_gpumin(delinquency_30#1975), partial_gpumin(delinquency_90#1976), partial_gpumin(delinquency_180#1977)]), filters=ArrayBuffer(None, None, None, None))\\n\",\n      \"                                                                                 :                                                                                +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"                                                                                 :                                                                                   +- *(4) Project [quarter#2233, loan_id#2202L, current_loan_delinquency_status#2212, CASE WHEN (current_loan_delinquency_status#2212 >= 1) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_30#1975, CASE WHEN (current_loan_delinquency_status#2212 >= 3) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_90#1976, CASE WHEN (current_loan_delinquency_status#2212 >= 6) THEN cast(gettimestamp(monthly_reporting_period#2203, MM/dd/yyyy, Some(America/Los_Angeles), false) as date) END AS delinquency_180#1977]\\n\",\n      \"                                                                                 :                                                                                      +- GpuColumnarToRow false\\n\",\n      \"                                                                                 :                                                                                         +- GpuFilter 
(gpuisnotnull(loan_id#2202L) AND gpuisnotnull(quarter#2233)), true\\n\",\n      \"                                                                                 :                                                                                            +- GpuFileGpuScan parquet [loan_id#2202L,monthly_reporting_period#2203,current_loan_delinquency_status#2212,quarter#2233] Batched: true, DataFilters: [isnotnull(loan_id#2202L), isnotnull(quarter#2233)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/perf], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,monthly_reporting_period:string,current_loan_delinquency_status:int,quarter...\\n\",\n      \"                                                                                 +- GpuColumnarExchange gpuhashpartitioning(loan_id#1603L, quarter#1629, 192), ENSURE_REQUIREMENTS, [id=#1130]\\n\",\n      \"                                                                                    +- GpuProject [loan_id#1603L, orig_channel#1604, gpucoalesce(to_seller_name#2570, seller_name#1605) AS seller_name#2689, first_home_buyer#1616, loan_purpose#1617, property_type#1618, occupancy_status#1620, property_state#1621, product_type#1624, relocation_mortgage_indicator#1627, quarter#1629]\\n\",\n      \"                                                                                       +- GpuShuffledHashJoin [seller_name#1605], [from_seller_name#2569], LeftOuter, GpuBuildRight, false\\n\",\n      \"                                                                                          :- GpuShuffleCoalesce 536870912\\n\",\n      \"                                                                                          :  +- GpuColumnarExchange gpuhashpartitioning(seller_name#1605, 192), ENSURE_REQUIREMENTS, [id=#862]\\n\",\n      \"                                                                                          :     
+- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"                                                                                          :        +- GpuFilter (gpuisnotnull(loan_id#1603L) AND gpuisnotnull(quarter#1629)), true\\n\",\n      \"                                                                                          :           +- GpuFileGpuScan parquet [loan_id#1603L,orig_channel#1604,seller_name#1605,first_home_buyer#1616,loan_purpose#1617,property_type#1618,occupancy_status#1620,property_state#1621,product_type#1624,relocation_mortgage_indicator#1627,quarter#1629] Batched: true, DataFilters: [isnotnull(loan_id#1603L), isnotnull(quarter#1629)], Format: Parquet, Location: InMemoryFileIndex[file:/local/saralihalli/HOME/mortgage/acq], PartitionFilters: [], PushedFilters: [IsNotNull(loan_id), IsNotNull(quarter)], ReadSchema: struct<loan_id:bigint,orig_channel:string,seller_name:string,first_home_buyer:string,loan_purpose...\\n\",\n      \"                                                                                          +- GpuColumnarExchange gpuhashpartitioning(from_seller_name#2569, 192), ENSURE_REQUIREMENTS, [id=#1127]\\n\",\n      \"                                                                                             +- GpuCoalesceBatches targetsize(536870912)\\n\",\n      \"                                                                                                +- GpuFilter gpuisnotnull(from_seller_name#2569), true\\n\",\n      \"                                                                                                   +- GpuRowToColumnar targetsize(536870912)\\n\",\n      \"                                                                                                      +- *(7) Scan ExistingRDD[from_seller_name#2569,to_seller_name#2570]\\n\",\n      \"\\n\",\n      \"\\n\",\n      \"None\\n\",\n      \"249.0352599620819\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"start = time.time()\\n\",\n    
\"\\n\",\n    \"# run main function to process data\\n\",\n    \"out = run_mortgage(spark, perf, acq)\\n\",\n    \"\\n\",\n    \"# save processed data\\n\",\n    \"out.write.parquet(output_path, mode='overwrite')\\n\",\n    \"\\n\",\n    \"# save processed data\\n\",\n    \"if save_train_eval_dataset:\\n\",\n    \"    etlDf = spark.read.parquet(output_path)\\n\",\n    \"\\n\",\n    \"    # split 80% for training, 20% for test\\n\",\n    \"    splits = etlDf.randomSplit([0.8, 0.2])\\n\",\n    \"\\n\",\n    \"    splits[0].write.format('parquet').save(output_path_train, mode=\\\"overwrite\\\")\\n\",\n    \"    splits[1].write.format('parquet').save(output_path_eval, mode=\\\"overwrite\\\")\\n\",\n    \"\\n\",\n    \"# print explain and time\\n\",\n    \"print(out.explain())\\n\",\n    \"end = time.time()\\n\",\n    \"print(end - start)\\n\",\n    \"spark.stop()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.2\"\n  },\n  \"name\": \"gpu-mortgage\",\n  \"notebookId\": 4440374682851873\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 1\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/notebooks/python/cv-mortgage-gpu.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction to XGBoost-Spark Cross Validation with GPU\\n\",\n    \"\\n\",\n    \"The goal of this notebook is to show you how to leverage GPU to accelerate XGBoost spark cross validation for hyperparameter tuning. The best model for the given hyperparameters will be returned.\\n\",\n    \"\\n\",\n    \"Here takes the application 'Mortgage' as an example.\\n\",\n    \"\\n\",\n    \"A few libraries are required for this notebook:\\n\",\n    \"  1. cudf-cu11\\n\",\n    \"  2. xgboost\\n\",\n    \"  3. scikit-learn\\n\",\n    \"  4. numpy\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Import the Required Libraries\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel\\n\",\n    \"from pyspark.ml.tuning import ParamGridBuilder, CrossValidator\\n\",\n    \"from pyspark.ml.evaluation import MulticlassClassificationEvaluator\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.types import FloatType, IntegerType, StructField, StructType, DoubleType\\n\",\n    \"from pyspark.conf import SparkConf\\n\",\n    \"from time import time\\n\",\n    \"import os\\n\",\n    \"# if you pass/unpack the archive file and enable the environment\\n\",\n    \"# os.environ['PYSPARK_PYTHON'] = \\\"./environment/bin/python\\\"\\n\",\n    \"# os.environ['PYSPARK_DRIVER_PYTHON'] = \\\"./environment/bin/python\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create a Spark Session\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": 
[\n      \"2022-11-25 09:34:43,524 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"2022-11-25 09:34:43,952 WARN resource.ResourceUtils: The configuration of cores (exec = 4 task = 1, runnable tasks = 4) will result in wasted resources due to resource gpu limiting the number of runnable tasks per executor to: 1. Please adjust your configuration.\\n\",\n      \"2022-11-25 09:34:58,155 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1.\\n\",\n      \"2022-11-25 09:34:58,171 WARN rapids.RapidsPluginUtils: spark.rapids.sql.multiThreadedRead.numThreads is set to 20.\\n\",\n      \"2022-11-25 09:34:58,175 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\\n\",\n      \"2022-11-25 09:34:58,175 WARN rapids.RapidsPluginUtils: spark.rapids.sql.explain is set to `NOT_ON_GPU`. 
Set it to 'NONE' to suppress the diagnostics logging about the query placement on the GPU.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"SPARK_MASTER_URL = os.getenv(\\\"SPARK_MASTER_URL\\\", \\\"/your-url\\\")\\n\",\n    \"\\n\",\n    \"RAPIDS_JAR = os.getenv(\\\"RAPIDS_JAR\\\", \\\"/your-jar-path\\\")\\n\",\n    \"\\n\",\n    \"# You need to update with your real hardware resource \\n\",\n    \"driverMem = os.getenv(\\\"DRIVER_MEM\\\", \\\"2g\\\")\\n\",\n    \"executorMem = os.getenv(\\\"EXECUTOR_MEM\\\", \\\"2g\\\")\\n\",\n    \"pinnedPoolSize = os.getenv(\\\"PINNED_POOL_SIZE\\\", \\\"2g\\\")\\n\",\n    \"concurrentGpuTasks = os.getenv(\\\"CONCURRENT_GPU_TASKS\\\", \\\"2\\\")\\n\",\n    \"executorCores = int(os.getenv(\\\"EXECUTOR_CORES\\\", \\\"4\\\"))\\n\",\n    \"# Common spark settings\\n\",\n    \"conf = SparkConf()\\n\",\n    \"conf.setMaster(SPARK_MASTER_URL)\\n\",\n    \"conf.setAppName(\\\"Microbenchmark on GPU\\\")\\n\",\n    \"conf.set(\\\"spark.driver.memory\\\", driverMem)\\n\",\n    \"## The tasks will run on GPU memory, so there is no need to set a high host memory\\n\",\n    \"conf.set(\\\"spark.executor.memory\\\", executorMem)\\n\",\n    \"## The tasks will run on GPU cores, so there is no need to use many cpu cores\\n\",\n    \"conf.set(\\\"spark.executor.cores\\\", executorCores)\\n\",\n    \"\\n\",\n    \"# Plugin settings\\n\",\n    \"conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.sql.concurrentGpuTasks\\\", concurrentGpuTasks)\\n\",\n    \"conf.set(\\\"spark.rapids.memory.pinnedPool.size\\\", pinnedPoolSize)\\n\",\n    \"# since pyspark and xgboost share the same GPU, we disable RMM to avoid GPU OOM while training \\n\",\n    \"conf.set(\\\"spark.rapids.memory.gpu.pool\\\", \\\"NONE\\\")\\n\",\n    \"conf.set(\\\"spark.locality.wait\\\",\\\"0\\\")\\n\",\n    \"##############note: only support value=1 
https://github.com/dmlc/xgboost/blame/master/python-package/xgboost/spark/core.py#L370-L374\\n\",\n    \"conf.set(\\\"spark.task.resource.gpu.amount\\\", 1) \\n\",\n    \"conf.set(\\\"spark.rapids.sql.enabled\\\", \\\"true\\\") \\n\",\n    \"conf.set(\\\"spark.plugins\\\", \\\"com.nvidia.spark.SQLPlugin\\\")\\n\",\n    \"conf.set(\\\"spark.sql.cache.serializer\\\",\\\"com.nvidia.spark.ParquetCachedBatchSerializer\\\")\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", 200000) \\n\",\n    \"conf.set(\\\"spark.driver.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.executor.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"# if you pass/unpack the archive file and enable the environment\\n\",\n    \"# conf.set(\\\"spark.yarn.dist.archives\\\", \\\"your_pyspark_venv.tar.gz#environment\\\")\\n\",\n    \"# Create spark session\\n\",\n    \"spark = SparkSession.builder.config(conf=conf).getOrCreate()\\n\",\n    \"\\n\",\n    \"reader = spark.read\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Specify the Data Schema and Load the Data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"label = 'delinquency_12'\\n\",\n    \"schema = StructType([\\n\",\n    \"    StructField('orig_channel', FloatType()),\\n\",\n    \"    StructField('first_home_buyer', FloatType()),\\n\",\n    \"    StructField('loan_purpose', FloatType()),\\n\",\n    \"    StructField('property_type', FloatType()),\\n\",\n    \"    StructField('occupancy_status', FloatType()),\\n\",\n    \"    StructField('property_state', FloatType()),\\n\",\n    \"    StructField('product_type', FloatType()),\\n\",\n    \"    StructField('relocation_mortgage_indicator', FloatType()),\\n\",\n    \"    StructField('seller_name', FloatType()),\\n\",\n    \"    StructField('mod_flag', FloatType()),\\n\",\n    \"    
StructField('orig_interest_rate', FloatType()),\\n\",\n    \"    StructField('orig_upb', DoubleType()),\\n\",\n    \"    StructField('orig_loan_term', IntegerType()),\\n\",\n    \"    StructField('orig_ltv', FloatType()),\\n\",\n    \"    StructField('orig_cltv', FloatType()),\\n\",\n    \"    StructField('num_borrowers', FloatType()),\\n\",\n    \"    StructField('dti', FloatType()),\\n\",\n    \"    StructField('borrower_credit_score', FloatType()),\\n\",\n    \"    StructField('num_units', IntegerType()),\\n\",\n    \"    StructField('zip', IntegerType()),\\n\",\n    \"    StructField('mortgage_insurance_percent', FloatType()),\\n\",\n    \"    StructField('current_loan_delinquency_status', IntegerType()),\\n\",\n    \"    StructField('current_actual_upb', FloatType()),\\n\",\n    \"    StructField('interest_rate', FloatType()),\\n\",\n    \"    StructField('loan_age', FloatType()),\\n\",\n    \"    StructField('msa', FloatType()),\\n\",\n    \"    StructField('non_interest_bearing_upb', FloatType()),\\n\",\n    \"    StructField(label, IntegerType()),\\n\",\n    \"])\\n\",\n    \"features = [ x.name for x in schema if x.name != label ]\\n\",\n    \"\\n\",\n    \"# You need to update them to your real paths!\\n\",\n    \"dataRoot = os.getenv(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"train_path = dataRoot + \\\"/mortgage/output/train\\\"\\n\",\n    \"eval_path = dataRoot + \\\"/mortgage/output/eval\\\"\\n\",\n    \"\\n\",\n    \"data_format = 'parquet'\\n\",\n    \"has_header = 'true'\\n\",\n    \"if data_format == 'csv':\\n\",\n    \"    train_data = reader.schema(schema).option('header',has_header).csv(train_path)\\n\",\n    \"    trans_data = reader.schema(schema).option('header',has_header).csv(eval_path)\\n\",\n    \"else :\\n\",\n    \"    train_data = reader.load(train_path)\\n\",\n    \"    trans_data = reader.load(eval_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Build a XGBoost-Spark 
CrossValidator\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"params = { \\n\",\n    \"    \\\"tree_method\\\": \\\"hist\\\",\\n\",\n    \"    \\\"grow_policy\\\": \\\"depthwise\\\",\\n\",\n    \"    \\\"num_workers\\\": 1,\\n\",\n    \"    \\\"device\\\": \\\"cuda\\\",\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"params['features_col'] = features\\n\",\n    \"params['label_col'] = label\\n\",\n    \"    \\n\",\n    \"classifier = SparkXGBClassifier(**params)\\n\",\n    \"\\n\",\n    \"# Then build the evaluator and the hyperparameters\\n\",\n    \"evaluator = (MulticlassClassificationEvaluator()\\n\",\n    \"    .setLabelCol(label))\\n\",\n    \"param_grid = (ParamGridBuilder()\\n\",\n    \"    .addGrid(classifier.max_depth, [3, 6])\\n\",\n    \"    .addGrid(classifier.n_estimators, [100, 200])\\n\",\n    \"    .build())\\n\",\n    \"# Finally the cross validator\\n\",\n    \"cross_validator = (CrossValidator()\\n\",\n    \"    .setEstimator(classifier)\\n\",\n    \"    .setEvaluator(evaluator)\\n\",\n    \"    .setEstimatorParamMaps(param_grid)\\n\",\n    \"    .setNumFolds(2))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Cross Validation by Fitting Data to CrossValidator\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-25 09:35:01,049 WARN util.package: Truncated the string representation of a plan since it was too large. 
This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"/data/home/yuanli/work/reviews/pr252/pyspark_venv_20221125/lib/python3.8/site-packages/xgboost/sklearn.py:808: UserWarning: Loading a native XGBoost model with Scikit-Learn interface.\\n\",\n      \"  warnings.warn(\\\"Loading a native XGBoost model with Scikit-Learn interface.\\\")\\n\",\n      \"2022-11-25 09:35:26,758 WARN rapids.GpuOverrides: \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! <CreateExternalRow> createexternalrow(prediction#2153, delinquency_12#2255, 1.0#2256, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#2153 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> delinquency_12#2255 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#2256 could run on GPU\\n\",\n      \"    ! <Invoke> newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\\n\",\n      \"      ! 
<NewInstance> newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\\n\",\n      \"      !Expression <AttributeReference> probability#2186 cannot run on GPU because expression AttributeReference probability#2186 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  !Expression <AttributeReference> obj#2261 cannot run on GPU because expression AttributeReference obj#2261 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#2186]\\n\",\n      \"    @Expression <Alias> pythonUDF0#2552.prediction AS prediction#2153 could run on GPU\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#2552.prediction could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#2552 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(delinquency_12#27 as double) AS delinquency_12#2255 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(delinquency_12#27 as double) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> delinquency_12#27 could run on GPU\\n\",\n      \"    @Expression <Alias> 1.0 AS 1.0#2256 could run on GPU\\n\",\n      \"      @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"    !Expression <Alias> UDF(pythonUDF0#2552.probability) AS probability#2186 cannot run on GPU because input expression ScalaUDF UDF(pythonUDF0#2552.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported); expression Alias UDF(pythonUDF0#2552.probability) AS probability#2186 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"      !Expression <ScalaUDF> 
UDF(pythonUDF0#2552.probability) cannot run on GPU because expression ScalaUDF UDF(pythonUDF0#2552.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled\\n\",\n      \"        @Expression <GetStructField> pythonUDF0#2552.probability could run on GPU\\n\",\n      \"          @Expression <AttributeReference> pythonUDF0#2552 could run on GPU\\n\",\n      \"\\n\",\n      \"If features_cols param set, then features_col param is ignored.                 \\n\",\n      \"2022-11-25 09:35:34,074 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! <CreateExternalRow> createexternalrow(prediction#4415, delinquency_12#4517, 1.0#4518, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#4415 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> delinquency_12#4517 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#4518 could run on GPU\\n\",\n      \"    ! <Invoke> newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\\n\",\n      \"      ! 
<NewInstance> newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\\n\",\n      \"      !Expression <AttributeReference> probability#4448 cannot run on GPU because expression AttributeReference probability#4448 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  !Expression <AttributeReference> obj#4523 cannot run on GPU because expression AttributeReference obj#4523 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#4448]; not all expressions can be replaced\\n\",\n      \"    @Expression <Alias> pythonUDF0#4814.prediction AS prediction#4415 could run on GPU\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#4814.prediction could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#4814 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(delinquency_12#27 as double) AS delinquency_12#4517 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(delinquency_12#27 as double) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> delinquency_12#27 could run on GPU\\n\",\n      \"    @Expression <Alias> 1.0 AS 1.0#4518 could run on GPU\\n\",\n      \"      @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"    !Expression <Alias> UDF(pythonUDF0#4814.probability) AS probability#4448 cannot run on GPU because expression Alias UDF(pythonUDF0#4814.probability) AS probability#4448 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#4814.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\\n\",\n      \"      !Expression <ScalaUDF> 
UDF(pythonUDF0#4814.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#4814.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"        @Expression <GetStructField> pythonUDF0#4814.probability could run on GPU\\n\",\n      \"          @Expression <AttributeReference> pythonUDF0#4814 could run on GPU\\n\",\n      \"\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-25 09:35:37,859 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! <CreateExternalRow> createexternalrow(prediction#6677, delinquency_12#6779, 1.0#6780, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#6677 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> delinquency_12#6779 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#6780 could run on GPU\\n\",\n      \"    ! <Invoke> newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\\n\",\n      \"      ! 
<NewInstance> newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\\n\",\n      \"      !Expression <AttributeReference> probability#6710 cannot run on GPU because expression AttributeReference probability#6710 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  !Expression <AttributeReference> obj#6785 cannot run on GPU because expression AttributeReference obj#6785 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#6710]; not all expressions can be replaced\\n\",\n      \"    @Expression <Alias> pythonUDF0#7076.prediction AS prediction#6677 could run on GPU\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#7076.prediction could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#7076 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(delinquency_12#27 as double) AS delinquency_12#6779 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(delinquency_12#27 as double) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> delinquency_12#27 could run on GPU\\n\",\n      \"    @Expression <Alias> 1.0 AS 1.0#6780 could run on GPU\\n\",\n      \"      @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"    !Expression <Alias> UDF(pythonUDF0#7076.probability) AS probability#6710 cannot run on GPU because input expression ScalaUDF UDF(pythonUDF0#7076.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported); expression Alias UDF(pythonUDF0#7076.probability) AS probability#6710 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"      !Expression <ScalaUDF> 
UDF(pythonUDF0#7076.probability) cannot run on GPU because expression ScalaUDF UDF(pythonUDF0#7076.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled\\n\",\n      \"        @Expression <GetStructField> pythonUDF0#7076.probability could run on GPU\\n\",\n      \"          @Expression <AttributeReference> pythonUDF0#7076 could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-25 09:35:41,551 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! <CreateExternalRow> createexternalrow(prediction#8939, delinquency_12#9041, 1.0#9042, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#8939 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> delinquency_12#9041 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#9042 could run on GPU\\n\",\n      \"    ! 
<Invoke> newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\\n\",\n      \"      ! <NewInstance> newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\\n\",\n      \"      !Expression <AttributeReference> probability#8972 cannot run on GPU because expression AttributeReference probability#8972 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  !Expression <AttributeReference> obj#9047 cannot run on GPU because expression AttributeReference obj#9047 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#8972]; not all expressions can be replaced\\n\",\n      \"    @Expression <Alias> pythonUDF0#9338.prediction AS prediction#8939 could run on GPU\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#9338.prediction could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#9338 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(delinquency_12#27 as double) AS delinquency_12#9041 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(delinquency_12#27 as double) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> delinquency_12#27 could run on GPU\\n\",\n      \"    @Expression <Alias> 1.0 AS 1.0#9042 could run on GPU\\n\",\n      \"      @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"    !Expression <Alias> UDF(pythonUDF0#9338.probability) AS probability#8972 cannot run on GPU because input expression ScalaUDF UDF(pythonUDF0#9338.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 
is not supported); expression Alias UDF(pythonUDF0#9338.probability) AS probability#8972 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"      !Expression <ScalaUDF> UDF(pythonUDF0#9338.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#9338.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"        @Expression <GetStructField> pythonUDF0#9338.probability could run on GPU\\n\",\n      \"          @Expression <AttributeReference> pythonUDF0#9338 could run on GPU\\n\",\n      \"\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-25 09:35:45,231 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! <CreateExternalRow> createexternalrow(prediction#11491, delinquency_12#11593, 1.0#11594, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#11491 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> delinquency_12#11593 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#11594 could run on GPU\\n\",\n      \"    ! 
<Invoke> newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\\n\",\n      \"      ! <NewInstance> newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\\n\",\n      \"      !Expression <AttributeReference> probability#11524 cannot run on GPU because expression AttributeReference probability#11524 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  !Expression <AttributeReference> obj#11599 cannot run on GPU because expression AttributeReference obj#11599 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#11524]\\n\",\n      \"    @Expression <Alias> pythonUDF0#11890.prediction AS prediction#11491 could run on GPU\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#11890.prediction could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#11890 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(delinquency_12#27 as double) AS delinquency_12#11593 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(delinquency_12#27 as double) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> delinquency_12#27 could run on GPU\\n\",\n      \"    @Expression <Alias> 1.0 AS 1.0#11594 could run on GPU\\n\",\n      \"      @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"    !Expression <Alias> UDF(pythonUDF0#11890.probability) AS probability#11524 cannot run on GPU because expression Alias UDF(pythonUDF0#11890.probability) AS probability#11524 produces an 
unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#11890.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\\n\",\n      \"      !Expression <ScalaUDF> UDF(pythonUDF0#11890.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#11890.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"        @Expression <GetStructField> pythonUDF0#11890.probability could run on GPU\\n\",\n      \"          @Expression <AttributeReference> pythonUDF0#11890 could run on GPU\\n\",\n      \"\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-25 09:35:49,003 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! <CreateExternalRow> createexternalrow(prediction#13753, delinquency_12#13855, 1.0#13856, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#13753 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> delinquency_12#13855 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#13856 could run on GPU\\n\",\n      \"    ! 
<Invoke> newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\\n\",\n      \"      ! <NewInstance> newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\\n\",\n      \"      !Expression <AttributeReference> probability#13786 cannot run on GPU because expression AttributeReference probability#13786 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  !Expression <AttributeReference> obj#13861 cannot run on GPU because expression AttributeReference obj#13861 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#13786]; not all expressions can be replaced\\n\",\n      \"    @Expression <Alias> pythonUDF0#14152.prediction AS prediction#13753 could run on GPU\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#14152.prediction could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#14152 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(delinquency_12#27 as double) AS delinquency_12#13855 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(delinquency_12#27 as double) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> delinquency_12#27 could run on GPU\\n\",\n      \"    @Expression <Alias> 1.0 AS 1.0#13856 could run on GPU\\n\",\n      \"      @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"    !Expression <Alias> UDF(pythonUDF0#14152.probability) AS probability#13786 cannot run on GPU because expression Alias UDF(pythonUDF0#14152.probability) AS probability#13786 produces an 
unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#14152.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\\n\",\n      \"      !Expression <ScalaUDF> UDF(pythonUDF0#14152.probability) cannot run on GPU because expression ScalaUDF UDF(pythonUDF0#14152.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled\\n\",\n      \"        @Expression <GetStructField> pythonUDF0#14152.probability could run on GPU\\n\",\n      \"          @Expression <AttributeReference> pythonUDF0#14152 could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-25 09:35:52,578 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! 
<CreateExternalRow> createexternalrow(prediction#16015, delinquency_12#16117, 1.0#16118, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#16015 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> delinquency_12#16117 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#16118 could run on GPU\\n\",\n      \"    ! <Invoke> newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\\n\",\n      \"      ! 
<NewInstance> newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\\n\",\n      \"      !Expression <AttributeReference> probability#16048 cannot run on GPU because expression AttributeReference probability#16048 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  !Expression <AttributeReference> obj#16123 cannot run on GPU because expression AttributeReference obj#16123 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#16048]; not all expressions can be replaced\\n\",\n      \"    @Expression <Alias> pythonUDF0#16414.prediction AS prediction#16015 could run on GPU\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#16414.prediction could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#16414 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(delinquency_12#27 as double) AS delinquency_12#16117 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(delinquency_12#27 as double) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> delinquency_12#27 could run on GPU\\n\",\n      \"    @Expression <Alias> 1.0 AS 1.0#16118 could run on GPU\\n\",\n      \"      @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"    !Expression <Alias> UDF(pythonUDF0#16414.probability) AS probability#16048 cannot run on GPU because expression Alias UDF(pythonUDF0#16414.probability) AS probability#16048 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#16414.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\\n\",\n      \"      !Expression <ScalaUDF> 
UDF(pythonUDF0#16414.probability) cannot run on GPU because expression ScalaUDF UDF(pythonUDF0#16414.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled\\n\",\n      \"        @Expression <GetStructField> pythonUDF0#16414.probability could run on GPU\\n\",\n      \"          @Expression <AttributeReference> pythonUDF0#16414 could run on GPU\\n\",\n      \"\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-25 09:35:56,267 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! <CreateExternalRow> createexternalrow(prediction#18277, delinquency_12#18379, 1.0#18380, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#18277 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> delinquency_12#18379 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#18380 could run on GPU\\n\",\n      \"    ! <Invoke> newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\\n\",\n      \"      ! 
<NewInstance> newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\\n\",\n      \"      !Expression <AttributeReference> probability#18310 cannot run on GPU because expression AttributeReference probability#18310 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  !Expression <AttributeReference> obj#18385 cannot run on GPU because expression AttributeReference obj#18385 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#18310]; not all expressions can be replaced\\n\",\n      \"    @Expression <Alias> pythonUDF0#18676.prediction AS prediction#18277 could run on GPU\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#18676.prediction could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#18676 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(delinquency_12#27 as double) AS delinquency_12#18379 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(delinquency_12#27 as double) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> delinquency_12#27 could run on GPU\\n\",\n      \"    @Expression <Alias> 1.0 AS 1.0#18380 could run on GPU\\n\",\n      \"      @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"    !Expression <Alias> UDF(pythonUDF0#18676.probability) AS probability#18310 cannot run on GPU because expression Alias UDF(pythonUDF0#18676.probability) AS probability#18310 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#18676.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\\n\",\n      \"      !Expression <ScalaUDF> 
UDF(pythonUDF0#18676.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#18676.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"        @Expression <GetStructField> pythonUDF0#18676.probability could run on GPU\\n\",\n      \"          @Expression <AttributeReference> pythonUDF0#18676 could run on GPU\\n\",\n      \"\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"[Stage 69:>                                                         (0 + 1) / 1]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Cross-Validation takes 59.46 seconds\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\r\",\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def with_benchmark(phrase, action):\\n\",\n    \"    start = time()\\n\",\n    \"    result = action()\\n\",\n    \"    end = time()\\n\",\n    \"    print('{} takes {} seconds'.format(phrase, round(end - start, 2)))\\n\",\n    \"    return result\\n\",\n    \"model = with_benchmark('Cross-Validation', lambda: cross_validator.fit(train_data)).bestModel\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Transform On the Best Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-25 09:35:59,886 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because unsupported 
data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction#18908, probability#18974]; not all expressions can be replaced\\n\",\n      \"  @Expression <AttributeReference> orig_channel#56 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> first_home_buyer#57 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> loan_purpose#58 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> property_type#59 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> occupancy_status#60 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> property_state#61 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> product_type#62 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> relocation_mortgage_indicator#63 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> seller_name#64 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> mod_flag#65 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_interest_rate#66 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_upb#67 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_loan_term#68 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_ltv#69 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_cltv#70 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> num_borrowers#71 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> dti#72 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> borrower_credit_score#73 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> num_units#74 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> zip#75 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> mortgage_insurance_percent#76 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> current_loan_delinquency_status#77 
could run on GPU\\n\",\n      \"  @Expression <AttributeReference> current_actual_upb#78 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> interest_rate#79 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> loan_age#80 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> msa#81 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> non_interest_bearing_upb#82 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> delinquency_12#83 could run on GPU\\n\",\n      \"  !Expression <Alias> UDF(pythonUDF0#19041.rawPrediction) AS rawPrediction#18908 cannot run on GPU because input expression ScalaUDF UDF(pythonUDF0#19041.rawPrediction) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported); expression Alias UDF(pythonUDF0#19041.rawPrediction) AS rawPrediction#18908 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"    !Expression <ScalaUDF> UDF(pythonUDF0#19041.rawPrediction) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#19041.rawPrediction) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#19041.rawPrediction could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#19041 could run on GPU\\n\",\n      \"  @Expression <Alias> pythonUDF0#19041.prediction AS prediction#18942 could run on GPU\\n\",\n      \"    @Expression <GetStructField> pythonUDF0#19041.prediction could run on GPU\\n\",\n      \"      @Expression <AttributeReference> pythonUDF0#19041 could run on GPU\\n\",\n      \"  !Expression <Alias> UDF(pythonUDF0#19041.probability) AS probability#18974 cannot run on GPU because expression Alias UDF(pythonUDF0#19041.probability) AS 
probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; input expression ScalaUDF UDF(pythonUDF0#19041.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported)\\n\",\n      \"    !Expression <ScalaUDF> UDF(pythonUDF0#19041.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3606/1625633331 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#19041.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#19041.probability could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#19041 could run on GPU\\n\",\n      \"\\n\",\n      \"2022-11-25 09:35:59,893 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <InMemoryTableScanExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction#18908, probability#18974]; not all expressions can be replaced\\n\",\n      \"  @Expression <AttributeReference> orig_channel#56 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> first_home_buyer#57 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> loan_purpose#58 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> property_type#59 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> occupancy_status#60 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> property_state#61 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> product_type#62 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> relocation_mortgage_indicator#63 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> seller_name#64 
could run on GPU\\n\",\n      \"  @Expression <AttributeReference> mod_flag#65 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_interest_rate#66 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_upb#67 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_loan_term#68 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_ltv#69 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_cltv#70 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> num_borrowers#71 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> dti#72 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> borrower_credit_score#73 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> num_units#74 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> zip#75 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> mortgage_insurance_percent#76 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> current_loan_delinquency_status#77 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> current_actual_upb#78 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> interest_rate#79 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> loan_age#80 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> msa#81 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> non_interest_bearing_upb#82 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> delinquency_12#83 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> rawPrediction#18908 cannot run on GPU because expression AttributeReference rawPrediction#18908 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  @Expression <AttributeReference> prediction#18942 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> probability#18974 cannot 
run on GPU because expression AttributeReference probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"\\n\",\n      \"2022-11-25 09:36:00,975 WARN rapids.GpuOverrides:                               \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because unsupported data types in input: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#18974, rawPrediction#18908]; not all expressions can be replaced\\n\",\n      \"    @Expression <Alias> cast(delinquency_12#83 as string) AS delinquency_12#19670 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(delinquency_12#83 as string) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> delinquency_12#83 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(rawPrediction#18908 as string) AS rawPrediction#19671 could run on GPU\\n\",\n      \"      !Expression <Cast> cast(rawPrediction#18908 as string) cannot run on GPU because Cast from org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 to StringType is not supported\\n\",\n      \"        !Expression <AttributeReference> rawPrediction#18908 cannot run on GPU because expression AttributeReference rawPrediction#18908 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"    @Expression <Alias> cast(probability#18974 as string) AS probability#19672 could run on GPU\\n\",\n      \"      !Expression <Cast> cast(probability#18974 as string) cannot run on GPU because Cast from 
org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 to StringType is not supported\\n\",\n      \"        !Expression <AttributeReference> probability#18974 cannot run on GPU because expression AttributeReference probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"    @Expression <Alias> cast(prediction#18942 as string) AS prediction#19673 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(prediction#18942 as string) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> prediction#18942 could run on GPU\\n\",\n      \"    !Exec <InMemoryTableScanExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#18974, rawPrediction#18908]; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; not all expressions can be replaced\\n\",\n      \"      @Expression <AttributeReference> delinquency_12#83 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> prediction#18942 could run on GPU\\n\",\n      \"      !Expression <AttributeReference> probability#18974 cannot run on GPU because expression AttributeReference probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"      !Expression <AttributeReference> rawPrediction#18908 cannot run on GPU because expression AttributeReference rawPrediction#18908 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Transforming takes 1.15 seconds\\n\",\n      \"+--------------+--------------------+--------------------+----------+\\n\",\n      \"|delinquency_12|       rawPrediction|         probability|prediction|\\n\",\n      \"+--------------+--------------------+--------------------+----------+\\n\",\n      \"|        
     0|[10.2152490615844...|[0.99996340274810...|       0.0|\\n\",\n      \"|             0|[8.85215473175048...|[0.99985694885253...|       0.0|\\n\",\n      \"|             0|[8.85215473175048...|[0.99985694885253...|       0.0|\\n\",\n      \"|             0|[8.85215473175048...|[0.99985694885253...|       0.0|\\n\",\n      \"|             0|[10.2152490615844...|[0.99996340274810...|       0.0|\\n\",\n      \"+--------------+--------------------+--------------------+----------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def transform():\\n\",\n    \"    result = model.transform(trans_data).cache()\\n\",\n    \"    result.foreachPartition(lambda _: None)\\n\",\n    \"    return result\\n\",\n    \"result = with_benchmark('Transforming', transform)\\n\",\n    \"result.select(label, 'rawPrediction', 'probability', 'prediction').show(5)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Evaluation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-25 09:36:01,155 WARN rapids.GpuOverrides: \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! 
<CreateExternalRow> createexternalrow(prediction#18942, delinquency_12#20148, 1.0#20149, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#18942 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> delinquency_12#20148 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#20149 could run on GPU\\n\",\n      \"    ! <Invoke> newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\\n\",\n      \"      ! 
<NewInstance> newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\\n\",\n      \"      !Expression <AttributeReference> probability#18974 cannot run on GPU because expression AttributeReference probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  !Expression <AttributeReference> obj#20154 cannot run on GPU because expression AttributeReference obj#20154 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because unsupported data types in input: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#18974]; not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#18974]\\n\",\n      \"    @Expression <AttributeReference> prediction#18942 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(delinquency_12#83 as double) AS delinquency_12#20148 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(delinquency_12#83 as double) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> delinquency_12#83 could run on GPU\\n\",\n      \"    @Expression <Alias> 1.0 AS 1.0#20149 could run on GPU\\n\",\n      \"      @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"    !Expression <AttributeReference> probability#18974 cannot run on GPU because expression AttributeReference probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"    !Exec <InMemoryTableScanExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#18974]\\n\",\n  
    \"      @Expression <AttributeReference> delinquency_12#83 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> prediction#18942 could run on GPU\\n\",\n      \"      !Expression <AttributeReference> probability#18974 cannot run on GPU because expression AttributeReference probability#18974 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"\\n\",\n      \"[Stage 72:>                                                         (0 + 1) / 1]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Evaluation takes 1.41 seconds\\n\",\n      \"Accuracy is 1.0\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\r\",\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"accuracy = with_benchmark(\\n\",\n    \"    'Evaluation',\\n\",\n    \"    lambda: MulticlassClassificationEvaluator().setLabelCol(label).evaluate(result))\\n\",\n    \"print('Accuracy is ' + str(accuracy))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.2\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/notebooks/python/mortgage-gpu.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction to XGBoost Spark with GPU\\n\",\n    \"\\n\",\n    \"The goal of this notebook is to show how to train an XGBoost Model with Spark RAPIDS XGBoost library on GPUs. The dataset used with this notebook is derived from Fannie Mae’s Single-Family Loan Performance Data with all rights reserved by Fannie Mae. This processed dataset is redistributed with permission and consent from Fannie Mae. This notebook uses XGBoost to train a 12-month mortgage loan delinquency prediction model.\\n\",\n    \"\\n\",\n    \"A few libraries required for this notebook:\\n\",\n    \"  1. cudf-cu11\\n\",\n    \"  2. xgboost\\n\",\n    \"  3. scikit-learn\\n\",\n    \"  4. numpy\\n\",\n    \"\\n\",\n    \"This notebook also illustrates the ease of porting a sample CPU based Spark xgboost4j code into GPU. There is no change required for running Spark XGBoost on GPU because both CPU and GPU call the same API. 
For CPU run, we need to vectorize the training dataset before fitting data to the classifier.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Import All Libraries\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"\\n\",\n    \"# if you pass/unpack the archive file and enable the environment\\n\",\n    \"# os.environ['PYSPARK_PYTHON'] = \\\"./environment/bin/python\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel\\n\",\n    \"from pyspark.ml.evaluation import MulticlassClassificationEvaluator\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.types import FloatType, IntegerType, StructField, StructType, DoubleType\\n\",\n    \"from pyspark.conf import SparkConf\\n\",\n    \"from time import time\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Besides, the CPU version requires two extra libraries.\\n\",\n    \"```Python\\n\",\n    \"from pyspark.ml.feature import VectorAssembler\\n\",\n    \"from pyspark.sql.functions import col\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session and Data Reader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). 
For SparkR, use setLogLevel(newLevel).\\n\",\n      \"22/11/24 06:14:05 WARN org.apache.spark.resource.ResourceUtils: The configuration of cores (exec = 4 task = 1, runnable tasks = 4) will result in wasted resources due to resource gpu limiting the number of runnable tasks per executor to: 1. Please adjust your configuration.\\n\",\n      \"22/11/24 06:14:06 INFO org.apache.spark.SparkEnv: Registering MapOutputTracker\\n\",\n      \"22/11/24 06:14:06 INFO org.apache.spark.SparkEnv: Registering BlockManagerMaster\\n\",\n      \"22/11/24 06:14:06 INFO org.apache.spark.SparkEnv: Registering BlockManagerMasterHeartbeat\\n\",\n      \"22/11/24 06:14:06 INFO org.apache.spark.SparkEnv: Registering OutputCommitCoordinator\\n\",\n      \"22/11/24 06:14:07 WARN com.nvidia.spark.rapids.RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1.\\n\",\n      \"22/11/24 06:14:07 WARN com.nvidia.spark.rapids.RapidsPluginUtils: spark.rapids.sql.multiThreadedRead.numThreads is set to 20.\\n\",\n      \"22/11/24 06:14:07 WARN com.nvidia.spark.rapids.RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\\n\",\n      \"22/11/24 06:14:07 WARN com.nvidia.spark.rapids.RapidsPluginUtils: spark.rapids.sql.explain is set to `NOT_ON_GPU`. 
Set it to 'NONE' to suppress the diagnostics logging about the query placement on the GPU.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"SPARK_MASTER_URL = os.getenv(\\\"SPARK_MASTER_URL\\\", \\\"/your-url\\\")\\n\",\n    \"RAPIDS_JAR = os.getenv(\\\"RAPIDS_JAR\\\", \\\"/your-jar-path\\\")\\n\",\n    \"\\n\",\n    \"# You need to update with your real hardware resource \\n\",\n    \"driverMem = os.getenv(\\\"DRIVER_MEM\\\", \\\"10g\\\")\\n\",\n    \"executorMem = os.getenv(\\\"EXECUTOR_MEM\\\", \\\"10g\\\")\\n\",\n    \"pinnedPoolSize = os.getenv(\\\"PINNED_POOL_SIZE\\\", \\\"2g\\\")\\n\",\n    \"concurrentGpuTasks = os.getenv(\\\"CONCURRENT_GPU_TASKS\\\", \\\"2\\\")\\n\",\n    \"executorCores = int(os.getenv(\\\"EXECUTOR_CORES\\\", \\\"4\\\"))\\n\",\n    \"\\n\",\n    \"# Common spark settings\\n\",\n    \"conf = SparkConf()\\n\",\n    \"conf.setMaster(SPARK_MASTER_URL)\\n\",\n    \"conf.setAppName(\\\"Microbenchmark on GPU\\\")\\n\",\n    \"conf.set(\\\"spark.driver.memory\\\", driverMem)\\n\",\n    \"## The tasks will run on GPU memory, so there is no need to set a high host memory\\n\",\n    \"conf.set(\\\"spark.executor.memory\\\", executorMem)\\n\",\n    \"## The tasks will run on GPU cores, so there is no need to use many cpu cores\\n\",\n    \"conf.set(\\\"spark.executor.cores\\\", executorCores)\\n\",\n    \"\\n\",\n    \"# Plugin settings\\n\",\n    \"conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.sql.concurrentGpuTasks\\\", concurrentGpuTasks)\\n\",\n    \"conf.set(\\\"spark.rapids.memory.pinnedPool.size\\\", pinnedPoolSize)\\n\",\n    \"##############note: only support value=1 see https://github.com/dmlc/xgboost/blame/master/python-package/xgboost/spark/core.py#L370-L374\\n\",\n    \"conf.set(\\\"spark.task.resource.gpu.amount\\\", 1) \\n\",\n    \"# since pyspark and xgboost share the same GPU, we disable RMM to avoid GPU OOM while training \\n\",\n    
\"conf.set(\\\"spark.rapids.memory.gpu.pool\\\", \\\"NONE\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.sql.enabled\\\", \\\"true\\\") \\n\",\n    \"conf.set(\\\"spark.plugins\\\", \\\"com.nvidia.spark.SQLPlugin\\\")\\n\",\n    \"conf.set(\\\"spark.sql.cache.serializer\\\",\\\"com.nvidia.spark.ParquetCachedBatchSerializer\\\")\\n\",\n    \"conf.set(\\\"spark.driver.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", 200000) \\n\",\n    \"conf.set(\\\"spark.executor.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.jars\\\", RAPIDS_JAR)\\n\",\n    \"\\n\",\n    \"# if you pass/unpack the archive file and enable the environment\\n\",\n    \"# conf.set(\\\"spark.yarn.dist.archives\\\", \\\"your_pyspark_venv.tar.gz#environment\\\")\\n\",\n    \"\\n\",\n    \"# Create spark session\\n\",\n    \"spark = SparkSession.builder.config(conf=conf).getOrCreate()\\n\",\n    \"reader = spark.read\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Specify the Data Schema and Load the Data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"label = 'delinquency_12'\\n\",\n    \"schema = StructType([\\n\",\n    \"    StructField('orig_channel', FloatType()),\\n\",\n    \"    StructField('first_home_buyer', FloatType()),\\n\",\n    \"    StructField('loan_purpose', FloatType()),\\n\",\n    \"    StructField('property_type', FloatType()),\\n\",\n    \"    StructField('occupancy_status', FloatType()),\\n\",\n    \"    StructField('property_state', FloatType()),\\n\",\n    \"    StructField('product_type', FloatType()),\\n\",\n    \"    StructField('relocation_mortgage_indicator', FloatType()),\\n\",\n    \"    StructField('seller_name', FloatType()),\\n\",\n    \"    StructField('mod_flag', FloatType()),\\n\",\n    \"    StructField('orig_interest_rate', FloatType()),\\n\",\n  
  \"    StructField('orig_upb', DoubleType()),\\n\",\n    \"    StructField('orig_loan_term', IntegerType()),\\n\",\n    \"    StructField('orig_ltv', FloatType()),\\n\",\n    \"    StructField('orig_cltv', FloatType()),\\n\",\n    \"    StructField('num_borrowers', FloatType()),\\n\",\n    \"    StructField('dti', FloatType()),\\n\",\n    \"    StructField('borrower_credit_score', FloatType()),\\n\",\n    \"    StructField('num_units', IntegerType()),\\n\",\n    \"    StructField('zip', IntegerType()),\\n\",\n    \"    StructField('mortgage_insurance_percent', FloatType()),\\n\",\n    \"    StructField('current_loan_delinquency_status', IntegerType()),\\n\",\n    \"    StructField('current_actual_upb', FloatType()),\\n\",\n    \"    StructField('interest_rate', FloatType()),\\n\",\n    \"    StructField('loan_age', FloatType()),\\n\",\n    \"    StructField('msa', FloatType()),\\n\",\n    \"    StructField('non_interest_bearing_upb', FloatType()),\\n\",\n    \"    StructField(label, IntegerType()),\\n\",\n    \"])\\n\",\n    \"features = [ x.name for x in schema if x.name != label ]\\n\",\n    \"\\n\",\n    \"# You need to update them to your real paths!\\n\",\n    \"dataRoot = os.getenv(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"train_path = dataRoot + \\\"/mortgage/output/train\\\"\\n\",\n    \"eval_path = dataRoot + \\\"/mortgage/output/eval\\\"\\n\",\n    \"\\n\",\n    \"data_format = 'parquet'\\n\",\n    \"has_header = 'true'\\n\",\n    \"if data_format == 'csv':\\n\",\n    \"    train_data = reader.schema(schema).option('header',has_header).csv(train_path)\\n\",\n    \"    trans_data = reader.schema(schema).option('header',has_header).csv(eval_path)\\n\",\n    \"else :\\n\",\n    \"    train_data = reader.load(train_path)\\n\",\n    \"    trans_data = reader.load(eval_path)\\n\",\n    \"  \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Note on CPU version, vectorization is required before fitting data 
to the classifier, which means you need to assemble all feature columns into one column.\\n\",\n    \"\\n\",\n    \"```Python\\n\",\n    \"def vectorize(data_frame):\\n\",\n    \"    to_floats = [ col(x.name).cast(FloatType()) for x in data_frame.schema ]\\n\",\n    \"    return (VectorAssembler()\\n\",\n    \"        .setInputCols(features)\\n\",\n    \"        .setOutputCol('features')\\n\",\n    \"        .transform(data_frame.select(to_floats))\\n\",\n    \"        .select(col('features'), col(label)))\\n\",\n    \"\\n\",\n    \"train_data = vectorize(train_data)\\n\",\n    \"trans_data = vectorize(trans_data)\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create an XGBoostClassifier\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"params = { \\n\",\n    \"    \\\"tree_method\\\": \\\"hist\\\",\\n\",\n    \"    \\\"grow_policy\\\": \\\"depthwise\\\",\\n\",\n    \"    \\\"num_workers\\\": 1,\\n\",\n    \"    \\\"device\\\": \\\"cuda\\\",\\n\",\n    \"}\\n\",\n    \"params['features_col'] = features\\n\",\n    \"params['label_col'] = label\\n\",\n    \"    \\n\",\n    \"classifier = SparkXGBClassifier(**params)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The parameter `num_workers` should be set to the number of GPUs in Spark cluster for GPU version, while for CPU version it is usually equal to the number of CPU cores.\\n\",\n    \"\\n\",\n    \"Concerning the device, GPU version only supports `cuda` currently, while `cpu` is designed and used here for CPU training.\\n\",\n    \"\\n\",\n    \"An example of CPU classifier:\\n\",\n    \"```\\n\",\n    \"classifier = SparkXGBClassifier(\\n\",\n    \"  features_col=features,\\n\",\n    \"  label_col=label,  \\n\",\n    \"  num_workers=1024,\\n\",\n    \")\\n\",\n    \"```\"\n   ]\n  },\n  {\n   
\"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Train the Data with Benchmark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"22/11/24 06:14:44 WARN org.apache.spark.sql.catalyst.util.package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\\n\",\n      \"[Stage 12:>                                                         (0 + 1) / 1]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[06:15:10] WARNING: ../src/learner.cc:553: \\n\",\n      \"  If you are loading a serialized model (like pickle in Python, RDS in R) generated by\\n\",\n      \"  older XGBoost, please export the model by calling `Booster.save_model` from that version\\n\",\n      \"  first, then load it back in current version. 
See:\\n\",\n      \"\\n\",\n      \"    https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html\\n\",\n      \"\\n\",\n      \"  for more details about differences between saving model and serializing.\\n\",\n      \"\\n\",\n      \"Training takes 28.6 seconds\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\r\",\n      \"                                                                                \\r\",\n      \"/home/yuali_nvidia_com/.local/lib/python3.8/site-packages/xgboost/sklearn.py:808: UserWarning: Loading a native XGBoost model with Scikit-Learn interface.\\n\",\n      \"  warnings.warn(\\\"Loading a native XGBoost model with Scikit-Learn interface.\\\")\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def with_benchmark(phrase, action):\\n\",\n    \"    start = time()\\n\",\n    \"    result = action()\\n\",\n    \"    end = time()\\n\",\n    \"    print('{} takes {} seconds'.format(phrase, round(end - start, 2)))\\n\",\n    \"    return result\\n\",\n    \"model = with_benchmark('Training', lambda: classifier.fit(train_data))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Save and Reload the Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"model.write().overwrite().save(dataRoot + '/model/mortgage')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"loaded_model = SparkXGBClassifierModel().load(dataRoot + '/model/mortgage')\"\n   ]\n  },\n  
{\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Transformation and Show Result Sample\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"22/11/24 06:15:13 WARN com.nvidia.spark.rapids.GpuOverrides: \\n\",\n      \"!Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction#209, probability#275]\\n\",\n      \"  @Expression <AttributeReference> orig_channel#56 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> first_home_buyer#57 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> loan_purpose#58 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> property_type#59 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> occupancy_status#60 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> property_state#61 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> product_type#62 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> relocation_mortgage_indicator#63 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> seller_name#64 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> mod_flag#65 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_interest_rate#66 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_upb#67 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_loan_term#68 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_ltv#69 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_cltv#70 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> num_borrowers#71 could run on GPU\\n\",\n      \"  @Expression 
<AttributeReference> dti#72 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> borrower_credit_score#73 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> num_units#74 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> zip#75 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> mortgage_insurance_percent#76 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> current_loan_delinquency_status#77 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> current_actual_upb#78 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> interest_rate#79 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> loan_age#80 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> msa#81 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> non_interest_bearing_upb#82 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> delinquency_12#83 could run on GPU\\n\",\n      \"  !Expression <Alias> UDF(pythonUDF0#342.rawPrediction) AS rawPrediction#209 cannot run on GPU because input expression ScalaUDF UDF(pythonUDF0#342.rawPrediction) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported); expression Alias UDF(pythonUDF0#342.rawPrediction) AS rawPrediction#209 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"    !Expression <ScalaUDF> UDF(pythonUDF0#342.rawPrediction) cannot run on GPU because expression ScalaUDF UDF(pythonUDF0#342.rawPrediction) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7; neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3898/645590696 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#342.rawPrediction could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#342 could run on GPU\\n\",\n      \" 
 @Expression <Alias> pythonUDF0#342.prediction AS prediction#243 could run on GPU\\n\",\n      \"    @Expression <GetStructField> pythonUDF0#342.prediction could run on GPU\\n\",\n      \"      @Expression <AttributeReference> pythonUDF0#342 could run on GPU\\n\",\n      \"  !Expression <Alias> UDF(pythonUDF0#342.probability) AS probability#275 cannot run on GPU because input expression ScalaUDF UDF(pythonUDF0#342.probability) (org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 is not supported); expression Alias UDF(pythonUDF0#342.probability) AS probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"    !Expression <ScalaUDF> UDF(pythonUDF0#342.probability) cannot run on GPU because neither UDF implemented by class org.apache.spark.ml.functions$$$Lambda$3898/645590696 provides a GPU implementation, nor the conf `spark.rapids.sql.rowBasedUDF.enabled` is enabled; expression ScalaUDF UDF(pythonUDF0#342.probability) produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"      @Expression <GetStructField> pythonUDF0#342.probability could run on GPU\\n\",\n      \"        @Expression <AttributeReference> pythonUDF0#342 could run on GPU\\n\",\n      \"\\n\",\n      \"22/11/24 06:15:13 WARN com.nvidia.spark.rapids.GpuOverrides: \\n\",\n      \"!Exec <InMemoryTableScanExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; not all expressions can be replaced; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction#209, probability#275]\\n\",\n      \"  @Expression <AttributeReference> orig_channel#56 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> first_home_buyer#57 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> loan_purpose#58 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> property_type#59 could run on 
GPU\\n\",\n      \"  @Expression <AttributeReference> occupancy_status#60 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> property_state#61 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> product_type#62 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> relocation_mortgage_indicator#63 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> seller_name#64 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> mod_flag#65 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_interest_rate#66 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_upb#67 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_loan_term#68 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_ltv#69 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> orig_cltv#70 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> num_borrowers#71 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> dti#72 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> borrower_credit_score#73 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> num_units#74 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> zip#75 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> mortgage_insurance_percent#76 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> current_loan_delinquency_status#77 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> current_actual_upb#78 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> interest_rate#79 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> loan_age#80 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> msa#81 could run on GPU\\n\",\n      \"  @Expression <AttributeReference> non_interest_bearing_upb#82 could run on GPU\\n\",\n      \"  @Expression 
<AttributeReference> delinquency_12#83 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> rawPrediction#209 cannot run on GPU because expression AttributeReference rawPrediction#209 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  @Expression <AttributeReference> prediction#243 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> probability#275 cannot run on GPU because expression AttributeReference probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"\\n\",\n      \"22/11/24 06:15:28 WARN com.nvidia.spark.rapids.GpuOverrides:                    \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because not all expressions can be replaced; unsupported data types in input: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#275, rawPrediction#209]\\n\",\n      \"    @Expression <Alias> cast(delinquency_12#83 as string) AS delinquency_12#971 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(delinquency_12#83 as string) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> delinquency_12#83 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(rawPrediction#209 as string) AS rawPrediction#972 could run on GPU\\n\",\n      \"      !Expression <Cast> cast(rawPrediction#209 as string) cannot run on GPU because Cast from org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 to StringType is not supported\\n\",\n      \"        !Expression <AttributeReference> 
rawPrediction#209 cannot run on GPU because expression AttributeReference rawPrediction#209 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"    @Expression <Alias> cast(probability#275 as string) AS probability#973 could run on GPU\\n\",\n      \"      !Expression <Cast> cast(probability#275 as string) cannot run on GPU because Cast from org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 to StringType is not supported\\n\",\n      \"        !Expression <AttributeReference> probability#275 cannot run on GPU because expression AttributeReference probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"    @Expression <Alias> cast(prediction#243 as string) AS prediction#974 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(prediction#243 as string) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> prediction#243 could run on GPU\\n\",\n      \"    !Exec <InMemoryTableScanExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#275, rawPrediction#209]; not all expressions can be replaced\\n\",\n      \"      @Expression <AttributeReference> delinquency_12#83 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> prediction#243 could run on GPU\\n\",\n      \"      !Expression <AttributeReference> probability#275 cannot run on GPU because expression AttributeReference probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"      !Expression <AttributeReference> rawPrediction#209 cannot run on GPU because expression AttributeReference rawPrediction#209 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     
\"output_type\": \"stream\",\n     \"text\": [\n      \"Transformation takes 15.62 seconds\\n\",\n      \"+--------------+--------------------+--------------------+----------+\\n\",\n      \"|delinquency_12|       rawPrediction|         probability|prediction|\\n\",\n      \"+--------------+--------------------+--------------------+----------+\\n\",\n      \"|             0|[8.84631538391113...|[0.99985611438751...|       0.0|\\n\",\n      \"|             0|[9.41864871978759...|[0.99991881847381...|       0.0|\\n\",\n      \"|             0|[9.41864871978759...|[0.99991881847381...|       0.0|\\n\",\n      \"|             0|[9.41864871978759...|[0.99991881847381...|       0.0|\\n\",\n      \"|             0|[8.84631538391113...|[0.99985611438751...|       0.0|\\n\",\n      \"+--------------+--------------------+--------------------+----------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def transform():\\n\",\n    \"    result = loaded_model.transform(trans_data).cache()\\n\",\n    \"    result.foreachPartition(lambda _: None)\\n\",\n    \"    return result\\n\",\n    \"result = with_benchmark('Transformation', transform)\\n\",\n    \"result.select(label, 'rawPrediction', 'probability', 'prediction').show(5)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Evaluation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def check_classification_accuracy(data_frame, label):\\n\",\n    \"    accuracy = (MulticlassClassificationEvaluator()\\n\",\n    \"                .setLabelCol(label)\\n\",\n    \"                .evaluate(data_frame))\\n\",\n    \"    print('-' * 100)\\n\",\n    \"    print('Accuracy is ' + str(accuracy))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     
\"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"22/11/24 06:15:28 WARN com.nvidia.spark.rapids.GpuOverrides: \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! <CreateExternalRow> createexternalrow(prediction#243, delinquency_12#1450, 1.0#1449, newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize, StructField(prediction,DoubleType,true), StructField(delinquency_12,DoubleType,true), StructField(1.0,DoubleType,false), StructField(probability,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#243 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> delinquency_12#1450 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#1449 could run on GPU\\n\",\n      \"    ! <Invoke> newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.Invoke\\n\",\n      \"      ! 
<NewInstance> newInstance(class org.apache.spark.ml.linalg.VectorUDT) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.NewInstance\\n\",\n      \"      !Expression <AttributeReference> probability#275 cannot run on GPU because expression AttributeReference probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"  !Expression <AttributeReference> obj#1455 cannot run on GPU because expression AttributeReference obj#1455 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"  !Exec <ProjectExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#275]; unsupported data types in input: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#275]; not all expressions can be replaced\\n\",\n      \"    @Expression <AttributeReference> prediction#243 could run on GPU\\n\",\n      \"    @Expression <Alias> cast(delinquency_12#83 as double) AS delinquency_12#1450 could run on GPU\\n\",\n      \"      @Expression <Cast> cast(delinquency_12#83 as double) could run on GPU\\n\",\n      \"        @Expression <AttributeReference> delinquency_12#83 could run on GPU\\n\",\n      \"    @Expression <Alias> 1.0 AS 1.0#1449 could run on GPU\\n\",\n      \"      @Expression <Literal> 1.0 could run on GPU\\n\",\n      \"    !Expression <AttributeReference> probability#275 cannot run on GPU because expression AttributeReference probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"    !Exec <InMemoryTableScanExec> cannot run on GPU because unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [rawPrediction, probability]; unsupported data types in output: org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 [probability#275]; not all expressions can be replaced\\n\",\n      \"      
@Expression <AttributeReference> delinquency_12#83 could run on GPU\\n\",\n      \"      @Expression <AttributeReference> prediction#243 could run on GPU\\n\",\n      \"      !Expression <AttributeReference> probability#275 cannot run on GPU because expression AttributeReference probability#275 produces an unsupported type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7\\n\",\n      \"\\n\",\n      \"[Stage 19:>                                                         (0 + 1) / 1]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"----------------------------------------------------------------------------------------------------\\n\",\n      \"Accuracy is 1.0\\n\",\n      \"Evaluation takes 2.29 seconds\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\r\",\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"with_benchmark('Evaluation', lambda: check_classification_accuracy(result, label))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.stop()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.2\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-ETL.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e82e9fb4\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction to Mortgage ETL Job\\n\",\n    \"This is the mortgage ETL job to generate the input datasets for the mortgage Xgboost job.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d0c8c3fa\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Prerequirement\\n\",\n    \"### 1. Download data\\n\",\n    \"<!-- Refer these [instructions](https://github.com/NVIDIA/spark-rapids-examples/blob/branch-23.12/docs/get-started/xgboost-examples/dataset/mortgage.md) to download the dataset -->\\n\",\n    \"Refer to these [instructions](https://github.com/NVIDIA/spark-rapids-examples/blob/branch-23.12/docs/get-started/xgboost-examples/dataset/mortgage.md) to download the dataset.\\n\",\n    \"\\n\",\n    \"### 2. Download needed jars\\n\",\n    \"* [rapids-4-spark_2.12-26.02.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)\\n\",\n    \"\\n\",\n    \"### 3. Start Spark Standalone\\n\",\n    \"Before Running the script, please setup Spark standalone mode\\n\",\n    \"\\n\",\n    \"### 4. 
Add ENV\\n\",\n    \"```\\n\",\n    \"$ export SPARK_JARS=rapids-4-spark_2.12-26.02.0.jar\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"### 5.Start Jupyter Notebook with spylon-kernel or toree\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"$ jupyter notebook --allow-root --notebook-dir=${your-dir} --config=${your-configs}\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"## Import Libs\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"3ecc912c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import org.apache.hadoop.fs.Path\\n\",\n    \"import org.apache.spark.sql.expressions.Window\\n\",\n    \"import org.apache.spark.sql.functions._\\n\",\n    \"import org.apache.spark.sql.types._\\n\",\n    \"import org.apache.spark.sql.{Column, DataFrame, SparkSession}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b58fcd6d\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Script Settings\\n\",\n    \"\\n\",\n    \"### 1. File Path Settings\\n\",\n    \"* Define input file path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b2834c06\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"val dataRoot = sys.env.getOrElse(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"val dataOut = sys.env.getOrElse(\\\"DATA_OUT\\\", \\\"/data\\\")\\n\",\n    \"val dataPath = dataRoot + \\\"/mortgage/input\\\"\\n\",\n    \"val outPath = dataOut + \\\"/mortgage/output\\\"\\n\",\n    \"val output_csv2parquet = dataOut + \\\"/mortgage/output/csv2parquet/\\\"\\n\",\n    \"val saveTrainEvalDataset = true\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"775a2c7b\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Function and Object Define\\n\",\n    \"### 1. 
Define the constants\\n\",\n    \"\\n\",\n    \"* Define input/output file schema (Performance and Acquisition)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"e557beb0\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"rawSchema = StructType(StructField(reference_pool_id,StringType,true), StructField(loan_id,LongType,true), StructField(monthly_reporting_period,StringType,true), StructField(orig_channel,StringType,true), StructField(seller_name,StringType,true), StructField(servicer,StringType,true), StructField(master_servicer,StringType,true), StructField(orig_interest_rate,DoubleType,true), StructField(interest_rate,DoubleType,true), StructField(orig_upb,IntegerType,true), StructField(upb_at_issuance,StringType,true), StructField(current_actual_upb,DoubleType,true), StructField(orig_loan_term,IntegerType,true), StructField(orig_date,StringType,true), StructField(first_pay_date,StringType,true), StructField(loan_age,DoubleType,true), StructField(remaining_months...\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType(StructField(reference_pool_id,StringType,true), StructField(loan_id,LongType,true), StructField(monthly_reporting_period,StringType,true), StructField(orig_channel,StringType,true), StructField(seller_name,StringType,true), StructField(servicer,StringType,true), StructField(master_servicer,StringType,true), StructField(orig_interest_rate,DoubleType,true), StructField(interest_rate,DoubleType,true), StructField(orig_upb,IntegerType,true), StructField(upb_at_issuance,StringType,true), StructField(current_actual_upb,DoubleType,true), StructField(orig_loan_term,IntegerType,true), StructField(orig_date,StringType,true), StructField(first_pay_date,StringType,true), StructField(loan_age,DoubleType,true), StructField(remaining_months...\"\n      ]\n     
},\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// File schema\\n\",\n    \"val rawSchema = StructType(Array(\\n\",\n    \"      StructField(\\\"reference_pool_id\\\", StringType),\\n\",\n    \"      StructField(\\\"loan_id\\\", LongType),\\n\",\n    \"      StructField(\\\"monthly_reporting_period\\\", StringType),\\n\",\n    \"      StructField(\\\"orig_channel\\\", StringType),\\n\",\n    \"      StructField(\\\"seller_name\\\", StringType),\\n\",\n    \"      StructField(\\\"servicer\\\", StringType),\\n\",\n    \"      StructField(\\\"master_servicer\\\", StringType),\\n\",\n    \"      StructField(\\\"orig_interest_rate\\\", DoubleType),\\n\",\n    \"      StructField(\\\"interest_rate\\\", DoubleType),\\n\",\n    \"      StructField(\\\"orig_upb\\\", DoubleType),\\n\",\n    \"      StructField(\\\"upb_at_issuance\\\", StringType),\\n\",\n    \"      StructField(\\\"current_actual_upb\\\", DoubleType),\\n\",\n    \"      StructField(\\\"orig_loan_term\\\", IntegerType),\\n\",\n    \"      StructField(\\\"orig_date\\\", StringType),\\n\",\n    \"      StructField(\\\"first_pay_date\\\", StringType),    \\n\",\n    \"      StructField(\\\"loan_age\\\", DoubleType),\\n\",\n    \"      StructField(\\\"remaining_months_to_legal_maturity\\\", DoubleType),\\n\",\n    \"      StructField(\\\"adj_remaining_months_to_maturity\\\", DoubleType),\\n\",\n    \"      StructField(\\\"maturity_date\\\", StringType),\\n\",\n    \"      StructField(\\\"orig_ltv\\\", DoubleType),\\n\",\n    \"      StructField(\\\"orig_cltv\\\", DoubleType),\\n\",\n    \"      StructField(\\\"num_borrowers\\\", DoubleType),\\n\",\n    \"      StructField(\\\"dti\\\", DoubleType),\\n\",\n    \"      StructField(\\\"borrower_credit_score\\\", DoubleType),\\n\",\n    \"      StructField(\\\"coborrow_credit_score\\\", DoubleType),\\n\",\n    \"      StructField(\\\"first_home_buyer\\\", StringType),\\n\",\n    
\"      StructField(\\\"loan_purpose\\\", StringType),\\n\",\n    \"      StructField(\\\"property_type\\\", StringType),\\n\",\n    \"      StructField(\\\"num_units\\\", IntegerType),\\n\",\n    \"      StructField(\\\"occupancy_status\\\", StringType),\\n\",\n    \"      StructField(\\\"property_state\\\", StringType),\\n\",\n    \"      StructField(\\\"msa\\\", DoubleType),\\n\",\n    \"      StructField(\\\"zip\\\", IntegerType),\\n\",\n    \"      StructField(\\\"mortgage_insurance_percent\\\", DoubleType),\\n\",\n    \"      StructField(\\\"product_type\\\", StringType),\\n\",\n    \"      StructField(\\\"prepayment_penalty_indicator\\\", StringType),\\n\",\n    \"      StructField(\\\"interest_only_loan_indicator\\\", StringType),\\n\",\n    \"      StructField(\\\"interest_only_first_principal_and_interest_payment_date\\\", StringType),\\n\",\n    \"      StructField(\\\"months_to_amortization\\\", StringType),\\n\",\n    \"      StructField(\\\"current_loan_delinquency_status\\\", IntegerType),\\n\",\n    \"      StructField(\\\"loan_payment_history\\\", StringType),\\n\",\n    \"      StructField(\\\"mod_flag\\\", StringType),\\n\",\n    \"      StructField(\\\"mortgage_insurance_cancellation_indicator\\\", StringType),\\n\",\n    \"      StructField(\\\"zero_balance_code\\\", StringType),\\n\",\n    \"      StructField(\\\"zero_balance_effective_date\\\", StringType),\\n\",\n    \"      StructField(\\\"upb_at_the_time_of_removal\\\", StringType),\\n\",\n    \"      StructField(\\\"repurchase_date\\\", StringType),\\n\",\n    \"      StructField(\\\"scheduled_principal_current\\\", StringType),\\n\",\n    \"      StructField(\\\"total_principal_current\\\", StringType),\\n\",\n    \"      StructField(\\\"unscheduled_principal_current\\\", StringType),\\n\",\n    \"      StructField(\\\"last_paid_installment_date\\\", StringType),\\n\",\n    \"      StructField(\\\"foreclosed_after\\\", StringType),\\n\",\n    \"      StructField(\\\"disposition_date\\\", 
StringType),\\n\",\n    \"      StructField(\\\"foreclosure_costs\\\", DoubleType),\\n\",\n    \"      StructField(\\\"prop_preservation_and_repair_costs\\\", DoubleType),\\n\",\n    \"      StructField(\\\"asset_recovery_costs\\\", DoubleType),\\n\",\n    \"      StructField(\\\"misc_holding_expenses\\\", DoubleType),\\n\",\n    \"      StructField(\\\"holding_taxes\\\", DoubleType),\\n\",\n    \"      StructField(\\\"net_sale_proceeds\\\", DoubleType),\\n\",\n    \"      StructField(\\\"credit_enhancement_proceeds\\\", DoubleType),\\n\",\n    \"      StructField(\\\"repurchase_make_whole_proceeds\\\", StringType),\\n\",\n    \"      StructField(\\\"other_foreclosure_proceeds\\\", DoubleType),\\n\",\n    \"      StructField(\\\"non_interest_bearing_upb\\\", DoubleType),\\n\",\n    \"      StructField(\\\"principal_forgiveness_upb\\\", StringType),\\n\",\n    \"      StructField(\\\"original_list_start_date\\\", StringType),\\n\",\n    \"      StructField(\\\"original_list_price\\\", StringType),\\n\",\n    \"      StructField(\\\"current_list_start_date\\\", StringType),\\n\",\n    \"      StructField(\\\"current_list_price\\\", StringType),\\n\",\n    \"      StructField(\\\"borrower_credit_score_at_issuance\\\", StringType),\\n\",\n    \"      StructField(\\\"co-borrower_credit_score_at_issuance\\\", StringType),\\n\",\n    \"      StructField(\\\"borrower_credit_score_current\\\", StringType),\\n\",\n    \"      StructField(\\\"co-Borrower_credit_score_current\\\", StringType),\\n\",\n    \"      StructField(\\\"mortgage_insurance_type\\\", DoubleType),\\n\",\n    \"      StructField(\\\"servicing_activity_indicator\\\", StringType),\\n\",\n    \"      StructField(\\\"current_period_modification_loss_amount\\\", StringType),\\n\",\n    \"      StructField(\\\"cumulative_modification_loss_amount\\\", StringType),\\n\",\n    \"      StructField(\\\"current_period_credit_event_net_gain_or_loss\\\", StringType),\\n\",\n    \"      
StructField(\\\"cumulative_credit_event_net_gain_or_loss\\\", StringType),\\n\",\n    \"      StructField(\\\"homeready_program_indicator\\\", StringType),\\n\",\n    \"      StructField(\\\"foreclosure_principal_write_off_amount\\\", StringType),\\n\",\n    \"      StructField(\\\"relocation_mortgage_indicator\\\", StringType),\\n\",\n    \"      StructField(\\\"zero_balance_code_change_date\\\", StringType),\\n\",\n    \"      StructField(\\\"loan_holdback_indicator\\\", StringType),\\n\",\n    \"      StructField(\\\"loan_holdback_effective_date\\\", StringType),\\n\",\n    \"      StructField(\\\"delinquent_accrued_interest\\\", StringType),\\n\",\n    \"      StructField(\\\"property_valuation_method\\\", StringType),\\n\",\n    \"      StructField(\\\"high_balance_loan_indicator\\\", StringType),\\n\",\n    \"      StructField(\\\"arm_initial_fixed-rate_period_lt_5_yr_indicator\\\", StringType),\\n\",\n    \"      StructField(\\\"arm_product_type\\\", StringType),\\n\",\n    \"      StructField(\\\"initial_fixed-rate_period\\\", StringType),\\n\",\n    \"      StructField(\\\"interest_rate_adjustment_frequency\\\", StringType),\\n\",\n    \"      StructField(\\\"next_interest_rate_adjustment_date\\\", StringType),\\n\",\n    \"      StructField(\\\"next_payment_change_date\\\", StringType),\\n\",\n    \"      StructField(\\\"index\\\", StringType),\\n\",\n    \"      StructField(\\\"arm_cap_structure\\\", StringType),\\n\",\n    \"      StructField(\\\"initial_interest_rate_cap_up_percent\\\", StringType),\\n\",\n    \"      StructField(\\\"periodic_interest_rate_cap_up_percent\\\", StringType),\\n\",\n    \"      StructField(\\\"lifetime_interest_rate_cap_up_percent\\\", StringType),\\n\",\n    \"      StructField(\\\"mortgage_margin\\\", StringType),\\n\",\n    \"      StructField(\\\"arm_balloon_indicator\\\", StringType),\\n\",\n    \"      StructField(\\\"arm_plan_number\\\", StringType),\\n\",\n    \"      StructField(\\\"borrower_assistance_plan\\\", 
StringType),\\n\",\n    \"      StructField(\\\"hltv_refinance_option_indicator\\\", StringType),\\n\",\n    \"      StructField(\\\"deal_name\\\", StringType),\\n\",\n    \"      StructField(\\\"repurchase_make_whole_proceeds_flag\\\", StringType),\\n\",\n    \"      StructField(\\\"alternative_delinquency_resolution\\\", StringType),\\n\",\n    \"      StructField(\\\"alternative_delinquency_resolution_count\\\", StringType),\\n\",\n    \"      StructField(\\\"total_deferral_amount\\\", StringType)\\n\",\n    \"      )\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"86af48b6\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define seller name mapping\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"69f193d7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"defined object NameMapping\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"object NameMapping {\\n\",\n    \"  /**\\n\",\n    \"    * Returns a dataframe with two columns named based off of the column names passed in.\\n\",\n    \"    * The fromColName has the original name we want to clean up, the toColName\\n\",\n    \"    * will have the name we want to go to, the unambiguous name.\\n\",\n    \"    */\\n\",\n    \"  def apply(spark: SparkSession, fromColName: String, toColName: String): DataFrame = {\\n\",\n    \"    import spark.sqlContext.implicits._\\n\",\n    \"    broadcast(Seq(\\n\",\n    \"      (\\\"WITMER FUNDING, LLC\\\", \\\"Witmer\\\"),\\n\",\n    \"      (\\\"WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015\\\", \\\"Wells Fargo\\\"),\\n\",\n    \"      (\\\"WELLS FARGO BANK,  NA\\\" , \\\"Wells Fargo\\\"),\\n\",\n    \"      (\\\"WELLS FARGO BANK, N.A.\\\" , \\\"Wells Fargo\\\"),\\n\",\n    \"      (\\\"WELLS FARGO BANK, NA\\\" , \\\"Wells Fargo\\\"),\\n\",\n    \"      (\\\"USAA 
FEDERAL SAVINGS BANK\\\" , \\\"USAA\\\"),\\n\",\n    \"      (\\\"UNITED SHORE FINANCIAL SERVICES, LLC D\\\\\\\\/B\\\\\\\\/A UNITED WHOLESALE MORTGAGE\\\" , \\\"United Seq(e\\\"),\\n\",\n    \"      (\\\"U.S. BANK N.A.\\\" , \\\"US Bank\\\"),\\n\",\n    \"      (\\\"SUNTRUST MORTGAGE INC.\\\" , \\\"Suntrust\\\"),\\n\",\n    \"      (\\\"STONEGATE MORTGAGE CORPORATION\\\" , \\\"Stonegate Mortgage\\\"),\\n\",\n    \"      (\\\"STEARNS LENDING, LLC\\\" , \\\"Stearns Lending\\\"),\\n\",\n    \"      (\\\"STEARNS LENDING, INC.\\\" , \\\"Stearns Lending\\\"),\\n\",\n    \"      (\\\"SIERRA PACIFIC MORTGAGE COMPANY, INC.\\\" , \\\"Sierra Pacific Mortgage\\\"),\\n\",\n    \"      (\\\"REGIONS BANK\\\" , \\\"Regions\\\"),\\n\",\n    \"      (\\\"RBC MORTGAGE COMPANY\\\" , \\\"RBC\\\"),\\n\",\n    \"      (\\\"QUICKEN LOANS INC.\\\" , \\\"Quicken Loans\\\"),\\n\",\n    \"      (\\\"PULTE MORTGAGE, L.L.C.\\\" , \\\"Pulte Mortgage\\\"),\\n\",\n    \"      (\\\"PROVIDENT FUNDING ASSOCIATES, L.P.\\\" , \\\"Provident Funding\\\"),\\n\",\n    \"      (\\\"PROSPECT MORTGAGE, LLC\\\" , \\\"Prospect Mortgage\\\"),\\n\",\n    \"      (\\\"PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC\\\" , \\\"Principal Residential\\\"),\\n\",\n    \"      (\\\"PNC BANK, N.A.\\\" , \\\"PNC\\\"),\\n\",\n    \"      (\\\"PMT CREDIT RISK TRANSFER TRUST 2015-2\\\" , \\\"PennyMac\\\"),\\n\",\n    \"      (\\\"PHH MORTGAGE CORPORATION\\\" , \\\"PHH Mortgage\\\"),\\n\",\n    \"      (\\\"PENNYMAC CORP.\\\" , \\\"PennyMac\\\"),\\n\",\n    \"      (\\\"PACIFIC UNION FINANCIAL, LLC\\\" , \\\"Other\\\"),\\n\",\n    \"      (\\\"OTHER\\\" , \\\"Other\\\"),\\n\",\n    \"      (\\\"NYCB MORTGAGE COMPANY, LLC\\\" , \\\"NYCB\\\"),\\n\",\n    \"      (\\\"NEW YORK COMMUNITY BANK\\\" , \\\"NYCB\\\"),\\n\",\n    \"      (\\\"NETBANK FUNDING SERVICES\\\" , \\\"Netbank\\\"),\\n\",\n    \"      (\\\"NATIONSTAR MORTGAGE, LLC\\\" , \\\"Nationstar Mortgage\\\"),\\n\",\n    \"      (\\\"METLIFE BANK, NA\\\" , 
\\\"Metlife\\\"),\\n\",\n    \"      (\\\"LOANDEPOT.COM, LLC\\\" , \\\"LoanDepot.com\\\"),\\n\",\n    \"      (\\\"J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"      (\\\"J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"      (\\\"JPMORGAN CHASE BANK, NATIONAL ASSOCIATION\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"      (\\\"JPMORGAN CHASE BANK, NA\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"      (\\\"JP MORGAN CHASE BANK, NA\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"      (\\\"IRWIN MORTGAGE, CORPORATION\\\" , \\\"Irwin Mortgage\\\"),\\n\",\n    \"      (\\\"IMPAC MORTGAGE CORP.\\\" , \\\"Impac Mortgage\\\"),\\n\",\n    \"      (\\\"HSBC BANK USA, NATIONAL ASSOCIATION\\\" , \\\"HSBC\\\"),\\n\",\n    \"      (\\\"HOMEWARD RESIDENTIAL, INC.\\\" , \\\"Homeward Mortgage\\\"),\\n\",\n    \"      (\\\"HOMESTREET BANK\\\" , \\\"Other\\\"),\\n\",\n    \"      (\\\"HOMEBRIDGE FINANCIAL SERVICES, INC.\\\" , \\\"HomeBridge\\\"),\\n\",\n    \"      (\\\"HARWOOD STREET FUNDING I, LLC\\\" , \\\"Harwood Mortgage\\\"),\\n\",\n    \"      (\\\"GUILD MORTGAGE COMPANY\\\" , \\\"Guild Mortgage\\\"),\\n\",\n    \"      (\\\"GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)\\\" , \\\"GMAC\\\"),\\n\",\n    \"      (\\\"GMAC MORTGAGE, LLC\\\" , \\\"GMAC\\\"),\\n\",\n    \"      (\\\"GMAC (USAA)\\\" , \\\"GMAC\\\"),\\n\",\n    \"      (\\\"FREMONT BANK\\\" , \\\"Fremont Bank\\\"),\\n\",\n    \"      (\\\"FREEDOM MORTGAGE CORP.\\\" , \\\"Freedom Mortgage\\\"),\\n\",\n    \"      (\\\"FRANKLIN AMERICAN MORTGAGE COMPANY\\\" , \\\"Franklin America\\\"),\\n\",\n    \"      (\\\"FLEET NATIONAL BANK\\\" , \\\"Fleet National\\\"),\\n\",\n    \"      (\\\"FLAGSTAR CAPITAL MARKETS CORPORATION\\\" , \\\"Flagstar Bank\\\"),\\n\",\n    \"      (\\\"FLAGSTAR BANK, FSB\\\" , \\\"Flagstar Bank\\\"),\\n\",\n    \"      (\\\"FIRST TENNESSEE BANK NATIONAL ASSOCIATION\\\" , 
\\\"Other\\\"),\\n\",\n    \"      (\\\"FIFTH THIRD BANK\\\" , \\\"Fifth Third Bank\\\"),\\n\",\n    \"      (\\\"FEDERAL HOME LOAN BANK OF CHICAGO\\\" , \\\"Fedral Home of Chicago\\\"),\\n\",\n    \"      (\\\"FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB\\\" , \\\"FDIC\\\"),\\n\",\n    \"      (\\\"DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A.\\\" , \\\"Downey Mortgage\\\"),\\n\",\n    \"      (\\\"DITECH FINANCIAL LLC\\\" , \\\"Ditech\\\"),\\n\",\n    \"      (\\\"CITIMORTGAGE, INC.\\\" , \\\"Citi\\\"),\\n\",\n    \"      (\\\"CHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY\\\" , \\\"Chicago Mortgage\\\"),\\n\",\n    \"      (\\\"CHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY\\\" , \\\"Chicago Mortgage\\\"),\\n\",\n    \"      (\\\"CHASE HOME FINANCE, LLC\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"      (\\\"CHASE HOME FINANCE FRANKLIN AMERICAN MORTGAGE COMPANY\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"      (\\\"CHASE HOME FINANCE (CIE 1)\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"      (\\\"CHASE HOME FINANCE\\\" , \\\"JP Morgan Chase\\\"),\\n\",\n    \"      (\\\"CASHCALL, INC.\\\" , \\\"CashCall\\\"),\\n\",\n    \"      (\\\"CAPITAL ONE, NATIONAL ASSOCIATION\\\" , \\\"Capital One\\\"),\\n\",\n    \"      (\\\"CALIBER HOME LOANS, INC.\\\" , \\\"Caliber Funding\\\"),\\n\",\n    \"      (\\\"BISHOPS GATE RESIDENTIAL MORTGAGE TRUST\\\" , \\\"Bishops Gate Mortgage\\\"),\\n\",\n    \"      (\\\"BANK OF AMERICA, N.A.\\\" , \\\"Bank of America\\\"),\\n\",\n    \"      (\\\"AMTRUST BANK\\\" , \\\"AmTrust\\\"),\\n\",\n    \"      (\\\"AMERISAVE MORTGAGE CORPORATION\\\" , \\\"Amerisave\\\"),\\n\",\n    \"      (\\\"AMERIHOME MORTGAGE COMPANY, LLC\\\" , \\\"AmeriHome Mortgage\\\"),\\n\",\n    \"      (\\\"ALLY BANK\\\" , \\\"Ally Bank\\\"),\\n\",\n    \"      (\\\"ACADEMY MORTGAGE CORPORATION\\\" , \\\"Academy Mortgage\\\"),\\n\",\n    \"      (\\\"NO CASH-OUT REFINANCE\\\" , \\\"OTHER REFINANCE\\\"),\\n\",\n    \"      (\\\"REFINANCE - NOT 
SPECIFIED\\\" , \\\"OTHER REFINANCE\\\"),\\n\",\n    \"      (\\\"Other REFINANCE\\\" , \\\"OTHER REFINANCE\\\")\\n\",\n    \"    ).toDF(fromColName, toColName))\\n\",\n    \"  }\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"42098a5a\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 2. Define ETL Process\\n\",\n    \"\\n\",\n    \"Define the function to do the ETL process\\n\",\n    \"\\n\",\n    \"* Define function to get quarter from input CSV file name\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"f18cab51\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"defined object GetQuarterFromCsvFileName\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"object GetQuarterFromCsvFileName {\\n\",\n    \"  // The format is path/TYPE_yyyy\\\\QQ.txt followed by a (_index)* where index is a single digit number [0-9]\\n\",\n    \"  // i.e. 
mortgage/perf/Performance_2003Q4.txt_0_1\\n\",\n    \"  // So we strip off the .txt and everything after it\\n\",\n    \"  // and then take everything after the last remaining _\\n\",\n    \"  def apply(): Column = substring_index(\\n\",\n    \"    substring_index(input_file_name(), \\\".\\\", 1), \\\"/\\\", -1)\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ead44543\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define category (string) column and numeric column\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"9936e221\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"labelColName = delinquency_12\\n\",\n       \"categaryCols = List((orig_channel,FloatType), (first_home_buyer,FloatType), (loan_purpose,FloatType), (property_type,FloatType), (occupancy_status,FloatType), (property_state,FloatType), (product_type,FloatType), (relocation_mortgage_indicator,FloatType), (seller_name,FloatType), (mod_flag,FloatType))\\n\",\n       \"numericCols = List((orig_interest_rate,FloatType), (orig_upb,IntegerType), (orig_loan_term,IntegerType), (orig_ltv,FloatType), (orig_cltv,FloatType), (num_borrowers,FloatType), (dti,FloatType), (borrower_credit_score,FloatType), (num_units,IntegerType), (zip,IntegerType), (mortgage_insurance_percent,FloatType...\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"List((orig_interest_rate,FloatType), (orig_upb,IntegerType), (orig_loan_term,IntegerType), (orig_ltv,FloatType), (orig_cltv,FloatType), (num_borrowers,FloatType), (dti,FloatType), (borrower_credit_score,FloatType), (num_units,IntegerType), (zip,IntegerType), (mortgage_insurance_percent,FloatType...\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    
\"val labelColName = \\\"delinquency_12\\\"\\n\",\n    \"val categaryCols = List(\\n\",\n    \"    (\\\"orig_channel\\\", FloatType),\\n\",\n    \"    (\\\"first_home_buyer\\\", FloatType),\\n\",\n    \"    (\\\"loan_purpose\\\", FloatType),\\n\",\n    \"    (\\\"property_type\\\", FloatType),\\n\",\n    \"    (\\\"occupancy_status\\\", FloatType),\\n\",\n    \"    (\\\"property_state\\\", FloatType),\\n\",\n    \"    (\\\"product_type\\\", FloatType),\\n\",\n    \"    (\\\"relocation_mortgage_indicator\\\", FloatType),\\n\",\n    \"    (\\\"seller_name\\\", FloatType),\\n\",\n    \"    (\\\"mod_flag\\\", FloatType)\\n\",\n    \"  )\\n\",\n    \"\\n\",\n    \"val numericCols = List(\\n\",\n    \"    (\\\"orig_interest_rate\\\", FloatType),\\n\",\n    \"    (\\\"orig_upb\\\", DoubleType),\\n\",\n    \"    (\\\"orig_loan_term\\\", IntegerType),\\n\",\n    \"    (\\\"orig_ltv\\\", FloatType),\\n\",\n    \"    (\\\"orig_cltv\\\", FloatType),\\n\",\n    \"    (\\\"num_borrowers\\\", FloatType),\\n\",\n    \"    (\\\"dti\\\", FloatType),\\n\",\n    \"    (\\\"borrower_credit_score\\\", FloatType),\\n\",\n    \"    (\\\"num_units\\\", IntegerType),\\n\",\n    \"    (\\\"zip\\\", IntegerType),\\n\",\n    \"    (\\\"mortgage_insurance_percent\\\", FloatType),\\n\",\n    \"    (\\\"current_loan_delinquency_status\\\", IntegerType),\\n\",\n    \"    (\\\"current_actual_upb\\\", FloatType),\\n\",\n    \"    (\\\"interest_rate\\\", FloatType),\\n\",\n    \"    (\\\"loan_age\\\", FloatType),\\n\",\n    \"    (\\\"msa\\\", FloatType),\\n\",\n    \"    (\\\"non_interest_bearing_upb\\\", FloatType),\\n\",\n    \"    (labelColName, IntegerType)\\n\",\n    \"  )\\n\",\n    \"\\n\",\n    \"var cachedDictDF: DataFrame = _\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6177b6b8\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define Casting Process\\n\",\n    \"This part is casting String column to Numbric. 
\\n\",\n    \"Example:\\n\",\n    \"```\\n\",\n    \"col_1\\n\",\n    \" \\\"a\\\"\\n\",\n    \" \\\"b\\\"\\n\",\n    \" \\\"c\\\"\\n\",\n    \" \\\"a\\\"\\n\",\n    \"# After String ====> Numberic\\n\",\n    \"col_1\\n\",\n    \" 0\\n\",\n    \" 1\\n\",\n    \" 2\\n\",\n    \" 0\\n\",\n    \"```  \\n\",\n    \"<br>\\n\",\n    \"\\n\",\n    \"* Define function to get column dictionary\\n\",\n    \"\\n\",\n    \"    Example\\n\",\n    \"    ```\\n\",\n    \"    col1 = [row(data=\\\"a\\\",id=0), row(data=\\\"b\\\",id=1)]\\n\",\n    \"    ```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"5091c8a1\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"genDictionary: (etlDF: org.apache.spark.sql.DataFrame, colNames: Seq[String])org.apache.spark.sql.DataFrame\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"def genDictionary(etlDF: DataFrame, colNames: Seq[String]): DataFrame = {\\n\",\n    \"    val cntTable = etlDF\\n\",\n    \"      .select(posexplode(array(colNames.map(col(_)): _*)))\\n\",\n    \"      .withColumnRenamed(\\\"pos\\\", \\\"column_id\\\")\\n\",\n    \"      .withColumnRenamed(\\\"col\\\", \\\"data\\\")\\n\",\n    \"      .filter(\\\"data is not null\\\")\\n\",\n    \"      .groupBy(\\\"column_id\\\", \\\"data\\\")\\n\",\n    \"      .count()\\n\",\n    \"    val windowed = Window.partitionBy(\\\"column_id\\\").orderBy(desc(\\\"count\\\"))\\n\",\n    \"    cntTable\\n\",\n    \"      .withColumn(\\\"id\\\", row_number().over(windowed))\\n\",\n    \"      .drop(\\\"count\\\")\\n\",\n    \"  }\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1466af65\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define function to convert string columns to numeric\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"9df8fe60\",\n   \"metadata\": 
{},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"castStringColumnsToNumeric: (inputDF: org.apache.spark.sql.DataFrame, spark: org.apache.spark.sql.SparkSession)org.apache.spark.sql.DataFrame\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"def castStringColumnsToNumeric(inputDF: DataFrame, spark: SparkSession): DataFrame = {\\n\",\n    \"    val cateColNames = categaryCols.map(_._1)\\n\",\n    \"    cachedDictDF = genDictionary(inputDF, cateColNames).cache()\\n\",\n    \"\\n\",\n    \"    // Generate the final table with all columns being numeric.\\n\",\n    \"    cateColNames.foldLeft(inputDF) {\\n\",\n    \"      case (df, colName) =>\\n\",\n    \"        val colPos = cateColNames.indexOf(colName)\\n\",\n    \"        val colDictDF = cachedDictDF\\n\",\n    \"          .filter(col(\\\"column_id\\\") === colPos)\\n\",\n    \"          .drop(\\\"column_id\\\")\\n\",\n    \"          .withColumnRenamed(\\\"data\\\", colName)\\n\",\n    \"        df.join(broadcast(colDictDF), Seq(colName), \\\"left\\\")\\n\",\n    \"          .drop(colName)\\n\",\n    \"          .withColumnRenamed(\\\"id\\\", colName)\\n\",\n    \"    }\\n\",\n    \"  }\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"9e1fbb61\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"defined object extractPerfColumns\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"object extractPerfColumns{\\n\",\n    \"  def apply(rawDf : DataFrame) : DataFrame = {\\n\",\n    \"    val perfDf = rawDf.select(\\n\",\n    \"      col(\\\"loan_id\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"monthly_reporting_period\\\"),\\\"MMyyyy\\\"), \\\"MM/dd/yyyy\\\").as(\\\"monthly_reporting_period\\\"),\\n\",\n    \"      
upper(col(\\\"servicer\\\")).as(\\\"servicer\\\"),\\n\",\n    \"      col(\\\"interest_rate\\\"),\\n\",\n    \"      col(\\\"current_actual_upb\\\"),\\n\",\n    \"      col(\\\"loan_age\\\"),\\n\",\n    \"      col(\\\"remaining_months_to_legal_maturity\\\"),\\n\",\n    \"      col(\\\"adj_remaining_months_to_maturity\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"maturity_date\\\"),\\\"MMyyyy\\\"), \\\"MM/yyyy\\\").as(\\\"maturity_date\\\"),\\n\",\n    \"      col(\\\"msa\\\"),\\n\",\n    \"      col(\\\"current_loan_delinquency_status\\\"),\\n\",\n    \"      col(\\\"mod_flag\\\"),\\n\",\n    \"      col(\\\"zero_balance_code\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"zero_balance_effective_date\\\"),\\\"MMyyyy\\\"), \\\"MM/yyyy\\\").as(\\\"zero_balance_effective_date\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"last_paid_installment_date\\\"),\\\"MMyyyy\\\"), \\\"MM/dd/yyyy\\\").as(\\\"last_paid_installment_date\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"foreclosed_after\\\"),\\\"MMyyyy\\\"), \\\"MM/dd/yyyy\\\").as(\\\"foreclosed_after\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"disposition_date\\\"),\\\"MMyyyy\\\"), \\\"MM/dd/yyyy\\\").as(\\\"disposition_date\\\"),\\n\",\n    \"      col(\\\"foreclosure_costs\\\"),\\n\",\n    \"      col(\\\"prop_preservation_and_repair_costs\\\"),\\n\",\n    \"      col(\\\"asset_recovery_costs\\\"),\\n\",\n    \"      col(\\\"misc_holding_expenses\\\"),\\n\",\n    \"      col(\\\"holding_taxes\\\"),\\n\",\n    \"      col(\\\"net_sale_proceeds\\\"),\\n\",\n    \"      col(\\\"credit_enhancement_proceeds\\\"),\\n\",\n    \"      col(\\\"repurchase_make_whole_proceeds\\\"),\\n\",\n    \"      col(\\\"other_foreclosure_proceeds\\\"),\\n\",\n    \"      col(\\\"non_interest_bearing_upb\\\"),\\n\",\n    \"      col(\\\"principal_forgiveness_upb\\\"),\\n\",\n    \"      col(\\\"repurchase_make_whole_proceeds_flag\\\"),\\n\",\n    \"      
col(\\\"foreclosure_principal_write_off_amount\\\"),\\n\",\n    \"      col(\\\"servicing_activity_indicator\\\"),\\n\",\n    \"      col(\\\"quarter\\\")\\n\",\n    \"    )\\n\",\n    \"    \\n\",\n    \"    perfDf.select(\\\"*\\\").filter(\\\"current_actual_upb != 0.0\\\")\\n\",\n    \"  }\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"ce429163\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"defined object extractAcqColumns\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"object extractAcqColumns{\\n\",\n    \"  def apply(rawDf : DataFrame) : DataFrame = {\\n\",\n    \"    val acqDf = rawDf.select(\\n\",\n    \"      col(\\\"loan_id\\\"),\\n\",\n    \"      col(\\\"orig_channel\\\"),\\n\",\n    \"      upper(col(\\\"seller_name\\\")).as(\\\"seller_name\\\"),\\n\",\n    \"      col(\\\"orig_interest_rate\\\"),\\n\",\n    \"      col(\\\"orig_upb\\\"),\\n\",\n    \"      col(\\\"orig_loan_term\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"orig_date\\\"),\\\"MMyyyy\\\"), \\\"MM/yyyy\\\").as(\\\"orig_date\\\"),\\n\",\n    \"      date_format(to_date(col(\\\"first_pay_date\\\"),\\\"MMyyyy\\\"), \\\"MM/yyyy\\\").as(\\\"first_pay_date\\\"),\\n\",\n    \"      col(\\\"orig_ltv\\\"),\\n\",\n    \"      col(\\\"orig_cltv\\\"),\\n\",\n    \"      col(\\\"num_borrowers\\\"),\\n\",\n    \"      col(\\\"dti\\\"),\\n\",\n    \"      col(\\\"borrower_credit_score\\\"),\\n\",\n    \"      col(\\\"first_home_buyer\\\"),\\n\",\n    \"      col(\\\"loan_purpose\\\"),\\n\",\n    \"      col(\\\"property_type\\\"),\\n\",\n    \"      col(\\\"num_units\\\"),\\n\",\n    \"      col(\\\"occupancy_status\\\"),\\n\",\n    \"      col(\\\"property_state\\\"),\\n\",\n    \"      col(\\\"zip\\\"),\\n\",\n    \"      col(\\\"mortgage_insurance_percent\\\"),\\n\",\n    \"      
col(\\\"product_type\\\"),\\n\",\n    \"      col(\\\"coborrow_credit_score\\\"),\\n\",\n    \"      col(\\\"mortgage_insurance_type\\\"),\\n\",\n    \"      col(\\\"relocation_mortgage_indicator\\\"),\\n\",\n    \"      col(\\\"quarter\\\"),\\n\",\n    \"      dense_rank().over(Window.partitionBy(\\\"loan_id\\\").orderBy(to_date(col(\\\"monthly_reporting_period\\\"),\\\"MMyyyy\\\"))).as(\\\"rank\\\")\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    acqDf.select(\\\"*\\\").filter(col(\\\"rank\\\") === 1)\\n\",\n    \"  }\\n\",\n    \"\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"37c64d85\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Build the spark session and data reader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"id\": \"98d37174\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"sparkSession = org.apache.spark.sql.SparkSession@694178ec\\n\",\n       \"reader = org.apache.spark.sql.DataFrameReader@4b2afd51\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"org.apache.spark.sql.DataFrameReader@4b2afd51\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// Build the spark session and data reader as usual\\n\",\n    \"val sparkSession = SparkSession.builder.appName(\\\"mortgage-gpu\\\").config(\\\"spark.sql.cache.serializer\\\", \\\"com.nvidia.spark.ParquetCachedBatchSerializer\\\").getOrCreate\\n\",\n    \"\\n\",\n    \"// GPU run, set to true\\n\",\n    \"sparkSession.conf.set(\\\"spark.rapids.sql.enabled\\\", true)\\n\",\n    \"// CPU run, set to false\\n\",\n    \"// sparkSession.conf.set('spark.rapids.sql.enabled', 'false')\\n\",\n    \"// remove config(\\\"spark.sql.cache.serializer\\\", 
\\\"com.nvidia.spark.ParquetCachedBatchSerializer\\\") for CPU\\n\",\n    \"sparkSession.conf.set(\\\"spark.sql.files.maxPartitionBytes\\\", \\\"1G\\\")\\n\",\n    \"sparkSession.conf.set(\\\"spark.sql.broadcastTimeout\\\", 700)\\n\",\n    \"// use GPU to read CSV\\n\",\n    \"sparkSession.conf.set(\\\"spark.rapids.sql.csv.read.double.enabled\\\", true)\\n\",\n    \"\\n\",\n    \"val reader = sparkSession.read.schema(rawSchema)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b47b5456\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Read CSV Files\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"5bac2301\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"optionsMap = Map(header -> true)\\n\",\n       \"rawDf = [reference_pool_id: string, loan_id: bigint ... 107 more fields]\\n\",\n       \"perfSet = [loan_id: bigint, monthly_reporting_period: string ... 30 more fields]\\n\",\n       \"acqSet = [loan_id: bigint, orig_channel: string ... 25 more fields]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[loan_id: bigint, orig_channel: string ... 
25 more fields]\"\n      ]\n     },\n     \"execution_count\": 12,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val rawDf_csv = reader.option(\\\"header\\\", false)\\n\",\n    \"      .option(\\\"nullValue\\\", \\\"\\\")\\n\",\n    \"      .option(\\\"delimiter\\\", \\\"|\\\")\\n\",\n    \"      .option(\\\"parserLib\\\", \\\"univocity\\\")\\n\",\n    \"      .schema(rawSchema)\\n\",\n    \"      .csv(dataPath)\\n\",\n    \"      .withColumn(\\\"quarter\\\", GetQuarterFromCsvFileName())\\n\",\n    \"\\n\",\n    \"rawDf_csv.write.mode(\\\"overwrite\\\").parquet(output_csv2parquet)\\n\",\n    \"val rawDf = spark.read.parquet(output_csv2parquet)\\n\",\n    \"\\n\",\n    \"val perfSet = extractPerfColumns(rawDf)\\n\",\n    \"val acqSet = extractAcqColumns(rawDf)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f4c814c8\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define ETL Object\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"a16155cb\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"defined trait MortgageETL\\n\",\n       \"allCols = List(orig_channel, first_home_buyer, loan_purpose, property_type, occupancy_status, property_state, product_type, relocation_mortgage_indicator, seller_name, mod_flag, orig_interest_rate, orig_upb, orig_loan_term, orig_ltv, orig_cltv, num_borrowers, dti, borrower_credit_score, num_units, zip, mortgage_insurance_percent, current_loan_delinquency_status, current_actual_upb, interest_rate, loan_age, msa, non_interest_bearing_upb, delinquency_12)\\n\",\n       \"defined object PerformanceETL\\n\",\n       \"defined object AcquisitionETL\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"List(orig_channel, first_home_buyer, loan_purpose, property_type, 
occupancy_status, property_state, product_type, relocation_mortgage_indicator, seller_name, mod_flag, orig_interest_rate, orig_upb, orig_loan_term, orig_ltv, orig_cltv, num_borrowers, dti, borrower_credit_score, num_units, zip, mortgage_insurance_percent, current_loan_delinquency_status, current_actual_upb, interest_rate, loan_age, msa, non_interest_bearing_upb, delinquency_12)\"\n      ]\n     },\n     \"execution_count\": 13,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"trait MortgageETL {\\n\",\n    \"  var dataFrame: DataFrame = _\\n\",\n    \"\\n\",\n    \"  def from(df: DataFrame): this.type = {\\n\",\n    \"    dataFrame = df\\n\",\n    \"    this\\n\",\n    \"  }\\n\",\n    \"}\\n\",\n    \"val allCols = (categaryCols ++ numericCols).map(c => col(c._1))\\n\",\n    \"\\n\",\n    \"object PerformanceETL extends MortgageETL {\\n\",\n    \"\\n\",\n    \"  def prepare: this.type = {\\n\",\n    \"    dataFrame = dataFrame\\n\",\n    \"      .withColumn(\\\"monthly_reporting_period\\\", to_date(col(\\\"monthly_reporting_period\\\"), \\\"MM/dd/yyyy\\\"))\\n\",\n    \"      .withColumn(\\\"monthly_reporting_period_month\\\", month(col(\\\"monthly_reporting_period\\\")))\\n\",\n    \"      .withColumn(\\\"monthly_reporting_period_year\\\", year(col(\\\"monthly_reporting_period\\\")))\\n\",\n    \"      .withColumn(\\\"monthly_reporting_period_day\\\", dayofmonth(col(\\\"monthly_reporting_period\\\")))\\n\",\n    \"      .withColumn(\\\"last_paid_installment_date\\\", to_date(col(\\\"last_paid_installment_date\\\"), \\\"MM/dd/yyyy\\\"))\\n\",\n    \"      .withColumn(\\\"foreclosed_after\\\", to_date(col(\\\"foreclosed_after\\\"), \\\"MM/dd/yyyy\\\"))\\n\",\n    \"      .withColumn(\\\"disposition_date\\\", to_date(col(\\\"disposition_date\\\"), \\\"MM/dd/yyyy\\\"))\\n\",\n    \"      .withColumn(\\\"maturity_date\\\", to_date(col(\\\"maturity_date\\\"), \\\"MM/yyyy\\\"))\\n\",\n    \"      
.withColumn(\\\"zero_balance_effective_date\\\", to_date(col(\\\"zero_balance_effective_date\\\"), \\\"MM/yyyy\\\"))\\n\",\n    \"      .withColumn(\\\"current_actual_upb\\\", col(\\\"current_actual_upb\\\"))\\n\",\n    \"      .withColumn(\\\"current_loan_delinquency_status\\\", col(\\\"current_loan_delinquency_status\\\"))\\n\",\n    \"    this\\n\",\n    \"  }\\n\",\n    \"\\n\",\n    \"  def createDelinquency(spark: SparkSession): this.type = {\\n\",\n    \"    val aggDF = dataFrame\\n\",\n    \"      .select(\\n\",\n    \"        col(\\\"quarter\\\"),\\n\",\n    \"        col(\\\"loan_id\\\"),\\n\",\n    \"        col(\\\"current_loan_delinquency_status\\\"),\\n\",\n    \"        when(col(\\\"current_loan_delinquency_status\\\") >= 1, col(\\\"monthly_reporting_period\\\")).alias(\\\"delinquency_30\\\"),\\n\",\n    \"        when(col(\\\"current_loan_delinquency_status\\\") >= 3, col(\\\"monthly_reporting_period\\\")).alias(\\\"delinquency_90\\\"),\\n\",\n    \"        when(col(\\\"current_loan_delinquency_status\\\") >= 6, col(\\\"monthly_reporting_period\\\")).alias(\\\"delinquency_180\\\")\\n\",\n    \"      )\\n\",\n    \"      .groupBy(\\\"quarter\\\", \\\"loan_id\\\")\\n\",\n    \"      .agg(\\n\",\n    \"        max(\\\"current_loan_delinquency_status\\\").alias(\\\"delinquency_12\\\"),\\n\",\n    \"        min(\\\"delinquency_30\\\").alias(\\\"delinquency_30\\\"),\\n\",\n    \"        min(\\\"delinquency_90\\\").alias(\\\"delinquency_90\\\"),\\n\",\n    \"        min(\\\"delinquency_180\\\").alias(\\\"delinquency_180\\\")\\n\",\n    \"      )\\n\",\n    \"      .select(\\n\",\n    \"        col(\\\"quarter\\\"),\\n\",\n    \"        col(\\\"loan_id\\\"),\\n\",\n    \"        (col(\\\"delinquency_12\\\") >= 1).alias(\\\"ever_30\\\"),\\n\",\n    \"        (col(\\\"delinquency_12\\\") >= 3).alias(\\\"ever_90\\\"),\\n\",\n    \"        (col(\\\"delinquency_12\\\") >= 6).alias(\\\"ever_180\\\"),\\n\",\n    \"        col(\\\"delinquency_30\\\"),\\n\",\n    \" 
       col(\\\"delinquency_90\\\"),\\n\",\n    \"        col(\\\"delinquency_180\\\")\\n\",\n    \"      )\\n\",\n    \"\\n\",\n    \"    val joinedDf = dataFrame\\n\",\n    \"      .withColumnRenamed(\\\"monthly_reporting_period\\\", \\\"timestamp\\\")\\n\",\n    \"      .withColumnRenamed(\\\"monthly_reporting_period_month\\\", \\\"timestamp_month\\\")\\n\",\n    \"      .withColumnRenamed(\\\"monthly_reporting_period_year\\\", \\\"timestamp_year\\\")\\n\",\n    \"      .withColumnRenamed(\\\"current_loan_delinquency_status\\\", \\\"delinquency_12\\\")\\n\",\n    \"      .withColumnRenamed(\\\"current_actual_upb\\\", \\\"upb_12\\\")\\n\",\n    \"      .select(\\\"quarter\\\", \\\"loan_id\\\", \\\"timestamp\\\", \\\"delinquency_12\\\", \\\"upb_12\\\", \\\"timestamp_month\\\", \\\"timestamp_year\\\")\\n\",\n    \"      .join(aggDF, Seq(\\\"loan_id\\\", \\\"quarter\\\"), \\\"left_outer\\\")\\n\",\n    \"\\n\",\n    \"    // calculate the 12 month delinquency and upb values\\n\",\n    \"    val months = 12\\n\",\n    \"    val monthArray = 0.until(months).toArray\\n\",\n    \"    val testDf = joinedDf\\n\",\n    \"      // explode on a small amount of data is actually slightly more efficient than a cross join\\n\",\n    \"      .withColumn(\\\"month_y\\\", explode(lit(monthArray)))\\n\",\n    \"      .select(\\n\",\n    \"        col(\\\"quarter\\\"),\\n\",\n    \"        floor(((col(\\\"timestamp_year\\\") * 12 + col(\\\"timestamp_month\\\")) - 24000) / months).alias(\\\"josh_mody\\\"),\\n\",\n    \"        floor(((col(\\\"timestamp_year\\\") * 12 + col(\\\"timestamp_month\\\")) - 24000 - col(\\\"month_y\\\")) / months).alias(\\\"josh_mody_n\\\"),\\n\",\n    \"        col(\\\"ever_30\\\"),\\n\",\n    \"        col(\\\"ever_90\\\"),\\n\",\n    \"        col(\\\"ever_180\\\"),\\n\",\n    \"        col(\\\"delinquency_30\\\"),\\n\",\n    \"        col(\\\"delinquency_90\\\"),\\n\",\n    \"        col(\\\"delinquency_180\\\"),\\n\",\n    \"        
col(\\\"loan_id\\\"),\\n\",\n    \"        col(\\\"month_y\\\"),\\n\",\n    \"        col(\\\"delinquency_12\\\"),\\n\",\n    \"        col(\\\"upb_12\\\")\\n\",\n    \"      )\\n\",\n    \"      .groupBy(\\\"quarter\\\", \\\"loan_id\\\", \\\"josh_mody_n\\\", \\\"ever_30\\\", \\\"ever_90\\\", \\\"ever_180\\\", \\\"delinquency_30\\\", \\\"delinquency_90\\\", \\\"delinquency_180\\\", \\\"month_y\\\")\\n\",\n    \"      .agg(max(\\\"delinquency_12\\\").alias(\\\"delinquency_12\\\"), min(\\\"upb_12\\\").alias(\\\"upb_12\\\"))\\n\",\n    \"      .withColumn(\\\"timestamp_year\\\", floor((lit(24000) + (col(\\\"josh_mody_n\\\") * lit(months)) + (col(\\\"month_y\\\") - 1)) / lit(12)))\\n\",\n    \"      .withColumn(\\\"timestamp_month_tmp\\\", pmod(lit(24000) + (col(\\\"josh_mody_n\\\") * lit(months)) + col(\\\"month_y\\\"), lit(12)))\\n\",\n    \"      .withColumn(\\\"timestamp_month\\\", when(col(\\\"timestamp_month_tmp\\\") === lit(0), lit(12)).otherwise(col(\\\"timestamp_month_tmp\\\")))\\n\",\n    \"      .withColumn(\\\"delinquency_12\\\", ((col(\\\"delinquency_12\\\") > 3).cast(\\\"int\\\") + (col(\\\"upb_12\\\") === 0).cast(\\\"int\\\")).alias(\\\"delinquency_12\\\"))\\n\",\n    \"      .drop(\\\"timestamp_month_tmp\\\", \\\"josh_mody_n\\\", \\\"month_y\\\")\\n\",\n    \"\\n\",\n    \"    dataFrame = dataFrame\\n\",\n    \"      .withColumnRenamed(\\\"monthly_reporting_period_month\\\", \\\"timestamp_month\\\")\\n\",\n    \"      .withColumnRenamed(\\\"monthly_reporting_period_year\\\", \\\"timestamp_year\\\")\\n\",\n    \"      .join(testDf, Seq(\\\"quarter\\\", \\\"loan_id\\\", \\\"timestamp_year\\\", \\\"timestamp_month\\\"), \\\"left\\\").drop(\\\"timestamp_year\\\", \\\"timestamp_month\\\")\\n\",\n    \"    this\\n\",\n    \"  }\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"object AcquisitionETL extends MortgageETL {\\n\",\n    \"\\n\",\n    \"  def createAcquisition(spark: SparkSession): this.type = {\\n\",\n    \"    val nameMapping = NameMapping(spark, 
\\\"from_seller_name\\\", \\\"to_seller_name\\\")\\n\",\n    \"    dataFrame = dataFrame\\n\",\n    \"      .join(nameMapping, col(\\\"seller_name\\\") === col(\\\"from_seller_name\\\"), \\\"left\\\")\\n\",\n    \"      .drop(\\\"from_seller_name\\\")\\n\",\n    \"      /* backup the original name before we replace it */\\n\",\n    \"      .withColumn(\\\"old_name\\\", col(\\\"seller_name\\\"))\\n\",\n    \"      /* replace seller_name with the new version if we found one in the mapping, or the old version\\n\",\n    \"       if we didn't */\\n\",\n    \"      .withColumn(\\\"seller_name\\\", coalesce(col(\\\"to_seller_name\\\"), col(\\\"seller_name\\\")))\\n\",\n    \"      .drop(\\\"to_seller_name\\\")\\n\",\n    \"      .withColumn(\\\"orig_date\\\", to_date(col(\\\"orig_date\\\"), \\\"MM/yyyy\\\"))\\n\",\n    \"      .withColumn(\\\"first_pay_date\\\", to_date(col(\\\"first_pay_date\\\"), \\\"MM/yyyy\\\"))\\n\",\n    \"    this\\n\",\n    \"  }\\n\",\n    \"\\n\",\n    \"  def cleanPrime(perfDF: DataFrame): this.type = {\\n\",\n    \"    dataFrame = perfDF.join(dataFrame, Seq(\\\"loan_id\\\", \\\"quarter\\\"), \\\"inner\\\").drop(\\\"quarter\\\")\\n\",\n    \"    this\\n\",\n    \"  }\\n\",\n    \"}\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"78b76252\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"transform: (perfDF: org.apache.spark.sql.DataFrame, acqDF: org.apache.spark.sql.DataFrame, spark: org.apache.spark.sql.SparkSession)org.apache.spark.sql.DataFrame\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"def transform(perfDF: DataFrame, acqDF: DataFrame, spark: SparkSession): DataFrame = {\\n\",\n    \"    val etlPerfDF = PerformanceETL.from(perfDF)\\n\",\n    \"      .prepare\\n\",\n    \"      .createDelinquency(spark)\\n\",\n    \"      .dataFrame\\n\",\n    \"    val cleanDF 
= AcquisitionETL.from(acqDF)\\n\",\n    \"      .createAcquisition(spark)\\n\",\n    \"      .cleanPrime(etlPerfDF)\\n\",\n    \"      .dataFrame\\n\",\n    \"\\n\",\n    \"    // Convert to xgb required Dataset\\n\",\n    \"    castStringColumnsToNumeric(cleanDF, spark)\\n\",\n    \"      .select(allCols: _*)\\n\",\n    \"      .withColumn(labelColName, when(col(labelColName) > 0, 1).otherwise(0))\\n\",\n    \"      .na.fill(0.0f)\\n\",\n    \"  }\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b1234f49\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Run ETL Process and Save the Result\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"ffdb0a62\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Elapsed time : 399.241s\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"t0 = 1656695479451\\n\",\n       \"optionsMap = Map(header -> true)\\n\",\n       \"rawDF = [orig_channel: int, first_home_buyer: int ... 
26 more fields]\\n\",\n       \"t1 = 1656695878692\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"1656695878692\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val t0 = System.currentTimeMillis\\n\",\n    \"val rawDF = transform(\\n\",\n    \"      perfSet,\\n\",\n    \"      acqSet,\\n\",\n    \"      sparkSession\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"val etlDataPath = new Path(outPath, \\\"data\\\").toString\\n\",\n    \"rawDF.write.mode(\\\"overwrite\\\").parquet(etlDataPath)\\n\",\n    \"\\n\",\n    \"if(saveTrainEvalDataset == true)\\n\",\n    \"{\\n\",\n    \"  val etlDf = sparkSession.read.parquet(etlDataPath)\\n\",\n    \"  val sets = etlDf.randomSplit(Array[Double](0.8, 0.2))\\n\",\n    \"  val train = sets(0)\\n\",\n    \"  val eval = sets(1)\\n\",\n    \"  train.write.mode(\\\"overwrite\\\").parquet(new Path(outPath, \\\"train\\\").toString)\\n\",\n    \"  eval.write.mode(\\\"overwrite\\\").parquet(new Path(outPath, \\\"eval\\\").toString)\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"val t1 = System.currentTimeMillis\\n\",\n    \"println(\\\"Elapsed time : \\\" + ((t1 - t0).toFloat / 1000) + \\\"s\\\")\\n\",\n    \"sparkSession.stop()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"XGBoost4j-Spark Scala\",\n   \"language\": \"scala\",\n   \"name\": \"XGBoost4j-Spark_scala\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": \"text/x-scala\",\n   \"file_extension\": \".scala\",\n   \"mimetype\": \"text/x-scala\",\n   \"name\": \"scala\",\n   \"pygments_lexer\": \"scala\",\n   \"version\": \"2.12.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage-gpu.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction to XGBoost Spark with GPU\\n\",\n    \"\\n\",\n    \"The goal of this notebook is to show how to train a XGBoost Model with Spark RAPIDS XGBoost library on GPUs. The dataset used with this notebook is derived from Fannie Mae’s Single-Family Loan Performance Data with all rights reserved by Fannie Mae. This processed dataset is redistributed with permission and consent from Fannie Mae. This notebook uses XGBoost to train 12-month mortgage loan delinquency prediction model.\\n\",\n    \"\\n\",\n    \"## Load libraries\\n\",\n    \"First load some common libraries will be used by both GPU version and CPU version xgboost.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}\\n\",\n    \"import org.apache.spark.sql.SparkSession\\n\",\n    \"import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\\n\",\n    \"import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Besides CPU version requires some extra libraries, such as:\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"import org.apache.spark.ml.feature.VectorAssembler\\n\",\n    \"import org.apache.spark.sql.DataFrame\\n\",\n    \"import org.apache.spark.sql.functions._\\n\",\n    \"import org.apache.spark.sql.types.FloatType\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Set the dataset path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"// You need to update them to your real paths! 
The input data files is the output of mortgage-etl jobs\\n\",\n    \"val dataRoot = sys.env.getOrElse(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"val trainPath = dataRoot + \\\"/mortgage/output/train/\\\"\\n\",\n    \"val evalPath  = dataRoot + \\\"/mortgage/output/eval/\\\"\\n\",\n    \"val transPath = dataRoot + \\\"/mortgage/output/eval/\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Build the schema and parameters\\n\",\n    \"The mortgage data has 27 columns: 26 features and 1 label. \\\"deinquency_12\\\" is the label column. The schema will be used to load data in the future.\\n\",\n    \"\\n\",\n    \"The next block also defines some key parameters used in xgboost training process.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"labelColName = delinquency_12\\n\",\n       \"schema = StructType(StructField(orig_channel,DoubleType,true), StructField(first_home_buyer,DoubleType,true), StructField(loan_purpose,DoubleType,true), StructField(property_type,DoubleType,true), StructField(occupancy_status,DoubleType,true), StructField(property_state,DoubleType,true), StructField(product_type,DoubleType,true), StructField(relocation_mortgage_indicator,DoubleType,true), StructField(seller_name,DoubleType,true), StructField(mod_flag,DoubleType,true), StructField(orig_interest_rate,DoubleType,true), StructField(orig_upb,IntegerType,true), StructField(orig_loan_term,IntegerType,true), StructField(orig_ltv,DoubleType,true), StructField(orig_cltv,DoubleType,true), StructField(num_borrowers,DoubleT...\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType(StructField(orig_channel,DoubleType,true), StructField(first_home_buyer,DoubleType,true), 
StructField(loan_purpose,DoubleType,true), StructField(property_type,DoubleType,true), StructField(occupancy_status,DoubleType,true), StructField(property_state,DoubleType,true), StructField(product_type,DoubleType,true), StructField(relocation_mortgage_indicator,DoubleType,true), StructField(seller_name,DoubleType,true), StructField(mod_flag,DoubleType,true), StructField(orig_interest_rate,DoubleType,true), StructField(orig_upb,IntegerType,true), StructField(orig_loan_term,IntegerType,true), StructField(orig_ltv,DoubleType,true), StructField(orig_cltv,DoubleType,true), StructField(num_borrowers,DoubleT...\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val labelColName = \\\"delinquency_12\\\"\\n\",\n    \"val schema = StructType(List(\\n\",\n    \"  StructField(\\\"orig_channel\\\", DoubleType),\\n\",\n    \"  StructField(\\\"first_home_buyer\\\", DoubleType),\\n\",\n    \"  StructField(\\\"loan_purpose\\\", DoubleType),\\n\",\n    \"  StructField(\\\"property_type\\\", DoubleType),\\n\",\n    \"  StructField(\\\"occupancy_status\\\", DoubleType),\\n\",\n    \"  StructField(\\\"property_state\\\", DoubleType),\\n\",\n    \"  StructField(\\\"product_type\\\", DoubleType),\\n\",\n    \"  StructField(\\\"relocation_mortgage_indicator\\\", DoubleType),\\n\",\n    \"  StructField(\\\"seller_name\\\", DoubleType),\\n\",\n    \"  StructField(\\\"mod_flag\\\", DoubleType),\\n\",\n    \"  StructField(\\\"orig_interest_rate\\\", DoubleType),\\n\",\n    \"  StructField(\\\"orig_upb\\\", DoubleType),\\n\",\n    \"  StructField(\\\"orig_loan_term\\\", IntegerType),\\n\",\n    \"  StructField(\\\"orig_ltv\\\", DoubleType),\\n\",\n    \"  StructField(\\\"orig_cltv\\\", DoubleType),\\n\",\n    \"  StructField(\\\"num_borrowers\\\", DoubleType),\\n\",\n    \"  StructField(\\\"dti\\\", DoubleType),\\n\",\n    \"  StructField(\\\"borrower_credit_score\\\", DoubleType),\\n\",\n 
   \"  StructField(\\\"num_units\\\", IntegerType),\\n\",\n    \"  StructField(\\\"zip\\\", IntegerType),\\n\",\n    \"  StructField(\\\"mortgage_insurance_percent\\\", DoubleType),\\n\",\n    \"  StructField(\\\"current_loan_delinquency_status\\\", IntegerType),\\n\",\n    \"  StructField(\\\"current_actual_upb\\\", DoubleType),\\n\",\n    \"  StructField(\\\"interest_rate\\\", DoubleType),\\n\",\n    \"  StructField(\\\"loan_age\\\", DoubleType),\\n\",\n    \"  StructField(\\\"msa\\\", DoubleType),\\n\",\n    \"  StructField(\\\"non_interest_bearing_upb\\\", DoubleType),\\n\",\n    \"  StructField(labelColName, IntegerType)))\\n\",\n    \"\\n\",\n    \"val featureNames = schema.filter(_.name != labelColName).map(_.name).toArray\\n\",\n    \"\\n\",\n    \"val commParamMap = Map(\\n\",\n    \"  \\\"objective\\\" -> \\\"binary:logistic\\\",\\n\",\n    \"  \\\"num_round\\\" -> 100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create a new spark session and load data\\n\",\n    \"\\n\",\n    \"A new spark session should be created to continue all the following spark operations.\\n\",\n    \"\\n\",\n    \"NOTE: in this notebook, the dependency jars have been loaded when installing toree kernel. Alternatively the jars can be loaded into notebook by [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, there's one restriction for `%AddJar`: the jar uploaded can only be available when `AddJar` is called just after a new spark session is created. 
Do it as below:\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"import org.apache.spark.sql.SparkSession\\n\",\n    \"val spark = SparkSession.builder().appName(\\\"mortgage-GPU\\\").getOrCreate\\n\",\n    \"%AddJar file:/data/libs/rapids-4-spark-XXX.jar\\n\",\n    \"%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\\n\",\n    \"%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\\n\",\n    \"// ...\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"##### Please note the new jar \\\"rapids-4-spark-XXX.jar\\\" is only needed for GPU version, you can not add it to dependence list for CPU version.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"sparkSession = org.apache.spark.sql.SparkSession@26420dda\\n\",\n       \"reader = org.apache.spark.sql.DataFrameReader@77740a8c\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"org.apache.spark.sql.DataFrameReader@77740a8c\"\n      ]\n     },\n     \"execution_count\": 4,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// Build the spark session and data reader as usual\\n\",\n    \"val sparkSession = SparkSession.builder.appName(\\\"mortgage-gpu\\\").getOrCreate\\n\",\n    \"val reader = sparkSession.read\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"trainSet = [orig_channel: double, first_home_buyer: double ... 26 more fields]\\n\",\n       \"evalSet = [orig_channel: double, first_home_buyer: double ... 26 more fields]\\n\",\n       \"transSet = [orig_channel: double, first_home_buyer: double ... 
26 more fields]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[orig_channel: double, first_home_buyer: double ... 26 more fields]\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val trainSet = reader.parquet(trainPath)\\n\",\n    \"val evalSet  = reader.parquet(evalPath)\\n\",\n    \"val transSet = reader.parquet(transPath)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Set xgboost parameters and build a XGBoostClassifier\\n\",\n    \"\\n\",\n    \"For CPU version, `num_workers` is recommended being equal to the number of CPU cores, while for GPU version, it should be set to the number of GPUs in Spark cluster.\\n\",\n    \"\\n\",\n    \"Besides the `device` for CPU version is also different from that for GPU version. Now only \\\"cuda\\\" is supported for training on GPU.\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"// difference in parameters\\n\",\n    \"  \\\"num_workers\\\" -> 12,\\n\",\n    \"  \\\"device\\\" -> \\\"cpu\\\",\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbParamFinal = Map(objective -> binary:logistic, num_round -> 100, tree_method -> hist, device -> cuda, num_workers -> 1)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Map(objective -> binary:logistic, num_round -> 100, tree_method -> hist, device -> cuda, num_workers -> 1)\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": \"val xgbParamFinal = commParamMap ++ 
Map(\\\"tree_method\\\" -> \\\"hist\\\", \\\"device\\\" -> \\\"cuda\\\", \\\"num_workers\\\" -> 1)\"\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbClassifier = xgbc_ecac6474dbb2\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbc_ecac6474dbb2\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val xgbClassifier = new XGBoostClassifier(xgbParamFinal)\\n\",\n    \"      .setLabelCol(labelColName)\\n\",\n    \"      .setFeaturesCol(featureNames)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Benchmark and train\\n\",\n    \"The object `benchmark` is used to compute the elapsed time of some operations.\\n\",\n    \"\\n\",\n    \"Training with evaluation dataset is also supported, the same as CPU version's behavior:\\n\",\n    \"\\n\",\n    \"* Call API `setEvalDataset` after initializing an XGBoostClassifier\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"xgbClassifier.setEvalDataset(evalSet)\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbc_ecac6474dbb2\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"xgbClassifier.setEvalDataset(evalSet)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"defined object Benchmark\\n\"\n      ]\n     },\n     \"metadata\": {},\n     
\"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"object Benchmark {\\n\",\n    \"  def time[R](phase: String)(block: => R): (R, Float) = {\\n\",\n    \"    val t0 = System.currentTimeMillis\\n\",\n    \"    val result = block // call-by-name\\n\",\n    \"    val t1 = System.currentTimeMillis\\n\",\n    \"    println(\\\"Elapsed time [\\\" + phase + \\\"]: \\\" + ((t1 - t0).toFloat / 1000) + \\\"s\\\")\\n\",\n    \"    (result, (t1 - t0).toFloat / 1000)\\n\",\n    \"  }\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"------ Training ------\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=38315, DMLC_NUM_WORKER=1}\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbClassificationModel = xgbc_ecac6474dbb2\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Elapsed time [train]: 8.083s\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbc_ecac6474dbb2\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// Start training\\n\",\n    \"println(\\\"\\\\n------ Training ------\\\")\\n\",\n    \"val (xgbClassificationModel, _) = Benchmark.time(\\\"train\\\") {\\n\",\n    \"  xgbClassifier.fit(trainSet)\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Transformation and evaluation\\n\",\n    \"Here uses `transSet` to evaluate our model and prints some useful columns to show our prediction result. 
After that `MulticlassClassificationEvaluator` is used to calculate an overall accuracy of our predictions.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"------ Transforming ------\\n\",\n      \"Elapsed time [transform]: 1.916s\\n\",\n      \"+------------+--------------+--------------------+--------------------+----------+\\n\",\n      \"|orig_channel|delinquency_12|       rawPrediction|         probability|prediction|\\n\",\n      \"+------------+--------------+--------------------+--------------------+----------+\\n\",\n      \"|         0.0|             0|[7.57764625549316...|[0.99948849738575...|       0.0|\\n\",\n      \"|         0.0|             0|[8.74893283843994...|[0.99984139463049...|       0.0|\\n\",\n      \"|         0.0|             0|[8.74893283843994...|[0.99984139463049...|       0.0|\\n\",\n      \"|         0.0|             0|[8.74893283843994...|[0.99984139463049...|       0.0|\\n\",\n      \"|         0.0|             0|[7.57764625549316...|[0.99948849738575...|       0.0|\\n\",\n      \"|         0.0|             0|[7.57764625549316...|[0.99948849738575...|       0.0|\\n\",\n      \"|         0.0|             0|[7.57764625549316...|[0.99948849738575...|       0.0|\\n\",\n      \"|         0.0|             0|[6.58476591110229...|[0.99862065445631...|       0.0|\\n\",\n      \"|         0.0|             0|[7.98751401901245...|[0.99966043786844...|       0.0|\\n\",\n      \"|         0.0|             0|[7.21919107437133...|[0.99926814140053...|       0.0|\\n\",\n      \"+------------+--------------+--------------------+--------------------+----------+\\n\",\n      \"only showing top 10 rows\\n\",\n      \"\\n\",\n      \"\\n\",\n      \"------Accuracy of Evaluation------\\n\",\n      \"1.0\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n 
      \"results = [orig_channel: double, first_home_buyer: double ... 29 more fields]\\n\",\n       \"evaluator = MulticlassClassificationEvaluator: uid=mcEval_d9645b60a007, metricName=f1, metricLabel=0.0, beta=1.0, eps=1.0E-15\\n\",\n       \"accuracy = 1.0\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"1.0\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"println(\\\"\\\\n------ Transforming ------\\\")\\n\",\n    \"val (results, _) = Benchmark.time(\\\"transform\\\") {\\n\",\n    \"  val ret = xgbClassificationModel.transform(transSet).cache()\\n\",\n    \"  ret.foreachPartition((_: Iterator[_]) => ())\\n\",\n    \"  ret\\n\",\n    \"}\\n\",\n    \"results.select(\\\"orig_channel\\\", labelColName,\\\"rawPrediction\\\",\\\"probability\\\",\\\"prediction\\\").show(10)\\n\",\n    \"\\n\",\n    \"println(\\\"\\\\n------Accuracy of Evaluation------\\\")\\n\",\n    \"val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName)\\n\",\n    \"val accuracy = evaluator.evaluate(results)\\n\",\n    \"println(accuracy)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Save the model to disk and load model\\n\",\n    \"Save the model to disk and then load it to memory. After that use the loaded model to do a new prediction.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Elapsed time [transform2]: 0.044s\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"modelFromDisk = xgbc_ecac6474dbb2\\n\",\n       \"results2 = [orig_channel: double, first_home_buyer: double ... 
29 more fields]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+------------+----------------+------------+-------------+----------------+--------------+------------+-----------------------------+-----------+--------+------------------+--------+--------------+--------+---------+-------------+----+---------------------+---------+---+--------------------------+-------------------------------+------------------+-------------+--------+-------+------------------------+--------------+--------------------+--------------------+----------+\\n\",\n      \"|orig_channel|first_home_buyer|loan_purpose|property_type|occupancy_status|property_state|product_type|relocation_mortgage_indicator|seller_name|mod_flag|orig_interest_rate|orig_upb|orig_loan_term|orig_ltv|orig_cltv|num_borrowers| dti|borrower_credit_score|num_units|zip|mortgage_insurance_percent|current_loan_delinquency_status|current_actual_upb|interest_rate|loan_age|    msa|non_interest_bearing_upb|delinquency_12|       rawPrediction|         probability|prediction|\\n\",\n      \"+------------+----------------+------------+-------------+----------------+--------------+------------+-----------------------------+-----------+--------+------------------+--------+--------------+--------+---------+-------------+----+---------------------+---------+---+--------------------------+-------------------------------+------------------+-------------+--------+-------+------------------------+--------------+--------------------+--------------------+----------+\\n\",\n      \"|         0.0|             0.0|         0.0|          0.0|             0.0|           0.0|         0.0|                          0.0|        0.0|     0.0|              5.75|   81000|           360|    95.0|      0.0|          1.0|39.0|                696.0|        1|191|                      30.0|                           
  -2|           7747.01|         5.75|    81.0|37980.0|                     0.0|             0|[7.57764625549316...|[0.99948849738575...|       0.0|\\n\",\n      \"|         0.0|             0.0|         0.0|          0.0|             0.0|           0.0|         0.0|                          0.0|        0.0|     0.0|              5.75|   81000|           360|    95.0|      0.0|          1.0|39.0|                696.0|        1|191|                      30.0|                              0|               0.0|         5.75|     0.0|37980.0|                     0.0|             0|[8.74893283843994...|[0.99984139463049...|       0.0|\\n\",\n      \"|         0.0|             0.0|         0.0|          0.0|             0.0|           0.0|         0.0|                          0.0|        0.0|     0.0|              5.75|   81000|           360|    95.0|      0.0|          1.0|39.0|                696.0|        1|191|                      30.0|                              0|               0.0|         5.75|     2.0|37980.0|                     0.0|             0|[8.74893283843994...|[0.99984139463049...|       0.0|\\n\",\n      \"|         0.0|             0.0|         0.0|          0.0|             0.0|           0.0|         0.0|                          0.0|        0.0|     0.0|              5.75|   81000|           360|    95.0|      0.0|          1.0|39.0|                696.0|        1|191|                      30.0|                              0|               0.0|         5.75|     5.0|37980.0|                     0.0|             0|[8.74893283843994...|[0.99984139463049...|       0.0|\\n\",\n      \"|         0.0|             0.0|         0.0|          0.0|             0.0|           0.0|         0.0|                          0.0|        0.0|     0.0|              5.75|   81000|           360|    95.0|      0.0|          1.0|39.0|                696.0|        1|191|                      30.0|                              0|           7747.01|         5.75|    
80.0|37980.0|                     0.0|             0|[7.57764625549316...|[0.99948849738575...|       0.0|\\n\",\n      \"|         0.0|             0.0|         0.0|          0.0|             0.0|           0.0|         0.0|                          0.0|        0.0|     0.0|              5.75|   81000|           360|    95.0|      0.0|          1.0|39.0|                696.0|        1|191|                      30.0|                              0|          13155.21|         5.75|    79.0|37980.0|                     0.0|             0|[7.57764625549316...|[0.99948849738575...|       0.0|\\n\",\n      \"|         0.0|             0.0|         0.0|          0.0|             0.0|           0.0|         0.0|                          0.0|        0.0|     0.0|              5.75|   81000|           360|    95.0|      0.0|          1.0|39.0|                696.0|        1|191|                      30.0|                              0|          18526.93|         5.75|    78.0|37980.0|                     0.0|             0|[7.57764625549316...|[0.99948849738575...|       0.0|\\n\",\n      \"|         0.0|             0.0|         0.0|          0.0|             0.0|           0.0|         0.0|                          0.0|        0.0|     0.0|              5.75|   81000|           360|    95.0|      0.0|          1.0|39.0|                696.0|        1|191|                      30.0|                              0|          23883.73|         5.75|    77.0|37980.0|                     0.0|             0|[6.58476591110229...|[0.99862065445631...|       0.0|\\n\",\n      \"|         0.0|             0.0|         0.0|          0.0|             0.0|           0.0|         0.0|                          0.0|        0.0|     0.0|              5.75|   81000|           360|    95.0|      0.0|          1.0|39.0|                696.0|        1|191|                      30.0|                              0|          29214.98|         5.75|    76.0|37980.0|                     0.0|      
       0|[7.98751401901245...|[0.99966043786844...|       0.0|\\n\",\n      \"|         0.0|             0.0|         0.0|          0.0|             0.0|           0.0|         0.0|                          0.0|        0.0|     0.0|              5.75|   81000|           360|    95.0|      0.0|          1.0|39.0|                696.0|        1|191|                      30.0|                              0|          34520.81|         5.75|    75.0|37980.0|                     0.0|             0|[7.21919107437133...|[0.99926814140053...|       0.0|\\n\",\n      \"+------------+----------------+------------+-------------+----------------+--------------+------------+-----------------------------+-----------+--------+------------------+--------+--------------+--------+---------+-------------+----+---------------------+---------+---+--------------------------+-------------------------------+------------------+-------------+--------+-------+------------------------+--------------+--------------------+--------------------+----------+\\n\",\n      \"only showing top 10 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[orig_channel: double, first_home_buyer: double ... 
29 more fields]\"\n      ]\n     },\n     \"execution_count\": 12,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"xgbClassificationModel.write.overwrite.save(dataRoot + \\\"/mortgage/model/\\\")\\n\",\n    \"\\n\",\n    \"val modelFromDisk = XGBoostClassificationModel.load(dataRoot + \\\"/mortgage/model/\\\")\\n\",\n    \"\\n\",\n    \"val (results2, _) = Benchmark.time(\\\"transform2\\\") {\\n\",\n    \"  modelFromDisk.transform(transSet)\\n\",\n    \"}\\n\",\n    \"results2.show(10)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sparkSession.close()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"XGBoost4j-Spark - Scala\",\n   \"language\": \"scala\",\n   \"name\": \"XGBoost4j-Spark_scala\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": \"text/x-scala\",\n   \"file_extension\": \".scala\",\n   \"mimetype\": \"text/x-scala\",\n   \"name\": \"scala\",\n   \"pygments_lexer\": \"scala\",\n   \"version\": \"2.12.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/notebooks/scala/mortgage_gpu_crossvalidation.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Mortgage CrossValidation with GPU accelerating on XGBoost\\n\",\n    \"\\n\",\n    \"In this notebook, we will show you how to levarage GPU to accelerate mortgage CrossValidation of XGBoost to find out the best model given a group of parameters.\\n\",\n    \"\\n\",\n    \"## Import classes\\n\",\n    \"First we need load some common classes that both GPU version and CPU version will use:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}\\n\",\n    \"\\n\",\n    \"import org.apache.spark.sql.SparkSession\\n\",\n    \"import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\\n\",\n    \"import org.apache.spark.ml.tuning.{ParamGridBuilder,CrossValidator}\\n\",\n    \"import org.apache.spark.sql.types.{FloatType, IntegerType, StructField, StructType, DoubleType}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"what is new to xgboost-spark users is **rapids.CrossValidator**\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Set dataset path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"// You need to update them to your real paths!\\n\",\n    \"val dataRoot = sys.env.getOrElse(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"val trainParquetPath=dataRoot + \\\"/mortgage/output/train\\\"\\n\",\n    \"val evalParquetPath=dataRoot + \\\"/mortgage/output/eval\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Set the schema of the dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 
null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"val labelColName = \\\"delinquency_12\\\"\\n\",\n    \"val schema = StructType(List(\\n\",\n    \"    StructField(\\\"orig_channel\\\", FloatType),\\n\",\n    \"    StructField(\\\"first_home_buyer\\\", FloatType),\\n\",\n    \"    StructField(\\\"loan_purpose\\\", FloatType),\\n\",\n    \"    StructField(\\\"property_type\\\", FloatType),\\n\",\n    \"    StructField(\\\"occupancy_status\\\", FloatType),\\n\",\n    \"    StructField(\\\"property_state\\\", FloatType),\\n\",\n    \"    StructField(\\\"product_type\\\", FloatType),\\n\",\n    \"    StructField(\\\"relocation_mortgage_indicator\\\", FloatType),\\n\",\n    \"    StructField(\\\"seller_name\\\", FloatType),\\n\",\n    \"    StructField(\\\"mod_flag\\\", FloatType),\\n\",\n    \"    StructField(\\\"orig_interest_rate\\\", FloatType),\\n\",\n    \"    StructField(\\\"orig_upb\\\", DoubleType),\\n\",\n    \"    StructField(\\\"orig_loan_term\\\", IntegerType),\\n\",\n    \"    StructField(\\\"orig_ltv\\\", FloatType),\\n\",\n    \"    StructField(\\\"orig_cltv\\\", FloatType),\\n\",\n    \"    StructField(\\\"num_borrowers\\\", FloatType),\\n\",\n    \"    StructField(\\\"dti\\\", FloatType),\\n\",\n    \"    StructField(\\\"borrower_credit_score\\\", FloatType),\\n\",\n    \"    StructField(\\\"num_units\\\", IntegerType),\\n\",\n    \"    StructField(\\\"zip\\\", IntegerType),\\n\",\n    \"    StructField(\\\"mortgage_insurance_percent\\\", FloatType),\\n\",\n    \"    StructField(\\\"current_loan_delinquency_status\\\", IntegerType),\\n\",\n    \"    StructField(\\\"current_actual_upb\\\", FloatType),\\n\",\n    \"    StructField(\\\"interest_rate\\\", FloatType),\\n\",\n    \"    StructField(\\\"loan_age\\\", FloatType),\\n\",\n    \"    StructField(\\\"msa\\\", FloatType),\\n\",\n    \"    StructField(\\\"non_interest_bearing_upb\\\", FloatType),\\n\",\n    \"    StructField(labelColName, IntegerType)))\"\n   ]\n  },\n  {\n   
\"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create a new spark session and load data\\n\",\n    \"we must create a new spark session to continue all spark operations.\\n\",\n    \"\\n\",\n    \"NOTE: in this notebook, we have uploaded dependency jars when installing toree kernel. If we don't upload them at installation time, we can also upload in notebook by [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, there's one restriction for `%AddJar`: the jar uploaded can only be available when `AddJar` is called after a new spark session is created. We must use it as below:\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"import org.apache.spark.sql.SparkSession\\n\",\n    \"val spark = SparkSession.builder().appName(\\\"mortgage-gpu-cv\\\").getOrCreate\\n\",\n    \"%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\\n\",\n    \"%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\\n\",\n    \"// ...\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"spark = org.apache.spark.sql.SparkSession@51af6ff3\\n\",\n       \"trainDs = [orig_channel: double, first_home_buyer: double ... 26 more fields]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[orig_channel: double, first_home_buyer: double ... 
26 more fields]\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val spark = SparkSession.builder().appName(\\\"mortgage-gpu-cv\\\").getOrCreate()\\n\",\n    \"val trainDs = spark.read.parquet(trainParquetPath)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Find out features to train\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"featureNames = Array(orig_channel, first_home_buyer, loan_purpose, property_type, occupancy_status, property_state, product_type, relocation_mortgage_indicator, seller_name, mod_flag, orig_interest_rate, orig_upb, orig_loan_term, orig_ltv, orig_cltv, num_borrowers, dti, borrower_credit_score, num_units, zip, mortgage_insurance_percent, current_loan_delinquency_status, current_actual_upb, interest_rate, loan_age, msa, non_interest_bearing_upb)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Array(orig_channel, first_home_buyer, loan_purpose, property_type, occupancy_status, property_state, product_type, relocation_mortgage_indicator, seller_name, mod_flag, orig_interest_rate, orig_upb, orig_loan_term, orig_ltv, orig_cltv, num_borrowers, dti, borrower_credit_score, num_units, zip, mortgage_insurance_percent, current_loan_delinquency_status, current_actual_upb, interest_rate, loan_age, msa, non_interest_bearing_upb)\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val featureNames = schema.filter(_.name != labelColName).map(_.name).toArray\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   
\"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"classifierParam = Map(objective -> binary:logistic, num_round -> 100, num_workers -> 1, tree_method -> hist, device -> cuda)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Map(objective -> binary:logistic, num_round -> 100, num_workers -> 1, tree_method -> hist, device -> cuda)\"\n      ]\n     },\n     \"execution_count\": 13,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val classifierParam = Map(\\n\",\n    \"    \\\"objective\\\" -> \\\"binary:logistic\\\",\\n\",\n    \"    \\\"num_round\\\" -> 100,\\n\",\n    \"    \\\"num_workers\\\" -> 1,\\n\",\n    \"    \\\"tree_method\\\" -> \\\"hist\\\",\\n\",\n    \"    \\\"device\\\" -> \\\"cuda\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Construct CrossValidator\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"classifier = xgbc_ae8896ab2b67\\n\",\n       \"paramGrid = \\n\",\n       \"evaluator = MulticlassClassificationEvaluator: uid=mcEval_ebda5b6cea6c, metricName=f1, metricLabel=0.0, beta=1.0, eps=1.0E-15\\n\",\n       \"cv = cv_cb7d8efe9ab5\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Array({\\n\",\n       \"\\txgbc_ae8896ab2b67-eta: 0.2,\\n\",\n       \"\\txgbc_ae8896ab2b67-maxDepth: 3\\n\",\n       \"}, {\\n\",\n       \"\\txgbc_ae8896ab2b67-eta: 0.2,\\n\",\n       \"\\txgbc_ae8896ab2b67-maxDepth: 10\\n\",\n       \"}, {\\n\",\n       \"\\txgbc_ae8896ab2b67-eta: 0.6,\\n\",\n       \"\\txgbc_ae8896ab2b67-maxDepth: 3\\n\",\n       \"}, {\\n\",\n       \"\\txgbc_ae8896ab2b67-eta: 
0.6,\\n\",\n       \"\\txgbc_ae8896ab2b67-maxDepth: 10\\n\",\n       \"})\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"cv_cb7d8efe9ab5\"\n      ]\n     },\n     \"execution_count\": 14,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val classifier = new XGBoostClassifier(classifierParam)\\n\",\n    \"    .setLabelCol(labelColName)\\n\",\n    \"    .setFeaturesCol(featureNames)\\n\",\n    \"val paramGrid = new ParamGridBuilder()\\n\",\n    \"    .addGrid(classifier.maxDepth, Array(3, 10))\\n\",\n    \"    .addGrid(classifier.eta, Array(0.2, 0.6))\\n\",\n    \"    .build()\\n\",\n    \"val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName)\\n\",\n    \"val cv = new CrossValidator()\\n\",\n    \"    .setEstimator(classifier)\\n\",\n    \"    .setEvaluator(evaluator)\\n\",\n    \"    .setEstimatorParamMaps(paramGrid)\\n\",\n    \"    .setNumFolds(3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## train with CrossValidator\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=41609, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=45469, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=52795, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=53483, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, 
DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=58067, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=43717, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=36075, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=53851, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=42227, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=46587, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=51295, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=54695, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=54019, DMLC_NUM_WORKER=1}\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"model = xgbc_ae8896ab2b67\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbc_ae8896ab2b67\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val model = cv.fit(trainDs).bestModel.asInstanceOf[XGBoostClassificationModel]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## transform with best model trained by CrossValidator\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      
\"text/plain\": [\n       \"transformDs = [orig_channel: double, first_home_buyer: double ... 26 more fields]\\n\",\n       \"df = [orig_channel: double, first_home_buyer: double ... 29 more fields]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------+--------------------+--------------------+----------+\\n\",\n      \"|delinquency_12|       rawPrediction|         probability|prediction|\\n\",\n      \"+--------------+--------------------+--------------------+----------+\\n\",\n      \"|             0|[17.3849449157714...|[0.99999997182821...|       0.0|\\n\",\n      \"|             0|[16.6074829101562...|[0.99999993869981...|       0.0|\\n\",\n      \"|             0|[16.0062618255615...|[0.99999988816731...|       0.0|\\n\",\n      \"|             0|[16.7623615264892...|[0.99999994749521...|       0.0|\\n\",\n      \"|             0|[15.1363153457641...|[0.99999973307967...|       0.0|\\n\",\n      \"+--------------+--------------------+--------------------+----------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[orig_channel: double, first_home_buyer: double ... 
29 more fields]\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val transformDs = spark.read.parquet(evalParquetPath)\\n\",\n    \"val df = model.transform(transformDs).cache()\\n\",\n    \"df.drop(featureNames: _*).show(5)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"evaluator = MulticlassClassificationEvaluator: uid=mcEval_d880c25944f1, metricName=f1, metricLabel=0.0, beta=1.0, eps=1.0E-15\\n\",\n       \"accuracy = 1.0\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"1.0\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName)\\n\",\n    \"val accuracy = evaluator.evaluate(df)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.close()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"XGBoost4j-Spark - Scala\",\n   \"language\": \"scala\",\n   \"name\": \"XGBoost4j-Spark_scala\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": \"text/x-scala\",\n   \"file_extension\": \".scala\",\n   \"mimetype\": \"text/x-scala\",\n   \"name\": \"scala\",\n   \"pygments_lexer\": \"scala\",\n   \"version\": \"2.12.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/pom.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!--\n  ~ Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.\n  ~\n  ~ Licensed under the Apache License, Version 2.0 (the \"License\");\n  ~ you may not use this file except in compliance with the License.\n  ~ You may obtain a copy of the License at\n  ~\n  ~ http://www.apache.org/licenses/LICENSE-2.0\n  ~\n  ~ Unless required by applicable law or agreed to in writing, software\n  ~ distributed under the License is distributed on an \"AS IS\" BASIS,\n  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n  ~ See the License for the specific language governing permissions and\n  ~ limitations under the License.\n  -->\n\n<project xmlns=\"http://maven.apache.org/POM/4.0.0\"\n         xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n         xsi:schemaLocation=\"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd\">\n    <parent>\n        <artifactId>sample_xgboost_examples</artifactId>\n        <groupId>com.nvidia</groupId>\n        <version>0.2.3-SNAPSHOT</version>\n    </parent>\n    <modelVersion>4.0.0</modelVersion>\n\n    <artifactId>spark_examples_mortgage_${scala.binary.version}</artifactId>\n\n    <properties>\n        <maven.compiler.source>8</maven.compiler.source>\n        <maven.compiler.target>8</maven.compiler.target>\n    </properties>\n\n    <dependencies>\n        <dependency>\n            <groupId>com.nvidia</groupId>\n            <artifactId>spark_examples_utility_${scala.binary.version}</artifactId>\n            <version>${project.version}</version>\n            <scope>compile</scope>\n        </dependency>\n    </dependencies>\n\n    <build>\n        <sourceDirectory>scala/src</sourceDirectory>\n    </build>\n</project>"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/python/com/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/python/com/nvidia/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/mortgage/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/mortgage/consts.py",
    "content": "#\n# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\nfrom pyspark.sql.types import *\n\nlabel = 'delinquency_12'\n\nschema = StructType([\n    StructField('orig_channel', FloatType()),\n    StructField('first_home_buyer', FloatType()),\n    StructField('loan_purpose', FloatType()),\n    StructField('property_type', FloatType()),\n    StructField('occupancy_status', FloatType()),\n    StructField('property_state', FloatType()),\n    StructField('product_type', FloatType()),\n    StructField('relocation_mortgage_indicator', FloatType()),\n    StructField('seller_name', FloatType()),\n    StructField('mod_flag', FloatType()),\n    StructField('orig_interest_rate', FloatType()),\n    StructField('orig_upb', DoubleType()),\n    StructField('orig_loan_term', IntegerType()),\n    StructField('orig_ltv', FloatType()),\n    StructField('orig_cltv', FloatType()),\n    StructField('num_borrowers', FloatType()),\n    StructField('dti', FloatType()),\n    StructField('borrower_credit_score', FloatType()),\n    StructField('num_units', IntegerType()),\n    StructField('zip', IntegerType()),\n    StructField('mortgage_insurance_percent', FloatType()),\n    StructField('current_loan_delinquency_status', IntegerType()),\n    StructField('current_actual_upb', FloatType()),\n    StructField('interest_rate', FloatType()),\n    StructField('loan_age', FloatType()),\n    StructField('msa', 
FloatType()),\n    StructField('non_interest_bearing_upb', FloatType()),\n    StructField(label, IntegerType()),\n])\n\n\nname_mapping = {\n    'WITMER FUNDING, LLC': 'Witmer',\n    'WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015': 'Wells Fargo',\n    'WELLS FARGO BANK,  NA': 'Wells Fargo',\n    'WELLS FARGO BANK, N.A.': 'Wells Fargo',\n    'WELLS FARGO BANK, NA': 'Wells Fargo',\n    'USAA FEDERAL SAVINGS BANK': 'USAA',\n    'UNITED SHORE FINANCIAL SERVICES, LLC D\\\\/B\\\\/A UNITED WHOLESALE MORTGAGE': 'United Seq(e',\n    'U.S. BANK N.A.': 'US Bank',\n    'SUNTRUST MORTGAGE INC.': 'Suntrust',\n    'STONEGATE MORTGAGE CORPORATION': 'Stonegate Mortgage',\n    'STEARNS LENDING, LLC': 'Stearns Lending',\n    'STEARNS LENDING, INC.': 'Stearns Lending',\n    'SIERRA PACIFIC MORTGAGE COMPANY, INC.': 'Sierra Pacific Mortgage',\n    'REGIONS BANK': 'Regions',\n    'RBC MORTGAGE COMPANY': 'RBC',\n    'QUICKEN LOANS INC.': 'Quicken Loans',\n    'PULTE MORTGAGE, L.L.C.': 'Pulte Mortgage',\n    'PROVIDENT FUNDING ASSOCIATES, L.P.': 'Provident Funding',\n    'PROSPECT MORTGAGE, LLC': 'Prospect Mortgage',\n    'PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC': 'Principal Residential',\n    'PNC BANK, N.A.': 'PNC',\n    'PMT CREDIT RISK TRANSFER TRUST 2015-2': 'PennyMac',\n    'PHH MORTGAGE CORPORATION': 'PHH Mortgage',\n    'PENNYMAC CORP.': 'PennyMac',\n    'PACIFIC UNION FINANCIAL, LLC': 'Other',\n    'OTHER': 'Other',\n    'NYCB MORTGAGE COMPANY, LLC': 'NYCB',\n    'NEW YORK COMMUNITY BANK': 'NYCB',\n    'NETBANK FUNDING SERVICES': 'Netbank',\n    'NATIONSTAR MORTGAGE, LLC': 'Nationstar Mortgage',\n    'METLIFE BANK, NA': 'Metlife',\n    'LOANDEPOT.COM, LLC': 'LoanDepot.com',\n    'J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1': 'JP Morgan Chase',\n    'J.P. 
MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1': 'JP Morgan Chase',\n    'JPMORGAN CHASE BANK, NATIONAL ASSOCIATION': 'JP Morgan Chase',\n    'JPMORGAN CHASE BANK, NA': 'JP Morgan Chase',\n    'JP MORGAN CHASE BANK, NA': 'JP Morgan Chase',\n    'IRWIN MORTGAGE, CORPORATION': 'Irwin Mortgage',\n    'IMPAC MORTGAGE CORP.': 'Impac Mortgage',\n    'HSBC BANK USA, NATIONAL ASSOCIATION': 'HSBC',\n    'HOMEWARD RESIDENTIAL, INC.': 'Homeward Mortgage',\n    'HOMESTREET BANK': 'Other',\n    'HOMEBRIDGE FINANCIAL SERVICES, INC.': 'HomeBridge',\n    'HARWOOD STREET FUNDING I, LLC': 'Harwood Mortgage',\n    'GUILD MORTGAGE COMPANY': 'Guild Mortgage',\n    'GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)': 'GMAC',\n    'GMAC MORTGAGE, LLC': 'GMAC',\n    'GMAC (USAA)': 'GMAC',\n    'FREMONT BANK': 'Fremont Bank',\n    'FREEDOM MORTGAGE CORP.': 'Freedom Mortgage',\n    'FRANKLIN AMERICAN MORTGAGE COMPANY': 'Franklin America',\n    'FLEET NATIONAL BANK': 'Fleet National',\n    'FLAGSTAR CAPITAL MARKETS CORPORATION': 'Flagstar Bank',\n    'FLAGSTAR BANK, FSB': 'Flagstar Bank',\n    'FIRST TENNESSEE BANK NATIONAL ASSOCIATION': 'Other',\n    'FIFTH THIRD BANK': 'Fifth Third Bank',\n    'FEDERAL HOME LOAN BANK OF CHICAGO': 'Fedral Home of Chicago',\n    'FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB': 'FDIC',\n    'DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A.': 'Downey Mortgage',\n    'DITECH FINANCIAL LLC': 'Ditech',\n    'CITIMORTGAGE, INC.': 'Citi',\n    'CHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY': 'Chicago Mortgage',\n    'CHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY': 'Chicago Mortgage',\n    'CHASE HOME FINANCE, LLC': 'JP Morgan Chase',\n    'CHASE HOME FINANCE FRANKLIN AMERICAN MORTGAGE COMPANY': 'JP Morgan Chase',\n    'CHASE HOME FINANCE (CIE 1)': 'JP Morgan Chase',\n    'CHASE HOME FINANCE': 'JP Morgan Chase',\n    'CASHCALL, INC.': 'CashCall',\n    'CAPITAL ONE, NATIONAL ASSOCIATION': 'Capital One',\n    'CALIBER HOME LOANS, INC.': 'Caliber 
Funding',\n    'BISHOPS GATE RESIDENTIAL MORTGAGE TRUST': 'Bishops Gate Mortgage',\n    'BANK OF AMERICA, N.A.': 'Bank of America',\n    'AMTRUST BANK': 'AmTrust',\n    'AMERISAVE MORTGAGE CORPORATION': 'Amerisave',\n    'AMERIHOME MORTGAGE COMPANY, LLC': 'AmeriHome Mortgage',\n    'ALLY BANK': 'Ally Bank',\n    'ACADEMY MORTGAGE CORPORATION': 'Academy Mortgage',\n    'NO CASH-OUT REFINANCE': 'OTHER REFINANCE',\n    'REFINANCE - NOT SPECIFIED': 'OTHER REFINANCE',\n    'Other REFINANCE': 'OTHER REFINANCE',\n}\n\n\nrawSchema = StructType([\n      StructField(\"reference_pool_id\", StringType()),\n      StructField(\"loan_id\", LongType()),\n      StructField(\"monthly_reporting_period\", StringType()),\n      StructField(\"orig_channel\", StringType()),\n      StructField(\"seller_name\", StringType()),\n      StructField(\"servicer\", StringType()),\n      StructField(\"master_servicer\", StringType()),\n      StructField(\"orig_interest_rate\", DoubleType()),\n      StructField(\"interest_rate\", DoubleType()),\n      StructField(\"orig_upb\", DoubleType()),\n      StructField(\"upb_at_issuance\", StringType()),\n      StructField(\"current_actual_upb\", DoubleType()),\n      StructField(\"orig_loan_term\", IntegerType()),\n      StructField(\"orig_date\", StringType()),\n      StructField(\"first_pay_date\", StringType()),    \n      StructField(\"loan_age\", DoubleType()),\n      StructField(\"remaining_months_to_legal_maturity\", DoubleType()),\n      StructField(\"adj_remaining_months_to_maturity\", DoubleType()),\n      StructField(\"maturity_date\", StringType()),\n      StructField(\"orig_ltv\", DoubleType()),\n      StructField(\"orig_cltv\", DoubleType()),\n      StructField(\"num_borrowers\", DoubleType()),\n      StructField(\"dti\", DoubleType()),\n      StructField(\"borrower_credit_score\", DoubleType()),\n      StructField(\"coborrow_credit_score\", DoubleType()),\n      StructField(\"first_home_buyer\", StringType()),\n      
StructField(\"loan_purpose\", StringType()),\n      StructField(\"property_type\", StringType()),\n      StructField(\"num_units\", IntegerType()),\n      StructField(\"occupancy_status\", StringType()),\n      StructField(\"property_state\", StringType()),\n      StructField(\"msa\", DoubleType()),\n      StructField(\"zip\", IntegerType()),\n      StructField(\"mortgage_insurance_percent\", DoubleType()),\n      StructField(\"product_type\", StringType()),\n      StructField(\"prepayment_penalty_indicator\", StringType()),\n      StructField(\"interest_only_loan_indicator\", StringType()),\n      StructField(\"interest_only_first_principal_and_interest_payment_date\", StringType()),\n      StructField(\"months_to_amortization\", StringType()),\n      StructField(\"current_loan_delinquency_status\", IntegerType()),\n      StructField(\"loan_payment_history\", StringType()),\n      StructField(\"mod_flag\", StringType()),\n      StructField(\"mortgage_insurance_cancellation_indicator\", StringType()),\n      StructField(\"zero_balance_code\", StringType()),\n      StructField(\"zero_balance_effective_date\", StringType()),\n      StructField(\"upb_at_the_time_of_removal\", StringType()),\n      StructField(\"repurchase_date\", StringType()),\n      StructField(\"scheduled_principal_current\", StringType()),\n      StructField(\"total_principal_current\", StringType()),\n      StructField(\"unscheduled_principal_current\", StringType()),\n      StructField(\"last_paid_installment_date\", StringType()),\n      StructField(\"foreclosed_after\", StringType()),\n      StructField(\"disposition_date\", StringType()),\n      StructField(\"foreclosure_costs\", DoubleType()),\n      StructField(\"prop_preservation_and_repair_costs\", DoubleType()),\n      StructField(\"asset_recovery_costs\", DoubleType()),\n      StructField(\"misc_holding_expenses\", DoubleType()),\n      StructField(\"holding_taxes\", DoubleType()),\n      StructField(\"net_sale_proceeds\", 
DoubleType()),\n      StructField(\"credit_enhancement_proceeds\", DoubleType()),\n      StructField(\"repurchase_make_whole_proceeds\", StringType()),\n      StructField(\"other_foreclosure_proceeds\", DoubleType()),\n      StructField(\"non_interest_bearing_upb\", DoubleType()),\n      StructField(\"principal_forgiveness_upb\", StringType()),\n      StructField(\"original_list_start_date\", StringType()),\n      StructField(\"original_list_price\", StringType()),\n      StructField(\"current_list_start_date\", StringType()),\n      StructField(\"current_list_price\", StringType()),\n      StructField(\"borrower_credit_score_at_issuance\", StringType()),\n      StructField(\"co-borrower_credit_score_at_issuance\", StringType()),\n      StructField(\"borrower_credit_score_current\", StringType()),\n      StructField(\"co-Borrower_credit_score_current\", StringType()),\n      StructField(\"mortgage_insurance_type\", DoubleType()),\n      StructField(\"servicing_activity_indicator\", StringType()),\n      StructField(\"current_period_modification_loss_amount\", StringType()),\n      StructField(\"cumulative_modification_loss_amount\", StringType()),\n      StructField(\"current_period_credit_event_net_gain_or_loss\", StringType()),\n      StructField(\"cumulative_credit_event_net_gain_or_loss\", StringType()),\n      StructField(\"homeready_program_indicator\", StringType()),\n      StructField(\"foreclosure_principal_write_off_amount\", StringType()),\n      StructField(\"relocation_mortgage_indicator\", StringType()),\n      StructField(\"zero_balance_code_change_date\", StringType()),\n      StructField(\"loan_holdback_indicator\", StringType()),\n      StructField(\"loan_holdback_effective_date\", StringType()),\n      StructField(\"delinquent_accrued_interest\", StringType()),\n      StructField(\"property_valuation_method\", StringType()),\n      StructField(\"high_balance_loan_indicator\", StringType()),\n      
StructField(\"arm_initial_fixed-rate_period_lt_5_yr_indicator\", StringType()),\n      StructField(\"arm_product_type\", StringType()),\n      StructField(\"initial_fixed-rate_period\", StringType()),\n      StructField(\"interest_rate_adjustment_frequency\", StringType()),\n      StructField(\"next_interest_rate_adjustment_date\", StringType()),\n      StructField(\"next_payment_change_date\", StringType()),\n      StructField(\"index\", StringType()),\n      StructField(\"arm_cap_structure\", StringType()),\n      StructField(\"initial_interest_rate_cap_up_percent\", StringType()),\n      StructField(\"periodic_interest_rate_cap_up_percent\", StringType()),\n      StructField(\"lifetime_interest_rate_cap_up_percent\", StringType()),\n      StructField(\"mortgage_margin\", StringType()),\n      StructField(\"arm_balloon_indicator\", StringType()),\n      StructField(\"arm_plan_number\", StringType()),\n      StructField(\"borrower_assistance_plan\", StringType()),\n      StructField(\"hltv_refinance_option_indicator\", StringType()),\n      StructField(\"deal_name\", StringType()),\n      StructField(\"repurchase_make_whole_proceeds_flag\", StringType()),\n      StructField(\"alternative_delinquency_resolution\", StringType()),\n      StructField(\"alternative_delinquency_resolution_count\", StringType()),\n      StructField(\"total_deferral_amount\", StringType())\n      ])\n\ncategorical_columns = [\n    'orig_channel',\n    'first_home_buyer',\n    'loan_purpose',\n    'property_type',\n    'occupancy_status',\n    'property_state',\n    'product_type',\n    'relocation_mortgage_indicator',\n    'seller_name',\n    'mod_flag',\n]\n\nnumeric_columns = [\n    'orig_interest_rate',\n    'orig_upb',\n    'orig_loan_term',\n    'orig_ltv',\n    'orig_cltv',\n    'num_borrowers',\n    'dti',\n    'borrower_credit_score',\n    'num_units',\n    'zip',\n    'mortgage_insurance_percent',\n    'current_loan_delinquency_status',\n    'current_actual_upb',\n    
'interest_rate',\n    'loan_age',\n    'msa',\n    'non_interest_bearing_upb',\n    'delinquency_12',\n]\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/mortgage/cross_validator_main.py",
    "content": "#\n# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nfrom pyspark.ml.tuning import ParamGridBuilder, CrossValidator\n\nfrom .consts import *\nfrom com.nvidia.spark.examples.utility.utils import *\nfrom pyspark.sql import SparkSession\n\nfrom xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel\n\n\ndef main(args, xgboost_args):\n    spark = (SparkSession\n             .builder\n             .appName(args.mainClass)\n             .getOrCreate())\n\n    train_data, eval_data, trans_data = valid_input_data(spark, args, '', schema)\n\n    if args.mode in ['all', 'train']:\n        if train_data is None:\n            print('-' * 80)\n            print('Usage: training data path required when mode is all or train')\n            exit(1)\n\n        train_data, features = transform_data(train_data, label, args.use_gpu)\n        xgboost_args['features_col'] = features\n        xgboost_args['label_col'] = label\n\n        classifier = SparkXGBClassifier(**xgboost_args)\n\n        evaluator = (MulticlassClassificationEvaluator()\n                     .setLabelCol(label))\n\n        param_grid = (ParamGridBuilder()\n                      .addGrid(classifier.max_depth, [6, 8])\n                      .addGrid(classifier.n_estimators, [20, 40])\n                      .build())\n        cross_validator = (CrossValidator()\n                           .setEstimator(classifier)\n          
                 .setEvaluator(evaluator)\n                           .setEstimatorParamMaps(param_grid)\n                           .setNumFolds(3))\n        if not train_data:\n            print('-' * 80)\n            print('Usage: training data path required when mode is all or train')\n            exit(1)\n\n        model = with_benchmark('Training', lambda: cross_validator.fit(train_data))\n        # get the best model to do transform\n        model = model.bestModel\n        if args.modelPath:\n            writer = model.write().overwrite() if args.overwrite else model\n            writer.save(args.modelPath)\n    else:\n        model = SparkXGBClassifierModel.load(args.modelPath)\n\n    if args.mode in ['all', 'transform']:\n        if not trans_data:\n            print('-' * 80)\n            print('Usage: trans data path required when mode is all or transform')\n            exit(1)\n\n        trans_data, _ = transform_data(trans_data, label, args.use_gpu)\n\n        def transform():\n            result = model.transform(trans_data).cache()\n            result.foreachPartition(lambda _: None)\n            return result\n\n        result = with_benchmark('Transformation', transform)\n        show_sample(args, result, label)\n        with_benchmark('Evaluation', lambda: check_classification_accuracy(result, label))\n\n    spark.stop()\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/mortgage/etl.py",
    "content": "#\n# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nfrom .consts import *\nfrom pyspark.sql.functions import *\nfrom pyspark.sql.types import *\nfrom pyspark.sql.window import Window\nfrom sys import exit\n\nget_quarter = udf(lambda path: path.split(r'.')[0].split('/')[-1], StringType())\nstandardize_name = udf(lambda name: name_mapping.get(name), StringType())\n\ndef load_data(spark, paths, schema, args, extra_csv_opts={}):\n    reader = (spark\n        .read\n        .format(args.format)\n        .option('asFloats', args.asFloats)\n        .option('maxRowsPerChunk', args.maxRowsPerChunk))\n    if args.format == 'csv':\n        (reader\n            .schema(schema)\n            .option('delimiter', '|')\n            .option('header', False))\n        for k, v in extra_csv_opts.items():\n            reader.option(k, v)\n    return reader.load(paths)\n\ndef prepare_rawDf(spark, args):\n    extra_csv_options = {\n        'nullValue': '',\n        'parserLib': 'univocity',\n    }\n    paths = extract_paths(args.dataPaths, 'data::')\n    rawDf = load_data(spark, paths, rawSchema, args, extra_csv_options)\n\n    return rawDf\n\ndef extract_perf_columns(rawDf):\n    perfDf = rawDf.select(\n      col(\"loan_id\"),\n      date_format(to_date(col(\"monthly_reporting_period\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"monthly_reporting_period\"),\n      
upper(col(\"servicer\")).alias(\"servicer\"),\n      col(\"interest_rate\"),\n      col(\"current_actual_upb\"),\n      col(\"loan_age\"),\n      col(\"remaining_months_to_legal_maturity\"),\n      col(\"adj_remaining_months_to_maturity\"),\n      date_format(to_date(col(\"maturity_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"maturity_date\"),\n      col(\"msa\"),\n      col(\"current_loan_delinquency_status\"),\n      col(\"mod_flag\"),\n      col(\"zero_balance_code\"),\n      date_format(to_date(col(\"zero_balance_effective_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"zero_balance_effective_date\"),\n      date_format(to_date(col(\"last_paid_installment_date\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"last_paid_installment_date\"),\n      date_format(to_date(col(\"foreclosed_after\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"foreclosed_after\"),\n      date_format(to_date(col(\"disposition_date\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"disposition_date\"),\n      col(\"foreclosure_costs\"),\n      col(\"prop_preservation_and_repair_costs\"),\n      col(\"asset_recovery_costs\"),\n      col(\"misc_holding_expenses\"),\n      col(\"holding_taxes\"),\n      col(\"net_sale_proceeds\"),\n      col(\"credit_enhancement_proceeds\"),\n      col(\"repurchase_make_whole_proceeds\"),\n      col(\"other_foreclosure_proceeds\"),\n      col(\"non_interest_bearing_upb\"),\n      col(\"principal_forgiveness_upb\"),\n      col(\"repurchase_make_whole_proceeds_flag\"),\n      col(\"foreclosure_principal_write_off_amount\"),\n      col(\"servicing_activity_indicator\"))\n\n    return perfDf.select(\"*\").filter(\"current_actual_upb != 0.0\")\n    \n\ndef prepare_performance(spark, args, rawDf):\n    performance = (extract_perf_columns(rawDf)\n        .withColumn('quarter', get_quarter(input_file_name()))\n        .withColumn('timestamp', to_date(col('monthly_reporting_period'), 'MM/dd/yyyy'))\n        .withColumn('timestamp_year', year(col('timestamp')))\n        .withColumn('timestamp_month', 
month(col('timestamp'))))\n\n    aggregation = (performance\n        .select(\n            'quarter',\n            'loan_id',\n            'current_loan_delinquency_status',\n            when(col('current_loan_delinquency_status') >= 1, col('timestamp'))\n                .alias('delinquency_30'),\n            when(col('current_loan_delinquency_status') >= 3, col('timestamp'))\n                .alias('delinquency_90'),\n            when(col('current_loan_delinquency_status') >= 6, col('timestamp'))\n                .alias('delinquency_180'))\n        .groupBy('quarter', 'loan_id')\n        .agg(\n            max('current_loan_delinquency_status').alias('delinquency_12'),\n            min('delinquency_30').alias('delinquency_30'),\n            min('delinquency_90').alias('delinquency_90'),\n            min('delinquency_180').alias('delinquency_180'))\n        .select(\n            'quarter',\n            'loan_id',\n            (col('delinquency_12') >= 1).alias('ever_30'),\n            (col('delinquency_12') >= 3).alias('ever_90'),\n            (col('delinquency_12') >= 6).alias('ever_180'),\n            'delinquency_30',\n            'delinquency_90',\n            'delinquency_180'))\n\n    months = spark.createDataFrame(range(12), IntegerType()).withColumnRenamed('value', 'month_y')\n    to_join = (performance\n        .select(\n            'quarter',\n            'loan_id',\n            'timestamp_year',\n            'timestamp_month',\n            col('current_loan_delinquency_status').alias('delinquency_12'),\n            col('current_actual_upb').alias('upb_12'))\n        .join(aggregation, ['loan_id', 'quarter'], 'left_outer')\n        .crossJoin(months)\n        .select(\n            'quarter',\n            floor(\n                (col('timestamp_year') * 12 + col('timestamp_month') - 24000 - col('month_y')) / 12\n            ).alias('josh_mody_n'),\n            'ever_30',\n            'ever_90',\n            'ever_180',\n            'delinquency_30',\n      
      'delinquency_90',\n            'delinquency_180',\n            'loan_id',\n            'month_y',\n            'delinquency_12',\n            'upb_12')\n        .groupBy(\n            'quarter',\n            'loan_id',\n            'josh_mody_n',\n            'ever_30',\n            'ever_90',\n            'ever_180',\n            'delinquency_30',\n            'delinquency_90',\n            'delinquency_180',\n            'month_y')\n        .agg(\n            max('delinquency_12').alias('delinquency_12'),\n            min('upb_12').alias('upb_12'))\n        .withColumn(\n            'timestamp_year',\n            floor((24000 + (col('josh_mody_n') * 12) + (col('month_y') - 1)) / 12))\n        .withColumn(\n            'timestamp_month_tmp',\n            (24000 + (col('josh_mody_n') * 12) + col('month_y')) % 12)\n        .withColumn(\n            'timestamp_month',\n            when(col('timestamp_month_tmp') == 0, 12).otherwise(col('timestamp_month_tmp')))\n        .withColumn(\n            'delinquency_12',\n            ((col('delinquency_12') > 3).cast('int') + (col('upb_12') == 0).cast('int')))\n        .drop('timestamp_month_tmp', 'josh_mody_n', 'month_y'))\n\n    return (performance\n        .join(to_join, ['quarter', 'loan_id', 'timestamp_year', 'timestamp_month'], 'left')\n        .drop('timestamp_year', 'timestamp_month'))\n\ndef extract_acq_columns(rawDf):\n    acqDf = rawDf.select(\n      col(\"loan_id\"),\n      col(\"orig_channel\"),\n      upper(col(\"seller_name\")).alias(\"seller_name\"),\n      col(\"orig_interest_rate\"),\n      col(\"orig_upb\"),\n      col(\"orig_loan_term\"),\n      date_format(to_date(col(\"orig_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"orig_date\"),\n      date_format(to_date(col(\"first_pay_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"first_pay_date\"),\n      col(\"orig_ltv\"),\n      col(\"orig_cltv\"),\n      col(\"num_borrowers\"),\n      col(\"dti\"),\n      col(\"borrower_credit_score\"),\n      
col(\"first_home_buyer\"),\n      col(\"loan_purpose\"),\n      col(\"property_type\"),\n      col(\"num_units\"),\n      col(\"occupancy_status\"),\n      col(\"property_state\"),\n      col(\"zip\"),\n      col(\"mortgage_insurance_percent\"),\n      col(\"product_type\"),\n      col(\"coborrow_credit_score\"),\n      col(\"mortgage_insurance_type\"),\n      col(\"relocation_mortgage_indicator\"),\n      dense_rank().over(Window.partitionBy(\"loan_id\").orderBy(to_date(col(\"monthly_reporting_period\"),\"MMyyyy\"))).alias(\"rank\")\n      )\n\n    return acqDf.select(\"*\").filter(col(\"rank\")==1)\n\n    \n\ndef prepare_acquisition(spark, args, rawDf):\n    return (extract_acq_columns(rawDf)\n        .withColumn('quarter', get_quarter(input_file_name()))\n        .withColumn('seller_name', standardize_name(col('seller_name'))))\n\ndef extract_paths(paths, prefix):\n    results = [ path[len(prefix):] for path in paths if path.startswith(prefix) ]\n    if not results:\n        print('-' * 80)\n        print('Usage: {} data path required'.format(prefix))\n        exit(1)\n    return results\n\ndef etl(spark, args):\n    rawDf = prepare_rawDf(spark, args)\n    rawDf.write.parquet(extract_paths(args.dataPaths, 'tmp::')[0], mode='overwrite')\n    rawDf = spark.read.parquet(extract_paths(args.dataPaths, 'tmp::')[0])\n    \n    performance = prepare_performance(spark, args, rawDf)\n    acquisition = prepare_acquisition(spark, args, rawDf)\n    return (performance\n        .join(acquisition, ['loan_id', 'quarter'], 'left_outer')\n        .select(\n            [(md5(col(x)) % 100).alias(x) for x in categorical_columns]\n            + [col(x) for x in numeric_columns])\n        .withColumn('delinquency_12', when(col('delinquency_12') > 0, 1).otherwise(0))\n        .na\n        .fill(0))\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/mortgage/etl_main.py",
    "content": "#\n# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nfrom .etl import etl, extract_paths\nfrom com.nvidia.spark.examples.utility.utils import *\nfrom pyspark.sql import SparkSession\n\n\ndef main(args, xgboost_args):\n    spark = (SparkSession\n             .builder\n             .appName(args.mainClass)\n             .getOrCreate())\n    etled_df = etl(spark, args)\n    # outPath should has only one input\n    outPath = extract_paths(args.dataPaths, 'out::')[0]\n    etled_df.write.mode(\"overwrite\").parquet(outPath)\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/python/com/nvidia/spark/examples/mortgage/main.py",
    "content": "#\n# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\nfrom xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel\n\nfrom .consts import *\nfrom com.nvidia.spark.examples.utility.utils import *\nfrom pyspark.sql import SparkSession\n\n\ndef main(args, xgboost_args):\n    spark = (SparkSession\n             .builder\n             .appName(args.mainClass)\n             .getOrCreate())\n\n    train_data, eval_data, trans_data = valid_input_data(spark, args, '', schema)\n\n    if args.mode in ['all', 'train']:\n        if train_data is None:\n            print('-' * 80)\n            print('Usage: training data path required when mode is all or train')\n            exit(1)\n\n        train_data, features = transform_data(train_data, label, args.use_gpu)\n        xgboost_args['features_col'] = features\n        xgboost_args['label_col'] = label\n        classifier = SparkXGBClassifier(**xgboost_args)\n\n        if eval_data:\n            # TODO\n            pass\n\n        model = with_benchmark('Training', lambda: classifier.fit(train_data))\n\n        if args.modelPath:\n            writer = model.write().overwrite() if args.overwrite else model\n            writer.save(args.modelPath)\n    else:\n        model = SparkXGBClassifierModel.load(args.modelPath)\n\n    if args.mode in ['all', 'transform']:\n        trans_data, _ = transform_data(trans_data, label, 
args.use_gpu)\n\n        def transform():\n            result = model.transform(trans_data).cache()\n            result.foreachPartition(lambda _: None)\n            return result\n\n        if not trans_data:\n            print('-' * 80)\n            print('Usage: trans data path required when mode is all or transform')\n            exit(1)\n\n        result = with_benchmark('Transformation', transform)\n        show_sample(args, result, label)\n        with_benchmark('Evaluation', lambda: check_classification_accuracy(result, label))\n\n    spark.stop()\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/scala/src/com/nvidia/spark/examples/mortgage/CrossValidationMain.scala",
    "content": "/*\n * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.mortgage\n\nimport com.nvidia.spark.examples.utility.{XGBoostArgs, Benchmark}\nimport ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}\nimport org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\nimport org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}\nimport org.apache.spark.sql.SparkSession\n\nobject CrossValidationMain extends Mortgage {\n\n  def main(args: Array[String]): Unit = {\n    val appArgs = XGBoostArgs(args)\n    val processor = this.getClass.getSimpleName.stripSuffix(\"$\").substring(0, 3)\n    val appInfo = Seq(appName, processor, appArgs.format)\n    val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2))\n    // build spark session\n    val spark = SparkSession.builder().appName(appInfo.mkString(\"-\")).getOrCreate()\n    // build data reader\n    val dataReader = spark.read\n\n    try {\n      // loaded XGBoost ETLed data\n      val pathsArray = appArgs.getDataPaths\n      // 0: train 1: eval 2:transform\n      val datasets = pathsArray.map { paths =>\n        if (paths.nonEmpty) {\n          appArgs.format match {\n            case \"csv\" => Some(dataReader.option(\"header\", appArgs.hasHeader).schema(schema).csv(paths: _*))\n            case \"orc\" => Some(dataReader.orc(paths: _*))\n            
case \"parquet\" => Some(dataReader.parquet(paths: _*))\n            case _ => throw new IllegalArgumentException(\"Unsupported data file format!\")\n          }\n        } else {\n          None\n        }\n      }\n\n      val xgbClassificationModel = if (appArgs.isToTrain) {\n        // build XGBoost classifier\n        val xgbParamFinal = appArgs.xgboostParams(commParamMap)\n        val xgbClassifier = new XGBoostClassifier(xgbParamFinal)\n          .setLabelCol(labelColName)\n          .setFeaturesCol(featureNames)\n\n        // Tune model using cross validation\n        val paramGrid = new ParamGridBuilder()\n          .addGrid(xgbClassifier.maxDepth, Array(3, 10))\n          .addGrid(xgbClassifier.eta, Array(0.2, 0.6))\n          .build()\n        val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName)\n\n        val cv = new CrossValidator()\n          .setEstimator(xgbClassifier)\n          .setEvaluator(evaluator)\n          .setEstimatorParamMaps(paramGrid)\n          .setNumFolds(appArgs.numFold)\n\n        // Start training\n        println(\"\\n------ CrossValidation ------\")\n        // Shall we not log the time if it is abnormal, which is usually caused by training failure\n        val (model, _) = benchmark.time(\"CrossValidation\") {\n          cv.fit(datasets(0).get).bestModel.asInstanceOf[XGBoostClassificationModel]\n        }\n        // Save model if modelPath exists\n        appArgs.modelPath.foreach(path =>\n          if (appArgs.isOverwrite) model.write.overwrite().save(path) else model.save(path))\n        model\n      } else {\n        XGBoostClassificationModel.load(appArgs.modelPath.get)\n      }\n\n      if (appArgs.isToTransform) {\n        println(\"\\n------ Transforming ------\")\n        var (results, _) = benchmark.time(\"transform\") {\n          val ret = xgbClassificationModel.transform(datasets(2).get).cache()\n          // Trigger the transformation\n          ret.foreachPartition((_: Iterator[_]) 
=> ())\n          ret\n        }\n        results = if (appArgs.isShowFeatures) {\n          results\n        } else {\n          results.select(labelColName, \"rawPrediction\", \"probability\", \"prediction\")\n        }\n        results.show(appArgs.numRows)\n\n        println(\"\\n------Accuracy of Evaluation------\")\n        val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName)\n        evaluator.evaluate(results) match {\n          case accuracy if !accuracy.isNaN =>\n            benchmark.value(accuracy, \"Accuracy\", \"Accuracy for\")\n          // Throw an exception when NaN ?\n        }\n      }\n    } finally {\n      spark.close()\n    }\n  }\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/scala/src/com/nvidia/spark/examples/mortgage/ETLMain.scala",
    "content": "/*\n * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.mortgage\n\nimport com.nvidia.spark.examples.utility.{XGBoostArgs, Benchmark}\nimport org.apache.hadoop.fs.Path\nimport org.apache.spark.sql.SparkSession\n\nobject ETLMain extends Mortgage {\n\n  def main(args: Array[String]): Unit = {\n    val xgbArgs = XGBoostArgs(args)\n    val subTitle = getClass.getSimpleName.stripSuffix(\"$\").substring(0, 3)\n    val appInfo = Seq(appName, subTitle, xgbArgs.format)\n    val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2))\n    // build spark session\n    val spark = SparkSession.builder().appName(appInfo.mkString(\"-\")).getOrCreate()\n\n    try {\n      val (dataPaths, outPath, tmpPath) = checkAndGetPaths(xgbArgs.dataPaths)\n      println(\"\\n------ Start ETL ------\")\n      benchmark.time(\"ETL\") {\n        // ETL the raw data\n        val rawDF = xgbArgs.format match {\n          case \"csv\" => XGBoostETL.csv(spark, dataPaths, tmpPath, false)\n          case \"orc\" => XGBoostETL.orc(spark, dataPaths)\n          case \"parquet\" => XGBoostETL.parquet(spark, dataPaths)\n          case _ => throw new IllegalArgumentException(\"Unsupported data file format!\")\n        }\n        rawDF.write.mode(\"overwrite\").parquet(outPath)\n      }\n      if (xgbArgs.saveDict) {\n        XGBoostETL.saveDictTable(new Path(outPath, 
\".dict\").toString)\n      }\n    } finally {\n      XGBoostETL.clean()\n      spark.close()\n    }\n  }\n\n  def checkAndGetPaths(paths: Seq[String]): (Seq[String], String, String) = {\n    val prefixes = Array(\"data::\", \"out::\",  \"tmp::\")\n    val validPaths = paths.filter(_.nonEmpty).map(_.trim)\n\n    // get and check perf data paths\n    val dataPaths = validPaths.filter(_.startsWith(prefixes.head))\n    require(dataPaths.nonEmpty, s\"$appName ETL requires at least one path for data file.\" +\n      s\" Please specify it by '-dataPath=data::your_data_path'\")\n\n    // get and check out path\n    val outPath = validPaths.filter(_.startsWith(prefixes(1)))\n    require(outPath.nonEmpty, s\"$appName ETL requires a path to save the ETLed data file. Please specify it\" +\n      \" by '-dataPath=out::your_out_path', only the first path is used if multiple paths are found.\")\n    \n    // get and check tmp path\n    val tmpPath = validPaths.filter(_.startsWith(prefixes(2)))\n    require(tmpPath.nonEmpty, s\"$appName ETL requires a path to save the temp parquet files. Please specify it\" +\n      \" by '-dataPath=tmp::your_out_path'.\")\n\n    // check data paths not specified type\n    val unknownPaths = validPaths.filterNot(p => prefixes.exists(p.contains(_)))\n    require(unknownPaths.isEmpty, s\"Unknown type for data path: ${unknownPaths.head}, $appName requires to specify\" +\n      \" the type for each data path by adding the prefix 'data::' or 'out::'.\")\n\n    (dataPaths.map(_.stripPrefix(prefixes.head)),\n     outPath.head.stripPrefix(prefixes(1)),\n     tmpPath.head.stripPrefix(prefixes(2)))\n  }\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/scala/src/com/nvidia/spark/examples/mortgage/Main.scala",
    "content": "/*\n * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.mortgage\n\nimport com.nvidia.spark.examples.utility.{XGBoostArgs, Benchmark}\nimport ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}\nimport org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\nimport org.apache.spark.sql.SparkSession\n\nobject Main extends Mortgage {\n\n  def main(args: Array[String]): Unit = {\n    val appArgs = XGBoostArgs(args)\n    val processor = this.getClass.getSimpleName.stripSuffix(\"$\").substring(0, 3)\n    val appInfo = Seq(appName, processor, appArgs.format)\n    val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2))\n    // build spark session\n    val spark = SparkSession.builder().appName(appInfo.mkString(\"-\")).getOrCreate()\n    // build data reader\n    val dataReader = spark.read\n\n    try {\n      // loaded XGBoost ETLed data\n      val pathsArray = appArgs.getDataPaths\n      // 0: train 1: eval 2:transform\n      val datasets = pathsArray.map { paths =>\n        if (paths.nonEmpty) {\n          appArgs.format match {\n            case \"csv\" => Some(dataReader.option(\"header\", appArgs.hasHeader).schema(schema).csv(paths: _*))\n            case \"orc\" => Some(dataReader.orc(paths: _*))\n            case \"parquet\" => Some(dataReader.parquet(paths: _*))\n            case _ => throw 
new IllegalArgumentException(\"Unsupported data file format!\")\n          }\n        } else {\n          None\n        }\n      }\n\n      val xgbClassificationModel = if (appArgs.isToTrain) {\n        // build XGBoost classifier\n        val xgbParamFinal = appArgs.xgboostParams(commParamMap)\n        val xgbClassifier = new XGBoostClassifier(xgbParamFinal)\n          .setLabelCol(labelColName)\n          .setFeaturesCol(featureNames)\n\n        datasets(1).foreach(_ => xgbClassifier.setEvalDataset(_))\n\n        // Start training\n        println(\"\\n------ Training ------\")\n        // Shall we not log the time if it is abnormal, which is usually caused by training failure\n        val (model, _) = benchmark.time(\"train\") {\n          xgbClassifier.fit(datasets(0).get)\n        }\n        // Save model if modelPath exists\n        appArgs.modelPath.foreach(path =>\n          if (appArgs.isOverwrite) model.write.overwrite().save(path) else model.save(path))\n        model\n      } else {\n        XGBoostClassificationModel.load(appArgs.modelPath.get)\n      }\n\n      if (appArgs.isToTransform) {\n        println(\"\\n------ Transforming ------\")\n        var (results, _) = benchmark.time(\"transform\") {\n          val ret = xgbClassificationModel.transform(datasets(2).get).cache()\n          // Trigger the transformation\n          ret.foreachPartition((_: Iterator[_]) => ())\n          ret\n        }\n        results = if (appArgs.isShowFeatures) {\n          results\n        } else {\n          results.select(labelColName, \"rawPrediction\", \"probability\", \"prediction\")\n        }\n        results.show(appArgs.numRows)\n\n        println(\"\\n------Accuracy of Evaluation------\")\n        val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName)\n        evaluator.evaluate(results) match {\n          case accuracy if !accuracy.isNaN =>\n            benchmark.value(accuracy, \"Accuracy\", \"Accuracy for\")\n          // Throw 
an exception when NaN ?\n        }\n      }\n    } finally {\n      spark.close()\n    }\n  }\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/scala/src/com/nvidia/spark/examples/mortgage/Mortgage.scala",
    "content": "/*\n * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.mortgage\n\nimport org.apache.spark.sql.types.{FloatType, IntegerType, StructField, StructType, DoubleType}\n\nprivate[mortgage] trait Mortgage {\n  val appName = \"Mortgage\"\n  val labelColName = \"delinquency_12\"\n\n  protected val categaryCols = List(\n    (\"orig_channel\", FloatType),\n    (\"first_home_buyer\", FloatType),\n    (\"loan_purpose\", FloatType),\n    (\"property_type\", FloatType),\n    (\"occupancy_status\", FloatType),\n    (\"property_state\", FloatType),\n    (\"product_type\", FloatType),\n    (\"relocation_mortgage_indicator\", FloatType),\n    (\"seller_name\", FloatType),\n    (\"mod_flag\", FloatType)\n  )\n\n  protected val numericCols = List(\n    (\"orig_interest_rate\", FloatType),\n    (\"orig_upb\", DoubleType),\n    (\"orig_loan_term\", IntegerType),\n    (\"orig_ltv\", FloatType),\n    (\"orig_cltv\", FloatType),\n    (\"num_borrowers\", FloatType),\n    (\"dti\", FloatType),\n    (\"borrower_credit_score\", FloatType),\n    (\"num_units\", IntegerType),\n    (\"zip\", IntegerType),\n    (\"mortgage_insurance_percent\", FloatType),\n    (\"current_loan_delinquency_status\", IntegerType),\n    (\"current_actual_upb\", FloatType),\n    (\"interest_rate\", FloatType),\n    (\"loan_age\", FloatType),\n    (\"msa\", FloatType),\n    
(\"non_interest_bearing_upb\", FloatType),\n    (labelColName, IntegerType)\n  )\n\n  lazy val schema = StructType((categaryCols ++ numericCols).map(col => StructField(col._1, col._2)))\n  lazy val featureNames = schema.filter(_.name != labelColName).map(_.name).toArray\n\n  val commParamMap = Map(\n    \"objective\" -> \"binary:logistic\",\n    \"num_round\" -> 100)\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/mortgage/scala/src/com/nvidia/spark/examples/mortgage/XGBoostETL.scala",
    "content": "/*\n * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\npackage com.nvidia.spark.examples.mortgage\n\nimport org.apache.spark.sql.expressions.Window\nimport org.apache.spark.sql.functions._\nimport org.apache.spark.sql.types._\nimport org.apache.spark.sql.{Column, DataFrame, SparkSession}\n\n\nobject GetQuarterFromCsvFileName {\n  // The format is path/TYPE_yyyy\\QQ.txt followed by a (_index)* where index is a single digit number [0-9]\n  // i.e. 
mortgage/perf/Performance_2003Q4.txt_0_1\n  // So we strip off the .txt and everything after it\n  // and then take everything after the last remaining _\n  def apply(): Column = substring_index(\n    substring_index(input_file_name(), \".\", 1), \"/\", -1)\n}\n\nprivate object CsvReader {\n\n  def readRaw(spark: SparkSession, paths: Seq[String], optionsMap: Map[String, String]): DataFrame = {\n\n    val rawSchema = StructType(Array(\n      StructField(\"reference_pool_id\", StringType),\n      StructField(\"loan_id\", LongType),\n      StructField(\"monthly_reporting_period\", StringType),\n      StructField(\"orig_channel\", StringType),\n      StructField(\"seller_name\", StringType),\n      StructField(\"servicer\", StringType),\n      StructField(\"master_servicer\", StringType),\n      StructField(\"orig_interest_rate\", DoubleType),\n      StructField(\"interest_rate\", DoubleType),\n      StructField(\"orig_upb\", DoubleType),\n      StructField(\"upb_at_issuance\", StringType),\n      StructField(\"current_actual_upb\", DoubleType),\n      StructField(\"orig_loan_term\", IntegerType),\n      StructField(\"orig_date\", StringType),\n      StructField(\"first_pay_date\", StringType),    \n      StructField(\"loan_age\", DoubleType),\n      StructField(\"remaining_months_to_legal_maturity\", DoubleType),\n      StructField(\"adj_remaining_months_to_maturity\", DoubleType),\n      StructField(\"maturity_date\", StringType),\n      StructField(\"orig_ltv\", DoubleType),\n      StructField(\"orig_cltv\", DoubleType),\n      StructField(\"num_borrowers\", DoubleType),\n      StructField(\"dti\", DoubleType),\n      StructField(\"borrower_credit_score\", DoubleType),\n      StructField(\"coborrow_credit_score\", DoubleType),\n      StructField(\"first_home_buyer\", StringType),\n      StructField(\"loan_purpose\", StringType),\n      StructField(\"property_type\", StringType),\n      StructField(\"num_units\", IntegerType),\n      StructField(\"occupancy_status\", 
StringType),\n      StructField(\"property_state\", StringType),\n      StructField(\"msa\", DoubleType),\n      StructField(\"zip\", IntegerType),\n      StructField(\"mortgage_insurance_percent\", DoubleType),\n      StructField(\"product_type\", StringType),\n      StructField(\"prepayment_penalty_indicator\", StringType),\n      StructField(\"interest_only_loan_indicator\", StringType),\n      StructField(\"interest_only_first_principal_and_interest_payment_date\", StringType),\n      StructField(\"months_to_amortization\", StringType),\n      StructField(\"current_loan_delinquency_status\", IntegerType),\n      StructField(\"loan_payment_history\", StringType),\n      StructField(\"mod_flag\", StringType),\n      StructField(\"mortgage_insurance_cancellation_indicator\", StringType),\n      StructField(\"zero_balance_code\", StringType),\n      StructField(\"zero_balance_effective_date\", StringType),\n      StructField(\"upb_at_the_time_of_removal\", StringType),\n      StructField(\"repurchase_date\", StringType),\n      StructField(\"scheduled_principal_current\", StringType),\n      StructField(\"total_principal_current\", StringType),\n      StructField(\"unscheduled_principal_current\", StringType),\n      StructField(\"last_paid_installment_date\", StringType),\n      StructField(\"foreclosed_after\", StringType),\n      StructField(\"disposition_date\", StringType),\n      StructField(\"foreclosure_costs\", DoubleType),\n      StructField(\"prop_preservation_and_repair_costs\", DoubleType),\n      StructField(\"asset_recovery_costs\", DoubleType),\n      StructField(\"misc_holding_expenses\", DoubleType),\n      StructField(\"holding_taxes\", DoubleType),\n      StructField(\"net_sale_proceeds\", DoubleType),\n      StructField(\"credit_enhancement_proceeds\", DoubleType),\n      StructField(\"repurchase_make_whole_proceeds\", StringType),\n      StructField(\"other_foreclosure_proceeds\", DoubleType),\n      StructField(\"non_interest_bearing_upb\", 
DoubleType),\n      StructField(\"principal_forgiveness_upb\", StringType),\n      StructField(\"original_list_start_date\", StringType),\n      StructField(\"original_list_price\", StringType),\n      StructField(\"current_list_start_date\", StringType),\n      StructField(\"current_list_price\", StringType),\n      StructField(\"borrower_credit_score_at_issuance\", StringType),\n      StructField(\"co-borrower_credit_score_at_issuance\", StringType),\n      StructField(\"borrower_credit_score_current\", StringType),\n      StructField(\"co-Borrower_credit_score_current\", StringType),\n      StructField(\"mortgage_insurance_type\", DoubleType),\n      StructField(\"servicing_activity_indicator\", StringType),\n      StructField(\"current_period_modification_loss_amount\", StringType),\n      StructField(\"cumulative_modification_loss_amount\", StringType),\n      StructField(\"current_period_credit_event_net_gain_or_loss\", StringType),\n      StructField(\"cumulative_credit_event_net_gain_or_loss\", StringType),\n      StructField(\"homeready_program_indicator\", StringType),\n      StructField(\"foreclosure_principal_write_off_amount\", StringType),\n      StructField(\"relocation_mortgage_indicator\", StringType),\n      StructField(\"zero_balance_code_change_date\", StringType),\n      StructField(\"loan_holdback_indicator\", StringType),\n      StructField(\"loan_holdback_effective_date\", StringType),\n      StructField(\"delinquent_accrued_interest\", StringType),\n      StructField(\"property_valuation_method\", StringType),\n      StructField(\"high_balance_loan_indicator\", StringType),\n      StructField(\"arm_initial_fixed-rate_period_lt_5_yr_indicator\", StringType),\n      StructField(\"arm_product_type\", StringType),\n      StructField(\"initial_fixed-rate_period\", StringType),\n      StructField(\"interest_rate_adjustment_frequency\", StringType),\n      StructField(\"next_interest_rate_adjustment_date\", StringType),\n      
StructField(\"next_payment_change_date\", StringType),\n      StructField(\"index\", StringType),\n      StructField(\"arm_cap_structure\", StringType),\n      StructField(\"initial_interest_rate_cap_up_percent\", StringType),\n      StructField(\"periodic_interest_rate_cap_up_percent\", StringType),\n      StructField(\"lifetime_interest_rate_cap_up_percent\", StringType),\n      StructField(\"mortgage_margin\", StringType),\n      StructField(\"arm_balloon_indicator\", StringType),\n      StructField(\"arm_plan_number\", StringType),\n      StructField(\"borrower_assistance_plan\", StringType),\n      StructField(\"hltv_refinance_option_indicator\", StringType),\n      StructField(\"deal_name\", StringType),\n      StructField(\"repurchase_make_whole_proceeds_flag\", StringType),\n      StructField(\"alternative_delinquency_resolution\", StringType),\n      StructField(\"alternative_delinquency_resolution_count\", StringType),\n      StructField(\"total_deferral_amount\", StringType)\n      )\n    )\n\n    spark.read\n      .options(optionsMap)\n      .option(\"nullValue\", \"\")\n      .option(\"delimiter\", \"|\")\n      .schema(rawSchema)\n      .csv(paths: _*)\n      .withColumn(\"quarter\", GetQuarterFromCsvFileName())\n  }\n}\n\nobject extractPerfColumns{\n  def apply(rawDf : DataFrame) : DataFrame = {\n    val perfDf = rawDf.select(\n      col(\"loan_id\"),\n      date_format(to_date(col(\"monthly_reporting_period\"),\"MMyyyy\"), \"MM/dd/yyyy\").as(\"monthly_reporting_period\"),\n      upper(col(\"servicer\")).as(\"servicer\"),\n      col(\"interest_rate\"),\n      col(\"current_actual_upb\"),\n      col(\"loan_age\"),\n      col(\"remaining_months_to_legal_maturity\"),\n      col(\"adj_remaining_months_to_maturity\"),\n      date_format(to_date(col(\"maturity_date\"),\"MMyyyy\"), \"MM/yyyy\").as(\"maturity_date\"),\n      col(\"msa\"),\n      col(\"current_loan_delinquency_status\"),\n      col(\"mod_flag\"),\n      col(\"zero_balance_code\"),\n      
date_format(to_date(col(\"zero_balance_effective_date\"),\"MMyyyy\"), \"MM/yyyy\").as(\"zero_balance_effective_date\"),\n      date_format(to_date(col(\"last_paid_installment_date\"),\"MMyyyy\"), \"MM/dd/yyyy\").as(\"last_paid_installment_date\"),\n      date_format(to_date(col(\"foreclosed_after\"),\"MMyyyy\"), \"MM/dd/yyyy\").as(\"foreclosed_after\"),\n      date_format(to_date(col(\"disposition_date\"),\"MMyyyy\"), \"MM/dd/yyyy\").as(\"disposition_date\"),\n      col(\"foreclosure_costs\"),\n      col(\"prop_preservation_and_repair_costs\"),\n      col(\"asset_recovery_costs\"),\n      col(\"misc_holding_expenses\"),\n      col(\"holding_taxes\"),\n      col(\"net_sale_proceeds\"),\n      col(\"credit_enhancement_proceeds\"),\n      col(\"repurchase_make_whole_proceeds\"),\n      col(\"other_foreclosure_proceeds\"),\n      col(\"non_interest_bearing_upb\"),\n      col(\"principal_forgiveness_upb\"),\n      col(\"repurchase_make_whole_proceeds_flag\"),\n      col(\"foreclosure_principal_write_off_amount\"),\n      col(\"servicing_activity_indicator\"),\n      col(\"quarter\")\n    )\n\n    perfDf.select(\"*\").filter(\"current_actual_upb != 0.0\")\n  }\n}\n\nobject extractAcqColumns{\n  def apply(rawDf : DataFrame) : DataFrame = {\n    val acqDf = rawDf.select(\n      col(\"loan_id\"),\n      col(\"orig_channel\"),\n      upper(col(\"seller_name\")).as(\"seller_name\"),\n      col(\"orig_interest_rate\"),\n      col(\"orig_upb\"),\n      col(\"orig_loan_term\"),\n      date_format(to_date(col(\"orig_date\"),\"MMyyyy\"), \"MM/yyyy\").as(\"orig_date\"),\n      date_format(to_date(col(\"first_pay_date\"),\"MMyyyy\"), \"MM/yyyy\").as(\"first_pay_date\"),\n      col(\"orig_ltv\"),\n      col(\"orig_cltv\"),\n      col(\"num_borrowers\"),\n      col(\"dti\"),\n      col(\"borrower_credit_score\"),\n      col(\"first_home_buyer\"),\n      col(\"loan_purpose\"),\n      col(\"property_type\"),\n      col(\"num_units\"),\n      col(\"occupancy_status\"),\n      
col(\"property_state\"),\n      col(\"zip\"),\n      col(\"mortgage_insurance_percent\"),\n      col(\"product_type\"),\n      col(\"coborrow_credit_score\"),\n      col(\"mortgage_insurance_type\"),\n      col(\"relocation_mortgage_indicator\"),\n      col(\"quarter\"),\n      dense_rank().over(Window.partitionBy(\"loan_id\").orderBy(to_date(col(\"monthly_reporting_period\"),\"MMyyyy\"))).as(\"rank\")\n    )\n\n    acqDf.select(\"*\").filter(col(\"rank\") === 1).drop(\"rank\")\n  }\n\n}\n\nobject NameMapping {\n  /**\n    * Returns a dataframe with two columns named based off of the column names passed in.\n    * The fromColName has the original name we want to clean up, the toColName\n    * will have the name we want to go to, the unambiguous name.\n    */\n  def apply(spark: SparkSession, fromColName: String, toColName: String): DataFrame = {\n    import spark.sqlContext.implicits._\n    broadcast(Seq(\n      (\"WITMER FUNDING, LLC\", \"Witmer\"),\n      (\"WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015\", \"Wells Fargo\"),\n      (\"WELLS FARGO BANK,  NA\" , \"Wells Fargo\"),\n      (\"WELLS FARGO BANK, N.A.\" , \"Wells Fargo\"),\n      (\"WELLS FARGO BANK, NA\" , \"Wells Fargo\"),\n      (\"USAA FEDERAL SAVINGS BANK\" , \"USAA\"),\n      (\"UNITED SHORE FINANCIAL SERVICES, LLC D\\\\/B\\\\/A UNITED WHOLESALE MORTGAGE\" , \"United Seq(e\"),\n      (\"U.S. 
BANK N.A.\" , \"US Bank\"),\n      (\"SUNTRUST MORTGAGE INC.\" , \"Suntrust\"),\n      (\"STONEGATE MORTGAGE CORPORATION\" , \"Stonegate Mortgage\"),\n      (\"STEARNS LENDING, LLC\" , \"Stearns Lending\"),\n      (\"STEARNS LENDING, INC.\" , \"Stearns Lending\"),\n      (\"SIERRA PACIFIC MORTGAGE COMPANY, INC.\" , \"Sierra Pacific Mortgage\"),\n      (\"REGIONS BANK\" , \"Regions\"),\n      (\"RBC MORTGAGE COMPANY\" , \"RBC\"),\n      (\"QUICKEN LOANS INC.\" , \"Quicken Loans\"),\n      (\"PULTE MORTGAGE, L.L.C.\" , \"Pulte Mortgage\"),\n      (\"PROVIDENT FUNDING ASSOCIATES, L.P.\" , \"Provident Funding\"),\n      (\"PROSPECT MORTGAGE, LLC\" , \"Prospect Mortgage\"),\n      (\"PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC\" , \"Principal Residential\"),\n      (\"PNC BANK, N.A.\" , \"PNC\"),\n      (\"PMT CREDIT RISK TRANSFER TRUST 2015-2\" , \"PennyMac\"),\n      (\"PHH MORTGAGE CORPORATION\" , \"PHH Mortgage\"),\n      (\"PENNYMAC CORP.\" , \"PennyMac\"),\n      (\"PACIFIC UNION FINANCIAL, LLC\" , \"Other\"),\n      (\"OTHER\" , \"Other\"),\n      (\"NYCB MORTGAGE COMPANY, LLC\" , \"NYCB\"),\n      (\"NEW YORK COMMUNITY BANK\" , \"NYCB\"),\n      (\"NETBANK FUNDING SERVICES\" , \"Netbank\"),\n      (\"NATIONSTAR MORTGAGE, LLC\" , \"Nationstar Mortgage\"),\n      (\"METLIFE BANK, NA\" , \"Metlife\"),\n      (\"LOANDEPOT.COM, LLC\" , \"LoanDepot.com\"),\n      (\"J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1\" , \"JP Morgan Chase\"),\n      (\"J.P. 
MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1\" , \"JP Morgan Chase\"),\n      (\"JPMORGAN CHASE BANK, NATIONAL ASSOCIATION\" , \"JP Morgan Chase\"),\n      (\"JPMORGAN CHASE BANK, NA\" , \"JP Morgan Chase\"),\n      (\"JP MORGAN CHASE BANK, NA\" , \"JP Morgan Chase\"),\n      (\"IRWIN MORTGAGE, CORPORATION\" , \"Irwin Mortgage\"),\n      (\"IMPAC MORTGAGE CORP.\" , \"Impac Mortgage\"),\n      (\"HSBC BANK USA, NATIONAL ASSOCIATION\" , \"HSBC\"),\n      (\"HOMEWARD RESIDENTIAL, INC.\" , \"Homeward Mortgage\"),\n      (\"HOMESTREET BANK\" , \"Other\"),\n      (\"HOMEBRIDGE FINANCIAL SERVICES, INC.\" , \"HomeBridge\"),\n      (\"HARWOOD STREET FUNDING I, LLC\" , \"Harwood Mortgage\"),\n      (\"GUILD MORTGAGE COMPANY\" , \"Guild Mortgage\"),\n      (\"GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)\" , \"GMAC\"),\n      (\"GMAC MORTGAGE, LLC\" , \"GMAC\"),\n      (\"GMAC (USAA)\" , \"GMAC\"),\n      (\"FREMONT BANK\" , \"Fremont Bank\"),\n      (\"FREEDOM MORTGAGE CORP.\" , \"Freedom Mortgage\"),\n      (\"FRANKLIN AMERICAN MORTGAGE COMPANY\" , \"Franklin America\"),\n      (\"FLEET NATIONAL BANK\" , \"Fleet National\"),\n      (\"FLAGSTAR CAPITAL MARKETS CORPORATION\" , \"Flagstar Bank\"),\n      (\"FLAGSTAR BANK, FSB\" , \"Flagstar Bank\"),\n      (\"FIRST TENNESSEE BANK NATIONAL ASSOCIATION\" , \"Other\"),\n      (\"FIFTH THIRD BANK\" , \"Fifth Third Bank\"),\n      (\"FEDERAL HOME LOAN BANK OF CHICAGO\" , \"Fedral Home of Chicago\"),\n      (\"FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB\" , \"FDIC\"),\n      (\"DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A.\" , \"Downey Mortgage\"),\n      (\"DITECH FINANCIAL LLC\" , \"Ditech\"),\n      (\"CITIMORTGAGE, INC.\" , \"Citi\"),\n      (\"CHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY\" , \"Chicago Mortgage\"),\n      (\"CHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY\" , \"Chicago Mortgage\"),\n      (\"CHASE HOME FINANCE, LLC\" , \"JP Morgan Chase\"),\n      (\"CHASE HOME FINANCE FRANKLIN 
AMERICAN MORTGAGE COMPANY\" , \"JP Morgan Chase\"),\n      (\"CHASE HOME FINANCE (CIE 1)\" , \"JP Morgan Chase\"),\n      (\"CHASE HOME FINANCE\" , \"JP Morgan Chase\"),\n      (\"CASHCALL, INC.\" , \"CashCall\"),\n      (\"CAPITAL ONE, NATIONAL ASSOCIATION\" , \"Capital One\"),\n      (\"CALIBER HOME LOANS, INC.\" , \"Caliber Funding\"),\n      (\"BISHOPS GATE RESIDENTIAL MORTGAGE TRUST\" , \"Bishops Gate Mortgage\"),\n      (\"BANK OF AMERICA, N.A.\" , \"Bank of America\"),\n      (\"AMTRUST BANK\" , \"AmTrust\"),\n      (\"AMERISAVE MORTGAGE CORPORATION\" , \"Amerisave\"),\n      (\"AMERIHOME MORTGAGE COMPANY, LLC\" , \"AmeriHome Mortgage\"),\n      (\"ALLY BANK\" , \"Ally Bank\"),\n      (\"ACADEMY MORTGAGE CORPORATION\" , \"Academy Mortgage\"),\n      (\"NO CASH-OUT REFINANCE\" , \"OTHER REFINANCE\"),\n      (\"REFINANCE - NOT SPECIFIED\" , \"OTHER REFINANCE\"),\n      (\"Other REFINANCE\" , \"OTHER REFINANCE\")\n    ).toDF(fromColName, toColName))\n  }\n}\n\nprivate trait MortgageETL {\n  var dataFrame: DataFrame = _\n\n  def from(df: DataFrame): this.type = {\n    dataFrame = df\n    this\n  }\n}\n\nprivate object PerformanceETL extends MortgageETL {\n\n  def prepare: this.type = {\n    dataFrame = dataFrame\n      .withColumn(\"monthly_reporting_period\", to_date(col(\"monthly_reporting_period\"), \"MM/dd/yyyy\"))\n      .withColumn(\"monthly_reporting_period_month\", month(col(\"monthly_reporting_period\")))\n      .withColumn(\"monthly_reporting_period_year\", year(col(\"monthly_reporting_period\")))\n      .withColumn(\"monthly_reporting_period_day\", dayofmonth(col(\"monthly_reporting_period\")))\n      .withColumn(\"last_paid_installment_date\", to_date(col(\"last_paid_installment_date\"), \"MM/dd/yyyy\"))\n      .withColumn(\"foreclosed_after\", to_date(col(\"foreclosed_after\"), \"MM/dd/yyyy\"))\n      .withColumn(\"disposition_date\", to_date(col(\"disposition_date\"), \"MM/dd/yyyy\"))\n      .withColumn(\"maturity_date\", 
to_date(col(\"maturity_date\"), \"MM/yyyy\"))\n      .withColumn(\"zero_balance_effective_date\", to_date(col(\"zero_balance_effective_date\"), \"MM/yyyy\"))\n      .withColumn(\"current_actual_upb\", col(\"current_actual_upb\"))\n      .withColumn(\"current_loan_delinquency_status\", col(\"current_loan_delinquency_status\"))\n    this\n  }\n\n  def createDelinquency(spark: SparkSession): this.type = {\n    val aggDF = dataFrame\n      .select(\n        col(\"quarter\"),\n        col(\"loan_id\"),\n        col(\"current_loan_delinquency_status\"),\n        when(col(\"current_loan_delinquency_status\") >= 1, col(\"monthly_reporting_period\")).alias(\"delinquency_30\"),\n        when(col(\"current_loan_delinquency_status\") >= 3, col(\"monthly_reporting_period\")).alias(\"delinquency_90\"),\n        when(col(\"current_loan_delinquency_status\") >= 6, col(\"monthly_reporting_period\")).alias(\"delinquency_180\")\n      )\n      .groupBy(\"quarter\", \"loan_id\")\n      .agg(\n        max(\"current_loan_delinquency_status\").alias(\"delinquency_12\"),\n        min(\"delinquency_30\").alias(\"delinquency_30\"),\n        min(\"delinquency_90\").alias(\"delinquency_90\"),\n        min(\"delinquency_180\").alias(\"delinquency_180\")\n      )\n      .select(\n        col(\"quarter\"),\n        col(\"loan_id\"),\n        (col(\"delinquency_12\") >= 1).alias(\"ever_30\"),\n        (col(\"delinquency_12\") >= 3).alias(\"ever_90\"),\n        (col(\"delinquency_12\") >= 6).alias(\"ever_180\"),\n        col(\"delinquency_30\"),\n        col(\"delinquency_90\"),\n        col(\"delinquency_180\")\n      )\n\n    val joinedDf = dataFrame\n      .withColumnRenamed(\"monthly_reporting_period\", \"timestamp\")\n      .withColumnRenamed(\"monthly_reporting_period_month\", \"timestamp_month\")\n      .withColumnRenamed(\"monthly_reporting_period_year\", \"timestamp_year\")\n      .withColumnRenamed(\"current_loan_delinquency_status\", \"delinquency_12\")\n      
.withColumnRenamed(\"current_actual_upb\", \"upb_12\")\n      .select(\"quarter\", \"loan_id\", \"timestamp\", \"delinquency_12\", \"upb_12\", \"timestamp_month\", \"timestamp_year\")\n      .join(aggDF, Seq(\"loan_id\", \"quarter\"), \"left_outer\")\n\n    // calculate the 12 month delinquency and upb values\n    val months = 12\n    val monthArray = 0.until(months).toArray\n    val testDf = joinedDf\n      // explode on a small amount of data is actually slightly more efficient than a cross join\n      .withColumn(\"month_y\", explode(lit(monthArray)))\n      .select(\n        col(\"quarter\"),\n        floor(((col(\"timestamp_year\") * 12 + col(\"timestamp_month\")) - 24000) / months).alias(\"josh_mody\"),\n        floor(((col(\"timestamp_year\") * 12 + col(\"timestamp_month\")) - 24000 - col(\"month_y\")) / months).alias(\"josh_mody_n\"),\n        col(\"ever_30\"),\n        col(\"ever_90\"),\n        col(\"ever_180\"),\n        col(\"delinquency_30\"),\n        col(\"delinquency_90\"),\n        col(\"delinquency_180\"),\n        col(\"loan_id\"),\n        col(\"month_y\"),\n        col(\"delinquency_12\"),\n        col(\"upb_12\")\n      )\n      .groupBy(\"quarter\", \"loan_id\", \"josh_mody_n\", \"ever_30\", \"ever_90\", \"ever_180\", \"delinquency_30\", \"delinquency_90\", \"delinquency_180\", \"month_y\")\n      .agg(max(\"delinquency_12\").alias(\"delinquency_12\"), min(\"upb_12\").alias(\"upb_12\"))\n      .withColumn(\"timestamp_year\", floor((lit(24000) + (col(\"josh_mody_n\") * lit(months)) + (col(\"month_y\") - 1)) / lit(12)))\n      .withColumn(\"timestamp_month_tmp\", pmod(lit(24000) + (col(\"josh_mody_n\") * lit(months)) + col(\"month_y\"), lit(12)))\n      .withColumn(\"timestamp_month\", when(col(\"timestamp_month_tmp\") === lit(0), lit(12)).otherwise(col(\"timestamp_month_tmp\")))\n      .withColumn(\"delinquency_12\", ((col(\"delinquency_12\") > 3).cast(\"int\") + (col(\"upb_12\") === 0).cast(\"int\")).alias(\"delinquency_12\"))\n      
.drop(\"timestamp_month_tmp\", \"josh_mody_n\", \"month_y\")\n\n    dataFrame = dataFrame\n      .withColumnRenamed(\"monthly_reporting_period_month\", \"timestamp_month\")\n      .withColumnRenamed(\"monthly_reporting_period_year\", \"timestamp_year\")\n      .join(testDf, Seq(\"quarter\", \"loan_id\", \"timestamp_year\", \"timestamp_month\"), \"left\").drop(\"timestamp_year\", \"timestamp_month\")\n    this\n  }\n}\n\nprivate object AcquisitionETL extends MortgageETL {\n\n  def createAcquisition(spark: SparkSession): this.type = {\n    val nameMapping = NameMapping(spark, \"from_seller_name\", \"to_seller_name\")\n    dataFrame = dataFrame\n      .join(nameMapping, col(\"seller_name\") === col(\"from_seller_name\"), \"left\")\n      .drop(\"from_seller_name\")\n      /* backup the original name before we replace it */\n      .withColumn(\"old_name\", col(\"seller_name\"))\n      /* replace seller_name with the new version if we found one in the mapping, or the old version\n       if we didn't */\n      .withColumn(\"seller_name\", coalesce(col(\"to_seller_name\"), col(\"seller_name\")))\n      .drop(\"to_seller_name\")\n      .withColumn(\"orig_date\", to_date(col(\"orig_date\"), \"MM/yyyy\"))\n      .withColumn(\"first_pay_date\", to_date(col(\"first_pay_date\"), \"MM/yyyy\"))\n    this\n  }\n\n  def cleanPrime(perfDF: DataFrame): this.type = {\n    dataFrame = perfDF.join(dataFrame, Seq(\"loan_id\", \"quarter\"), \"inner\").drop(\"quarter\")\n    this\n  }\n}\n\nobject XGBoostETL extends Mortgage {\n\n  private lazy val allCols = (categaryCols ++ numericCols).map(c => col(c._1))\n  private var cachedDictDF: DataFrame = _\n\n  /**\n    * Generate a dictionary from string to numeric value for multiple category columns.\n    *\n    * (Copied the solution of casting string to numeric from the utils of DLRM.)\n    */\n  private def genDictionary(etlDF: DataFrame, colNames: Seq[String]): DataFrame = {\n    val cntTable = etlDF\n      
.select(posexplode(array(colNames.map(col(_)): _*)))\n      .withColumnRenamed(\"pos\", \"column_id\")\n      .withColumnRenamed(\"col\", \"data\")\n      .filter(\"data is not null\")\n      .groupBy(\"column_id\", \"data\")\n      .count()\n    val windowed = Window.partitionBy(\"column_id\").orderBy(desc(\"count\"))\n    cntTable\n      .withColumn(\"id\", row_number().over(windowed))\n      .drop(\"count\")\n  }\n\n  /**\n    * Cast all the category columns to numeric columns in the given data frame.\n    * Then it is suitable for XGBoost training/transforming\n    */\n  private def castStringColumnsToNumeric(inputDF: DataFrame, spark: SparkSession): DataFrame = {\n    val cateColNames = categaryCols.map(_._1)\n    cachedDictDF = genDictionary(inputDF, cateColNames).cache()\n\n    // Generate the final table with all columns being numeric.\n    cateColNames.foldLeft(inputDF) {\n      case (df, colName) =>\n        val colPos = cateColNames.indexOf(colName)\n        val colDictDF = cachedDictDF\n          .filter(col(\"column_id\") === colPos)\n          .drop(\"column_id\")\n          .withColumnRenamed(\"data\", colName)\n        df.join(broadcast(colDictDF), Seq(colName), \"left\")\n          .drop(colName)\n          .withColumnRenamed(\"id\", colName)\n    }\n  }\n\n  private def transform(perfDF: DataFrame, acqDF: DataFrame, spark: SparkSession): DataFrame = {\n    val etlPerfDF = PerformanceETL.from(perfDF)\n      .prepare\n      .createDelinquency(spark)\n      .dataFrame\n    val cleanDF = AcquisitionETL.from(acqDF)\n      .createAcquisition(spark)\n      .cleanPrime(etlPerfDF)\n      .dataFrame\n\n    // Convert to xgb required Dataset\n    castStringColumnsToNumeric(cleanDF, spark)\n      .select(allCols: _*)\n      .withColumn(labelColName, when(col(labelColName) > 0, 1).otherwise(0))\n      .na.fill(0.0f)\n  }\n\n  def clean(): Unit = {\n    if (cachedDictDF != null) {\n      cachedDictDF.unpersist()\n      cachedDictDF = null\n    }\n  }\n\n  def 
saveDictTable(outPath: String): Unit = {\n    if (cachedDictDF != null) {\n      // The dict data is small, so merge it into one file.\n      cachedDictDF\n        .repartition(1)\n        .write\n        .mode(\"overwrite\")\n        .parquet(outPath)\n    }\n  }\n\n  def csv(spark: SparkSession, dataPaths: Seq[String], tmpPath: String, hasHeader: Boolean): DataFrame = {\n    val optionsMap = Map(\"header\" -> hasHeader.toString)\n    val rawDf_csv = CsvReader.readRaw(spark, dataPaths, optionsMap)\n    \n    rawDf_csv.write.mode(\"overwrite\").parquet(tmpPath)\n    val rawDf = spark.read.parquet(tmpPath)\n    \n    val perfDf = extractPerfColumns(rawDf)\n    val acqDf = extractAcqColumns(rawDf)\n    transform(\n      perfDf,\n      acqDf,\n      spark\n    )\n  }\n\n  def parquet(spark: SparkSession, dataPaths: Seq[String]): DataFrame = {\n    val rawDf = spark.read.parquet(dataPaths: _*)\n    val perfDf = extractPerfColumns(rawDf)\n    val acqDf = extractAcqColumns(rawDf)\n    transform(\n      perfDf,\n      acqDf,\n      spark\n    )\n  }\n\n  def orc(spark: SparkSession, dataPaths: Seq[String]): DataFrame = {\n    val rawDf = spark.read.orc(dataPaths: _*)\n    val perfDf = extractPerfColumns(rawDf)\n    val acqDf = extractAcqColumns(rawDf)\n    transform(\n      perfDf,\n      acqDf,\n      spark\n    )\n  }\n  \n  \n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/pack_pyspark_example.sh",
    "content": "#!/bin/bash\n# Copyright (c) 2024-2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Follow these steps to package the Python zip file\nrm -fr samples.zip\ncd agaricus/python ; zip -r ../../samples.zip com ; cd ../..\ncd mortgage/python ; zip -r ../../samples.zip com ; cd ../..\ncd taxi/python ; zip -r ../../samples.zip com ; cd ../..\ncd utility/python ; zip -r ../../samples.zip com ; cd ../..\n"
  },
  {
    "path": "examples/XGBoost-Examples/pom.xml",
    "content": "<?xml version='1.0' encoding='UTF-8'?>\n<!--\n  ~ Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.\n  ~\n  ~ Licensed under the Apache License, Version 2.0 (the \"License\");\n  ~ you may not use this file except in compliance with the License.\n  ~ You may obtain a copy of the License at\n  ~\n  ~ http://www.apache.org/licenses/LICENSE-2.0\n  ~\n  ~ Unless required by applicable law or agreed to in writing, software\n  ~ distributed under the License is distributed on an \"AS IS\" BASIS,\n  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n  ~ See the License for the specific language governing permissions and\n  ~ limitations under the License.\n  -->\n\n<project xsi:schemaLocation=\"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd\"\n         xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\" xmlns=\"http://maven.apache.org/POM/4.0.0\">\n\n    <modelVersion>4.0.0</modelVersion>\n    <groupId>com.nvidia</groupId>\n    <artifactId>sample_xgboost_examples</artifactId>\n\n    <packaging>pom</packaging>\n    <description>Sample XGBoost4J-Spark applications</description>\n    <modules>\n        \n        <module>utility</module>\n        <module>agaricus</module>\n        <module>mortgage</module>\n        <module>taxi</module>\n        <module>aggregator</module>\n    </modules>\n\n    <version>0.2.3-SNAPSHOT</version>\n    <name>sample_xgboost_apps</name>\n\n    <properties>\n        <encoding>UTF-8</encoding>\n        <xgboost.version>3.1.0-SNAPSHOT</xgboost.version>\n        <spark.version>3.5.0</spark.version>\n        <scala.version>2.12.8</scala.version>\n        <scala.binary.version>2.12</scala.binary.version>\n    </properties>\n\n    <dependencies>\n        <dependency>\n            <groupId>ml.dmlc</groupId>\n            <artifactId>xgboost4j-spark-gpu_${scala.binary.version}</artifactId>\n            <version>${xgboost.version}</version>\n        </dependency>\n   
     <dependency>\n            <groupId>org.scala-lang</groupId>\n            <artifactId>scala-library</artifactId>\n            <version>${scala.version}</version>\n            <scope>provided</scope>\n        </dependency>\n        <dependency>\n            <groupId>org.apache.spark</groupId>\n            <artifactId>spark-sql_${scala.binary.version}</artifactId>\n            <version>${spark.version}</version>\n            <scope>provided</scope>\n        </dependency>\n        <dependency>\n            <groupId>org.apache.spark</groupId>\n            <artifactId>spark-mllib_${scala.binary.version}</artifactId>\n            <version>${spark.version}</version>\n            <scope>provided</scope>\n        </dependency>\n        <dependency>\n            <groupId>org.scalatest</groupId>\n            <artifactId>scalatest_${scala.binary.version}</artifactId>\n            <version>3.2.15</version>\n            <scope>test</scope>\n        </dependency>\n    </dependencies>\n\n    <build>\n        <plugins>\n            <plugin>\n                <groupId>org.scala-tools</groupId>\n                <artifactId>maven-scala-plugin</artifactId>\n                <version>2.15.2</version>\n                <executions>\n                    <execution>\n                        <goals>\n                            <goal>compile</goal>\n                            <goal>testCompile</goal>\n                        </goals>\n                    </execution>\n                </executions>\n            </plugin>\n            <plugin>\n                <groupId>org.scalatest</groupId>\n                <artifactId>scalatest-maven-plugin</artifactId>\n                <version>1.0</version>\n                <configuration>\n                </configuration>\n                <executions>\n                    <execution>\n                        <id>test</id>\n                        <goals>\n                            <goal>test</goal>\n                        </goals>\n                 
   </execution>\n                </executions>\n            </plugin>\n            <plugin>\n                <groupId>org.apache.maven.plugins</groupId>\n                <artifactId>maven-assembly-plugin</artifactId>\n                <version>2.6</version>\n                <configuration>\n                    <descriptors>\n                        <descriptor>assembly/assembly-no-scala.xml</descriptor>\n                    </descriptors>\n                </configuration>\n                <executions>\n                    <execution>\n                        <id>assembly</id>\n                        <phase>package</phase>\n                        <goals>\n                            <goal>single</goal>\n                        </goals>\n                    </execution>\n                </executions>\n            </plugin>\n        </plugins>\n    </build>\n    <profiles>\n        <profile>\n            <id>scala-2.13</id>\n            <properties>\n                <xgboost.version>2.1.0-SNAPSHOT</xgboost.version>\n                <spark.version>3.5.0</spark.version>\n                <scala.version>2.13.11</scala.version>\n                <scala.binary.version>2.13</scala.binary.version>\n            </properties>\n        </profile>\n        <profile>\n            <id>sonatype-repo</id>\n            <repositories>\n                <repository>\n                    <id>sonatype-staging-repo</id>\n                    <name>Sonatype staging repo</name>\n                    <url>https://oss.sonatype.org/content/repositories/staging</url>\n                </repository>\n            </repositories>\n        </profile>\n    </profiles>\n\n    <repositories>\n      <repository>\n        <id>XGBoost4J Snapshot Repo</id>\n        <name>XGBoost4J Snapshot Repo</name>\n        <url>https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/snapshot/</url>\n      </repository>\n    </repositories>\n\n</project>\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/.gitignore",
    "content": ".idea\ntarget\n*.iml\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/notebooks/python/cv-taxi-gpu.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction to XGBoost-Spark Cross Validation with GPU\\n\",\n    \"\\n\",\n    \"The goal of this notebook is to show you how to levarage GPU to accelerate XGBoost spark cross validatoin for hyperparameter tuning. The best model for the given hyperparameters will be returned.\\n\",\n    \"\\n\",\n    \"Here takes the application 'Taxi' as an example.\\n\",\n    \"\\n\",\n    \"A few libraries are required for this notebook:\\n\",\n    \"  1. cudf-cu11\\n\",\n    \"  2. xgboost\\n\",\n    \"  3. scikit-learn\\n\",\n    \"  4. numpy\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Import the Required Libraries\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from xgboost.spark import SparkXGBRegressor, SparkXGBRegressorModel\\n\",\n    \"from pyspark.ml.tuning import ParamGridBuilder, CrossValidator\\n\",\n    \"from pyspark.ml.evaluation import RegressionEvaluator\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.types import FloatType, IntegerType, StructField, StructType\\n\",\n    \"from time import time\\n\",\n    \"from pyspark.conf import SparkConf\\n\",\n    \"import os\\n\",\n    \"# os.environ['PYSPARK_PYTHON'] = \\\"./environment/bin/python\\\"\\n\",\n    \"# os.environ['PYSPARK_DRIVER_PYTHON'] = \\\"./environment/bin/python\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create a Spark Session\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-30 08:02:09,748 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your 
platform... using builtin-java classes where applicable\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"2022-11-30 08:02:10,103 WARN resource.ResourceUtils: The configuration of cores (exec = 2 task = 1, runnable tasks = 2) will result in wasted resources due to resource gpu limiting the number of runnable tasks per executor to: 1. Please adjust your configuration.\\n\",\n      \"2022-11-30 08:02:23,737 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1.\\n\",\n      \"2022-11-30 08:02:23,752 WARN rapids.RapidsPluginUtils: spark.rapids.sql.multiThreadedRead.numThreads is set to 20.\\n\",\n      \"2022-11-30 08:02:23,756 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\\n\",\n      \"2022-11-30 08:02:23,757 WARN rapids.RapidsPluginUtils: spark.rapids.sql.explain is set to `NOT_ON_GPU`. 
Set it to 'NONE' to suppress the diagnostics logging about the query placement on the GPU.\\n\",\n      \"2022-11-30 08:02:24,226 WARN yarn.Client: Neither spark.yarn.jars nor spark.yarn.archive is set, falling back to uploading libraries under SPARK_HOME.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"SPARK_MASTER_URL = os.getenv(\\\"SPARK_MASTER_URL\\\", \\\"/your-url\\\")\\n\",\n    \"\\n\",\n    \"RAPIDS_JAR = os.getenv(\\\"RAPIDS_JAR\\\", \\\"/your-jar-path\\\")\\n\",\n    \"\\n\",\n    \"# You need to update with your real hardware resource \\n\",\n    \"driverMem = os.getenv(\\\"DRIVER_MEM\\\", \\\"2g\\\")\\n\",\n    \"executorMem = os.getenv(\\\"EXECUTOR_MEM\\\", \\\"2g\\\")\\n\",\n    \"pinnedPoolSize = os.getenv(\\\"PINNED_POOL_SIZE\\\", \\\"2g\\\")\\n\",\n    \"concurrentGpuTasks = os.getenv(\\\"CONCURRENT_GPU_TASKS\\\", \\\"2\\\")\\n\",\n    \"executorCores = int(os.getenv(\\\"EXECUTOR_CORES\\\", \\\"2\\\"))\\n\",\n    \"# Common spark settings\\n\",\n    \"conf = SparkConf()\\n\",\n    \"conf.setMaster(SPARK_MASTER_URL)\\n\",\n    \"conf.setAppName(\\\"Microbenchmark on GPU\\\")\\n\",\n    \"conf.set(\\\"spark.executor.instances\\\",\\\"1\\\")\\n\",\n    \"conf.set(\\\"spark.driver.memory\\\", driverMem)\\n\",\n    \"## The tasks will run on GPU memory, so there is no need to set a high host memory\\n\",\n    \"conf.set(\\\"spark.executor.memory\\\", executorMem)\\n\",\n    \"## The tasks will run on GPU cores, so there is no need to use many cpu cores\\n\",\n    \"conf.set(\\\"spark.executor.cores\\\", executorCores)\\n\",\n    \"\\n\",\n    \"# Plugin settings\\n\",\n    \"conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.sql.concurrentGpuTasks\\\", concurrentGpuTasks)\\n\",\n    \"conf.set(\\\"spark.rapids.memory.pinnedPool.size\\\", pinnedPoolSize)\\n\",\n    \"# since pyspark and xgboost share the same GPU, we disable RMM to avoid GPU OOM while training \\n\",\n    
\"conf.set(\\\"spark.rapids.memory.gpu.pool\\\", \\\"NONE\\\")\\n\",\n    \"conf.set(\\\"spark.locality.wait\\\",\\\"0\\\")\\n\",\n    \"##############note: only support value=1 https://github.com/dmlc/xgboost/blame/master/python-package/xgboost/spark/core.py#L370-L374\\n\",\n    \"conf.set(\\\"spark.task.resource.gpu.amount\\\", 1) \\n\",\n    \"conf.set(\\\"spark.rapids.sql.enabled\\\", \\\"true\\\") \\n\",\n    \"conf.set(\\\"spark.plugins\\\", \\\"com.nvidia.spark.SQLPlugin\\\")\\n\",\n    \"conf.set(\\\"spark.sql.cache.serializer\\\",\\\"com.nvidia.spark.ParquetCachedBatchSerializer\\\")\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", 200000) \\n\",\n    \"conf.set(\\\"spark.driver.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.executor.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"# if you pass/unpack the archive file and enable the environment\\n\",\n    \"# conf.set(\\\"spark.yarn.dist.archives\\\", \\\"your_pyspark_venv.tar.gz#environment\\\")\\n\",\n    \"# Create spark session\\n\",\n    \"spark = SparkSession.builder.config(conf=conf).getOrCreate()\\n\",\n    \"\\n\",\n    \"reader = spark.read\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Specify the Data Schema and Load the Data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"label = 'fare_amount'\\n\",\n    \"schema = StructType([\\n\",\n    \"    StructField('vendor_id', FloatType()),\\n\",\n    \"    StructField('passenger_count', FloatType()),\\n\",\n    \"    StructField('trip_distance', FloatType()),\\n\",\n    \"    StructField('pickup_longitude', FloatType()),\\n\",\n    \"    StructField('pickup_latitude', FloatType()),\\n\",\n    \"    StructField('rate_code', FloatType()),\\n\",\n    \"    StructField('store_and_fwd', FloatType()),\\n\",\n    \"    StructField('dropoff_longitude', 
FloatType()),\\n\",\n    \"    StructField('dropoff_latitude', FloatType()),\\n\",\n    \"    StructField(label, FloatType()),\\n\",\n    \"    StructField('hour', FloatType()),\\n\",\n    \"    StructField('year', IntegerType()),\\n\",\n    \"    StructField('month', IntegerType()),\\n\",\n    \"    StructField('day', FloatType()),\\n\",\n    \"    StructField('day_of_week', FloatType()),\\n\",\n    \"    StructField('is_weekend', FloatType()),\\n\",\n    \"])\\n\",\n    \"\\n\",\n    \"features = [ x.name for x in schema if x.name != label ]\\n\",\n    \"\\n\",\n    \"# You need to update them to your real paths!\\n\",\n    \"dataRoot = os.getenv(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"train_path = dataRoot + \\\"/taxi/csv/train\\\"\\n\",\n    \"eval_path = dataRoot + \\\"/taxi/csv/test\\\"\\n\",\n    \"\\n\",\n    \"data_format = 'csv'\\n\",\n    \"has_header = 'true'\\n\",\n    \"if data_format == 'csv':\\n\",\n    \"    train_data = reader.schema(schema).option('header',has_header).csv(train_path)\\n\",\n    \"    trans_data = reader.schema(schema).option('header',has_header).csv(eval_path)\\n\",\n    \"else :\\n\",\n    \"    train_data = reader.load(train_path)\\n\",\n    \"    trans_data = reader.load(eval_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Build a XGBoost-Spark CrossValidator\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# First build a regressor of GPU version using *setFeaturesCols* to set feature columns\\n\",\n    \"params = { \\n\",\n    \"    \\\"tree_method\\\": \\\"hist\\\",\\n\",\n    \"    \\\"grow_policy\\\": \\\"depthwise\\\",\\n\",\n    \"    \\\"num_workers\\\": 1,\\n\",\n    \"    \\\"device\\\": \\\"cuda\\\",\\n\",\n    \"}\\n\",\n    \"params['features_col'] = features\\n\",\n    \"params['label_col'] = label\\n\",\n    \"\\n\",\n    \"regressor = 
SparkXGBRegressor(**params)\\n\",\n    \"# Then build the evaluator and the hyperparameters\\n\",\n    \"evaluator = (RegressionEvaluator()\\n\",\n    \"    .setLabelCol(label))\\n\",\n    \"param_grid = (ParamGridBuilder()\\n\",\n    \"    .addGrid(regressor.max_depth, [3, 6])\\n\",\n    \"    .addGrid(regressor.n_estimators, [100, 200])\\n\",\n    \"    .build())\\n\",\n    \"# Finally the cross validator\\n\",\n    \"cross_validator = (CrossValidator()\\n\",\n    \"    .setEstimator(regressor)\\n\",\n    \"    .setEvaluator(evaluator)\\n\",\n    \"    .setEstimatorParamMaps(param_grid)\\n\",\n    \"    .setNumFolds(2))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Start Cross Validation by Fitting Data to CrossValidator\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"/data/home/yuanli/work/reviews/pr252/pyspark_venv_20221125/lib/python3.8/site-packages/xgboost/sklearn.py:808: UserWarning: Loading a native XGBoost model with Scikit-Learn interface.\\n\",\n      \"  warnings.warn(\\\"Loading a native XGBoost model with Scikit-Learn interface.\\\")\\n\",\n      \"2022-11-30 08:03:14,308 WARN rapids.GpuOverrides: \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! 
<CreateExternalRow> createexternalrow(prediction#889, fare_amount#890, 1.0#891, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#889 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> fare_amount#890 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#891 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> obj#895 cannot run on GPU because expression AttributeReference obj#895 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"\\n\",\n      \"2022-11-30 08:03:14,317 WARN util.package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-30 08:03:20,073 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! 
<CreateExternalRow> createexternalrow(prediction#1789, fare_amount#1790, 1.0#1791, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#1789 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> fare_amount#1790 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#1791 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> obj#1795 cannot run on GPU because expression AttributeReference obj#1795 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-30 08:03:23,687 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! 
<CreateExternalRow> createexternalrow(prediction#2689, fare_amount#2690, 1.0#2691, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#2689 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> fare_amount#2690 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#2691 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> obj#2695 cannot run on GPU because expression AttributeReference obj#2695 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-30 08:03:27,457 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! 
<CreateExternalRow> createexternalrow(prediction#3589, fare_amount#3590, 1.0#3591, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#3589 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> fare_amount#3590 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#3591 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> obj#3595 cannot run on GPU because expression AttributeReference obj#3595 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-30 08:03:30,964 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! 
<CreateExternalRow> createexternalrow(prediction#4659, fare_amount#4660, 1.0#4661, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#4659 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> fare_amount#4660 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#4661 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> obj#4665 cannot run on GPU because expression AttributeReference obj#4665 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-30 08:03:34,524 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! 
<CreateExternalRow> createexternalrow(prediction#5559, fare_amount#5560, 1.0#5561, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#5559 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> fare_amount#5560 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#5561 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> obj#5565 cannot run on GPU because expression AttributeReference obj#5565 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-30 08:03:38,067 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! 
<CreateExternalRow> createexternalrow(prediction#6459, fare_amount#6460, 1.0#6461, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#6459 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> fare_amount#6460 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#6461 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> obj#6465 cannot run on GPU because expression AttributeReference obj#6465 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"\\n\",\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"2022-11-30 08:03:41,793 WARN rapids.GpuOverrides:                               \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! 
<CreateExternalRow> createexternalrow(prediction#7359, fare_amount#7360, 1.0#7361, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#7359 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> fare_amount#7360 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#7361 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> obj#7365 cannot run on GPU because expression AttributeReference obj#7365 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"[Stage 34:>                                                         (0 + 1) / 1]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Cross-Validation takes 55.19 seconds\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\r\",\n      \"                                                                                \\r\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def with_benchmark(phrase, action):\\n\",\n    \"    start = time()\\n\",\n    \"    result = action()\\n\",\n    \"    end = time()\\n\",\n    \"    print('{} takes {} seconds'.format(phrase, round(end - start, 2)))\\n\",\n    \"    return result\\n\",\n    \"model = with_benchmark('Cross-Validation', lambda: cross_validator.fit(train_data)).bestModel\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Transform On the Best 
Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Transforming takes 0.23 seconds\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-30 08:03:45,503 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+-----------+-----------+\\n\",\n      \"|fare_amount| prediction|\\n\",\n      \"+-----------+-----------+\\n\",\n      \"|        5.0| 5.01032114|\\n\",\n      \"|       34.0|  31.134758|\\n\",\n      \"|       10.0|9.288980484|\\n\",\n      \"|       16.5|15.33446312|\\n\",\n      \"|        7.0|8.197098732|\\n\",\n      \"+-----------+-----------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def transform():\\n\",\n    \"    result = model.transform(trans_data).cache()\\n\",\n    \"    result.foreachPartition(lambda _: None)\\n\",\n    \"    return result\\n\",\n    \"result = with_benchmark('Transforming', transform)\\n\",\n    \"result.select(label, 'prediction').show(5)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Evaluation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     
\"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Evaluation takes 0.05 seconds\\n\",\n      \"RMSE is 2.055690464034438\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-30 08:03:45,728 WARN rapids.GpuOverrides: \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! <CreateExternalRow> createexternalrow(prediction#7645, fare_amount#8271, 1.0#8272, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#7645 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> fare_amount#8271 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#8272 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> obj#8276 cannot run on GPU because expression AttributeReference obj#8276 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"accuracy = with_benchmark(\\n\",\n    \"    'Evaluation',\\n\",\n    \"    lambda: RegressionEvaluator().setLabelCol(label).evaluate(result))\\n\",\n    \"print('RMSE is ' + str(accuracy))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   
\"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.2\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/notebooks/python/taxi-ETL.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"71bf747a\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction to Taxi ETL Job\\n\",\n    \"This is the Taxi ETL job to generate the input datasets for the Taxi XGBoost job.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f0524408\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Prerequisites\\n\",\n    \"### 1. Download data\\n\",\n    \"All data can be found at https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page\\n\",\n    \"\\n\",\n    \"### 2. Download needed jars\\n\",\n    \"* [rapids-4-spark_2.12-26.02.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)\\n\",\n    \"\\n\",\n    \"### 3. Start Spark Standalone\\n\",\n    \"Before running the script, please set up Spark standalone mode\\n\",\n    \"\\n\",\n    \"### 4. Add ENV\\n\",\n    \"```\\n\",\n    \"$ export SPARK_JARS=rapids-4-spark_2.12-26.02.0.jar\\n\",\n    \"$ export PYSPARK_DRIVER_PYTHON=jupyter \\n\",\n    \"$ export PYSPARK_DRIVER_PYTHON_OPTS=notebook\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"### 5. 
Start Jupyter Notebook with plugin config\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"$ pyspark --master ${SPARK_MASTER}            \\\\\\n\",\n    \"--jars ${SPARK_JARS}                \\\\\\n\",\n    \"--conf spark.plugins=com.nvidia.spark.SQLPlugin \\\\\\n\",\n    \"--conf spark.rapids.sql.incompatibleDateFormats.enabled=true \\\\\\n\",\n    \"--conf spark.rapids.sql.csv.read.double.enabled=true \\\\\\n\",\n    \"--py-files ${SPARK_PY_FILES}\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"## Import Libs\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"d2283aab\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import time\\n\",\n    \"import os\\n\",\n    \"import math\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.functions import *\\n\",\n    \"from pyspark.sql.types import *\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f7ffcace\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Script Settings\\n\",\n    \"\\n\",\n    \"###  File Path Settings\\n\",\n    \"* Define input/output file path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"b348778a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# You need to update them to your real paths! 
You can download the dataset \\n\",\n    \"# from https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page\\n\",\n    \"# or you can just unzip datasets/taxi-small.tar.gz and use the provided\\n\",\n    \"# sample dataset datasets/taxi/taxi-etl-input-small.csv\\n\",\n    \"dataRoot = os.getenv('DATA_ROOT', '/data')\\n\",\n    \"rawPath = dataRoot + '/taxi/taxi-etl-input-small.csv'\\n\",\n    \"outPath = dataRoot + '/taxi/output'\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0a500530\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Function and Object Define\\n\",\n    \"### Define the constants\\n\",\n    \"\\n\",\n    \"* Define input file schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"094f31c5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"raw_schema = StructType([\\n\",\n    \"    StructField('vendor_id', StringType()),\\n\",\n    \"    StructField('pickup_datetime', StringType()),\\n\",\n    \"    StructField('dropoff_datetime', StringType()),\\n\",\n    \"    StructField('passenger_count', IntegerType()),\\n\",\n    \"    StructField('trip_distance', DoubleType()),\\n\",\n    \"    StructField('pickup_longitude', DoubleType()),\\n\",\n    \"    StructField('pickup_latitude', DoubleType()),\\n\",\n    \"    StructField('rate_code', StringType()),\\n\",\n    \"    StructField('store_and_fwd_flag', StringType()),\\n\",\n    \"    StructField('dropoff_longitude', DoubleType()),\\n\",\n    \"    StructField('dropoff_latitude', DoubleType()),\\n\",\n    \"    StructField('payment_type', StringType()),\\n\",\n    \"    StructField('fare_amount', DoubleType()),\\n\",\n    \"    StructField('surcharge', DoubleType()),\\n\",\n    \"    StructField('mta_tax', DoubleType()),\\n\",\n    \"    StructField('tip_amount', DoubleType()),\\n\",\n    \"    StructField('tolls_amount', DoubleType()),\\n\",\n    \"    StructField('total_amount', DoubleType()),\\n\",\n    \"])\"\n   
]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"72a4ae18\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define some ETL functions\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"b45b7606\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def drop_useless(data_frame):\\n\",\n    \"    return data_frame.drop(\\n\",\n    \"        'dropoff_datetime',\\n\",\n    \"        'payment_type',\\n\",\n    \"        'surcharge',\\n\",\n    \"        'mta_tax',\\n\",\n    \"        'tip_amount',\\n\",\n    \"        'tolls_amount',\\n\",\n    \"        'total_amount')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"7af7073d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def encode_categories(data_frame):\\n\",\n    \"    categories = [ 'vendor_id', 'rate_code', 'store_and_fwd_flag' ]\\n\",\n    \"    for category in categories:\\n\",\n    \"        data_frame = data_frame.withColumn(category, hash(col(category)))\\n\",\n    \"    return data_frame.withColumnRenamed(\\\"store_and_fwd_flag\\\", \\\"store_and_fwd\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"b799cd5a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def fill_na(data_frame):\\n\",\n    \"    return data_frame.fillna(-1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"ceee5c7c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def remove_invalid(data_frame):\\n\",\n    \"    conditions = [\\n\",\n    \"        ( 'fare_amount', 0, 500 ),\\n\",\n    \"        ( 'passenger_count', 0, 6 ),\\n\",\n    \"        ( 'pickup_longitude', -75, -73 ),\\n\",\n    \"        ( 'dropoff_longitude', -75, -73 ),\\n\",\n    \"        ( 'pickup_latitude', 40, 42 ),\\n\",\n    \"        ( 'dropoff_latitude', 40, 42 ),\\n\",\n    \"    ]\\n\",\n    \"    for column, 
min, max in conditions:\\n\",\n    \"        data_frame = data_frame.filter('{} > {} and {} < {}'.format(column, min, column, max))\\n\",\n    \"    return data_frame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"bd28ae14\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def convert_datetime(data_frame):\\n\",\n    \"    datetime = col('pickup_datetime')\\n\",\n    \"    return (data_frame\\n\",\n    \"        .withColumn('pickup_datetime', to_timestamp(datetime))\\n\",\n    \"        .withColumn('year', year(datetime))\\n\",\n    \"        .withColumn('month', month(datetime))\\n\",\n    \"        .withColumn('day', dayofmonth(datetime))\\n\",\n    \"        .withColumn('day_of_week', dayofweek(datetime))\\n\",\n    \"        .withColumn(\\n\",\n    \"            'is_weekend',\\n\",\n    \"            col('day_of_week').isin(1, 7).cast(IntegerType()))  # 1: Sunday, 7: Saturday\\n\",\n    \"        .withColumn('hour', hour(datetime))\\n\",\n    \"        .drop('pickup_datetime'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"39e45f15\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def add_h_distance(data_frame):\\n\",\n    \"    p = math.pi / 180\\n\",\n    \"    lat1 = col('pickup_latitude')\\n\",\n    \"    lon1 = col('pickup_longitude')\\n\",\n    \"    lat2 = col('dropoff_latitude')\\n\",\n    \"    lon2 = col('dropoff_longitude')\\n\",\n    \"    internal_value = (0.5\\n\",\n    \"        - cos((lat2 - lat1) * p) / 2\\n\",\n    \"        + cos(lat1 * p) * cos(lat2 * p) * (1 - cos((lon2 - lon1) * p)) / 2)\\n\",\n    \"    h_distance = 12734 * asin(sqrt(internal_value))\\n\",\n    \"    return data_frame.withColumn('h_distance', h_distance)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d52b062c\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define main ETL function\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": 10,\n   \"id\": \"9fd36618\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def pre_process(data_frame):\\n\",\n    \"    processes = [\\n\",\n    \"        drop_useless,\\n\",\n    \"        encode_categories,\\n\",\n    \"        fill_na,\\n\",\n    \"        remove_invalid,\\n\",\n    \"        convert_datetime,\\n\",\n    \"        add_h_distance,\\n\",\n    \"    ]\\n\",\n    \"    for process in processes:\\n\",\n    \"        data_frame = process(data_frame)\\n\",\n    \"    return data_frame\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2798f19a\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Run ETL Process and Save the Result\\n\",\n    \"* Create Spark Session and create dataframe\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"id\": \"26ca4ca6\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark = (SparkSession\\n\",\n    \"    .builder\\n\",\n    \"    .appName(\\\"Taxi-ETL\\\")\\n\",\n    \"    .getOrCreate())\\n\",\n    \"reader = (spark\\n\",\n    \"        .read\\n\",\n    \"        .format('csv'))\\n\",\n    \"reader.schema(raw_schema).option('header', 'True')\\n\",\n    \"\\n\",\n    \"raw_data = reader.load(rawPath)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6243b736\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Run ETL Process and Save the Result\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"27f2119b\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"5.114504098892212\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"start = time.time()\\n\",\n    \"etled_train, etled_eval, etled_trans = pre_process(raw_data).randomSplit(list(map(float, (80,20,0))))\\n\",\n    \"etled_train.write.mode(\\\"overwrite\\\").parquet(outPath+'/train')\\n\",\n    
\"etled_eval.write.mode(\\\"overwrite\\\").parquet(outPath+'/eval')\\n\",\n    \"etled_trans.write.mode(\\\"overwrite\\\").parquet(outPath+'/trans')\\n\",\n    \"end = time.time()\\n\",\n    \"print(end - start)\\n\",\n    \"spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"91af3c97\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.2\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/notebooks/python/taxi-gpu.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction to XGBoost Spark3.1 with GPU\\n\",\n    \"\\n\",\n    \"Taxi is an example of xgboost regressor. This notebook will show you how to load data, train the xgboost model and use this model to predict \\\"fare_amount\\\" of your taxi trip.\\n\",\n    \"\\n\",\n    \"A few libraries required for this notebook:\\n\",\n    \"  1. cudf-cu11\\n\",\n    \"  2. xgboost\\n\",\n    \"  3. scikit-learn\\n\",\n    \"  4. numpy\\n\",\n    \"\\n\",\n    \"This notebook also illustrates the ease of porting a sample CPU based Spark xgboost4j code into GPU. There is no change required for running Spark XGBoost on GPU because both CPU and GPU call the same API. For CPU run, we need to vectorize the trained dataset before fitting data to regressor.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Import Required Libraries\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from xgboost.spark import SparkXGBRegressor, SparkXGBRegressorModel\\n\",\n    \"from pyspark.ml.evaluation import RegressionEvaluator\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.types import FloatType, IntegerType, StructField, StructType\\n\",\n    \"from time import time\\n\",\n    \"from pyspark.conf import SparkConf\\n\",\n    \"import os\\n\",\n    \"# if you pass/unpack the archive file and enable the environment\\n\",\n    \"# os.environ['PYSPARK_PYTHON'] = \\\"./environment/bin/python\\\"\\n\",\n    \"# os.environ['PYSPARK_DRIVER_PYTHON'] = \\\"./environment/bin/python\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Besides CPU version requires two extra libraries.\\n\",\n    \"```Python\\n\",\n    \"from pyspark.ml.feature import 
VectorAssembler\\n\",\n    \"from pyspark.sql.functions import col\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create Spark Session and Data Reader\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-30 07:51:19,104 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\\n\",\n      \"Setting default log level to \\\"WARN\\\".\\n\",\n      \"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\\n\",\n      \"2022-11-30 07:51:19,480 WARN resource.ResourceUtils: The configuration of cores (exec = 2 task = 1, runnable tasks = 2) will result in wasted resources due to resource gpu limiting the number of runnable tasks per executor to: 1. Please adjust your configuration.\\n\",\n      \"2022-11-30 07:51:33,277 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator 25.02.1 using cudf 25.02.1.\\n\",\n      \"2022-11-30 07:51:33,292 WARN rapids.RapidsPluginUtils: spark.rapids.sql.multiThreadedRead.numThreads is set to 20.\\n\",\n      \"2022-11-30 07:51:33,295 WARN rapids.RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.\\n\",\n      \"2022-11-30 07:51:33,295 WARN rapids.RapidsPluginUtils: spark.rapids.sql.explain is set to `NOT_ON_GPU`. 
Set it to 'NONE' to suppress the diagnostics logging about the query placement on the GPU.\\n\",\n      \"2022-11-30 07:51:33,798 WARN yarn.Client: Neither spark.yarn.jars nor spark.yarn.archive is set, falling back to uploading libraries under SPARK_HOME.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"SPARK_MASTER_URL = os.getenv(\\\"SPARK_MASTER_URL\\\", \\\"/your-url\\\")\\n\",\n    \"\\n\",\n    \"RAPIDS_JAR = os.getenv(\\\"RAPIDS_JAR\\\", \\\"/your-jar-path\\\")\\n\",\n    \"\\n\",\n    \"# You need to update with your real hardware resource \\n\",\n    \"driverMem = os.getenv(\\\"DRIVER_MEM\\\", \\\"2g\\\")\\n\",\n    \"executorMem = os.getenv(\\\"EXECUTOR_MEM\\\", \\\"2g\\\")\\n\",\n    \"pinnedPoolSize = os.getenv(\\\"PINNED_POOL_SIZE\\\", \\\"2g\\\")\\n\",\n    \"concurrentGpuTasks = os.getenv(\\\"CONCURRENT_GPU_TASKS\\\", \\\"2\\\")\\n\",\n    \"executorCores = int(os.getenv(\\\"EXECUTOR_CORES\\\", \\\"2\\\"))\\n\",\n    \"# Common spark settings\\n\",\n    \"conf = SparkConf()\\n\",\n    \"conf.setMaster(SPARK_MASTER_URL)\\n\",\n    \"conf.setAppName(\\\"Microbenchmark on GPU\\\")\\n\",\n    \"conf.set(\\\"spark.executor.instances\\\",\\\"1\\\")\\n\",\n    \"conf.set(\\\"spark.driver.memory\\\", driverMem)\\n\",\n    \"## The tasks will run on GPU memory, so there is no need to set a high host memory\\n\",\n    \"conf.set(\\\"spark.executor.memory\\\", executorMem)\\n\",\n    \"## The tasks will run on GPU cores, so there is no need to use many cpu cores\\n\",\n    \"conf.set(\\\"spark.executor.cores\\\", executorCores)\\n\",\n    \"\\n\",\n    \"# Plugin settings\\n\",\n    \"conf.set(\\\"spark.executor.resource.gpu.amount\\\", \\\"1\\\")\\n\",\n    \"conf.set(\\\"spark.rapids.sql.concurrentGpuTasks\\\", concurrentGpuTasks)\\n\",\n    \"conf.set(\\\"spark.rapids.memory.pinnedPool.size\\\", pinnedPoolSize)\\n\",\n    \"# since pyspark and xgboost share the same GPU, we disable RMM to avoid GPU OOM while training \\n\",\n    
\"conf.set(\\\"spark.rapids.memory.gpu.pool\\\", \\\"NONE\\\")\\n\",\n    \"conf.set(\\\"spark.locality.wait\\\",\\\"0\\\")\\n\",\n    \"##############note: only support value=1 https://github.com/dmlc/xgboost/blame/master/python-package/xgboost/spark/core.py#L370-L374\\n\",\n    \"conf.set(\\\"spark.task.resource.gpu.amount\\\", 1) \\n\",\n    \"conf.set(\\\"spark.rapids.sql.enabled\\\", \\\"true\\\") \\n\",\n    \"conf.set(\\\"spark.plugins\\\", \\\"com.nvidia.spark.SQLPlugin\\\")\\n\",\n    \"conf.set(\\\"spark.sql.cache.serializer\\\",\\\"com.nvidia.spark.ParquetCachedBatchSerializer\\\")\\n\",\n    \"conf.set(\\\"spark.sql.execution.arrow.maxRecordsPerBatch\\\", 200000) \\n\",\n    \"conf.set(\\\"spark.driver.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"conf.set(\\\"spark.executor.extraClassPath\\\", RAPIDS_JAR)\\n\",\n    \"\\n\",\n    \"# if you pass/unpack the archive file and enable the environment\\n\",\n    \"# conf.set(\\\"spark.yarn.dist.archives\\\", \\\"your_pyspark_venv.tar.gz#environment\\\")\\n\",\n    \"# Create spark session\\n\",\n    \"spark = SparkSession.builder.config(conf=conf).getOrCreate()\\n\",\n    \"\\n\",\n    \"reader = spark.read\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Specify the Data Schema and Load the Data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"label = 'fare_amount'\\n\",\n    \"schema = StructType([\\n\",\n    \"    StructField('vendor_id', FloatType()),\\n\",\n    \"    StructField('passenger_count', FloatType()),\\n\",\n    \"    StructField('trip_distance', FloatType()),\\n\",\n    \"    StructField('pickup_longitude', FloatType()),\\n\",\n    \"    StructField('pickup_latitude', FloatType()),\\n\",\n    \"    StructField('rate_code', FloatType()),\\n\",\n    \"    StructField('store_and_fwd', FloatType()),\\n\",\n    \"    StructField('dropoff_longitude', 
FloatType()),\\n\",\n    \"    StructField('dropoff_latitude', FloatType()),\\n\",\n    \"    StructField(label, FloatType()),\\n\",\n    \"    StructField('hour', FloatType()),\\n\",\n    \"    StructField('year', IntegerType()),\\n\",\n    \"    StructField('month', IntegerType()),\\n\",\n    \"    StructField('day', FloatType()),\\n\",\n    \"    StructField('day_of_week', FloatType()),\\n\",\n    \"    StructField('is_weekend', FloatType()),\\n\",\n    \"])\\n\",\n    \"features = [ x.name for x in schema if x.name != label ]\\n\",\n    \"\\n\",\n    \"# You need to update them to your real paths!\\n\",\n    \"dataRoot = os.getenv(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"train_path = dataRoot + \\\"/taxi/csv/train\\\"\\n\",\n    \"eval_path = dataRoot + \\\"/taxi/csv/test\\\"\\n\",\n    \"\\n\",\n    \"data_format = 'csv'\\n\",\n    \"has_header = 'true'\\n\",\n    \"if data_format == 'csv':\\n\",\n    \"    train_data = reader.schema(schema).option('header',has_header).csv(train_path)\\n\",\n    \"    trans_data = reader.schema(schema).option('header',has_header).csv(eval_path)\\n\",\n    \"else :\\n\",\n    \"    train_data = reader.load(train_path)\\n\",\n    \"    trans_data = reader.load(eval_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Note on CPU version, vectorization is required before fitting data to regressor, which means you need to assemble all feature columns into one column.\\n\",\n    \"\\n\",\n    \"```Python\\n\",\n    \"def vectorize(data_frame):\\n\",\n    \"    to_floats = [ col(x.name).cast(FloatType()) for x in data_frame.schema ]\\n\",\n    \"    return (VectorAssembler()\\n\",\n    \"        .setInputCols(features)\\n\",\n    \"        .setOutputCol('features')\\n\",\n    \"        .transform(data_frame.select(to_floats))\\n\",\n    \"        .select(col('features'), col(label)))\\n\",\n    \"\\n\",\n    \"train_data = vectorize(train_data)\\n\",\n    \"trans_data = 
vectorize(trans_data)\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Create a XGBoostRegressor\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"params = { \\n\",\n    \"    \\\"tree_method\\\": \\\"hist\\\",\\n\",\n    \"    \\\"grow_policy\\\": \\\"depthwise\\\",\\n\",\n    \"    \\\"num_workers\\\": 1,\\n\",\n    \"    \\\"device\\\": \\\"cuda\\\",\\n\",\n    \"}\\n\",\n    \"params['features_col'] = features\\n\",\n    \"params['label_col'] = label\\n\",\n    \"    \\n\",\n    \"regressor = SparkXGBRegressor(**params)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The parameter `num_workers` should be set to the number of GPUs in Spark cluster for GPU version, while for CPU version it is usually equal to the number of the CPU cores.\\n\",\n    \"\\n\",\n    \"Concerning the `device`, GPU version only supports `cuda` currently, while `cpu` is designed and used here for CPU training.\\n\",\n    \"\\n\",\n    \"An example of CPU classifier:\\n\",\n    \"```\\n\",\n    \"classifier = SparkXGBClassifier(\\n\",\n    \"  feature_col=features,\\n\",\n    \"  label_col=label,  \\n\",\n    \"  num_workers=1024,\\n\",\n    \")\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Train the Data with Benchmark\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"If features_cols param set, then features_col param is ignored.\\n\",\n      \"[Stage 2:>                                                          (0 + 1) / 1]\\r\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": 
\"stream\",\n     \"text\": [\n      \"Training takes 24.12 seconds\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\r\",\n      \"                                                                                \\r\",\n      \"/data/home/yuanli/work/reviews/pr252/pyspark_venv_20221125/lib/python3.8/site-packages/xgboost/sklearn.py:808: UserWarning: Loading a native XGBoost model with Scikit-Learn interface.\\n\",\n      \"  warnings.warn(\\\"Loading a native XGBoost model with Scikit-Learn interface.\\\")\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"def with_benchmark(phrase, action):\\n\",\n    \"    start = time()\\n\",\n    \"    result = action()\\n\",\n    \"    end = time()\\n\",\n    \"    print('{} takes {} seconds'.format(phrase, round(end - start, 2)))\\n\",\n    \"    return result\\n\",\n    \"model = with_benchmark('Training', lambda: regressor.fit(train_data))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Save and Reload the Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"If features_cols param set, then features_col param is ignored.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"model.write().overwrite().save(dataRoot + '/model/taxi')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"loaded_model = SparkXGBRegressorModel().load(dataRoot + '/model/taxi')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Transformation and Show Result Sample\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {\n    \"scrolled\": false\n   },\n   \"outputs\": [\n    {\n     
\"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-30 07:52:27,357 WARN util.package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Transformation takes 0.93 seconds\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-30 07:52:28,189 WARN rapids.GpuOverrides: \\n\",\n      \"!Exec <CollectLimitExec> cannot run on GPU because the Exec CollectLimitExec has been disabled, and is disabled by default because Collect Limit replacement can be slower on the GPU, if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU. Set spark.rapids.sql.exec.CollectLimitExec to true if you wish to enable it\\n\",\n      \"  @Partitioning <SinglePartition$> could run on GPU\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+--------------+---------------+-------------+-----------+-----------+\\n\",\n      \"|     vendor_id|passenger_count|trip_distance|fare_amount| prediction|\\n\",\n      \"+--------------+---------------+-------------+-----------+-----------+\\n\",\n      \"|1.559730432E09|            2.0|  0.699999988|        5.0|5.046935558|\\n\",\n      \"|1.559730432E09|            3.0|  10.69999981|       34.0|31.72706413|\\n\",\n      \"|1.559730432E09|            1.0|  2.299999952|       10.0|9.294451714|\\n\",\n      \"|1.559730432E09|            1.0|  4.400000095|       16.5|15.05233097|\\n\",\n      \"|1.559730432E09|            1.0|          1.5|        7.0|8.995832443|\\n\",\n      \"+--------------+---------------+-------------+-----------+-----------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n 
    ]\n    }\n   ],\n   \"source\": [\n    \"def transform():\\n\",\n    \"    result = loaded_model.transform(trans_data).cache()\\n\",\n    \"    result.foreachPartition(lambda _: None)\\n\",\n    \"    return result\\n\",\n    \"result = with_benchmark('Transformation', transform)\\n\",\n    \"result.select('vendor_id', 'passenger_count', 'trip_distance', label, 'prediction').show(5)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Note on CPU version: You cannot `select` the feature columns after vectorization. So please use `result.show(5)` instead.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Evaluation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Evaluation takes 0.22 seconds\\n\",\n      \"RMSE is 1.9141528471228921\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"2022-11-30 07:52:28,580 WARN rapids.GpuOverrides: \\n\",\n      \"! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec\\n\",\n      \"  ! 
<CreateExternalRow> createexternalrow(prediction#87, fare_amount#728, 1.0#729, StructField(prediction,DoubleType,true), StructField(fare_amount,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow\\n\",\n      \"    @Expression <AttributeReference> prediction#87 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> fare_amount#728 could run on GPU\\n\",\n      \"    @Expression <AttributeReference> 1.0#729 could run on GPU\\n\",\n      \"  !Expression <AttributeReference> obj#733 cannot run on GPU because expression AttributeReference obj#733 produces an unsupported type ObjectType(interface org.apache.spark.sql.Row)\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"accuracy = with_benchmark(\\n\",\n    \"    'Evaluation',\\n\",\n    \"    lambda: RegressionEvaluator().setLabelCol(label).evaluate(result))\\n\",\n    \"print('RMSE is ' + str(accuracy))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Stop\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.2\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/notebooks/scala/taxi-ETL.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e0336840\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction to Taxi ETL Job\\n\",\n    \"This is the Taxi ETL job to generate the input datasets for the Taxi XGBoost job.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"86fd8ad9\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Prerequirement\\n\",\n    \"### 1. Download data\\n\",\n    \"All data could be found at https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page\\n\",\n    \"\\n\",\n    \"### 2. Download needed jar\\n\",\n    \"* [rapids-4-spark_2.12-26.02.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/26.02.0/rapids-4-spark_2.12-26.02.0.jar)\\n\",\n    \"\\n\",\n    \"### 3. Start Spark Standalone\\n\",\n    \"Before running the script, please setup Spark standalone mode\\n\",\n    \"\\n\",\n    \"### 4. Add ENV\\n\",\n    \"```\\n\",\n    \"$ export SPARK_JARS=rapids-4-spark_2.12-26.02.0.jar\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"### 5.Start Jupyter Notebook with spylon-kernel or toree\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"$ jupyter notebook --allow-root --notebook-dir=${your-dir} --config=${your-configs}\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"## Import Libs\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"1e50cfad\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import org.apache.spark.sql.SparkSession\\n\",\n    \"import org.apache.spark.sql.DataFrame\\n\",\n    \"import org.apache.spark.sql.functions._\\n\",\n    \"import org.apache.spark.sql.types.DataTypes.{DoubleType, IntegerType, StringType}\\n\",\n    \"import org.apache.spark.sql.types.{FloatType, StructField, StructType}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"24f69140\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Script Settings\\n\",\n    \"\\n\",\n    \"### 1. 
File Path Settings\\n\",\n    \"* Define input file path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"317b9415\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"lastException = null\\n\",\n       \"dataRoot = /data\\n\",\n       \"rawPath = /data/taxi/taxi-etl-input-small.csv\\n\",\n       \"outPath = /data/datasets/taxi/output\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"/data/taxi/output\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val dataRoot = sys.env.getOrElse(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"val rawPath = dataRoot + \\\"/taxi/taxi-etl-input-small.csv\\\"\\n\",\n    \"val outPath = dataRoot + \\\"/taxi/output\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6f036d30\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Function and Object Define\\n\",\n    \"### Define the constants\\n\",\n    \"\\n\",\n    \"* Define input file schema\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"acc23ac1\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"rawSchema = StructType(StructField(vendor_id,StringType,true), StructField(pickup_datetime,StringType,true), StructField(dropoff_datetime,StringType,true), StructField(passenger_count,IntegerType,true), StructField(trip_distance,DoubleType,true), StructField(pickup_longitude,DoubleType,true), StructField(pickup_latitude,DoubleType,true), StructField(rate_code,StringType,true), StructField(store_and_fwd_flag,StringType,true), StructField(dropoff_longitude,DoubleType,true), StructField(dropoff_latitude,DoubleType,true), StructField(payment_type,StringType,true), 
StructField(fare_amount,DoubleType,true), StructField(surcharge,DoubleType,true), StructField(mta_tax,DoubleType,true), StructField(tip_amount,DoubleType,true), StructField(tolls_amount,Doubl...\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"StructType(StructField(vendor_id,StringType,true), StructField(pickup_datetime,StringType,true), StructField(dropoff_datetime,StringType,true), StructField(passenger_count,IntegerType,true), StructField(trip_distance,DoubleType,true), StructField(pickup_longitude,DoubleType,true), StructField(pickup_latitude,DoubleType,true), StructField(rate_code,StringType,true), StructField(store_and_fwd_flag,StringType,true), StructField(dropoff_longitude,DoubleType,true), StructField(dropoff_latitude,DoubleType,true), StructField(payment_type,StringType,true), StructField(fare_amount,DoubleType,true), StructField(surcharge,DoubleType,true), StructField(mta_tax,DoubleType,true), StructField(tip_amount,DoubleType,true), StructField(tolls_amount,Doubl...\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val rawSchema = StructType(Seq(\\n\",\n    \"    StructField(\\\"vendor_id\\\", StringType),\\n\",\n    \"    StructField(\\\"pickup_datetime\\\", StringType),\\n\",\n    \"    StructField(\\\"dropoff_datetime\\\", StringType),\\n\",\n    \"    StructField(\\\"passenger_count\\\", IntegerType),\\n\",\n    \"    StructField(\\\"trip_distance\\\", DoubleType),\\n\",\n    \"    StructField(\\\"pickup_longitude\\\", DoubleType),\\n\",\n    \"    StructField(\\\"pickup_latitude\\\", DoubleType),\\n\",\n    \"    StructField(\\\"rate_code\\\", StringType),\\n\",\n    \"    StructField(\\\"store_and_fwd_flag\\\", StringType),\\n\",\n    \"    StructField(\\\"dropoff_longitude\\\", DoubleType),\\n\",\n    \"    
StructField(\\\"dropoff_latitude\\\", DoubleType),\\n\",\n    \"    StructField(\\\"payment_type\\\", StringType),\\n\",\n    \"    StructField(\\\"fare_amount\\\", DoubleType),\\n\",\n    \"    StructField(\\\"surcharge\\\", DoubleType),\\n\",\n    \"    StructField(\\\"mta_tax\\\", DoubleType),\\n\",\n    \"    StructField(\\\"tip_amount\\\", DoubleType),\\n\",\n    \"    StructField(\\\"tolls_amount\\\", DoubleType),\\n\",\n    \"    StructField(\\\"total_amount\\\", DoubleType)\\n\",\n    \"  ))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"2e467519\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"trainRatio = 80\\n\",\n       \"evalRatio = 20\\n\",\n       \"trainEvalRatio = 0\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"dataRatios: (Int, Int, Int)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"0\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"def dataRatios: (Int, Int, Int) = {\\n\",\n    \"    val ratios = (80, 20)\\n\",\n    \"    (ratios._1, ratios._2, 100 - ratios._1 - ratios._2)\\n\",\n    \"  }\\n\",\n    \"val (trainRatio, evalRatio, trainEvalRatio) = dataRatios\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5c2024d7\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Build the spark session and dataframe\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"b551ca1d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"sparkSession = org.apache.spark.sql.SparkSession@68530eb7\\n\",\n       \"df = [vendor_id: string, pickup_datetime: 
string ... 16 more fields]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[vendor_id: string, pickup_datetime: string ... 16 more fields]\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// Build the spark session and data reader as usual\\n\",\n    \"val sparkSession = SparkSession.builder.appName(\\\"taxi-etl\\\").getOrCreate\\n\",\n    \"val df = sparkSession.read.option(\\\"header\\\", true).schema(rawSchema).csv(rawPath)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2f50ff7d\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define some ETL functions\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"3ca5738f\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"dropUseless: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"def dropUseless(dataFrame: DataFrame): DataFrame = {\\n\",\n    \"    dataFrame.drop(\\n\",\n    \"      \\\"dropoff_datetime\\\",\\n\",\n    \"      \\\"payment_type\\\",\\n\",\n    \"      \\\"surcharge\\\",\\n\",\n    \"      \\\"mta_tax\\\",\\n\",\n    \"      \\\"tip_amount\\\",\\n\",\n    \"      \\\"tolls_amount\\\",\\n\",\n    \"      \\\"total_amount\\\")\\n\",\n    \"  }\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"id\": \"852b06c3\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"encodeCategories: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   
\"source\": [\n    \"def encodeCategories(dataFrame: DataFrame): DataFrame = {\\n\",\n    \"    val categories = Seq(\\\"vendor_id\\\", \\\"rate_code\\\", \\\"store_and_fwd_flag\\\")\\n\",\n    \"\\n\",\n    \"    (categories.foldLeft(dataFrame) {\\n\",\n    \"      case (df, category) => df.withColumn(category, hash(col(category)))\\n\",\n    \"    }).withColumnRenamed(\\\"store_and_fwd_flag\\\", \\\"store_and_fwd\\\")\\n\",\n    \"  }\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"dbf0ab75\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"fillNa: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"def fillNa(dataFrame: DataFrame): DataFrame = {\\n\",\n    \"    dataFrame.na.fill(-1)\\n\",\n    \"  }\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"39308a05\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"removeInvalid: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"def removeInvalid(dataFrame: DataFrame): DataFrame = {\\n\",\n    \"    val conditions = Seq(\\n\",\n    \"      Seq(\\\"fare_amount\\\", 0, 500),\\n\",\n    \"      Seq(\\\"passenger_count\\\", 0, 6),\\n\",\n    \"      Seq(\\\"pickup_longitude\\\", -75, -73),\\n\",\n    \"      Seq(\\\"dropoff_longitude\\\", -75, -73),\\n\",\n    \"      Seq(\\\"pickup_latitude\\\", 40, 42),\\n\",\n    \"      Seq(\\\"dropoff_latitude\\\", 40, 42))\\n\",\n    \"\\n\",\n    \"    conditions\\n\",\n    \"      .map { case Seq(column, min, max) => \\\"%s > %d and %s < %d\\\".format(column, min, column, max) }\\n\",\n    \"      
.foldLeft(dataFrame) {\\n\",\n    \"        _.filter(_)\\n\",\n    \"      }\\n\",\n    \"  }\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"11cd052b\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"convertDatetime: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"def convertDatetime(dataFrame: DataFrame): DataFrame = {\\n\",\n    \"    val datetime = col(\\\"pickup_datetime\\\")\\n\",\n    \"    dataFrame\\n\",\n    \"      .withColumn(\\\"pickup_datetime\\\", to_timestamp(datetime))\\n\",\n    \"      .withColumn(\\\"year\\\", year(datetime))\\n\",\n    \"      .withColumn(\\\"month\\\", month(datetime))\\n\",\n    \"      .withColumn(\\\"day\\\", dayofmonth(datetime))\\n\",\n    \"      .withColumn(\\\"day_of_week\\\", dayofweek(datetime))\\n\",\n    \"      .withColumn(\\n\",\n    \"        \\\"is_weekend\\\",\\n\",\n    \"        col(\\\"day_of_week\\\").isin(1, 7).cast(IntegerType)) // 1: Sunday, 7: Saturday\\n\",\n    \"      .withColumn(\\\"hour\\\", hour(datetime))\\n\",\n    \"      .drop(datetime.toString)\\n\",\n    \"  }\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"71e1b568\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"addHDistance: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"def addHDistance(dataFrame: DataFrame): DataFrame = {\\n\",\n    \"    val P = math.Pi / 180\\n\",\n    \"    val lat1 = col(\\\"pickup_latitude\\\")\\n\",\n    \"    val lon1 = col(\\\"pickup_longitude\\\")\\n\",\n    \"    val lat2 = col(\\\"dropoff_latitude\\\")\\n\",\n    \"   
 val lon2 = col(\\\"dropoff_longitude\\\")\\n\",\n    \"    val internalValue = (lit(0.5)\\n\",\n    \"      - cos((lat2 - lat1) * P) / 2\\n\",\n    \"      + cos(lat1 * P) * cos(lat2 * P) * (lit(1) - cos((lon2 - lon1) * P)) / 2)\\n\",\n    \"    val hDistance = lit(12734) * asin(sqrt(internalValue))\\n\",\n    \"    dataFrame.withColumn(\\\"h_distance\\\", hDistance)\\n\",\n    \"  }\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6fe805d5\",\n   \"metadata\": {},\n   \"source\": [\n    \"* Define main ETL function\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"id\": \"6da3b832\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"preProcess: (dataFrame: org.apache.spark.sql.DataFrame, splits: Array[Int])Array[org.apache.spark.sql.DataFrame]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"def preProcess(dataFrame: DataFrame, splits: Array[Int]): Array[DataFrame] = {\\n\",\n    \"    val processes = Seq[DataFrame => DataFrame](\\n\",\n    \"      dropUseless,\\n\",\n    \"      encodeCategories,\\n\",\n    \"      fillNa,\\n\",\n    \"      removeInvalid,\\n\",\n    \"      convertDatetime,\\n\",\n    \"      addHDistance\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    processes\\n\",\n    \"      .foldLeft(dataFrame) { case (df, process) => process(df) }\\n\",\n    \"      .randomSplit(splits.map(_.toDouble))\\n\",\n    \"  }\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"id\": \"85541b03\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"dataset = Array([vendor_id: int, passenger_count: int ... 15 more fields], [vendor_id: int, passenger_count: int ... 15 more fields], [vendor_id: int, passenger_count: int ... 
15 more fields])\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Array([vendor_id: int, passenger_count: int ... 15 more fields], [vendor_id: int, passenger_count: int ... 15 more fields], [vendor_id: int, passenger_count: int ... 15 more fields])\"\n      ]\n     },\n     \"execution_count\": 20,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val dataset = preProcess(df, Array(trainRatio, trainEvalRatio, evalRatio))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6787cac7\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Run ETL Process and Save the Result\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"371886e8\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Elapsed time : 4.371s\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"t0 = 1654139600797\\n\",\n       \"t1 = 1654139605168\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"1654139605168\"\n      ]\n     },\n     \"execution_count\": 21,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val t0 = System.currentTimeMillis\\n\",\n    \"for ((name, index) <- Seq(\\\"train\\\", \\\"eval\\\", \\\"trans\\\").zipWithIndex) {\\n\",\n    \"        dataset(index).write.mode(\\\"overwrite\\\").parquet(outPath + \\\"/parquet/\\\" + name)\\n\",\n    \"        dataset(index).write.mode(\\\"overwrite\\\").csv(outPath + \\\"/csv/\\\" + name)\\n\",\n    \"      }\\n\",\n    \"val t1 = System.currentTimeMillis\\n\",\n    \"println(\\\"Elapsed time : \\\" + ((t1 - t0).toFloat / 1000) + \\\"s\\\")\\n\",\n    
\"sparkSession.stop()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"8d89fa1b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"XGBoost4j-Spark - Scala\",\n   \"language\": \"scala\",\n   \"name\": \"XGBoost4j-Spark_scala\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": \"text/x-scala\",\n   \"file_extension\": \".scala\",\n   \"mimetype\": \"text/x-scala\",\n   \"name\": \"scala\",\n   \"pygments_lexer\": \"scala\",\n   \"version\": \"2.12.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/notebooks/scala/taxi-gpu.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Introduction to XGBoost Spark with GPU\\n\",\n    \"\\n\",\n    \"Taxi is an example of XGBoost regressor. This notebook will show you how to load data, train the XGBoost model and use this model to predict \\\"fare_amount\\\" of your taxi trip.\\n\",\n    \"\\n\",\n    \"## Load libraries\\n\",\n    \"First load some common libraries will be used by both GPU version and CPU version XGBoost.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressor, XGBoostRegressionModel}\\n\",\n    \"import org.apache.spark.sql.SparkSession\\n\",\n    \"import org.apache.spark.ml.evaluation.RegressionEvaluator\\n\",\n    \"import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Besides CPU version requires some extra libraries, such as:\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"import org.apache.spark.ml.feature.VectorAssembler\\n\",\n    \"import org.apache.spark.sql.DataFrame\\n\",\n    \"import org.apache.spark.sql.functions._\\n\",\n    \"import org.apache.spark.sql.types.FloatType\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Set the dataset path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"dataRoot = /data\\n\",\n       \"trainPath = /data/taxi/csv/train/\\n\",\n       \"evalPath = /data/taxi/csv/test/\\n\",\n       \"transPath = /data/taxi/csv/test/\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      
\"text/plain\": [\n       \"/data/taxi/csv/test/\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// You need to update them to your real paths! The input data files can be the output of taxi-etl jobs, or you can\\n\",\n    \"// just use the provided sample datasets under datasets path. \\n\",\n    \"val dataRoot = sys.env.getOrElse(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"val trainPath = dataRoot + \\\"/taxi/csv/train/\\\"\\n\",\n    \"val evalPath  = dataRoot + \\\"/taxi/csv/test/\\\"\\n\",\n    \"val transPath = dataRoot + \\\"/taxi/csv/test/\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Build the schema of the dataset\\n\",\n    \"The Taxi data has 16 columns: 15 features and 1 label. \\\"fare_amount\\\" is the label column. The schema will be used to load data in the future. \\n\",\n    \"\\n\",\n    \"The next block also defines some key parameters used in XGBoost training process.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"labelName = fare_amount\\n\",\n       \"schema = <lazy>\\n\",\n       \"featureNames = Array(vendor_id, passenger_count, trip_distance, pickup_longitude, pickup_latitude, rate_code, store_and_fwd, dropoff_longitude, dropoff_latitude, hour, year, month, day, day_of_week, is_weekend)\\n\",\n       \"paramMap = <lazy>\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<lazy>\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val labelName = \\\"fare_amount\\\"\\n\",\n    \"lazy val schema =\\n\",\n    \"  StructType(Array(\\n\",\n    \"   
 StructField(\\\"vendor_id\\\", DoubleType),\\n\",\n    \"    StructField(\\\"passenger_count\\\", DoubleType),\\n\",\n    \"    StructField(\\\"trip_distance\\\", DoubleType),\\n\",\n    \"    StructField(\\\"pickup_longitude\\\", DoubleType),\\n\",\n    \"    StructField(\\\"pickup_latitude\\\", DoubleType),\\n\",\n    \"    StructField(\\\"rate_code\\\", DoubleType),\\n\",\n    \"    StructField(\\\"store_and_fwd\\\", DoubleType),\\n\",\n    \"    StructField(\\\"dropoff_longitude\\\", DoubleType),\\n\",\n    \"    StructField(\\\"dropoff_latitude\\\", DoubleType),\\n\",\n    \"    StructField(labelName, DoubleType),\\n\",\n    \"    StructField(\\\"hour\\\", DoubleType),\\n\",\n    \"    StructField(\\\"year\\\", IntegerType),\\n\",\n    \"    StructField(\\\"month\\\", IntegerType),\\n\",\n    \"    StructField(\\\"day\\\", DoubleType),\\n\",\n    \"    StructField(\\\"day_of_week\\\", DoubleType),\\n\",\n    \"    StructField(\\\"is_weekend\\\", DoubleType)\\n\",\n    \"  ))\\n\",\n    \"\\n\",\n    \"val featureNames = schema.filter(_.name != labelName).map(_.name).toArray\\n\",\n    \"\\n\",\n    \"lazy val paramMap = Map(\\n\",\n    \"  \\\"num_round\\\" -> 100\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create a new spark session and load data\\n\",\n    \"\\n\",\n    \"A new spark session should be created to continue all the following spark operations.\\n\",\n    \"\\n\",\n    \"NOTE: in this notebook, the dependency jars have been loaded when installing toree kernel. Alternatively the jars can be loaded into notebook by [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, there's one restriction for `%AddJar`: the jar uploaded can only be available when `AddJar` is called just after a new spark session is created. 
Do it as below:\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"import org.apache.spark.sql.SparkSession\\n\",\n    \"val spark = SparkSession.builder().appName(\\\"taxi-GPU\\\").getOrCreate\\n\",\n    \"%AddJar file:/data/libs/rapids-4-spark-XXX.jar\\n\",\n    \"%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\\n\",\n    \"%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\\n\",\n    \"// ...\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"##### Please note the new jar \\\"rapids-4-spark-XXX.jar\\\" is only needed for GPU version, you can not add it to dependence list for CPU version.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"sparkSession = org.apache.spark.sql.SparkSession@6efbc93b\\n\",\n       \"reader = org.apache.spark.sql.DataFrameReader@64b8d6da\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"org.apache.spark.sql.DataFrameReader@64b8d6da\"\n      ]\n     },\n     \"execution_count\": 9,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// Build the spark session and data reader as usual\\n\",\n    \"val sparkSession = SparkSession.builder().appName(\\\"taxi-GPU\\\").getOrCreate\\n\",\n    \"val reader = sparkSession.read.option(\\\"header\\\", true).schema(schema)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"trainSet = [vendor_id: double, passenger_count: double ... 14 more fields]\\n\",\n       \"evalSet = [vendor_id: double, passenger_count: double ... 14 more fields]\\n\",\n       \"transSet = [vendor_id: double, passenger_count: double ... 
14 more fields]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[vendor_id: double, passenger_count: double ... 14 more fields]\"\n      ]\n     },\n     \"execution_count\": 10,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// Please make sure to change the api to reader.parquet if you load parquet files.\\n\",\n    \"val trainSet = reader.csv(trainPath)\\n\",\n    \"val evalSet  = reader.csv(evalPath)\\n\",\n    \"val transSet = reader.csv(transPath)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Set XGBoost parameters and build a XGBoostRegressor\\n\",\n    \"\\n\",\n    \"For CPU version, `num_workers` is recommended being equal to the number of CPU cores, while for GPU version, it should be set to the number of GPUs in Spark cluster.\\n\",\n    \"\\n\",\n    \"Besides the `device` for CPU version is also different from that for GPU version. 
Now only \\\"cuda\\\" is supported for training on GPU.\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"// difference in parameters\\n\",\n    \"  \\\"num_workers\\\" -> 12\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbParamFinal = Map(num_round -> 100, tree_method -> hist, device -> cuda, num_workers -> 1)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Map(num_round -> 100, tree_method -> hist, device -> cuda, num_workers -> 1)\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": \"val xgbParamFinal = paramMap ++ Map(\\\"tree_method\\\" -> \\\"hist\\\", \\\"device\\\" -> \\\"cuda\\\", \\\"num_workers\\\" -> 1)\"\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbRegressor = xgbr_d36c6f5fd67c\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbr_d36c6f5fd67c\"\n      ]\n     },\n     \"execution_count\": 12,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val xgbRegressor = new XGBoostRegressor(xgbParamFinal)\\n\",\n    \"  .setLabelCol(labelName)\\n\",\n    \"  .setFeaturesCol(featureNames)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Benchmark and train\\n\",\n    \"The object `benchmark` is used to compute the elapsed time of some operations.\\n\",\n    \"\\n\",\n    \"Training with evaluation dataset is also supported, the same as CPU version's behavior:\\n\",\n    \"\\n\",\n    \"* Call API 
`setEvalDataset` after initializing an XGBoostClassifier\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"xgbClassifier.setEvalDataset(evalSet)\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbr_d36c6f5fd67c\"\n      ]\n     },\n     \"execution_count\": 13,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"xgbRegressor.setEvalDataset(evalSet)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"defined object Benchmark\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"object Benchmark {\\n\",\n    \"  def time[R](phase: String)(block: => R): (R, Float) = {\\n\",\n    \"    val t0 = System.currentTimeMillis\\n\",\n    \"    val result = block // call-by-name\\n\",\n    \"    val t1 = System.currentTimeMillis\\n\",\n    \"    println(\\\"Elapsed time [\\\" + phase + \\\"]: \\\" + ((t1 - t0).toFloat / 1000) + \\\"s\\\")\\n\",\n    \"    (result, (t1 - t0).toFloat / 1000)\\n\",\n    \"  }\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=37275, DMLC_NUM_WORKER=1}\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"model = xgbr_d36c6f5fd67c\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Elapsed time [train]: 7.441s\\n\"\n     ]\n  
  },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbr_d36c6f5fd67c\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// start training\\n\",\n    \"val (model, _) = Benchmark.time(\\\"train\\\") {\\n\",\n    \"  xgbRegressor.fit(trainSet)\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Transformation and evaluation\\n\",\n    \"Here uses `transSet` to evaluate our model and use some key columns to show our predictions. Finally we use `RegressionEvaluator` to calculate an overall `rmse` of our predictions.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Elapsed time [transform]: 2.134s\\n\",\n      \"+-------------+---------------+------------------+-----------+------------------+\\n\",\n      \"|    vendor_id|passenger_count|     trip_distance|fare_amount|        prediction|\\n\",\n      \"+-------------+---------------+------------------+-----------+------------------+\\n\",\n      \"|1.559730423E9|            2.0|0.7000000000000001|        5.0|  5.04693603515625|\\n\",\n      \"|1.559730423E9|            3.0|10.700000000000001|       34.0|31.727073669433594|\\n\",\n      \"|1.559730423E9|            1.0|               2.3|       10.0| 9.294451713562012|\\n\",\n      \"|1.559730423E9|            1.0|               4.4|       16.5| 15.05233097076416|\\n\",\n      \"|1.559730423E9|            1.0|               1.5|        7.0| 8.995831489562988|\\n\",\n      \"|1.559730423E9|            1.0|               0.8|        7.5| 6.239481449127197|\\n\",\n      \"|1.559730423E9|            1.0|               1.2|        5.5| 7.339130401611328|\\n\",\n      \"|1.559730423E9|            1.0|               3.0|        
2.5|13.403449058532715|\\n\",\n      \"| 4.52563162E8|            1.0|2.3399999999999994|        9.5| 9.672189712524414|\\n\",\n      \"| 4.52563162E8|            1.0|              3.17|       12.0|11.674100875854492|\\n\",\n      \"+-------------+---------------+------------------+-----------+------------------+\\n\",\n      \"only showing top 10 rows\\n\",\n      \"\\n\",\n      \"Elapsed time [evaluation]: 0.17s\\n\",\n      \"RMSE == 1.9141528880798715\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"prediction = [vendor_id: double, passenger_count: double ... 15 more fields]\\n\",\n       \"evaluator = RegressionEvaluator: uid=regEval_547b9abc7a3b, metricName=rmse, throughOrigin=false\\n\",\n       \"rmse = 1.9141528880798715\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"1.9141528880798715\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// start transform\\n\",\n    \"val (prediction, _) = Benchmark.time(\\\"transform\\\") {\\n\",\n    \"  val ret = model.transform(transSet).cache()\\n\",\n    \"  ret.foreachPartition((_: Iterator[_]) => ())\\n\",\n    \"  ret\\n\",\n    \"}\\n\",\n    \"prediction.select(\\\"vendor_id\\\", \\\"passenger_count\\\", \\\"trip_distance\\\", labelName, \\\"prediction\\\").show(10)\\n\",\n    \"val evaluator = new RegressionEvaluator().setLabelCol(labelName)\\n\",\n    \"val (rmse, _) = Benchmark.time(\\\"evaluation\\\") {\\n\",\n    \"  evaluator.evaluate(prediction)\\n\",\n    \"}\\n\",\n    \"println(s\\\"RMSE == $rmse\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Save the model to disk and load model\\n\",\n    \"Save the model to disk and then load it to memory. 
After that use the loaded model to do a new prediction.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Elapsed time [transform2]: 0.025s\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"modelFromDisk = xgbr_d36c6f5fd67c\\n\",\n       \"results2 = [vendor_id: double, passenger_count: double ... 15 more fields]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+-------------+---------------+------------------+-----------+------------------+\\n\",\n      \"|    vendor_id|passenger_count|     trip_distance|fare_amount|        prediction|\\n\",\n      \"+-------------+---------------+------------------+-----------+------------------+\\n\",\n      \"|1.559730423E9|            2.0|0.7000000000000001|        5.0|  5.04693603515625|\\n\",\n      \"|1.559730423E9|            3.0|10.700000000000001|       34.0|31.727073669433594|\\n\",\n      \"|1.559730423E9|            1.0|               2.3|       10.0| 9.294451713562012|\\n\",\n      \"|1.559730423E9|            1.0|               4.4|       16.5| 15.05233097076416|\\n\",\n      \"|1.559730423E9|            1.0|               1.5|        7.0| 8.995831489562988|\\n\",\n      \"+-------------+---------------+------------------+-----------+------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[vendor_id: double, passenger_count: double ... 
15 more fields]\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"model.write.overwrite.save(dataRoot + \\\"/model/taxi\\\")\\n\",\n    \"\\n\",\n    \"val modelFromDisk = XGBoostRegressionModel.load(dataRoot + \\\"/model/taxi\\\")\\n\",\n    \"val (results2, _) = Benchmark.time(\\\"transform2\\\") {\\n\",\n    \"  modelFromDisk.transform(transSet)\\n\",\n    \"}\\n\",\n    \"results2.select(\\\"vendor_id\\\", \\\"passenger_count\\\", \\\"trip_distance\\\", labelName, \\\"prediction\\\").show(5)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sparkSession.close()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"XGBoost4j-Spark - Scala\",\n   \"language\": \"scala\",\n   \"name\": \"XGBoost4j-Spark_scala\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": \"text/x-scala\",\n   \"file_extension\": \".scala\",\n   \"mimetype\": \"text/x-scala\",\n   \"name\": \"scala\",\n   \"pygments_lexer\": \"scala\",\n   \"version\": \"2.12.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/notebooks/scala/taxi_gpu_crossvalidation.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Taxi CrossValidation with GPU accelerating on XGBoost\\n\",\n    \"\\n\",\n    \"In this notebook, we will show you how to leverage GPU to accelerate taxi CrossValidation on XGBoost to find out the best model given a group of parameters.\\n\",\n    \"\\n\",\n    \"## Import classes\\n\",\n    \"First we need to load some common classes that both GPU version and CPU version will use:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}\\n\",\n    \"import org.apache.spark.ml.evaluation.{RegressionEvaluator}\\n\",\n    \"import org.apache.spark.ml.tuning.{ParamGridBuilder,CrossValidator}\\n\",\n    \"import org.apache.spark.sql.SparkSession\\n\",\n    \"import org.apache.spark.sql.types.{FloatType, IntegerType, StructField, StructType}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"what is new to xgboost-spark users is rapids.GpuDataReader and **rapids.CrossValidator**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"ename\": \"Syntax Error.\",\n     \"evalue\": \"\",\n     \"output_type\": \"error\",\n     \"traceback\": []\n    }\n   ],\n   \"source\": [\n    \"// import ml.dmlc.xgboost4j.scala.spark.rapids.CrossValidator\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Set dataset path\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"dataRoot = /data\\n\",\n       \"trainParquetPath = /data/taxi/parquet/train\\n\",\n       \"evalParquetPath = 
/data/taxi/parquet/eval\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"/data/taxi/parquet/eval\"\n      ]\n     },\n     \"execution_count\": 18,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"// You need to update them to your real paths! The input data files can be the output of taxi-etl jobs, or you can\\n\",\n    \"// just use the provided sample datasets under datasets path. \\n\",\n    \"val dataRoot = sys.env.getOrElse(\\\"DATA_ROOT\\\", \\\"/data\\\")\\n\",\n    \"val trainParquetPath=dataRoot + \\\"/taxi/parquet/train\\\"\\n\",\n    \"val evalParquetPath=dataRoot + \\\"/taxi/parquet/eval\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Set the schema of the dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"labelColName = fare_amount\\n\",\n       \"schema = StructType(StructField(vendor_id,FloatType,true), StructField(passenger_count,FloatType,true), StructField(trip_distance,FloatType,true), StructField(pickup_longitude,FloatType,true), StructField(pickup_latitude,FloatType,true), StructField(rate_code,FloatType,true), StructField(store_and_fwd,FloatType,true), StructField(dropoff_longitude,FloatType,true), StructField(dropoff_latitude,FloatType,true), StructField(fare_amount,FloatType,true), StructField(hour,FloatType,true), StructField(year,IntegerType,true), StructField(month,IntegerType,true), StructField(day,FloatType,true), StructField(day_of_week,FloatType,true), StructField(is_weekend,FloatType,true))\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       
\"StructType(StructField(vendor_id,FloatType,true), StructField(passenger_count,FloatType,true), StructField(trip_distance,FloatType,true), StructField(pickup_longitude,FloatType,true), StructField(pickup_latitude,FloatType,true), StructField(rate_code,FloatType,true), StructField(store_and_fwd,FloatType,true), StructField(dropoff_longitude,FloatType,true), StructField(dropoff_latitude,FloatType,true), StructField(fare_amount,FloatType,true), StructField(hour,FloatType,true), StructField(year,IntegerType,true), StructField(month,IntegerType,true), StructField(day,FloatType,true), StructField(day_of_week,FloatType,true), StructField(is_weekend,FloatType,true))\"\n      ]\n     },\n     \"execution_count\": 19,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val labelColName = \\\"fare_amount\\\"\\n\",\n    \"val schema =\\n\",\n    \"    StructType(Array(\\n\",\n    \"      StructField(\\\"vendor_id\\\", FloatType),\\n\",\n    \"      StructField(\\\"passenger_count\\\", FloatType),\\n\",\n    \"      StructField(\\\"trip_distance\\\", FloatType),\\n\",\n    \"      StructField(\\\"pickup_longitude\\\", FloatType),\\n\",\n    \"      StructField(\\\"pickup_latitude\\\", FloatType),\\n\",\n    \"      StructField(\\\"rate_code\\\", FloatType),\\n\",\n    \"      StructField(\\\"store_and_fwd\\\", FloatType),\\n\",\n    \"      StructField(\\\"dropoff_longitude\\\", FloatType),\\n\",\n    \"      StructField(\\\"dropoff_latitude\\\", FloatType),\\n\",\n    \"      StructField(labelColName, FloatType),\\n\",\n    \"      StructField(\\\"hour\\\", FloatType),\\n\",\n    \"      StructField(\\\"year\\\", IntegerType),\\n\",\n    \"      StructField(\\\"month\\\", IntegerType),\\n\",\n    \"      StructField(\\\"day\\\", FloatType),\\n\",\n    \"      StructField(\\\"day_of_week\\\", FloatType),\\n\",\n    \"      StructField(\\\"is_weekend\\\", FloatType)\\n\",\n    \"    ))\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create a new spark session and load data\\n\",\n    \"We must create a new spark session to continue all spark operations. It will also be used to initialize the `GpuDataReader` which is a data reader powered by GPU.\\n\",\n    \"\\n\",\n    \"NOTE: in this notebook, we have uploaded dependency jars when installing toree kernel. If we don't upload them at installation time, we can also upload in notebook by [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, there's one restriction for `%AddJar`: the jar uploaded can only be available when `AddJar` is called after a new spark session is created. We must use it as below:\\n\",\n    \"\\n\",\n    \"```scala\\n\",\n    \"import org.apache.spark.sql.SparkSession\\n\",\n    \"val spark = SparkSession.builder().appName(\\\"Taxi-GPU-CV\\\").getOrCreate\\n\",\n    \"%AddJar file:/data/libs/rapids-4-spark-XXX.jar\\n\",\n    \"%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\\n\",\n    \"%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\\n\",\n    \"// ...\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"spark = org.apache.spark.sql.SparkSession@1b953a9c\\n\",\n       \"trainDs = [vendor_id: int, passenger_count: int ... 15 more fields]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[vendor_id: int, passenger_count: int ... 
15 more fields]\"\n      ]\n     },\n     \"execution_count\": 20,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val spark = SparkSession.builder().appName(\\\"taxi-gpu-cv\\\").getOrCreate()\\n\",\n    \"val trainDs = spark.read.parquet(trainParquetPath)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Find out features to train\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"featureNames = Array(vendor_id, passenger_count, trip_distance, pickup_longitude, pickup_latitude, rate_code, store_and_fwd, dropoff_longitude, dropoff_latitude, hour, year, month, day, day_of_week, is_weekend)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Array(vendor_id, passenger_count, trip_distance, pickup_longitude, pickup_latitude, rate_code, store_and_fwd, dropoff_longitude, dropoff_latitude, hour, year, month, day, day_of_week, is_weekend)\"\n      ]\n     },\n     \"execution_count\": 21,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val featureNames = schema.filter(_.name != labelColName).map(_.name).toArray\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"regressorParam = Map(num_round -> 100, tree_method -> hist, device -> cuda,  num_workers -> 1)\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Map(num_round -> 100, tree_method -> hist, device -> cuda, num_workers -> 1)\"\n      ]\n     },\n     \"execution_count\": 22,\n     \"metadata\": {},\n   
  \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val regressorParam = Map(\\n\",\n    \"    \\\"num_round\\\" -> 100,\\n\",\n    \"    \\\"tree_method\\\" -> \\\"hist\\\",\\n\",\n    \"    \\\"device\\\" -> \\\"cuda\\\",\\n\",\n    \"    \\\"num_workers\\\" -> 1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Construct CrossValidator\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"regressor = xgbr_1c1bd6fa3a5f\\n\",\n       \"paramGrid = \\n\",\n       \"evaluator = RegressionEvaluator: uid=regEval_c7293a967512, metricName=rmse, throughOrigin=false\\n\",\n       \"cv = cv_06528fc9d704\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Array({\\n\",\n       \"\\txgbr_1c1bd6fa3a5f-eta: 0.2,\\n\",\n       \"\\txgbr_1c1bd6fa3a5f-maxDepth: 3\\n\",\n       \"}, {\\n\",\n       \"\\txgbr_1c1bd6fa3a5f-eta: 0.6,\\n\",\n       \"\\txgbr_1c1bd6fa3a5f-maxDepth: 3\\n\",\n       \"}, {\\n\",\n       \"\\txgbr_1c1bd6fa3a5f-eta: 0.2,\\n\",\n       \"\\txgbr_1c1bd6fa3a5f-maxDepth: 10\\n\",\n       \"}, {\\n\",\n       \"\\txgbr_1c1bd6fa3a5f-eta: 0.6,\\n\",\n       \"\\txgbr_1c1bd6fa3a5f-maxDepth: 10\\n\",\n       \"})\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"cv_06528fc9d704\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val regressor = new XGBoostRegressor(regressorParam)\\n\",\n    \"    .setLabelCol(labelColName)\\n\",\n    \"    .setFeaturesCol(featureNames)\\n\",\n    \"val paramGrid = new ParamGridBuilder()\\n\",\n    \"    
.addGrid(regressor.maxDepth, Array(3, 10))\\n\",\n    \"    .addGrid(regressor.eta, Array(0.2, 0.6))\\n\",\n    \"    .build()\\n\",\n    \"val evaluator = new RegressionEvaluator().setLabelCol(labelColName)\\n\",\n    \"val cv = new CrossValidator()\\n\",\n    \"    .setEstimator(regressor)\\n\",\n    \"    .setEvaluator(evaluator)\\n\",\n    \"    .setEstimatorParamMaps(paramGrid)\\n\",\n    \"    .setNumFolds(3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## train with CrossValidator\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=36551, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=40153, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=46553, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=50795, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=44927, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=55309, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=55163, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=54783, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=49873, DMLC_NUM_WORKER=1}\\n\",\n      
\"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=36003, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=41429, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=60783, DMLC_NUM_WORKER=1}\\n\",\n      \"Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=49361, DMLC_NUM_WORKER=1}\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"model = xgbr_1c1bd6fa3a5f\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"xgbr_1c1bd6fa3a5f\"\n      ]\n     },\n     \"execution_count\": 24,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val model = cv.fit(trainDs).bestModel.asInstanceOf[XGBoostRegressionModel]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## transform with best model trained by CrossValidator\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"transformDs = [vendor_id: int, passenger_count: int ... 15 more fields]\\n\",\n       \"df = [vendor_id: int, passenger_count: int ... 
16 more fields]\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"+-----------+------------------+\\n\",\n      \"|fare_amount|        prediction|\\n\",\n      \"+-----------+------------------+\\n\",\n      \"|       11.4|12.278875350952148|\\n\",\n      \"|        7.4|7.4439215660095215|\\n\",\n      \"|        5.0| 4.565710067749023|\\n\",\n      \"|        8.5| 9.188780784606934|\\n\",\n      \"|        7.4| 7.266360759735107|\\n\",\n      \"+-----------+------------------+\\n\",\n      \"only showing top 5 rows\\n\",\n      \"\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[vendor_id: int, passenger_count: int ... 16 more fields]\"\n      ]\n     },\n     \"execution_count\": 25,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val transformDs = spark.read.parquet(evalParquetPath)\\n\",\n    \"val df = model.transform(transformDs).cache()\\n\",\n    \"df.select(\\\"fare_amount\\\", \\\"prediction\\\").show(5)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"evaluator = RegressionEvaluator: uid=regEval_1c57378a8fe1, metricName=rmse, throughOrigin=false\\n\",\n       \"rmse = 2.2492672858545992\\n\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"2.2492672858545992\"\n      ]\n     },\n     \"execution_count\": 26,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"val evaluator = new RegressionEvaluator().setLabelCol(labelColName)\\n\",\n    \"val rmse = evaluator.evaluate(df)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 27,\n 
  \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.close()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"XGBoost4j-Spark - Scala\",\n   \"language\": \"scala\",\n   \"name\": \"XGBoost4j-Spark_scala\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": \"text/x-scala\",\n   \"file_extension\": \".scala\",\n   \"mimetype\": \"text/x-scala\",\n   \"name\": \"scala\",\n   \"pygments_lexer\": \"scala\",\n   \"version\": \"2.12.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/pom.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!--\n  ~ Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.\n  ~\n  ~ Licensed under the Apache License, Version 2.0 (the \"License\");\n  ~ you may not use this file except in compliance with the License.\n  ~ You may obtain a copy of the License at\n  ~\n  ~ http://www.apache.org/licenses/LICENSE-2.0\n  ~\n  ~ Unless required by applicable law or agreed to in writing, software\n  ~ distributed under the License is distributed on an \"AS IS\" BASIS,\n  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n  ~ See the License for the specific language governing permissions and\n  ~ limitations under the License.\n  -->\n\n<project xmlns=\"http://maven.apache.org/POM/4.0.0\"\n         xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n         xsi:schemaLocation=\"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd\">\n    <parent>\n        <artifactId>sample_xgboost_examples</artifactId>\n        <groupId>com.nvidia</groupId>\n        <version>0.2.3-SNAPSHOT</version>\n    </parent>\n    <modelVersion>4.0.0</modelVersion>\n\n    <artifactId>spark_examples_taxi_${scala.binary.version}</artifactId>\n\n    <properties>\n        <maven.compiler.source>8</maven.compiler.source>\n        <maven.compiler.target>8</maven.compiler.target>\n    </properties>\n\n    <dependencies>\n        <dependency>\n            <groupId>com.nvidia</groupId>\n            <artifactId>spark_examples_utility_${scala.binary.version}</artifactId>\n            <version>${project.version}</version>\n            <scope>compile</scope>\n        </dependency>\n    </dependencies>\n\n    <build>\n        <sourceDirectory>scala/src</sourceDirectory>\n    </build>\n</project>"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/python/com/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/python/com/nvidia/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/python/com/nvidia/spark/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/taxi/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/taxi/consts.py",
    "content": "#\n# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\nfrom pyspark.sql.types import *\n\nlabel = 'fare_amount'\n\nraw_schema = StructType([\n    StructField('vendor_id', StringType()),\n    StructField('pickup_datetime', StringType()),\n    StructField('dropoff_datetime', StringType()),\n    StructField('passenger_count', IntegerType()),\n    StructField('trip_distance', DoubleType()),\n    StructField('pickup_longitude', DoubleType()),\n    StructField('pickup_latitude', DoubleType()),\n    StructField('rate_code', StringType()),\n    StructField('store_and_fwd_flag', StringType()),\n    StructField('dropoff_longitude', DoubleType()),\n    StructField('dropoff_latitude', DoubleType()),\n    StructField('payment_type', StringType()),\n    StructField(label, DoubleType()),\n    StructField('surcharge', DoubleType()),\n    StructField('mta_tax', DoubleType()),\n    StructField('tip_amount', DoubleType()),\n    StructField('tolls_amount', DoubleType()),\n    StructField('total_amount', DoubleType()),\n])\n\nfinal_schema = StructType([\n    StructField('vendor_id', FloatType()),\n    StructField('passenger_count', FloatType()),\n    StructField('trip_distance', FloatType()),\n    StructField('pickup_longitude', FloatType()),\n    StructField('pickup_latitude', FloatType()),\n    StructField('rate_code', FloatType()),\n    StructField('store_and_fwd', FloatType()),\n    
StructField('dropoff_longitude', FloatType()),\n    StructField('dropoff_latitude', FloatType()),\n    StructField(label, FloatType()),\n    StructField('hour', FloatType()),\n    StructField('year', IntegerType()),\n    StructField('month', IntegerType()),\n    StructField('day', FloatType()),\n    StructField('day_of_week', FloatType()),\n    StructField('is_weekend', FloatType()),\n])\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/taxi/cross_validator_main.py",
    "content": "#\n# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nfrom .consts import *\nfrom com.nvidia.spark.examples.utility.utils import *\nfrom pyspark.ml.tuning import ParamGridBuilder, CrossValidator\nfrom pyspark.sql import SparkSession\n\nfrom xgboost.spark import SparkXGBRegressor, SparkXGBRegressorModel\n\n\ndef main(args, xgboost_args):\n    spark = (SparkSession\n             .builder\n             .appName(args.mainClass)\n             .getOrCreate())\n\n    train_data, eval_data, trans_data = valid_input_data(spark, args, raw_schema, final_schema)\n\n    if args.mode in ['all', 'train']:\n        if train_data is None:\n            print('-' * 80)\n            print('Usage: training data path required when mode is all or train')\n            print('-' * 80)\n            exit(1)\n\n        train_data, features = transform_data(train_data, label, args.use_gpu)\n        xgboost_args['features_col'] = features\n        xgboost_args['label_col'] = label\n\n        regressor = SparkXGBRegressor(**xgboost_args)\n\n        param_grid = (ParamGridBuilder()\n                      .addGrid(regressor.max_depth, [6, 8])\n                      .addGrid(regressor.n_estimators, [20, 40])\n                      .build())\n\n        evaluator = (RegressionEvaluator()\n                     .setLabelCol(label))\n\n        cross_validator = (CrossValidator()\n                           
.setEstimator(regressor)\n                           .setEvaluator(evaluator)\n                           .setEstimatorParamMaps(param_grid)\n                           .setNumFolds(3))\n\n        model = with_benchmark('Training', lambda: cross_validator.fit(train_data))\n        # get the best model to do transform\n        model = model.bestModel\n        if args.modelPath:\n            writer = model.write().overwrite() if args.overwrite else model\n            writer.save(args.modelPath)\n    else:\n        model = SparkXGBRegressorModel.load(args.modelPath)\n\n    if args.mode in ['all', 'transform']:\n        if trans_data is None:\n            print('-' * 80)\n            print('Usage: trans data path required when mode is all or transform')\n            print('-' * 80)\n            exit(1)\n\n        trans_data, _ = transform_data(trans_data, label, args.use_gpu)\n\n        def transform():\n            result = model.transform(trans_data).cache()\n            result.foreachPartition(lambda _: None)\n            return result\n\n        result = with_benchmark('Transformation', transform)\n        show_sample(args, result, label)\n        with_benchmark('Evaluation', lambda: check_regression_accuracy(result, label))\n\n    spark.stop()\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/taxi/etl_main.py",
    "content": "#\n# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .consts import *\nfrom .pre_process import pre_process\nfrom com.nvidia.spark.examples.utility.utils import *\nfrom pyspark.sql import SparkSession\n\n\ndef main(args, xgboost_args):\n    spark = (SparkSession\n             .builder\n             .appName(args.mainClass)\n             .getOrCreate())\n    raw_data_path = extract_paths(args.dataPaths, 'raw::')\n    output_path = extract_paths(args.dataPaths, 'out::')[0]\n    if not raw_data_path:\n        print('-' * 80)\n        print('Usage: raw data path required when ETL')\n        exit(1)\n    if not output_path:\n        print('-' * 80)\n        print('Usage: output data path required when ETL')\n        exit(1)\n    raw_data = prepare_data(spark, args, raw_schema, raw_data_path)\n    etled_train, etled_eval, etled_trans = pre_process(raw_data).randomSplit(list(map(float, args.splitRatios)))\n    etled_train.write.mode(\"overwrite\").parquet(output_path + '/train')\n    etled_eval.write.mode(\"overwrite\").parquet(output_path + '/eval')\n    etled_trans.write.mode(\"overwrite\").parquet(output_path + '/trans')\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/taxi/main.py",
    "content": "#\n# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nfrom .consts import *\nfrom com.nvidia.spark.examples.utility.utils import *\nfrom pyspark.sql import SparkSession\n\nfrom xgboost.spark import SparkXGBRegressor, SparkXGBRegressorModel\n\n\ndef main(args, xgboost_args):\n    spark = (SparkSession\n             .builder\n             .appName(args.mainClass)\n             .getOrCreate())\n\n    train_data, eval_data, trans_data = valid_input_data(spark, args, raw_schema, final_schema)\n\n    if args.mode in ['all', 'train']:\n        if not train_data:\n            print('-' * 80)\n            print('Usage: training data path required when mode is all or train')\n            print('-' * 80)\n            exit(1)\n\n        train_data, features = transform_data(train_data, label, args.use_gpu)\n        xgboost_args['features_col'] = features\n        xgboost_args['label_col'] = label\n        regressor = SparkXGBRegressor(**xgboost_args)\n\n        if eval_data:\n            # pass\n            pass\n\n        model = with_benchmark('Training', lambda: regressor.fit(train_data))\n\n        if args.modelPath:\n            writer = model.write().overwrite() if args.overwrite else model\n            writer.save(args.modelPath)\n    else:\n        model = SparkXGBRegressorModel.load(args.modelPath)\n\n    if args.mode in ['all', 'transform']:\n        if not trans_data:\n            
print('-' * 80)\n            print('Usage: trans data path required when mode is all or transform')\n            print('-' * 80)\n            exit(1)\n\n        trans_data, _ = transform_data(trans_data, label, args.use_gpu)\n\n        def transform():\n            result = model.transform(trans_data).cache()\n            result.foreachPartition(lambda _: None)\n            return result\n\n        result = with_benchmark('Transformation', transform)\n        show_sample(args, result, label)\n        with_benchmark('Evaluation', lambda: check_regression_accuracy(result, label))\n\n    spark.stop()\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/python/com/nvidia/spark/examples/taxi/pre_process.py",
    "content": "#\n# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nimport math\n\nfrom pyspark.sql.functions import *\nfrom pyspark.sql.types import *\nfrom pyspark.sql.functions import col\n\n\ndef pre_process(data_frame):\n    processes = [\n        drop_useless,\n        encode_categories,\n        fill_na,\n        remove_invalid,\n        convert_datetime,\n        add_h_distance,\n    ]\n    for process in processes:\n        data_frame = process(data_frame)\n    return data_frame\n\ndef drop_useless(data_frame):\n    return data_frame.drop(\n        'dropoff_datetime',\n        'payment_type',\n        'surcharge',\n        'mta_tax',\n        'tip_amount',\n        'tolls_amount',\n        'total_amount')\n\ndef encode_categories(data_frame):\n    categories = [ 'vendor_id', 'rate_code', 'store_and_fwd_flag' ]\n    for category in categories:\n        data_frame = data_frame.withColumn(category, hash(col(category)))\n    return data_frame.withColumnRenamed(\"store_and_fwd_flag\", \"store_and_fwd\")\n\ndef fill_na(data_frame):\n    return data_frame.fillna(-1)\n\ndef remove_invalid(data_frame):\n    conditions = [\n        ( 'fare_amount', 0, 500 ),\n        ( 'passenger_count', 0, 6 ),\n        ( 'pickup_longitude', -75, -73 ),\n        ( 'dropoff_longitude', -75, -73 ),\n        ( 'pickup_latitude', 40, 42 ),\n        ( 'dropoff_latitude', 40, 42 ),\n    ]\n    for column, min, max in 
conditions:\n        data_frame = data_frame.filter('{} > {} and {} < {}'.format(column, min, column, max))\n    return data_frame\n\ndef convert_datetime(data_frame):\n    datetime = col('pickup_datetime')\n    return (data_frame\n        .withColumn('pickup_datetime', to_timestamp(datetime))\n        .withColumn('year', year(datetime))\n        .withColumn('month', month(datetime))\n        .withColumn('day', dayofmonth(datetime))\n        .withColumn('day_of_week', dayofweek(datetime))\n        .withColumn(\n            'is_weekend',\n            col('day_of_week').isin(1, 7).cast(IntegerType()))  # 1: Sunday, 7: Saturday\n        .withColumn('hour', hour(datetime))\n        .drop('pickup_datetime'))\n\ndef add_h_distance(data_frame):\n    p = math.pi / 180\n    lat1 = col('pickup_latitude')\n    lon1 = col('pickup_longitude')\n    lat2 = col('dropoff_latitude')\n    lon2 = col('dropoff_longitude')\n    internal_value = (0.5\n        - cos((lat2 - lat1) * p) / 2\n        + cos(lat1 * p) * cos(lat2 * p) * (1 - cos((lon2 - lon1) * p)) / 2)\n    h_distance = 12734 * asin(sqrt(internal_value))\n    return data_frame.withColumn('h_distance', h_distance)\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/scala/src/com/nvidia/spark/examples/taxi/CrossValidationMain.scala",
    "content": "/*\n * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.taxi\n\nimport com.nvidia.spark.examples.utility.{XGBoostArgs, Benchmark}\nimport ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}\nimport org.apache.spark.ml.evaluation.RegressionEvaluator\nimport org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}\nimport org.apache.spark.sql.SparkSession\n\nobject CrossValidationMain extends Taxi {\n\n  def main(args: Array[String]): Unit = {\n    val xgboostArgs = XGBoostArgs.parse(args)\n    val processor = this.getClass.getSimpleName.stripSuffix(\"$\").substring(0, 3)\n    val appInfo = Seq(appName, processor, xgboostArgs.format)\n\n    // build spark session\n    val spark = SparkSession.builder()\n      .appName(appInfo.mkString(\"-\"))\n      .getOrCreate()\n\n    val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2))\n\n    // build data reader\n    val dataReader = spark.read\n\n    val (pathsArray, dataReadSchema, needEtl) = getDataPaths(xgboostArgs.dataPaths, xgboostArgs.isToTrain, xgboostArgs.isToTransform)\n\n    // 0: train 1: eval 2:transform\n    var datasets = pathsArray.map { paths =>\n      if (paths.nonEmpty) {\n        xgboostArgs.format match {\n          case \"csv\" => Some(dataReader.option(\"header\", xgboostArgs.hasHeader).schema(dataReadSchema).csv(paths: _*))\n          
case \"orc\" => Some(dataReader.orc(paths: _*))\n          case \"parquet\" => Some(dataReader.parquet(paths: _*))\n          case _ => throw new IllegalArgumentException(\"Unsupported data file format!\")\n        }\n      } else {\n        None\n      }\n    }\n\n    if (needEtl) datasets = datasets.map(_.map(preProcess(_)))\n\n    val xgbRegressionModel = if (xgboostArgs.isToTrain) {\n      // build XGBoost XGBoostRegressor\n      val xgbParamFinal = xgboostArgs.xgboostParams(commParamMap)\n      val xgbRegressor = new XGBoostRegressor(xgbParamFinal)\n        .setLabelCol(labelColName)\n        .setFeaturesCol(featureNames)\n\n      // Tune model using cross validation\n      val paramGrid = new ParamGridBuilder()\n        .addGrid(xgbRegressor.maxDepth, Array(3, 10))\n        .addGrid(xgbRegressor.eta, Array(0.2, 0.6))\n        .build()\n\n      val evaluator = new RegressionEvaluator().setLabelCol(labelColName)\n\n      val cv = new CrossValidator()\n        .setEstimator(xgbRegressor)\n        .setEvaluator(evaluator)\n        .setEstimatorParamMaps(paramGrid)\n        .setNumFolds(xgboostArgs.numFold)\n\n      println(\"\\n------ Training ------\")\n      // Shall we not log the time if it is abnormal, which is usually caused by training failure\n      val (model, _) = benchmark.time(\"CrossValidator\") {\n        cv.fit(datasets(0).get).bestModel.asInstanceOf[XGBoostRegressionModel]\n      }\n      // Save model if modelPath exists\n      xgboostArgs.modelPath.foreach(path =>\n        if (xgboostArgs.isOverwrite) model.write.overwrite().save(path) else model.save(path))\n      model\n    } else {\n      XGBoostRegressionModel.load(xgboostArgs.modelPath.get)\n    }\n\n    if (xgboostArgs.isToTransform) {\n      println(\"\\n------ Transforming ------\")\n      var (prediction, _) = benchmark.time(\"transform\") {\n        val ret = xgbRegressionModel.transform(datasets(2).get).cache()\n        ret.foreachPartition((_: Iterator[_]) => ())\n        ret\n      
}\n      prediction = if (xgboostArgs.isShowFeatures) {\n        prediction\n      } else {\n        prediction.select(labelColName, \"prediction\")\n      }\n      prediction.show(xgboostArgs.numRows)\n\n      println(\"\\n------Accuracy of Evaluation------\")\n      val evaluator = new RegressionEvaluator().setLabelCol(labelColName)\n      evaluator.evaluate(prediction) match {\n        case rmse if !rmse.isNaN => benchmark.value(rmse, \"RMSE\", \"RMSE for\")\n        // Throw an exception when NaN ?\n      }\n    }\n\n    spark.close()\n  }\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/scala/src/com/nvidia/spark/examples/taxi/ETLMain.scala",
    "content": "/*\n * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.taxi\n\nimport com.nvidia.spark.examples.utility.{XGBoostArgs, Benchmark}\nimport org.apache.spark.sql.SparkSession\n\nobject ETLMain extends Taxi {\n\n  def main(args: Array[String]): Unit = {\n    val xgboostArgs = XGBoostArgs.parse(args)\n    val processor = this.getClass.getSimpleName.stripSuffix(\"$\").substring(0, 3)\n    val appInfo = Seq(appName, processor, xgboostArgs.format)\n\n    // build spark session\n    val spark = SparkSession.builder()\n      .appName(appInfo.mkString(\"-\"))\n      .getOrCreate()\n\n    val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2))\n\n    // build data reader\n    val dataReader = spark.read\n\n    val (rawPaths, outPath) = checkAndGetPaths(xgboostArgs.dataPaths)\n    val df = xgboostArgs.format match {\n      case \"csv\" => dataReader.option(\"header\", xgboostArgs.hasHeader).schema(rawSchema).csv(rawPaths: _*)\n      case \"parquet\" => dataReader.parquet(rawPaths: _*)\n      case \"orc\" => dataReader.orc(rawPaths: _*)\n      case _ => throw new IllegalArgumentException(\"Unsupported data file format!\")\n    }\n\n    val (trainRatio, evalRatio, trainEvalRatio) = xgboostArgs.dataRatios\n    val dataset = preProcess(df, Array(trainRatio, trainEvalRatio, evalRatio))\n\n    benchmark.time(\"ETL\") {\n      for ((name, index) 
<- Seq(\"train\", \"eval\", \"trans\").zipWithIndex) {\n        dataset(index).write.mode(\"overwrite\").parquet(outPath + \"/parquet/\" + name)\n        dataset(index).write.mode(\"overwrite\").csv(outPath + \"/csv/\" + name)\n      }\n    }\n\n    spark.close()\n  }\n\n  private def checkAndGetPaths(paths: Seq[String]): (Seq[String], String) = {\n    val prefixes = Array(\"raw::\", \"out::\")\n    val validPaths = paths.filter(_.nonEmpty).map(_.trim)\n\n    // get and check train data paths\n    val rawPaths = validPaths.filter(_.startsWith(prefixes.head))\n    require(rawPaths.nonEmpty, s\"$appName ETL requires at least one path for taxi data file.\" +\n      s\" Please specify it by '-dataPath=raw::your_taxi_data_path'\")\n\n    // get and check out path\n    val outPath = validPaths.filter(_.startsWith(prefixes(1)))\n    require(outPath.nonEmpty, s\"$appName ETL requires a path to save the ETLed data file. Please specify it\" +\n      \" by '-dataPath=out::your_out_path', only the first path is used if multiple paths are found.\")\n\n    // check data paths not specified type\n    val unknownPaths = validPaths.filterNot(p => prefixes.exists(p.contains(_)))\n    require(unknownPaths.isEmpty, s\"Unknown type for data path: ${unknownPaths.head}, $appName requires to specify\" +\n      \" the type for each data path by adding the prefix 'raw::' or 'out::'\")\n\n    (rawPaths.map(_.stripPrefix(prefixes.head)),\n      outPath.head.stripPrefix(prefixes(1)))\n  }\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/scala/src/com/nvidia/spark/examples/taxi/Main.scala",
    "content": "/*\n * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.taxi\n\nimport com.nvidia.spark.examples.utility.{XGBoostArgs, Benchmark}\nimport ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}\nimport org.apache.spark.ml.evaluation.RegressionEvaluator\nimport org.apache.spark.sql.SparkSession\n\nobject Main extends Taxi {\n\n  def main(args: Array[String]): Unit = {\n    val xgboostArgs = XGBoostArgs.parse(args)\n    val processor = this.getClass.getSimpleName.stripSuffix(\"$\").substring(0, 3)\n    val appInfo = Seq(appName, processor, xgboostArgs.format)\n\n    // build spark session\n    val spark = SparkSession.builder()\n      .appName(appInfo.mkString(\"-\"))\n      .getOrCreate()\n\n    val benchmark = Benchmark(appInfo(0), appInfo(1), appInfo(2))\n\n    // build data reader\n    val dataReader = spark.read\n\n    val (pathsArray, dataReadSchema, needEtl) = getDataPaths(xgboostArgs.dataPaths, xgboostArgs.isToTrain, xgboostArgs.isToTransform)\n\n    // 0: train 1: eval 2:transform\n    var datasets = pathsArray.map { paths =>\n      if (paths.nonEmpty) {\n        xgboostArgs.format match {\n          case \"csv\" => Some(dataReader.option(\"header\", xgboostArgs.hasHeader).schema(dataReadSchema).csv(paths: _*))\n          case \"orc\" => Some(dataReader.orc(paths: _*))\n          case \"parquet\" => 
Some(dataReader.parquet(paths: _*))\n          case _ => throw new IllegalArgumentException(\"Unsupported data file format!\")\n        }\n      } else {\n        None\n      }\n    }\n\n    if (needEtl) datasets = datasets.map(_.map(preProcess(_)))\n\n    val xgbRegressionModel = if (xgboostArgs.isToTrain) {\n      // build XGBoost XGBoostRegressor\n      val xgbParamFinal = xgboostArgs.xgboostParams(commParamMap)\n      val xgbRegressor = new XGBoostRegressor(xgbParamFinal)\n        .setLabelCol(labelColName)\n        .setFeaturesCol(featureNames)\n\n      datasets(1).foreach(_ => xgbRegressor.setEvalDataset(_))\n\n      println(\"\\n------ Training ------\")\n      // Shall we not log the time if it is abnormal, which is usually caused by training failure\n      val (model, _) = benchmark.time(\"train\") {\n        xgbRegressor.fit(datasets(0).get)\n      }\n      // Save model if modelPath exists\n      xgboostArgs.modelPath.foreach(path =>\n        if (xgboostArgs.isOverwrite) model.write.overwrite().save(path) else model.save(path))\n      model\n    } else {\n      XGBoostRegressionModel.load(xgboostArgs.modelPath.get)\n    }\n\n    if (xgboostArgs.isToTransform) {\n      println(\"\\n------ Transforming ------\")\n      var (prediction, _) = benchmark.time(\"transform\") {\n        val ret = xgbRegressionModel.transform(datasets(2).get).cache()\n        ret.foreachPartition((_: Iterator[_]) => ())\n        ret\n      }\n      prediction = if (xgboostArgs.isShowFeatures) {\n        prediction\n      } else {\n        prediction.select(labelColName, \"prediction\")\n      }\n      prediction.show(xgboostArgs.numRows)\n\n      println(\"\\n------Accuracy of Evaluation------\")\n      val evaluator = new RegressionEvaluator().setLabelCol(labelColName)\n      evaluator.evaluate(prediction) match {\n        case rmse if !rmse.isNaN => benchmark.value(rmse, \"RMSE\", \"RMSE for\")\n        // Throw an exception when NaN ?\n      }\n    }\n\n    spark.close()\n  
}\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/taxi/scala/src/com/nvidia/spark/examples/taxi/Taxi.scala",
    "content": "/*\n * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.taxi\n\nimport org.apache.spark.sql.DataFrame\nimport org.apache.spark.sql.functions._\nimport org.apache.spark.sql.types.DataTypes.{DoubleType, IntegerType, StringType}\nimport org.apache.spark.sql.types.{FloatType, StructField, StructType}\n\nprivate[taxi] trait Taxi {\n  val appName = \"Taxi\"\n  lazy val labelColName = \"fare_amount\"\n  lazy val featureNames = etledSchema.filter(_.name != labelColName).map(_.name).toArray\n\n  lazy val commParamMap = Map(\n    \"num_round\" -> 100\n  )\n\n  val rawSchema = StructType(Seq(\n    StructField(\"vendor_id\", StringType),\n    StructField(\"pickup_datetime\", StringType),\n    StructField(\"dropoff_datetime\", StringType),\n    StructField(\"passenger_count\", IntegerType),\n    StructField(\"trip_distance\", DoubleType),\n    StructField(\"pickup_longitude\", DoubleType),\n    StructField(\"pickup_latitude\", DoubleType),\n    StructField(\"rate_code\", StringType),\n    StructField(\"store_and_fwd_flag\", StringType),\n    StructField(\"dropoff_longitude\", DoubleType),\n    StructField(\"dropoff_latitude\", DoubleType),\n    StructField(\"payment_type\", StringType),\n    StructField(labelColName, DoubleType),\n    StructField(\"surcharge\", DoubleType),\n    StructField(\"mta_tax\", DoubleType),\n    
StructField(\"tip_amount\", DoubleType),\n    StructField(\"tolls_amount\", DoubleType),\n    StructField(\"total_amount\", DoubleType)\n  ))\n\n  private val etledSchema =\n    StructType(Array(\n      StructField(\"vendor_id\", FloatType),\n      StructField(\"passenger_count\", FloatType),\n      StructField(\"trip_distance\", FloatType),\n      StructField(\"pickup_longitude\", FloatType),\n      StructField(\"pickup_latitude\", FloatType),\n      StructField(\"rate_code\", FloatType),\n      StructField(\"store_and_fwd\", FloatType),\n      StructField(\"dropoff_longitude\", FloatType),\n      StructField(\"dropoff_latitude\", FloatType),\n      StructField(labelColName, FloatType),\n      StructField(\"hour\", FloatType),\n      StructField(\"year\", IntegerType),\n      StructField(\"month\", IntegerType),\n      StructField(\"day\", FloatType),\n      StructField(\"day_of_week\", FloatType),\n      StructField(\"is_weekend\", FloatType)\n    ))\n\n  def preProcess(dataFrame: DataFrame): DataFrame = {\n    val processes = Seq[DataFrame => DataFrame](\n      dropUseless,\n      encodeCategories,\n      fillNa,\n      removeInvalid,\n      convertDatetime,\n      addHDistance\n    )\n\n    processes\n      .foldLeft(dataFrame) { case (df, process) => process(df) }\n  }\n\n  def preProcess(dataFrame: DataFrame, splits: Array[Int]): Array[DataFrame] = {\n    val processes = Seq[DataFrame => DataFrame](\n      dropUseless,\n      encodeCategories,\n      fillNa,\n      removeInvalid,\n      convertDatetime,\n      addHDistance\n    )\n\n    processes\n      .foldLeft(dataFrame) { case (df, process) => process(df) }\n      .cache()\n      .randomSplit(splits.map(_.toDouble))\n  }\n\n  def dropUseless(dataFrame: DataFrame): DataFrame = {\n    dataFrame.drop(\n      \"dropoff_datetime\",\n      \"payment_type\",\n      \"surcharge\",\n      \"mta_tax\",\n      \"tip_amount\",\n      \"tolls_amount\",\n      \"total_amount\")\n  }\n\n  def encodeCategories(dataFrame: 
DataFrame): DataFrame = {\n    val categories = Seq(\"vendor_id\", \"rate_code\", \"store_and_fwd_flag\")\n\n    (categories.foldLeft(dataFrame) {\n      case (df, category) => df.withColumn(category, hash(col(category)))\n    }).withColumnRenamed(\"store_and_fwd_flag\", \"store_and_fwd\")\n  }\n\n  def fillNa(dataFrame: DataFrame): DataFrame = {\n    dataFrame.na.fill(-1)\n  }\n\n  def removeInvalid(dataFrame: DataFrame): DataFrame = {\n    val conditions = Seq(\n      Seq(\"fare_amount\", 0, 500),\n      Seq(\"passenger_count\", 0, 6),\n      Seq(\"pickup_longitude\", -75, -73),\n      Seq(\"dropoff_longitude\", -75, -73),\n      Seq(\"pickup_latitude\", 40, 42),\n      Seq(\"dropoff_latitude\", 40, 42))\n\n    conditions\n      .map { case Seq(column, min, max) => \"%s > %d and %s < %d\".format(column, min, column, max) }\n      .foldLeft(dataFrame) {\n        _.filter(_)\n      }\n  }\n\n  def convertDatetime(dataFrame: DataFrame): DataFrame = {\n    val datetime = col(\"pickup_datetime\")\n    dataFrame\n      .withColumn(\"pickup_datetime\", to_timestamp(datetime))\n      .withColumn(\"year\", year(datetime))\n      .withColumn(\"month\", month(datetime))\n      .withColumn(\"day\", dayofmonth(datetime))\n      .withColumn(\"day_of_week\", dayofweek(datetime))\n      .withColumn(\n        \"is_weekend\",\n        col(\"day_of_week\").isin(1, 7).cast(IntegerType)) // 1: Sunday, 7: Saturday\n      .withColumn(\"hour\", hour(datetime))\n      .drop(datetime.toString)\n  }\n\n  def addHDistance(dataFrame: DataFrame): DataFrame = {\n    val P = math.Pi / 180\n    val lat1 = col(\"pickup_latitude\")\n    val lon1 = col(\"pickup_longitude\")\n    val lat2 = col(\"dropoff_latitude\")\n    val lon2 = col(\"dropoff_longitude\")\n    val internalValue = (lit(0.5)\n      - cos((lat2 - lat1) * P) / 2\n      + cos(lat1 * P) * cos(lat2 * P) * (lit(1) - cos((lon2 - lon1) * P)) / 2)\n    val hDistance = lit(12734) * asin(sqrt(internalValue))\n    
dataFrame.withColumn(\"h_distance\", hDistance)\n  }\n\n  /**\n   * getDataPaths check and get train/eval/transform paths\n   *\n   * @return Array(train_paths, eval_paths, transform_paths)\n   */\n  def getDataPaths(dataPaths: Seq[String], isToTrain: Boolean, isToTransform: Boolean):\n  (Array[Seq[String]], StructType, Boolean) = {\n    val paths = dataPaths\n    val etledPrefixes = Array(\"train::\", \"eval::\", \"trans::\")\n    val rawPrefixes = Array(\"rawTrain::\", \"rawEval::\", \"rawTrans::\")\n    val validPaths = paths.filter(_.nonEmpty).map(_.trim)\n\n    val p1 = validPaths.filter(p => etledPrefixes.exists(p.startsWith(_)))\n    val p2 = validPaths.filter(p => rawPrefixes.exists(p.startsWith(_)))\n\n    require(p1.isEmpty || p2.isEmpty, s\"requires directly train by '-dataPath=${etledPrefixes(0)}train_data_path\" +\n      s\" -dataPath=${etledPrefixes(1)}eval_data_path -dataPath=${etledPrefixes(2)}transform_data_path' Or \" +\n      s\"E2E train by '-dataPath=${rawPrefixes(0)}train_data_path -dataPath=${rawPrefixes(1)}eval_data_path\" +\n      s\" -dataPath=${rawPrefixes(2)}transform_data_path'\")\n\n    val (prefixes, schema, needEtl) =\n      if (p1.nonEmpty) (etledPrefixes, etledSchema, false)\n      else (rawPrefixes, rawSchema, true)\n\n    // get train data paths\n    val trainPaths = validPaths.filter(_.startsWith(prefixes.head))\n    if (isToTrain) {\n      require(trainPaths.nonEmpty, s\"requires at least one path for train file.\" +\n        s\" Please specify it by '-dataPath=${prefixes(0)}your_train_data_path'\")\n    }\n\n    // get eval path\n    val evalPaths = validPaths.filter(_.startsWith(prefixes(1)))\n\n    // get and check train data paths\n    val transformPaths = validPaths.filter(_.startsWith(prefixes(2)))\n    if (isToTransform) {\n      require(transformPaths.nonEmpty, s\"requires at least one path for transform file.\" +\n        s\" Please specify it by '-dataPath=${prefixes(2)}your_transform_data_path'\")\n    }\n\n    // 
check data paths not specified type\n    val unknownPaths = validPaths.filterNot(p => prefixes.exists(p.startsWith(_)))\n    require(unknownPaths.isEmpty, s\"Unknown type for data path: ${unknownPaths.head}, requires to specify\" +\n      s\" the type for each data path by adding the prefix '${prefixes(0)}' or '${prefixes(1)}' or '${prefixes(2)}'.\")\n\n    (Array(trainPaths.map(_.stripPrefix(prefixes.head)),\n      evalPaths.map(_.stripPrefix(prefixes(1))),\n      transformPaths.map(_.stripPrefix(prefixes(2)))), schema, needEtl)\n  }\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/.gitignore",
    "content": ".idea\ntarget\n*.iml\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/pom.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!--\n  ~ Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.\n  ~\n  ~ Licensed under the Apache License, Version 2.0 (the \"License\");\n  ~ you may not use this file except in compliance with the License.\n  ~ You may obtain a copy of the License at\n  ~\n  ~ http://www.apache.org/licenses/LICENSE-2.0\n  ~\n  ~ Unless required by applicable law or agreed to in writing, software\n  ~ distributed under the License is distributed on an \"AS IS\" BASIS,\n  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n  ~ See the License for the specific language governing permissions and\n  ~ limitations under the License.\n  -->\n\n<project xmlns=\"http://maven.apache.org/POM/4.0.0\"\n         xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n         xsi:schemaLocation=\"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd\">\n    <parent>\n        <artifactId>sample_xgboost_examples</artifactId>\n        <groupId>com.nvidia</groupId>\n        <version>0.2.3-SNAPSHOT</version>\n    </parent>\n    <modelVersion>4.0.0</modelVersion>\n\n    <artifactId>spark_examples_utility_${scala.binary.version}</artifactId>\n\n    <properties>\n        <maven.compiler.source>8</maven.compiler.source>\n        <maven.compiler.target>8</maven.compiler.target>\n    </properties>\n\n    <build>\n        <sourceDirectory>scala/src</sourceDirectory>\n    </build>\n</project>"
  },
  {
    "path": "examples/XGBoost-Examples/utility/python/com/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/python/com/nvidia/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/python/com/nvidia/spark/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/python/com/nvidia/spark/examples/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/python/com/nvidia/spark/examples/main.py",
    "content": "#\n# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nfrom .utility.args import parse_arguments\nfrom importlib import import_module\n\n\ndef main():\n    args, xgboost_args = parse_arguments()\n    getattr(import_module(args.mainClass), 'main')(args, xgboost_args)\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/python/com/nvidia/spark/examples/utility/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/python/com/nvidia/spark/examples/utility/args.py",
    "content": "#\n# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nimport typing\nfrom argparse import ArgumentParser\nfrom distutils.util import strtobool\nfrom re import match\nfrom sys import exit\n\n\ndef _to_bool(literal):\n    return bool(strtobool(literal))\n\n\ndef _to_ratio_pair(literal):  # e.g., '80:20'\n    return match(r'^\\d+:\\d+$', literal) and [int(x) for x in literal.split(':')]\n\n\nMAX_CHUNK_SIZE = 2 ** 31 - 1\n\n_examples = [\n    'com.nvidia.spark.examples.agaricus.main',\n    'com.nvidia.spark.examples.mortgage.main',\n    'com.nvidia.spark.examples.mortgage.etl_main',\n    'com.nvidia.spark.examples.mortgage.cross_validator_main',\n    'com.nvidia.spark.examples.taxi.main',\n    'com.nvidia.spark.examples.taxi.etl_main',\n    'com.nvidia.spark.examples.taxi.cross_validator_main',\n]\n\n\ndef _validate_args(args):\n    usage = ''\n    if not args.dataPaths:\n        usage += '  --dataPaths is required.\\n'\n    if not (args.dataRatios\n            and 0 <= args.dataRatios[0] <= 100\n            and 0 <= args.dataRatios[1] <= 100\n            and args.dataRatios[0] + args.dataRatios[1] <= 100):\n        usage += '  --dataRatios should be in format \\'Int:Int\\', these two ints should be' \\\n                 ' in range [0, 100] and the sum should be less than or equal to 100.\\n'\n    if not (1 <= args.maxRowsPerChunk <= MAX_CHUNK_SIZE):\n        usage += '  
--maxRowsPerChunk should be in range [1, {}].\\n'.format(MAX_CHUNK_SIZE)\n    if usage:\n        print('-' * 80)\n        print('Usage:\\n' + usage)\n        exit(1)\n\n\ndef _attach_derived_args(args):\n    args.trainRatio = args.dataRatios[0]\n    args.evalRatio = args.dataRatios[1]\n    args.trainEvalRatio = 100 - args.trainRatio - args.evalRatio\n    args.splitRatios = [args.trainRatio, args.trainEvalRatio, args.evalRatio]\n\n\ndef _inspect_xgb_parameters() -> typing.Dict[str, type]:\n    \"\"\"inspect XGBModel parameters from __init__\"\"\"\n    from xgboost import XGBModel\n    from typing import get_type_hints, get_origin\n    xgb_parameters = {}\n    xgb_model_sig = get_type_hints(XGBModel.__init__)\n    for k, v in xgb_model_sig.items():\n        if k != \"kwargs\" and k != \"return\":\n            if get_origin(v) == typing.Union:\n                xgb_parameters[k] = v.__args__[0]\n            else:\n                xgb_parameters[k] = v\n\n    # some extra parameters used by xgboost pyspark\n    xgb_parameters['objective'] = str\n    xgb_parameters['force_repartition'] = _to_bool\n    xgb_parameters['use_gpu'] = _to_bool\n    xgb_parameters['num_workers'] = int\n    xgb_parameters['enable_sparse_data_optim'] = _to_bool\n    return xgb_parameters\n\n\ndef parse_arguments():\n    parser = ArgumentParser()\n\n    # application arguments\n    parser.add_argument('--mainClass', required=True, choices=_examples)\n    parser.add_argument('--mode', choices=['all', 'train', 'transform'], default='all')\n    parser.add_argument('--format', required=True, choices=['csv', 'parquet', 'orc'])\n    parser.add_argument('--hasHeader', type=_to_bool, default=True)\n    parser.add_argument('--asFloats', type=_to_bool, default=True)\n    parser.add_argument('--maxRowsPerChunk', type=int, default=MAX_CHUNK_SIZE)\n    parser.add_argument('--modelPath')\n    parser.add_argument('--overwrite', type=_to_bool, default=False)\n    parser.add_argument('--dataPath', 
dest='dataPaths', action='append')\n    parser.add_argument('--dataRatios', type=_to_ratio_pair, default=[80, 20])\n    parser.add_argument('--numRows', type=int, default=5)\n    parser.add_argument('--showFeatures', type=_to_bool, default=True)\n\n    xgboost_all_args = _inspect_xgb_parameters()\n    for arg, tp in xgboost_all_args.items():\n        parser.add_argument('--' + arg, type=tp)\n\n    parsed_all = parser.parse_args()\n    _validate_args(parsed_all)\n    _attach_derived_args(parsed_all)\n\n    parsed_xgboost = {\n        k: v\n        for k, v in vars(parsed_all).items()\n        if k in xgboost_all_args and v is not None\n    }\n\n    return parsed_all, parsed_xgboost\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/python/com/nvidia/spark/examples/utility/utils.py",
    "content": "#\n# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nimport typing\n\nfrom pyspark.ml.evaluation import *\nfrom pyspark.ml.feature import VectorAssembler\nfrom pyspark.sql import DataFrame\nfrom pyspark.sql.functions import col\nfrom pyspark.sql.types import FloatType\nfrom com.nvidia.spark.examples.taxi.pre_process import pre_process\nfrom time import time\n\n\ndef merge_dicts(dict_x, dict_y):\n    result = dict_x.copy()\n    result.update(dict_y)\n    return result\n\n\ndef show_sample(args, data_frame, label):\n    data_frame = data_frame if args.showFeatures else data_frame.select(label, 'prediction')\n    data_frame.show(args.numRows)\n\n\ndef vectorize_data_frame(data_frame, label):\n    features = [x.name for x in data_frame.schema if x.name != label]\n    to_floats = [col(x.name).cast(FloatType()) for x in data_frame.schema]\n    return (VectorAssembler()\n            .setInputCols(features)\n            .setOutputCol('features')\n            .transform(data_frame.select(to_floats))\n            .select(col('features'), col(label)))\n\n\ndef vectorize_data_frames(data_frames, label):\n    return [vectorize_data_frame(x, label) for x in data_frames]\n\n\ndef with_benchmark(phrase, action):\n    start = time()\n    result = action()\n    end = time()\n    print('-' * 100)\n    print('{} takes {} seconds'.format(phrase, round(end - start, 2)))\n    return result\n\n\ndef 
check_classification_accuracy(data_frame, label):\n    accuracy = (MulticlassClassificationEvaluator()\n                .setLabelCol(label)\n                .evaluate(data_frame))\n    print('-' * 100)\n    print('Accuracy is ' + str(accuracy))\n\n\ndef check_regression_accuracy(data_frame, label):\n    accuracy = (RegressionEvaluator()\n                .setLabelCol(label)\n                .evaluate(data_frame))\n    print('-' * 100)\n    print('RMSE is ' + str(accuracy))\n\n\ndef prepare_data(spark, args, schema, dataPath):\n    reader = (spark\n              .read\n              .format(args.format))\n    if args.format == 'csv':\n        reader.schema(schema).option('header', args.hasHeader)\n    return reader.load(dataPath)\n\n\ndef extract_paths(paths, prefix):\n    results = [path[len(prefix):] for path in paths if path.startswith(prefix)]\n    return results\n\n\ndef transform_data(\n        df: DataFrame,\n        label: str,\n        use_gpu: typing.Optional[bool],\n) -> (DataFrame, typing.Union[str, typing.List[str]]):\n    if use_gpu:\n        features = [x.name for x in df.schema if x.name != label]\n    else:\n        df = vectorize_data_frame(df, label)\n        features = 'features'\n    return df, features\n\n\ndef valid_input_data(spark, args, raw_schema, final_schema):\n    e2e = False\n    for path in args.dataPaths:\n        if 'raw' in path:\n            e2e = True\n            break\n    raw_train_path = ''\n    raw_eval_path = ''\n    raw_trans_path = ''\n    eval_path = ''\n\n    if e2e:\n        raw_train_path = extract_paths(args.dataPaths, 'rawTrain::')\n        raw_eval_path = extract_paths(args.dataPaths, 'rawEval::')\n        raw_trans_path = extract_paths(args.dataPaths, 'rawTrans::')\n\n    train_data = ''\n    eval_data = ''\n    trans_data = ''\n\n    # if this is an e2e run\n    if raw_train_path or raw_eval_path or raw_trans_path:\n        raw_train_data = prepare_data(spark, args, raw_schema, raw_train_path)\n        
raw_eval_data = ''\n        raw_trans_data = ''\n        if raw_eval_path:\n            raw_eval_data = prepare_data(spark, args, raw_schema, raw_eval_path)\n        if raw_trans_path:\n            raw_trans_data = prepare_data(spark, args, raw_schema, raw_trans_path)\n\n        train_data = pre_process(raw_train_data)\n        if raw_eval_data:\n            eval_data = pre_process(raw_eval_data)\n        if raw_trans_data:\n            trans_data = pre_process(raw_trans_data)\n\n    # if this is just a train/transform\n    else:\n        train_path = extract_paths(args.dataPaths, 'train::')\n        eval_path = extract_paths(args.dataPaths, 'eval::')\n        trans_path = extract_paths(args.dataPaths, 'trans::')\n        if train_path:\n            train_data = prepare_data(spark, args, final_schema, train_path)\n        if eval_path:\n            eval_data = prepare_data(spark, args, final_schema, eval_path)\n        if trans_path:\n            trans_data = prepare_data(spark, args, final_schema, trans_path)\n    return (train_data, eval_data, trans_data)\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/scala/src/com/nvidia/spark/examples/utility/Benchmark.scala",
    "content": "/*\n * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.utility\n\nimport scala.util.Properties\n\nclass Benchmark(\n    appName: String,\n    processor: String,\n    dataFormat: String) {\n\n  def time[R](phase: String, silent: (Any, Float) => Boolean = (_,_) => false)\n             (block: => R): (R, Float) = {\n    val t0 = System.currentTimeMillis\n    val result = block // call-by-name\n    val elapsedTimeSec = (System.currentTimeMillis - t0).toFloat / 1000\n    logging(elapsedTimeSec, phase, \"Elapsed time for\", \"s\", silent(result, elapsedTimeSec))\n    (result, elapsedTimeSec)\n  }\n\n  def value(value: Any, name: String = \"value\",  prefix: String=\"\", suffix: String = \"\") = {\n    logging(value, name, prefix, suffix, false)\n  }\n\n  private def logging(value: Any, name: String , prefix: String, suffix: String, silent: Boolean) = {\n    if (!silent) {\n      val logString = buildLogSimple(value, prefix, suffix, buildRuntimeInfo(name))\n      println(\"\\n--------------\")\n      println(\"==> Benchmark: \" + logString)\n      println(\"--------------\\n\")\n    }\n  }\n\n  private def buildRuntimeInfo(name: String): String = {\n    // Get runtime information from Environment\n    val osType = Properties.envOrElse(\"RAPIDS_XGB_EXAMPLE_OS_TYPE\", \"Unknown\")\n    val cudaVersion = 
Properties.envOrElse(\"RAPIDS_XGB_EXAMPLE_CUDA_VERSION\", \"Unknown\")\n    val sparkVersion = Properties.envOrElse(\"RAPIDS_XGB_EXAMPLE_SPARK_VERSION\", \"Unknown\")\n    Seq(appName, processor, name, dataFormat, \"stub\", cudaVersion, osType, sparkVersion)\n      .mkString(\" \")\n  }\n\n  private def buildLogSimple(value: Any, prefix: String, suffix: String, runtimeInfo: String): String =\n    prefix + \" [\" + runtimeInfo + \"]: \" + value + suffix\n}\n\nobject Benchmark {\n  def apply(appName: String, processor: String, dataFormat: String) =\n    new Benchmark(appName, processor, dataFormat)\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/scala/src/com/nvidia/spark/examples/utility/SparkSetup.scala",
    "content": "\n/*\n * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.utility\n\nimport org.apache.spark.sql.SparkSession\n\nobject SparkSetup {\n  def apply(args: Array[String], appName: String) = {\n    val builder = SparkSession.builder()\n    val masterBuilder = Option(System.getenv(\"SPARK_MASTER\")).map { master =>\n      builder.master(master)\n    }.getOrElse(builder)\n\n    masterBuilder.appName(appName).getOrCreate()\n  }\n\n  def apply(args: Array[String]): SparkSession = SparkSetup(args, \"default\")\n\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/scala/src/com/nvidia/spark/examples/utility/Vectorize.scala",
    "content": "\n/*\n * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.utility\n\nimport org.apache.spark.ml.feature.VectorAssembler\nimport org.apache.spark.sql.DataFrame\nimport org.apache.spark.sql.functions.col\nimport org.apache.spark.sql.types.FloatType\n\nobject Vectorize {\n  def apply(df: DataFrame, labelName: String, changeLabelName: Boolean = true): DataFrame = {\n    val features = df.schema.collect { case f if f.name != labelName => f.name }\n    val toFloat = df.schema.map(f => col(f.name).cast(FloatType))\n    val labelCol = if (changeLabelName) col(labelName).alias(\"label\") else col(labelName)\n    new VectorAssembler()\n      .setInputCols(features.toArray)\n      .setOutputCol(\"features\")\n      .transform(df.select(toFloat: _*))\n      .select(col(\"features\"), labelCol)\n  }\n\n  def apply(df: DataFrame, featureNames: Seq[String], labelName: String): DataFrame = {\n    val toFloat = df.schema.map(f => col(f.name).cast(FloatType))\n    new VectorAssembler()\n      .setInputCols(featureNames.toArray)\n      .setOutputCol(\"features\")\n      .transform(df.select(toFloat: _*))\n      .select(col(\"features\"), col(labelName))\n  }\n\n  def apply(featureNames: Seq[String], df: DataFrame, otherNames: String*): DataFrame = {\n    val resultCols = (otherNames :+ \"features\").map(col(_))\n    new VectorAssembler()\n      
.setInputCols(featureNames.toArray)\n      .setOutputCol(\"features\")\n      .transform(df)\n      .select(resultCols: _*)\n  }\n\n  def criteoApply(df: DataFrame, featureNames: Seq[String], labelName: String): DataFrame = {\n    val toFloat = df.schema.map(f => col(f.name).cast(FloatType))\n    new VectorAssembler()\n      .setHandleInvalid(\"keep\")\n      .setInputCols(featureNames.toArray)\n      .setOutputCol(\"features\")\n      .transform(df.select(toFloat: _*))\n      .select(col(\"features\"), col(labelName))\n  }\n\n}\n"
  },
  {
    "path": "examples/XGBoost-Examples/utility/scala/src/com/nvidia/spark/examples/utility/XGBoostArgs.scala",
    "content": "\n/*\n * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage com.nvidia.spark.examples.utility\n\nimport com.google.common.base.CaseFormat\n\nimport scala.collection.mutable\nimport scala.util.Try\n\nprivate case class XGBoostArg(\n  required: Boolean = false,\n  parse: String => Any = value => value,\n  message: String = \"\")\n\nobject XGBoostArgs {\n  private val modes = Seq(\"all\", \"train\", \"transform\")\n  private val formats = Seq(\"csv\", \"parquet\", \"orc\")\n  private val stringToBool = Map(\n    \"true\"  -> true,\n    \"false\" -> false,\n    \"1\" -> true,\n    \"0\" -> false\n  )\n  private val booleanMessage = \"Expect 'true' or '1' for true, 'false' or '0' for false.\"\n\n  private def parseDataRatios(value: String): (Int, Int) = {\n    val ratios = value.split(\":\").filter(_.nonEmpty).map(_.toInt)\n    require(ratios.length == 2 && ratios(0) + ratios(1) <= 100)\n    (ratios(0), ratios(1))\n  }\n\n  private val supportedArgs = Map(\n    \"mode\"   -> XGBoostArg(\n      parse = value => { require(modes.contains(value)); value },\n      message = s\"Expect one of [${modes.mkString(\", \")}]\"),\n    \"format\" -> XGBoostArg(true,\n      parse = value => { require(formats.contains(value)); value },\n      message = s\"Expect one of [${formats.mkString(\", \")}]\"),\n    \"dataPath\"  -> XGBoostArg(true),\n    \"dataRatios\" -> XGBoostArg(\n    
  parse = parseDataRatios,\n      message = \"Expect as <train>:<transform>, both train and transform require Int, and total value <= 100\"),\n    \"modelPath\" -> XGBoostArg(),\n    \"numRows\"   -> XGBoostArg(parse = _.toInt, message = \"Require an Int.\"),\n    \"numFold\" -> XGBoostArg(parse = _.toInt, message = \"Require an Int.\"),\n    \"showFeatures\" -> XGBoostArg(parse = stringToBool, message = booleanMessage),\n    \"overwrite\" -> XGBoostArg(parse = stringToBool, message = booleanMessage),\n    \"hasHeader\" -> XGBoostArg(parse = stringToBool, message = booleanMessage),\n    \"saveDict\"  -> XGBoostArg(parse = stringToBool, message = booleanMessage),\n  )\n\n  private def help: Unit = {\n    println(\"\\n\\nSupported arguments:\")\n    println(\"    -dataPath=path: String, Required\\n\" +\n      \"        The path of data file(s). Use multiple '-dataPath=path#' to specify multiple paths. Such as\" +\n      \" '-dataPath=path1 -dataPath=path2'.\\n\")\n    println(\"    -format=<csv/parquet/orc>: String, Required\\n\" +\n      \"        The format of the data, now only supports 'csv', 'parquet' and 'orc'.\\n\")\n    println(\"    -mode=<all/train/transform>: String\\n\" +\n      \"        To control the behavior of apps. Default is 'all'. \\n\" +\n      \"        * all: Do training and transformation.\\n\" +\n      \"        * train: Do training only, will save model to 'modelPath' if specified.\\n\" +\n      \"        * transform: Transformation only, 'modelPath' is required to provide the model.\\n\")\n    println(\"    -modelPath=path: String\\n\" +\n      \"        Specify where to save model after training, or where to load model for transforming only. \\n\")\n    println(\"    -overwrite=value: Boolean\\n\" +\n      \"        Whether to overwrite the current model data under 'modelPath'. 
Default is false\\n\")\n    println(\"    -dataRatios=train<Int>:transform<Int>\\n\" +\n      \"        The ratios of data used for train and transform, then the ratio for evaluation is (100-train-test).\" +\n      \" default is 80:20, no evaluation\\n\")\n    println(\"    -hasHeader=value: Boolean\\n\" +\n      \"        Whether the csv file has header. Default is true.\\n\")\n    println(\"    -numRows=value: Int\\n\" +\n      \"        Number of the rows to show after transformation. Default is 5.\\n\")\n    println(\"    -numFold=value: Int\\n\" +\n      \"        Number of the folders to be used in Cross Validation. Default is 3.\\n\")\n    println(\"    -showFeatures=value: Boolean\\n\" +\n      \"        Whether to include the features columns when showing results of transformation. Default is true.\\n\")\n    println(\"    -saveDict=value: Boolean\\n\" +\n      \"        Whether to save the dictionary table for Mortgage ETL. It is saved under '<out>/.dict'. Default is true.\\n\")\n    println(\"    -rabitTrackerHost=value: String\\n\" +\n      \"        Specify rabit tracker host IP address. In some environments XGBoost might fail to resolve\\n\" +\n               \"the IP address of the rabit tracker, a symptom is user receiving ``OSError: [Errno 99]\\n\" +\n               \"Cannot assign requested address`` error during training.  A quick workaround is to\\n\" +\n               \"specify the address explicitly.\\n\")\n    println(\"For XGBoost arguments:\")\n    println(\"    Now we pass all XGBoost parameters transparently to XGBoost, no longer to verify them.\")\n    println(\"    Both of the formats are supported, such as 'numWorkers'. 
You can pass as either one below:\")\n    println(\"    -numWorkers=10  or  -num_workers=10 \")\n    println()\n  }\n\n  def apply(args: Array[String]) = parse(args)\n\n  def parse(args: Array[String]): XGBoostArgs = {\n    val appArgsMap = mutable.HashMap.empty[String, Any]\n    val xgbArgsMap = mutable.HashMap.empty[String, String]\n    try {\n      args.filter(_.nonEmpty).foreach {\n        argString =>\n          require(argString.startsWith(\"-\") && argString.contains('='),\n            s\"Invalid argument: $argString, expect '-name=value'\")\n\n          val parts = argString.stripPrefix(\"-\").split('=').filter(_.nonEmpty)\n          require(parts.length == 2, s\"Invalid argument: $argString, expect '-name=value'\")\n\n          val (key, value) = (parts(0), parts(1))\n          if (supportedArgs.contains(key)) {\n            // App arguments\n            val parseTry = Try(supportedArgs(key).parse(value))\n            require(parseTry.isSuccess,\n              s\"Invalid value to '$key'. 
${supportedArgs(key).message}\")\n            if (key == \"dataPath\") {\n              val paths = appArgsMap.getOrElse(key, Seq.empty).asInstanceOf[Seq[String]] :+ parseTry.get\n              appArgsMap += key -> paths\n            } else {\n              appArgsMap += key -> parseTry.get\n            }\n          } else {\n            // Supposed to be XGBooost parameters\n            xgbArgsMap += key -> value\n          }\n      }\n      supportedArgs.filter(_._2.required).foreach {\n        case (name, _) => require(appArgsMap.contains(name), s\"Missing argument: $name.\")\n      }\n      new XGBoostArgs(appArgsMap.toMap, xgbArgsMap.toMap)\n    } catch {\n      case e: Exception =>\n        help\n        throw e\n    }\n  }\n}\n\nclass XGBoostArgs private[utility] (\n    val appArgsMap: Map[String, Any],\n    val xgbArgsMap: Map[String, String]) {\n\n  def format: String = appArgsMap(\"format\").asInstanceOf[String]\n\n  def modelPath: Option[String] = appArgsMap.get(\"modelPath\").asInstanceOf[Option[String]]\n\n  // mode is optional with default value 'all'\n  private def mode: String = appArgsMap.getOrElse(\"mode\", \"all\").asInstanceOf[String]\n\n  private[utility] def verifyArgsRelation: Unit = {\n    if (mode == \"train\" && modelPath.isEmpty) {\n      println(\"==> You may want to specify the 'modelPath' to save the model when 'train only' mode.\")\n    }\n    if (mode == \"transform\") {\n      require(modelPath.nonEmpty, \"'modelPath' is required for mode: transform\")\n    }\n  }\n  verifyArgsRelation\n\n  def isToTrain: Boolean = mode != \"transform\"\n  def isToTransform: Boolean = mode != \"train\"\n\n  def dataPaths: Seq[String] = appArgsMap(\"dataPath\").asInstanceOf[Seq[String]]\n\n  def dataRatios: (Int, Int, Int) = {\n    val ratios = appArgsMap.get(\"dataRatios\").asInstanceOf[Option[(Int, Int)]].getOrElse((80, 20))\n    (ratios._1, ratios._2, 100 - ratios._1 - ratios._2)\n  }\n\n  def isShowFeatures: Boolean = 
appArgsMap.get(\"showFeatures\").forall(_.asInstanceOf[Boolean])\n\n  def isOverwrite: Boolean = appArgsMap.get(\"overwrite\").exists(_.asInstanceOf[Boolean])\n\n  def hasHeader: Boolean = appArgsMap.get(\"hasHeader\").forall(_.asInstanceOf[Boolean])\n\n  def saveDict: Boolean = appArgsMap.get(\"saveDict\").forall(_.asInstanceOf[Boolean])\n\n  def numRows: Int = appArgsMap.get(\"numRows\").asInstanceOf[Option[Int]].getOrElse(5)\n\n  def numFold: Int = appArgsMap.get(\"numFold\").asInstanceOf[Option[Int]].getOrElse(3)\n\n  def xgboostParams(otherParams: Map[String, Any] = Map.empty): Map[String, Any] = {\n    val params = otherParams ++ xgbArgsMap.map{\n        case (name, value) if !name.contains('_') =>\n          (CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, name), value)\n        case (name, value) => (name, value)\n    }\n\n    val hostIp = params.getOrElse(\"rabit_tracker_host\", \"\").toString\n    if (!hostIp.isEmpty) {\n      params ++ Map(\"rabitTrackerHostIp\" -> hostIp)\n    } else params\n  }\n\n  /**\n   *  getDataPaths check and get train/eval/transform paths\n   * @return Array(train_paths, eval_paths, transform_paths)\n   */\n  def getDataPaths: Array[Seq[String]] = {\n    val paths = dataPaths\n    val prefixes = Array(\"train::\", \"eval::\", \"trans::\")\n    val validPaths = paths.filter(_.nonEmpty).map(_.trim)\n\n    // get train data paths\n    val trainPaths = validPaths.filter(_.startsWith(prefixes.head))\n    if (isToTrain) {\n      require(trainPaths.nonEmpty, s\"requires at least one path for train file.\" +\n        s\" Please specify it by '-dataPath=train::your_train_data_path'\")\n    }\n\n    // get eval path\n    val evalPaths = validPaths.filter(_.startsWith(prefixes(1)))\n\n    // get and check train data paths\n    val transformPaths = validPaths.filter(_.startsWith(prefixes(2)))\n    if (isToTransform) {\n      require(transformPaths.nonEmpty, s\"requires at least one path for transform file.\" +\n        s\" Please 
specify it by '-dataPath=trans::your_transform_data_path'\")\n    }\n\n    // check data paths not specified type\n    val unknownPaths = validPaths.filterNot(p => prefixes.exists(p.contains(_)))\n    require(unknownPaths.isEmpty, s\"Unknown type for data path: ${unknownPaths.head}, requires to specify\" +\n      \" the type for each data path by adding the prefix 'train::' or 'eval::' or 'trans::'.\")\n\n    Array(trainPaths.map(_.stripPrefix(prefixes.head)),\n      evalPaths.map(_.stripPrefix(prefixes(1))),\n      transformPaths.map(_.stripPrefix(prefixes(2))))\n  }\n}\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/Dockerfile",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Build stage\nFROM maven:3.8-openjdk-17 AS builder\n# Set platform to linux/amd64\nENV DOCKER_DEFAULT_PLATFORM=linux/amd64\n# Copy the entire project\nCOPY scala /build\nWORKDIR /build\n# Build all modules\nRUN mvn clean install -DskipTests\n\n# Spark connect client image\nFROM apache/spark:4.0.0\n\nUSER root\n\nRUN set -x  \\\n    && apt -q update -y && apt-get install -y vim git\n\nCOPY requirements.txt /tmp/requirements.txt\nRUN pip3 install -r /tmp/requirements.txt\nRUN update-alternatives --install /usr/bin/python python /usr/bin/python3 10\n\nRUN mkdir -p /home/spark/demo\n\nCOPY notebook /home/spark/demo/notebook\nCOPY scala /home/spark/demo/scala\nCOPY --from=builder /build/target/*-jar-with-dependencies.jar /home/spark/demo/scala/\nCOPY python /home/spark/demo/python\n\n# Prepare NDS, make NDS as a package.\nCOPY nds /home/spark/demo/nds\nRUN git clone --depth 1 -b dev https://github.com/NVIDIA/spark-rapids-benchmarks /tmp/spark-rapids-benchmarks && \\\n    cp /tmp/spark-rapids-benchmarks/nds/nds_power.py \\\n       /tmp/spark-rapids-benchmarks/nds/check.py \\\n       
/tmp/spark-rapids-benchmarks/nds/nds_schema.py \\\n       /tmp/spark-rapids-benchmarks/nds/PysparkBenchReport.py \\\n       /home/spark/demo/nds/ && \\\n    rm -rf /tmp/spark-rapids-benchmarks\n\nRUN chown -R spark:spark /home/spark\nRUN chown -R spark:spark /home/spark/demo\nRUN usermod -d /home/spark spark\n\nUSER spark\nWORKDIR /home/spark/demo\n\nSHELL [ \"/bin/bash\", \"-c\" ]\nENV SHELL=/bin/bash"
  },
  {
    "path": "examples/spark-connect-gpu/client/README.md",
    "content": "# GPU-Accelerated Spark Connect for ETL and ML (Spark 4.0)\n\nThis project demonstrates some python/scala batch jobs and a complete GPU-accelerated ETL and\nMachine Learning pipeline using Apache Spark 4.0 with Spark Connect, featuring the RAPIDS Accelerator.\n\n## 🏗️ Architecture\n\nThe client side consists of one Docker services:\n\n**Jupyter Lab - Spark Connect Client** (`spark-connect-client`) - Interactive development environment\n\nThe first step, however, is to set up the GPU-accelerated Spark Connect Server. More details can be\nfound [here](../server/README.md).\n\n## 📋 Prerequisites\n\n### Required\n- [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/linux)\n- At least 8GB of available RAM\n- Available ports: 8888\n\n## 🚀 Quick Start\n\n1. **Clone and navigate to the project:**\n   ```bash\n   cd examples/spark-connect-gpu/client\n   ```\n\n2. **Start all services:**\n\n   Set the `SPARK_REMOTE` environment variable to point to your spark-connect-gpu server. By default\n   this is `sc://localhost` (for same node deployments). If the client and server are on `different nodes`,\n   you can either establish an SSH tunnel with port 15002 forwarded (e.g., `ssh -g -L 15002:localhost:15002 -N CONNECT_SERVER_IP`)\n   and use the default `SPARK_REMOTE` value (`sc://localhost`), or override it with the server’s accessible IP address:\n\n   ``` bash\n   export SPARK_REMOTE=sc://CONNECT_SERVER_IP\n   ```\n   Then start the client service:\n\n   ```bash\n   $ docker compose up -d\n   ```\n   (`docker compose` can be used in place of `docker-compose` here and throughout)\n\n3. **Access the Web UI interfaces:**\n\n   **Jupyter Lab**: http://localhost:8888 (no password required) - Interactive notebook environment\n\n4. 
**Run the demo ETL + ML notebook:**\n   - Navigate to `notebook/spark-connect-gpu-etl-ml.ipynb` in Jupyter Lab\n   - You can also open it in VS Code by selecting http://localhost:8888 as the\n     existing notebook server connection\n   - Run the complete ETL and ML pipeline demonstration\n\n5. **Run the demo python batch job:**\n   - Create a Terminal in the Jupyter Lab\n   - Navigate to `/home/spark/demo/python`\n   - Execute `python batch-job.py`\n\n6. **Run the demo scala batch job:**\n   - Create a Terminal in the Jupyter Lab\n   - Navigate to `/home/spark/demo/scala`\n   - Execute `./run.sh`\n\n7. **Run the demo NDS notebook:**\n   - Navigate to `nds/nds.ipynb` in Jupyter Lab\n   - Run the nds demonstration\n\n## Advanced GPU Configurations\n\nMost users won't need to adjust the GPU configurations. However, if you'd like\nto tune your GPU for better performance, refer to the\n[advanced GPU configurations documentation](https://nvidia.github.io/spark-rapids/docs/additional-functionality/advanced_configs.html).\n\n**Note**: Configurations prefixed with spark.rapids.sql are session-specific\nand can be set safely. 
However, those marked as **startup** will not take\neffect in Spark Connect.\n\n## 🐳 Service Details\n\n### JupyterLab - Spark Connect Client\n- **Image**: Based on `apache/spark:4.0.0`\n- **Environment**: Pre-configured with PySpark Connect Client\n- **Ports**: 8888 (Jupyter Lab)\n- **Volumes**: Notebooks and work directory mounted\n\n## 🧹 Cleanup\n\nStop and remove all services:\n```bash\ndocker-compose down -v\n```\n\nRemove built images:\n```bash\ndocker-compose down --rmi all -v\n```\n\n### Logs\nLogs for the spark driver/connect server, standalone master, standalone worker, and jupyter server can be viewed using the respective commands:\n```bash\ndocker logs spark-connect-client\n```\n\n## 📖 Additional Resources\n\n- [Apache Spark 4.0 Documentation](https://spark.apache.org/docs/latest/)\n- [Spark Connect Guide](https://spark.apache.org/docs/latest/spark-connect-overview.html)\n- [NVIDIA RAPIDS Accelerator](https://nvidia.github.io/spark-rapids/)\n- [Data and AI Summit Session](https://www.databricks.com/dataaisummit/session/gpu-accelerated-spark-connect)\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/docker-compose.yaml",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# YAML anchors for shared configurations\nx-spark-common: &spark-common\n  volumes:\n    - ${DATA_DIR:-${PWD}/data}:/data\n\nservices:\n  spark-connect-client:\n    <<: *spark-common\n    image: spark-connect-client-image\n    build:\n      context: .\n      dockerfile: Dockerfile\n    container_name: spark-connect-client\n    hostname: spark-connect-client\n    network_mode: host\n    environment:\n      - SPARK_REMOTE=${SPARK_REMOTE:-sc://localhost}\n    command: >\n      bash -c\n      'jupyter-lab\n      --port 8888\n      --no-browser\n      --IdentityProvider.token=\"\"\n      --ServerApp.password=\"\"\n      --ServerApp.ip='0.0.0.0'\n      --ServerApp.allow_origin='*' '\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/nds/nds.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"metadata\": {},\n   \"cell_type\": \"markdown\",\n   \"source\": \"### Run nds_power.py directly\",\n   \"id\": \"2274cf637f6f4702\"\n  },\n  {\n   \"metadata\": {},\n   \"cell_type\": \"code\",\n   \"outputs\": [],\n   \"execution_count\": null,\n   \"source\": \"%run nds_power.py /data/nds query_0.sql time.csv\",\n   \"id\": \"b5d2eeaeb7a2f63\"\n  },\n  {\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"cell_type\": \"markdown\",\n   \"source\": [\n    \"### Importing and Executing APIs in a Jupyter Notebook\\n\",\n    \"\\n\",\n    \"Alternatively, you can import the relevant APIs into your Jupyter notebook and execute them as shown below:\"\n   ],\n   \"id\": \"cb4ca58b118a9209\"\n  },\n  {\n   \"metadata\": {},\n   \"cell_type\": \"code\",\n   \"outputs\": [],\n   \"execution_count\": null,\n   \"source\": [\n    \"from nds_power import gen_sql_from_stream, run_query_stream\\n\",\n    \"\\n\",\n    \"query_stream_file = \\\"query_0.sql\\\"\\n\",\n    \"nds_data_path = \\\"/data/nds\\\"\\n\",\n    \"time_log_file = \\\"time.csv\\\"\\n\",\n    \"\\n\",\n    \"query_dict = gen_sql_from_stream(query_stream_file)\\n\",\n    \"\\n\",\n    \"run_query_stream(input_prefix=nds_data_path,\\n\",\n    \"                 property_file=None,\\n\",\n    \"                 query_dict=query_dict,\\n\",\n    \"                 time_log_output_path=time_log_file,\\n\",\n    \"                 extra_time_log_output_path=None,\\n\",\n    \"                 sub_queries=None,\\n\",\n    \"                 warmup_iterations=0,\\n\",\n    \"                 iterations=1,\\n\",\n    \"                 plan_types=\\\"logical\\\",\\n\",\n    \"                 )\"\n   ],\n   \"id\": \"f8ccb334ec1c6766\"\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    
\"version\": 2\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython2\",\n   \"version\": \"2.7.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/nds/query_0.sql",
    "content": "-- start query 1 in stream 0 using template query96.tpl\nselect  count(*) \nfrom store_sales\n    ,household_demographics \n    ,time_dim, store\nwhere ss_sold_time_sk = time_dim.t_time_sk   \n    and ss_hdemo_sk = household_demographics.hd_demo_sk \n    and ss_store_sk = s_store_sk\n    and time_dim.t_hour = 8\n    and time_dim.t_minute >= 30\n    and household_demographics.hd_dep_count = 5\n    and store.s_store_name = 'ese'\norder by count(*)\n LIMIT 100;\n\n-- end query 1 in stream 0 using template query96.tpl\n-- start query 2 in stream 0 using template query7.tpl\nselect  i_item_id, \n        avg(ss_quantity) agg1,\n        avg(ss_list_price) agg2,\n        avg(ss_coupon_amt) agg3,\n        avg(ss_sales_price) agg4 \n from store_sales, customer_demographics, date_dim, item, promotion\n where ss_sold_date_sk = d_date_sk and\n       ss_item_sk = i_item_sk and\n       ss_cdemo_sk = cd_demo_sk and\n       ss_promo_sk = p_promo_sk and\n       cd_gender = 'M' and \n       cd_marital_status = 'M' and\n       cd_education_status = '4 yr Degree' and\n       (p_channel_email = 'N' or p_channel_event = 'N') and\n       d_year = 2001 \n group by i_item_id\n order by i_item_id\n  LIMIT 100;\n\n-- end query 2 in stream 0 using template query7.tpl\n-- start query 3 in stream 0 using template query75.tpl\nWITH all_sales AS (\n SELECT d_year\n       ,i_brand_id\n       ,i_class_id\n       ,i_category_id\n       ,i_manufact_id\n       ,SUM(sales_cnt) AS sales_cnt\n       ,SUM(sales_amt) AS sales_amt\n FROM (SELECT d_year\n             ,i_brand_id\n             ,i_class_id\n             ,i_category_id\n             ,i_manufact_id\n             ,cs_quantity - COALESCE(cr_return_quantity,0) AS sales_cnt\n             ,cs_ext_sales_price - COALESCE(cr_return_amount,0.0) AS sales_amt\n       FROM catalog_sales JOIN item ON i_item_sk=cs_item_sk\n                          JOIN date_dim ON d_date_sk=cs_sold_date_sk\n                          LEFT JOIN catalog_returns 
ON (cs_order_number=cr_order_number \n                                                    AND cs_item_sk=cr_item_sk)\n       WHERE i_category='Shoes'\n       UNION\n       SELECT d_year\n             ,i_brand_id\n             ,i_class_id\n             ,i_category_id\n             ,i_manufact_id\n             ,ss_quantity - COALESCE(sr_return_quantity,0) AS sales_cnt\n             ,ss_ext_sales_price - COALESCE(sr_return_amt,0.0) AS sales_amt\n       FROM store_sales JOIN item ON i_item_sk=ss_item_sk\n                        JOIN date_dim ON d_date_sk=ss_sold_date_sk\n                        LEFT JOIN store_returns ON (ss_ticket_number=sr_ticket_number \n                                                AND ss_item_sk=sr_item_sk)\n       WHERE i_category='Shoes'\n       UNION\n       SELECT d_year\n             ,i_brand_id\n             ,i_class_id\n             ,i_category_id\n             ,i_manufact_id\n             ,ws_quantity - COALESCE(wr_return_quantity,0) AS sales_cnt\n             ,ws_ext_sales_price - COALESCE(wr_return_amt,0.0) AS sales_amt\n       FROM web_sales JOIN item ON i_item_sk=ws_item_sk\n                      JOIN date_dim ON d_date_sk=ws_sold_date_sk\n                      LEFT JOIN web_returns ON (ws_order_number=wr_order_number \n                                            AND ws_item_sk=wr_item_sk)\n       WHERE i_category='Shoes') sales_detail\n GROUP BY d_year, i_brand_id, i_class_id, i_category_id, i_manufact_id)\n SELECT  prev_yr.d_year AS prev_year\n                          ,curr_yr.d_year AS year\n                          ,curr_yr.i_brand_id\n                          ,curr_yr.i_class_id\n                          ,curr_yr.i_category_id\n                          ,curr_yr.i_manufact_id\n                          ,prev_yr.sales_cnt AS prev_yr_cnt\n                          ,curr_yr.sales_cnt AS curr_yr_cnt\n                          ,curr_yr.sales_cnt-prev_yr.sales_cnt AS sales_cnt_diff\n                          
,curr_yr.sales_amt-prev_yr.sales_amt AS sales_amt_diff\n FROM all_sales curr_yr, all_sales prev_yr\n WHERE curr_yr.i_brand_id=prev_yr.i_brand_id\n   AND curr_yr.i_class_id=prev_yr.i_class_id\n   AND curr_yr.i_category_id=prev_yr.i_category_id\n   AND curr_yr.i_manufact_id=prev_yr.i_manufact_id\n   AND curr_yr.d_year=2000\n   AND prev_yr.d_year=2000-1\n   AND CAST(curr_yr.sales_cnt AS DECIMAL(17,2))/CAST(prev_yr.sales_cnt AS DECIMAL(17,2))<0.9\n ORDER BY sales_cnt_diff,sales_amt_diff\n  LIMIT 100;\n\n-- end query 3 in stream 0 using template query75.tpl\n-- start query 4 in stream 0 using template query44.tpl\nselect  asceding.rnk, i1.i_product_name best_performing, i2.i_product_name worst_performing\nfrom(select *\n     from (select item_sk,rank() over (order by rank_col asc) rnk\n           from (select ss_item_sk item_sk,avg(ss_net_profit) rank_col \n                 from store_sales ss1\n                 where ss_store_sk = 30\n                 group by ss_item_sk\n                 having avg(ss_net_profit) > 0.9*(select avg(ss_net_profit) rank_col\n                                                  from store_sales\n                                                  where ss_store_sk = 30\n                                                    and ss_hdemo_sk is null\n                                                  group by ss_store_sk))V1)V11\n     where rnk  < 11) asceding,\n    (select *\n     from (select item_sk,rank() over (order by rank_col desc) rnk\n           from (select ss_item_sk item_sk,avg(ss_net_profit) rank_col\n                 from store_sales ss1\n                 where ss_store_sk = 30\n                 group by ss_item_sk\n                 having avg(ss_net_profit) > 0.9*(select avg(ss_net_profit) rank_col\n                                                  from store_sales\n                                                  where ss_store_sk = 30\n                                                    and ss_hdemo_sk is null\n                     
                             group by ss_store_sk))V2)V21\n     where rnk  < 11) descending,\nitem i1,\nitem i2\nwhere asceding.rnk = descending.rnk \n  and i1.i_item_sk=asceding.item_sk\n  and i2.i_item_sk=descending.item_sk\norder by asceding.rnk\n LIMIT 100;\n\n-- end query 4 in stream 0 using template query44.tpl\n-- start query 5 in stream 0 using template query39.tpl\nwith inv as\n(select w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy\n       ,stdev,mean, case mean when 0 then null else stdev/mean end cov\n from(select w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy\n            ,stddev_samp(inv_quantity_on_hand) stdev,avg(inv_quantity_on_hand) mean\n      from inventory\n          ,item\n          ,warehouse\n          ,date_dim\n      where inv_item_sk = i_item_sk\n        and inv_warehouse_sk = w_warehouse_sk\n        and inv_date_sk = d_date_sk\n        and d_year =2001\n      group by w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy) foo\n where case mean when 0 then 0 else stdev/mean end > 1)\nselect inv1.w_warehouse_sk,inv1.i_item_sk,inv1.d_moy,inv1.mean, inv1.cov\n        ,inv2.w_warehouse_sk,inv2.i_item_sk,inv2.d_moy,inv2.mean, inv2.cov\nfrom inv inv1,inv inv2\nwhere inv1.i_item_sk = inv2.i_item_sk\n  and inv1.w_warehouse_sk =  inv2.w_warehouse_sk\n  and inv1.d_moy=1\n  and inv2.d_moy=1+1\norder by inv1.w_warehouse_sk,inv1.i_item_sk,inv1.d_moy,inv1.mean,inv1.cov\n        ,inv2.d_moy,inv2.mean, inv2.cov\n;\nwith inv as\n(select w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy\n       ,stdev,mean, case mean when 0 then null else stdev/mean end cov\n from(select w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy\n            ,stddev_samp(inv_quantity_on_hand) stdev,avg(inv_quantity_on_hand) mean\n      from inventory\n          ,item\n          ,warehouse\n          ,date_dim\n      where inv_item_sk = i_item_sk\n        and inv_warehouse_sk = w_warehouse_sk\n        and inv_date_sk = d_date_sk\n        and d_year =2001\n      group by 
w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy) foo\n where case mean when 0 then 0 else stdev/mean end > 1)\nselect inv1.w_warehouse_sk,inv1.i_item_sk,inv1.d_moy,inv1.mean, inv1.cov\n        ,inv2.w_warehouse_sk,inv2.i_item_sk,inv2.d_moy,inv2.mean, inv2.cov\nfrom inv inv1,inv inv2\nwhere inv1.i_item_sk = inv2.i_item_sk\n  and inv1.w_warehouse_sk =  inv2.w_warehouse_sk\n  and inv1.d_moy=1\n  and inv2.d_moy=1+1\n  and inv1.cov > 1.5\norder by inv1.w_warehouse_sk,inv1.i_item_sk,inv1.d_moy,inv1.mean,inv1.cov\n        ,inv2.d_moy,inv2.mean, inv2.cov\n;\n\n-- end query 5 in stream 0 using template query39.tpl\n-- start query 6 in stream 0 using template query80.tpl\nwith ssr as\n (select  s_store_id as store_id,\n          sum(ss_ext_sales_price) as sales,\n          sum(coalesce(sr_return_amt, 0)) as returns,\n          sum(ss_net_profit - coalesce(sr_net_loss, 0)) as profit\n  from store_sales left outer join store_returns on\n         (ss_item_sk = sr_item_sk and ss_ticket_number = sr_ticket_number),\n     date_dim,\n     store,\n     item,\n     promotion\n where ss_sold_date_sk = d_date_sk\n       and d_date between cast('2002-08-04' as date) \n                  and (cast('2002-08-04' as date) + interval 30 days)\n       and ss_store_sk = s_store_sk\n       and ss_item_sk = i_item_sk\n       and i_current_price > 50\n       and ss_promo_sk = p_promo_sk\n       and p_channel_tv = 'N'\n group by s_store_id)\n ,\n csr as\n (select  cp_catalog_page_id as catalog_page_id,\n          sum(cs_ext_sales_price) as sales,\n          sum(coalesce(cr_return_amount, 0)) as returns,\n          sum(cs_net_profit - coalesce(cr_net_loss, 0)) as profit\n  from catalog_sales left outer join catalog_returns on\n         (cs_item_sk = cr_item_sk and cs_order_number = cr_order_number),\n     date_dim,\n     catalog_page,\n     item,\n     promotion\n where cs_sold_date_sk = d_date_sk\n       and d_date between cast('2002-08-04' as date)\n                  and (cast('2002-08-04' as date) 
+ interval 30 days)\n        and cs_catalog_page_sk = cp_catalog_page_sk\n       and cs_item_sk = i_item_sk\n       and i_current_price > 50\n       and cs_promo_sk = p_promo_sk\n       and p_channel_tv = 'N'\ngroup by cp_catalog_page_id)\n ,\n wsr as\n (select  web_site_id,\n          sum(ws_ext_sales_price) as sales,\n          sum(coalesce(wr_return_amt, 0)) as returns,\n          sum(ws_net_profit - coalesce(wr_net_loss, 0)) as profit\n  from web_sales left outer join web_returns on\n         (ws_item_sk = wr_item_sk and ws_order_number = wr_order_number),\n     date_dim,\n     web_site,\n     item,\n     promotion\n where ws_sold_date_sk = d_date_sk\n       and d_date between cast('2002-08-04' as date)\n                  and (cast('2002-08-04' as date) + interval 30 days)\n        and ws_web_site_sk = web_site_sk\n       and ws_item_sk = i_item_sk\n       and i_current_price > 50\n       and ws_promo_sk = p_promo_sk\n       and p_channel_tv = 'N'\ngroup by web_site_id)\n  select  channel\n        , id\n        , sum(sales) as sales\n        , sum(returns) as returns\n        , sum(profit) as profit\n from \n (select 'store channel' as channel\n        , 'store' || store_id as id\n        , sales\n        , returns\n        , profit\n from   ssr\n union all\n select 'catalog channel' as channel\n        , 'catalog_page' || catalog_page_id as id\n        , sales\n        , returns\n        , profit\n from  csr\n union all\n select 'web channel' as channel\n        , 'web_site' || web_site_id as id\n        , sales\n        , returns\n        , profit\n from   wsr\n ) x\n group by rollup (channel, id)\n order by channel\n         ,id\n  LIMIT 100;\n\n-- end query 6 in stream 0 using template query80.tpl\n-- start query 7 in stream 0 using template query32.tpl\nselect  sum(cs_ext_discount_amt)  as `excess discount amount`\nfrom \n   catalog_sales \n   ,item \n   ,date_dim\nwhere\ni_manufact_id = 283\nand i_item_sk = cs_item_sk \nand d_date between '1999-02-22' and 
\n        (cast('1999-02-22' as date) + interval 90 days)\nand d_date_sk = cs_sold_date_sk \nand cs_ext_discount_amt  \n     > ( \n         select \n            1.3 * avg(cs_ext_discount_amt) \n         from \n            catalog_sales \n           ,date_dim\n         where \n              cs_item_sk = i_item_sk \n          and d_date between '1999-02-22' and\n                             (cast('1999-02-22' as date) + interval 90 days)\n          and d_date_sk = cs_sold_date_sk \n      ) \n LIMIT 100;\n\n-- end query 7 in stream 0 using template query32.tpl\n-- start query 8 in stream 0 using template query19.tpl\nselect  i_brand_id brand_id, i_brand brand, i_manufact_id, i_manufact,\n \tsum(ss_ext_sales_price) ext_price\n from date_dim, store_sales, item,customer,customer_address,store\n where d_date_sk = ss_sold_date_sk\n   and ss_item_sk = i_item_sk\n   and i_manager_id=8\n   and d_moy=11\n   and d_year=1999\n   and ss_customer_sk = c_customer_sk \n   and c_current_addr_sk = ca_address_sk\n   and substr(ca_zip,1,5) <> substr(s_zip,1,5) \n   and ss_store_sk = s_store_sk \n group by i_brand\n      ,i_brand_id\n      ,i_manufact_id\n      ,i_manufact\n order by ext_price desc\n         ,i_brand\n         ,i_brand_id\n         ,i_manufact_id\n         ,i_manufact\n LIMIT 100 ;\n\n-- end query 8 in stream 0 using template query19.tpl\n-- start query 9 in stream 0 using template query25.tpl\nselect  \n i_item_id\n ,i_item_desc\n ,s_store_id\n ,s_store_name\n ,min(ss_net_profit) as store_sales_profit\n ,min(sr_net_loss) as store_returns_loss\n ,min(cs_net_profit) as catalog_sales_profit\n from\n store_sales\n ,store_returns\n ,catalog_sales\n ,date_dim d1\n ,date_dim d2\n ,date_dim d3\n ,store\n ,item\n where\n d1.d_moy = 4\n and d1.d_year = 2002\n and d1.d_date_sk = ss_sold_date_sk\n and i_item_sk = ss_item_sk\n and s_store_sk = ss_store_sk\n and ss_customer_sk = sr_customer_sk\n and ss_item_sk = sr_item_sk\n and ss_ticket_number = sr_ticket_number\n and 
sr_returned_date_sk = d2.d_date_sk\n and d2.d_moy               between 4 and  10\n and d2.d_year              = 2002\n and sr_customer_sk = cs_bill_customer_sk\n and sr_item_sk = cs_item_sk\n and cs_sold_date_sk = d3.d_date_sk\n and d3.d_moy               between 4 and  10 \n and d3.d_year              = 2002\n group by\n i_item_id\n ,i_item_desc\n ,s_store_id\n ,s_store_name\n order by\n i_item_id\n ,i_item_desc\n ,s_store_id\n ,s_store_name\n  LIMIT 100;\n\n-- end query 9 in stream 0 using template query25.tpl\n-- start query 10 in stream 0 using template query78.tpl\nwith ws as\n  (select d_year AS ws_sold_year, ws_item_sk,\n    ws_bill_customer_sk ws_customer_sk,\n    sum(ws_quantity) ws_qty,\n    sum(ws_wholesale_cost) ws_wc,\n    sum(ws_sales_price) ws_sp\n   from web_sales\n   left join web_returns on wr_order_number=ws_order_number and ws_item_sk=wr_item_sk\n   join date_dim on ws_sold_date_sk = d_date_sk\n   where wr_order_number is null\n   group by d_year, ws_item_sk, ws_bill_customer_sk\n   ),\ncs as\n  (select d_year AS cs_sold_year, cs_item_sk,\n    cs_bill_customer_sk cs_customer_sk,\n    sum(cs_quantity) cs_qty,\n    sum(cs_wholesale_cost) cs_wc,\n    sum(cs_sales_price) cs_sp\n   from catalog_sales\n   left join catalog_returns on cr_order_number=cs_order_number and cs_item_sk=cr_item_sk\n   join date_dim on cs_sold_date_sk = d_date_sk\n   where cr_order_number is null\n   group by d_year, cs_item_sk, cs_bill_customer_sk\n   ),\nss as\n  (select d_year AS ss_sold_year, ss_item_sk,\n    ss_customer_sk,\n    sum(ss_quantity) ss_qty,\n    sum(ss_wholesale_cost) ss_wc,\n    sum(ss_sales_price) ss_sp\n   from store_sales\n   left join store_returns on sr_ticket_number=ss_ticket_number and ss_item_sk=sr_item_sk\n   join date_dim on ss_sold_date_sk = d_date_sk\n   where sr_ticket_number is null\n   group by d_year, ss_item_sk, ss_customer_sk\n   )\n select \nss_customer_sk,\nround(ss_qty/(coalesce(ws_qty,0)+coalesce(cs_qty,0)),2) ratio,\nss_qty 
store_qty, ss_wc store_wholesale_cost, ss_sp store_sales_price,\ncoalesce(ws_qty,0)+coalesce(cs_qty,0) other_chan_qty,\ncoalesce(ws_wc,0)+coalesce(cs_wc,0) other_chan_wholesale_cost,\ncoalesce(ws_sp,0)+coalesce(cs_sp,0) other_chan_sales_price\nfrom ss\nleft join ws on (ws_sold_year=ss_sold_year and ws_item_sk=ss_item_sk and ws_customer_sk=ss_customer_sk)\nleft join cs on (cs_sold_year=ss_sold_year and cs_item_sk=ss_item_sk and cs_customer_sk=ss_customer_sk)\nwhere (coalesce(ws_qty,0)>0 or coalesce(cs_qty, 0)>0) and ss_sold_year=2001\norder by \n  ss_customer_sk,\n  ss_qty desc, ss_wc desc, ss_sp desc,\n  other_chan_qty,\n  other_chan_wholesale_cost,\n  other_chan_sales_price,\n  ratio\n LIMIT 100;\n\n-- end query 10 in stream 0 using template query78.tpl\n-- start query 11 in stream 0 using template query86.tpl\nselect   \n    sum(ws_net_paid) as total_sum\n   ,i_category\n   ,i_class\n   ,grouping(i_category)+grouping(i_class) as lochierarchy\n   ,rank() over (\n \tpartition by grouping(i_category)+grouping(i_class),\n \tcase when grouping(i_class) = 0 then i_category end \n \torder by sum(ws_net_paid) desc) as rank_within_parent\n from\n    web_sales\n   ,date_dim       d1\n   ,item\n where\n    d1.d_month_seq between 1205 and 1205+11\n and d1.d_date_sk = ws_sold_date_sk\n and i_item_sk  = ws_item_sk\n group by rollup(i_category,i_class)\n order by\n   lochierarchy desc,\n   case when lochierarchy = 0 then i_category end,\n   rank_within_parent\n  LIMIT 100;\n\n-- end query 11 in stream 0 using template query86.tpl\n-- start query 12 in stream 0 using template query1.tpl\nwith customer_total_return as\n(select sr_customer_sk as ctr_customer_sk\n,sr_store_sk as ctr_store_sk\n,sum(SR_RETURN_AMT_INC_TAX) as ctr_total_return\nfrom store_returns\n,date_dim\nwhere sr_returned_date_sk = d_date_sk\nand d_year =1999\ngroup by sr_customer_sk\n,sr_store_sk)\n select  c_customer_id\nfrom customer_total_return ctr1\n,store\n,customer\nwhere ctr1.ctr_total_return > (select 
avg(ctr_total_return)*1.2\nfrom customer_total_return ctr2\nwhere ctr1.ctr_store_sk = ctr2.ctr_store_sk)\nand s_store_sk = ctr1.ctr_store_sk\nand s_state = 'SD'\nand ctr1.ctr_customer_sk = c_customer_sk\norder by c_customer_id\n LIMIT 100;\n\n-- end query 12 in stream 0 using template query1.tpl\n-- start query 13 in stream 0 using template query91.tpl\nselect  \n        cc_call_center_id Call_Center,\n        cc_name Call_Center_Name,\n        cc_manager Manager,\n        sum(cr_net_loss) Returns_Loss\nfrom\n        call_center,\n        catalog_returns,\n        date_dim,\n        customer,\n        customer_address,\n        customer_demographics,\n        household_demographics\nwhere\n        cr_call_center_sk       = cc_call_center_sk\nand     cr_returned_date_sk     = d_date_sk\nand     cr_returning_customer_sk= c_customer_sk\nand     cd_demo_sk              = c_current_cdemo_sk\nand     hd_demo_sk              = c_current_hdemo_sk\nand     ca_address_sk           = c_current_addr_sk\nand     d_year                  = 2002 \nand     d_moy                   = 11\nand     ( (cd_marital_status       = 'M' and cd_education_status     = 'Unknown')\n        or(cd_marital_status       = 'W' and cd_education_status     = 'Advanced Degree'))\nand     hd_buy_potential like 'Unknown%'\nand     ca_gmt_offset           = -6\ngroup by cc_call_center_id,cc_name,cc_manager,cd_marital_status,cd_education_status\norder by sum(cr_net_loss) desc;\n\n-- end query 13 in stream 0 using template query91.tpl\n-- start query 14 in stream 0 using template query21.tpl\nselect  *\n from(select w_warehouse_name\n            ,i_item_id\n            ,sum(case when (cast(d_date as date) < cast ('2000-05-19' as date))\n\t                then inv_quantity_on_hand \n                      else 0 end) as inv_before\n            ,sum(case when (cast(d_date as date) >= cast ('2000-05-19' as date))\n                      then inv_quantity_on_hand \n                      else 0 end) as inv_after\n   
from inventory\n       ,warehouse\n       ,item\n       ,date_dim\n   where i_current_price between 0.99 and 1.49\n     and i_item_sk          = inv_item_sk\n     and inv_warehouse_sk   = w_warehouse_sk\n     and inv_date_sk    = d_date_sk\n     and d_date between (cast ('2000-05-19' as date) - interval 30 days)\n                    and (cast ('2000-05-19' as date) + interval 30 days)\n   group by w_warehouse_name, i_item_id) x\n where (case when inv_before > 0 \n             then inv_after / inv_before \n             else null\n             end) between 2.0/3.0 and 3.0/2.0\n order by w_warehouse_name\n         ,i_item_id\n  LIMIT 100;\n\n-- end query 14 in stream 0 using template query21.tpl\n-- start query 15 in stream 0 using template query43.tpl\nselect  s_store_name, s_store_id,\n        sum(case when (d_day_name='Sunday') then ss_sales_price else null end) sun_sales,\n        sum(case when (d_day_name='Monday') then ss_sales_price else null end) mon_sales,\n        sum(case when (d_day_name='Tuesday') then ss_sales_price else  null end) tue_sales,\n        sum(case when (d_day_name='Wednesday') then ss_sales_price else null end) wed_sales,\n        sum(case when (d_day_name='Thursday') then ss_sales_price else null end) thu_sales,\n        sum(case when (d_day_name='Friday') then ss_sales_price else null end) fri_sales,\n        sum(case when (d_day_name='Saturday') then ss_sales_price else null end) sat_sales\n from date_dim, store_sales, store\n where d_date_sk = ss_sold_date_sk and\n       s_store_sk = ss_store_sk and\n       s_gmt_offset = -5 and\n       d_year = 2000 \n group by s_store_name, s_store_id\n order by s_store_name, s_store_id,sun_sales,mon_sales,tue_sales,wed_sales,thu_sales,fri_sales,sat_sales\n  LIMIT 100;\n\n-- end query 15 in stream 0 using template query43.tpl\n-- start query 16 in stream 0 using template query27.tpl\nselect  i_item_id,\n        s_state, grouping(s_state) g_state,\n        avg(ss_quantity) agg1,\n        
avg(ss_list_price) agg2,\n        avg(ss_coupon_amt) agg3,\n        avg(ss_sales_price) agg4\n from store_sales, customer_demographics, date_dim, store, item\n where ss_sold_date_sk = d_date_sk and\n       ss_item_sk = i_item_sk and\n       ss_store_sk = s_store_sk and\n       ss_cdemo_sk = cd_demo_sk and\n       cd_gender = 'F' and\n       cd_marital_status = 'D' and\n       cd_education_status = 'College' and\n       d_year = 2002 and\n       s_state in ('SD','AL', 'TN', 'TN', 'SD', 'SD')\n group by rollup (i_item_id, s_state)\n order by i_item_id\n         ,s_state\n  LIMIT 100;\n\n-- end query 16 in stream 0 using template query27.tpl\n-- start query 17 in stream 0 using template query94.tpl\nselect  \n   count(distinct ws_order_number) as `order count`\n  ,sum(ws_ext_ship_cost) as `total shipping cost`\n  ,sum(ws_net_profit) as `total net profit`\nfrom\n   web_sales ws1\n  ,date_dim\n  ,customer_address\n  ,web_site\nwhere\n    d_date between '2001-5-01' and \n           (cast('2001-5-01' as date) + interval 60 days)\nand ws1.ws_ship_date_sk = d_date_sk\nand ws1.ws_ship_addr_sk = ca_address_sk\nand ca_state = 'AR'\nand ws1.ws_web_site_sk = web_site_sk\nand web_company_name = 'pri'\nand exists (select *\n            from web_sales ws2\n            where ws1.ws_order_number = ws2.ws_order_number\n              and ws1.ws_warehouse_sk <> ws2.ws_warehouse_sk)\nand not exists(select *\n               from web_returns wr1\n               where ws1.ws_order_number = wr1.wr_order_number)\norder by count(distinct ws_order_number)\n LIMIT 100;\n\n-- end query 17 in stream 0 using template query94.tpl\n-- start query 18 in stream 0 using template query45.tpl\nselect  ca_zip, ca_county, sum(ws_sales_price)\n from web_sales, customer, customer_address, date_dim, item\n where ws_bill_customer_sk = c_customer_sk\n \tand c_current_addr_sk = ca_address_sk \n \tand ws_item_sk = i_item_sk \n \tand ( substr(ca_zip,1,5) in ('85669', '86197','88274','83405','86475', '85392', 
'85460', '80348', '81792')\n \t      or \n \t      i_item_id in (select i_item_id\n                             from item\n                             where i_item_sk in (2, 3, 5, 7, 11, 13, 17, 19, 23, 29)\n                             )\n \t    )\n \tand ws_sold_date_sk = d_date_sk\n \tand d_qoy = 2 and d_year = 2000\n group by ca_zip, ca_county\n order by ca_zip, ca_county\n  LIMIT 100;\n\n-- end query 18 in stream 0 using template query45.tpl\n-- start query 19 in stream 0 using template query58.tpl\nwith ss_items as\n (select i_item_id item_id\n        ,sum(ss_ext_sales_price) ss_item_rev \n from store_sales\n     ,item\n     ,date_dim\n where ss_item_sk = i_item_sk\n   and d_date in (select d_date\n                  from date_dim\n                  where d_week_seq = (select d_week_seq \n                                      from date_dim\n                                      where d_date = '2002-04-19'))\n   and ss_sold_date_sk   = d_date_sk\n group by i_item_id),\n cs_items as\n (select i_item_id item_id\n        ,sum(cs_ext_sales_price) cs_item_rev\n  from catalog_sales\n      ,item\n      ,date_dim\n where cs_item_sk = i_item_sk\n  and  d_date in (select d_date\n                  from date_dim\n                  where d_week_seq = (select d_week_seq \n                                      from date_dim\n                                      where d_date = '2002-04-19'))\n  and  cs_sold_date_sk = d_date_sk\n group by i_item_id),\n ws_items as\n (select i_item_id item_id\n        ,sum(ws_ext_sales_price) ws_item_rev\n  from web_sales\n      ,item\n      ,date_dim\n where ws_item_sk = i_item_sk\n  and  d_date in (select d_date\n                  from date_dim\n                  where d_week_seq =(select d_week_seq \n                                     from date_dim\n                                     where d_date = '2002-04-19'))\n  and ws_sold_date_sk   = d_date_sk\n group by i_item_id)\n  select  ss_items.item_id\n       ,ss_item_rev\n       
,ss_item_rev/((ss_item_rev+cs_item_rev+ws_item_rev)/3) * 100 ss_dev\n       ,cs_item_rev\n       ,cs_item_rev/((ss_item_rev+cs_item_rev+ws_item_rev)/3) * 100 cs_dev\n       ,ws_item_rev\n       ,ws_item_rev/((ss_item_rev+cs_item_rev+ws_item_rev)/3) * 100 ws_dev\n       ,(ss_item_rev+cs_item_rev+ws_item_rev)/3 average\n from ss_items,cs_items,ws_items\n where ss_items.item_id=cs_items.item_id\n   and ss_items.item_id=ws_items.item_id \n   and ss_item_rev between 0.9 * cs_item_rev and 1.1 * cs_item_rev\n   and ss_item_rev between 0.9 * ws_item_rev and 1.1 * ws_item_rev\n   and cs_item_rev between 0.9 * ss_item_rev and 1.1 * ss_item_rev\n   and cs_item_rev between 0.9 * ws_item_rev and 1.1 * ws_item_rev\n   and ws_item_rev between 0.9 * ss_item_rev and 1.1 * ss_item_rev\n   and ws_item_rev between 0.9 * cs_item_rev and 1.1 * cs_item_rev\n order by item_id\n         ,ss_item_rev\n  LIMIT 100;\n\n-- end query 19 in stream 0 using template query58.tpl\n-- start query 20 in stream 0 using template query64.tpl\nwith cs_ui as\n (select cs_item_sk\n        ,sum(cs_ext_list_price) as sale,sum(cr_refunded_cash+cr_reversed_charge+cr_store_credit) as refund\n  from catalog_sales\n      ,catalog_returns\n  where cs_item_sk = cr_item_sk\n    and cs_order_number = cr_order_number\n  group by cs_item_sk\n  having sum(cs_ext_list_price)>2*sum(cr_refunded_cash+cr_reversed_charge+cr_store_credit)),\ncross_sales as\n (select i_product_name product_name\n     ,i_item_sk item_sk\n     ,s_store_name store_name\n     ,s_zip store_zip\n     ,ad1.ca_street_number b_street_number\n     ,ad1.ca_street_name b_street_name\n     ,ad1.ca_city b_city\n     ,ad1.ca_zip b_zip\n     ,ad2.ca_street_number c_street_number\n     ,ad2.ca_street_name c_street_name\n     ,ad2.ca_city c_city\n     ,ad2.ca_zip c_zip\n     ,d1.d_year as syear\n     ,d2.d_year as fsyear\n     ,d3.d_year s2year\n     ,count(*) cnt\n     ,sum(ss_wholesale_cost) s1\n     ,sum(ss_list_price) s2\n     ,sum(ss_coupon_amt) s3\n  FROM   
store_sales\n        ,store_returns\n        ,cs_ui\n        ,date_dim d1\n        ,date_dim d2\n        ,date_dim d3\n        ,store\n        ,customer\n        ,customer_demographics cd1\n        ,customer_demographics cd2\n        ,promotion\n        ,household_demographics hd1\n        ,household_demographics hd2\n        ,customer_address ad1\n        ,customer_address ad2\n        ,income_band ib1\n        ,income_band ib2\n        ,item\n  WHERE  ss_store_sk = s_store_sk AND\n         ss_sold_date_sk = d1.d_date_sk AND\n         ss_customer_sk = c_customer_sk AND\n         ss_cdemo_sk= cd1.cd_demo_sk AND\n         ss_hdemo_sk = hd1.hd_demo_sk AND\n         ss_addr_sk = ad1.ca_address_sk and\n         ss_item_sk = i_item_sk and\n         ss_item_sk = sr_item_sk and\n         ss_ticket_number = sr_ticket_number and\n         ss_item_sk = cs_ui.cs_item_sk and\n         c_current_cdemo_sk = cd2.cd_demo_sk AND\n         c_current_hdemo_sk = hd2.hd_demo_sk AND\n         c_current_addr_sk = ad2.ca_address_sk and\n         c_first_sales_date_sk = d2.d_date_sk and\n         c_first_shipto_date_sk = d3.d_date_sk and\n         ss_promo_sk = p_promo_sk and\n         hd1.hd_income_band_sk = ib1.ib_income_band_sk and\n         hd2.hd_income_band_sk = ib2.ib_income_band_sk and\n         cd1.cd_marital_status <> cd2.cd_marital_status and\n         i_color in ('lawn','blush','smoke','ghost','floral','chartreuse') and\n         i_current_price between 51 and 51 + 10 and\n         i_current_price between 51 + 1 and 51 + 15\ngroup by i_product_name\n       ,i_item_sk\n       ,s_store_name\n       ,s_zip\n       ,ad1.ca_street_number\n       ,ad1.ca_street_name\n       ,ad1.ca_city\n       ,ad1.ca_zip\n       ,ad2.ca_street_number\n       ,ad2.ca_street_name\n       ,ad2.ca_city\n       ,ad2.ca_zip\n       ,d1.d_year\n       ,d2.d_year\n       ,d3.d_year\n)\nselect cs1.product_name\n     ,cs1.store_name\n     ,cs1.store_zip\n     ,cs1.b_street_number\n     ,cs1.b_street_name\n   
  ,cs1.b_city\n     ,cs1.b_zip\n     ,cs1.c_street_number\n     ,cs1.c_street_name\n     ,cs1.c_city\n     ,cs1.c_zip\n     ,cs1.syear\n     ,cs1.cnt\n     ,cs1.s1 as s11\n     ,cs1.s2 as s21\n     ,cs1.s3 as s31\n     ,cs2.s1 as s12\n     ,cs2.s2 as s22\n     ,cs2.s3 as s32\n     ,cs2.syear\n     ,cs2.cnt\nfrom cross_sales cs1,cross_sales cs2\nwhere cs1.item_sk=cs2.item_sk and\n     cs1.syear = 2001 and\n     cs2.syear = 2001 + 1 and\n     cs2.cnt <= cs1.cnt and\n     cs1.store_name = cs2.store_name and\n     cs1.store_zip = cs2.store_zip\norder by cs1.product_name\n       ,cs1.store_name\n       ,cs2.cnt\n       ,cs1.s1\n       ,cs2.s1;\n\n-- end query 20 in stream 0 using template query64.tpl\n-- start query 21 in stream 0 using template query36.tpl\nselect  \n    sum(ss_net_profit)/sum(ss_ext_sales_price) as gross_margin\n   ,i_category\n   ,i_class\n   ,grouping(i_category)+grouping(i_class) as lochierarchy\n   ,rank() over (\n \tpartition by grouping(i_category)+grouping(i_class),\n \tcase when grouping(i_class) = 0 then i_category end \n \torder by sum(ss_net_profit)/sum(ss_ext_sales_price) asc) as rank_within_parent\n from\n    store_sales\n   ,date_dim       d1\n   ,item\n   ,store\n where\n    d1.d_year = 1999 \n and d1.d_date_sk = ss_sold_date_sk\n and i_item_sk  = ss_item_sk \n and s_store_sk  = ss_store_sk\n and s_state in ('AL','TN','SD','SD',\n                 'SD','SD','SD','SD')\n group by rollup(i_category,i_class)\n order by\n   lochierarchy desc\n  ,case when lochierarchy = 0 then i_category end\n  ,rank_within_parent\n   LIMIT 100;\n\n-- end query 21 in stream 0 using template query36.tpl\n-- start query 22 in stream 0 using template query33.tpl\nwith ss as (\n select\n          i_manufact_id,sum(ss_ext_sales_price) total_sales\n from\n \tstore_sales,\n \tdate_dim,\n         customer_address,\n         item\n where\n         i_manufact_id in (select\n  i_manufact_id\nfrom\n item\nwhere i_category in ('Electronics'))\n and     ss_item_sk         
     = i_item_sk\n and     ss_sold_date_sk         = d_date_sk\n and     d_year                  = 2002\n and     d_moy                   = 1\n and     ss_addr_sk              = ca_address_sk\n and     ca_gmt_offset           = -6 \n group by i_manufact_id),\n cs as (\n select\n          i_manufact_id,sum(cs_ext_sales_price) total_sales\n from\n \tcatalog_sales,\n \tdate_dim,\n         customer_address,\n         item\n where\n         i_manufact_id               in (select\n  i_manufact_id\nfrom\n item\nwhere i_category in ('Electronics'))\n and     cs_item_sk              = i_item_sk\n and     cs_sold_date_sk         = d_date_sk\n and     d_year                  = 2002\n and     d_moy                   = 1\n and     cs_bill_addr_sk         = ca_address_sk\n and     ca_gmt_offset           = -6 \n group by i_manufact_id),\n ws as (\n select\n          i_manufact_id,sum(ws_ext_sales_price) total_sales\n from\n \tweb_sales,\n \tdate_dim,\n         customer_address,\n         item\n where\n         i_manufact_id               in (select\n  i_manufact_id\nfrom\n item\nwhere i_category in ('Electronics'))\n and     ws_item_sk              = i_item_sk\n and     ws_sold_date_sk         = d_date_sk\n and     d_year                  = 2002\n and     d_moy                   = 1\n and     ws_bill_addr_sk         = ca_address_sk\n and     ca_gmt_offset           = -6\n group by i_manufact_id)\n  select  i_manufact_id ,sum(total_sales) total_sales\n from  (select * from ss \n        union all\n        select * from cs \n        union all\n        select * from ws) tmp1\n group by i_manufact_id\n order by total_sales\n LIMIT 100;\n\n-- end query 22 in stream 0 using template query33.tpl\n-- start query 23 in stream 0 using template query46.tpl\nselect  c_last_name\n       ,c_first_name\n       ,ca_city\n       ,bought_city\n       ,ss_ticket_number\n       ,amt,profit \n from\n   (select ss_ticket_number\n          ,ss_customer_sk\n          ,ca_city bought_city\n          
,sum(ss_coupon_amt) amt\n          ,sum(ss_net_profit) profit\n    from store_sales,date_dim,store,household_demographics,customer_address \n    where store_sales.ss_sold_date_sk = date_dim.d_date_sk\n    and store_sales.ss_store_sk = store.s_store_sk  \n    and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk\n    and store_sales.ss_addr_sk = customer_address.ca_address_sk\n    and (household_demographics.hd_dep_count = 3 or\n         household_demographics.hd_vehicle_count= 4)\n    and date_dim.d_dow in (6,0)\n    and date_dim.d_year in (2000,2000+1,2000+2) \n    and store.s_city in ('Oak Grove','Fairview','Five Points','Riverside','Pleasant Hill') \n    group by ss_ticket_number,ss_customer_sk,ss_addr_sk,ca_city) dn,customer,customer_address current_addr\n    where ss_customer_sk = c_customer_sk\n      and customer.c_current_addr_sk = current_addr.ca_address_sk\n      and current_addr.ca_city <> bought_city\n  order by c_last_name\n          ,c_first_name\n          ,ca_city\n          ,bought_city\n          ,ss_ticket_number\n   LIMIT 100;\n\n-- end query 23 in stream 0 using template query46.tpl\n-- start query 24 in stream 0 using template query62.tpl\nselect  \n   substr(w_warehouse_name,1,20)\n  ,sm_type\n  ,web_name\n  ,sum(case when (ws_ship_date_sk - ws_sold_date_sk <= 30 ) then 1 else 0 end)  as `30 days`\n  ,sum(case when (ws_ship_date_sk - ws_sold_date_sk > 30) and \n                 (ws_ship_date_sk - ws_sold_date_sk <= 60) then 1 else 0 end )  as `31-60 days`\n  ,sum(case when (ws_ship_date_sk - ws_sold_date_sk > 60) and \n                 (ws_ship_date_sk - ws_sold_date_sk <= 90) then 1 else 0 end)  as `61-90 days`\n  ,sum(case when (ws_ship_date_sk - ws_sold_date_sk > 90) and\n                 (ws_ship_date_sk - ws_sold_date_sk <= 120) then 1 else 0 end)  as `91-120 days`\n  ,sum(case when (ws_ship_date_sk - ws_sold_date_sk  > 120) then 1 else 0 end)  as `>120 days`\nfrom\n   web_sales\n  ,warehouse\n  ,ship_mode\n  ,web_site\n  
,date_dim\nwhere\n    d_month_seq between 1211 and 1211 + 11\nand ws_ship_date_sk   = d_date_sk\nand ws_warehouse_sk   = w_warehouse_sk\nand ws_ship_mode_sk   = sm_ship_mode_sk\nand ws_web_site_sk    = web_site_sk\ngroup by\n   substr(w_warehouse_name,1,20)\n  ,sm_type\n  ,web_name\norder by substr(w_warehouse_name,1,20)\n        ,sm_type\n       ,web_name\n LIMIT 100;\n\n-- end query 24 in stream 0 using template query62.tpl\n-- start query 25 in stream 0 using template query16.tpl\nselect  \n   count(distinct cs_order_number) as `order count`\n  ,sum(cs_ext_ship_cost) as `total shipping cost`\n  ,sum(cs_net_profit) as `total net profit`\nfrom\n   catalog_sales cs1\n  ,date_dim\n  ,customer_address\n  ,call_center\nwhere\n    d_date between '1999-4-01' and \n           (cast('1999-4-01' as date) + interval 60 days)\nand cs1.cs_ship_date_sk = d_date_sk\nand cs1.cs_ship_addr_sk = ca_address_sk\nand ca_state = 'MD'\nand cs1.cs_call_center_sk = cc_call_center_sk\nand cc_county in ('Ziebach County','Williamson County','Walker County','Williamson County',\n                  'Ziebach County'\n)\nand exists (select *\n            from catalog_sales cs2\n            where cs1.cs_order_number = cs2.cs_order_number\n              and cs1.cs_warehouse_sk <> cs2.cs_warehouse_sk)\nand not exists(select *\n               from catalog_returns cr1\n               where cs1.cs_order_number = cr1.cr_order_number)\norder by count(distinct cs_order_number)\n LIMIT 100;\n\n-- end query 25 in stream 0 using template query16.tpl\n-- start query 26 in stream 0 using template query10.tpl\nselect  \n  cd_gender,\n  cd_marital_status,\n  cd_education_status,\n  count(*) cnt1,\n  cd_purchase_estimate,\n  count(*) cnt2,\n  cd_credit_rating,\n  count(*) cnt3,\n  cd_dep_count,\n  count(*) cnt4,\n  cd_dep_employed_count,\n  count(*) cnt5,\n  cd_dep_college_count,\n  count(*) cnt6\n from\n  customer c,customer_address ca,customer_demographics\n where\n  c.c_current_addr_sk = ca.ca_address_sk and\n 
 ca_county in ('Bottineau County','Marion County','Randolph County','Providence County','Sagadahoc County') and\n  cd_demo_sk = c.c_current_cdemo_sk and \n  exists (select *\n          from store_sales,date_dim\n          where c.c_customer_sk = ss_customer_sk and\n                ss_sold_date_sk = d_date_sk and\n                d_year = 2000 and\n                d_moy between 1 and 1+3) and\n   (exists (select *\n            from web_sales,date_dim\n            where c.c_customer_sk = ws_bill_customer_sk and\n                  ws_sold_date_sk = d_date_sk and\n                  d_year = 2000 and\n                  d_moy between 1 ANd 1+3) or \n    exists (select * \n            from catalog_sales,date_dim\n            where c.c_customer_sk = cs_ship_customer_sk and\n                  cs_sold_date_sk = d_date_sk and\n                  d_year = 2000 and\n                  d_moy between 1 and 1+3))\n group by cd_gender,\n          cd_marital_status,\n          cd_education_status,\n          cd_purchase_estimate,\n          cd_credit_rating,\n          cd_dep_count,\n          cd_dep_employed_count,\n          cd_dep_college_count\n order by cd_gender,\n          cd_marital_status,\n          cd_education_status,\n          cd_purchase_estimate,\n          cd_credit_rating,\n          cd_dep_count,\n          cd_dep_employed_count,\n          cd_dep_college_count\n LIMIT 100;\n\n-- end query 26 in stream 0 using template query10.tpl\n-- start query 27 in stream 0 using template query63.tpl\nselect  * \nfrom (select i_manager_id\n             ,sum(ss_sales_price) sum_sales\n             ,avg(sum(ss_sales_price)) over (partition by i_manager_id) avg_monthly_sales\n      from item\n          ,store_sales\n          ,date_dim\n          ,store\n      where ss_item_sk = i_item_sk\n        and ss_sold_date_sk = d_date_sk\n        and ss_store_sk = s_store_sk\n        and d_month_seq in (1179,1179+1,1179+2,1179+3,1179+4,1179+5,1179+6,1179+7,1179+8,1179+9,1179+10,1179+11)\n   
     and ((    i_category in ('Books','Children','Electronics')\n              and i_class in ('personal','portable','reference','self-help')\n              and i_brand in ('scholaramalgamalg #14','scholaramalgamalg #7',\n\t\t                  'exportiunivamalg #9','scholaramalgamalg #9'))\n           or(    i_category in ('Women','Music','Men')\n              and i_class in ('accessories','classical','fragrances','pants')\n              and i_brand in ('amalgimporto #1','edu packscholar #1','exportiimporto #1',\n\t\t                 'importoamalg #1')))\ngroup by i_manager_id, d_moy) tmp1\nwhere case when avg_monthly_sales > 0 then abs (sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1\norder by i_manager_id\n        ,avg_monthly_sales\n        ,sum_sales\n LIMIT 100;\n\n-- end query 27 in stream 0 using template query63.tpl\n-- start query 28 in stream 0 using template query69.tpl\nselect  \n  cd_gender,\n  cd_marital_status,\n  cd_education_status,\n  count(*) cnt1,\n  cd_purchase_estimate,\n  count(*) cnt2,\n  cd_credit_rating,\n  count(*) cnt3\n from\n  customer c,customer_address ca,customer_demographics\n where\n  c.c_current_addr_sk = ca.ca_address_sk and\n  ca_state in ('IN','ND','PA') and\n  cd_demo_sk = c.c_current_cdemo_sk and \n  exists (select *\n          from store_sales,date_dim\n          where c.c_customer_sk = ss_customer_sk and\n                ss_sold_date_sk = d_date_sk and\n                d_year = 1999 and\n                d_moy between 2 and 2+2) and\n   (not exists (select *\n            from web_sales,date_dim\n            where c.c_customer_sk = ws_bill_customer_sk and\n                  ws_sold_date_sk = d_date_sk and\n                  d_year = 1999 and\n                  d_moy between 2 and 2+2) and\n    not exists (select * \n            from catalog_sales,date_dim\n            where c.c_customer_sk = cs_ship_customer_sk and\n                  cs_sold_date_sk = d_date_sk and\n                  d_year = 1999 
and\n                  d_moy between 2 and 2+2))\n group by cd_gender,\n          cd_marital_status,\n          cd_education_status,\n          cd_purchase_estimate,\n          cd_credit_rating\n order by cd_gender,\n          cd_marital_status,\n          cd_education_status,\n          cd_purchase_estimate,\n          cd_credit_rating\n  LIMIT 100;\n\n-- end query 28 in stream 0 using template query69.tpl\n-- start query 29 in stream 0 using template query60.tpl\nwith ss as (\n select\n          i_item_id,sum(ss_ext_sales_price) total_sales\n from\n \tstore_sales,\n \tdate_dim,\n         customer_address,\n         item\n where\n         i_item_id in (select\n  i_item_id\nfrom\n item\nwhere i_category in ('Music'))\n and     ss_item_sk              = i_item_sk\n and     ss_sold_date_sk         = d_date_sk\n and     d_year                  = 1998\n and     d_moy                   = 10\n and     ss_addr_sk              = ca_address_sk\n and     ca_gmt_offset           = -5 \n group by i_item_id),\n cs as (\n select\n          i_item_id,sum(cs_ext_sales_price) total_sales\n from\n \tcatalog_sales,\n \tdate_dim,\n         customer_address,\n         item\n where\n         i_item_id               in (select\n  i_item_id\nfrom\n item\nwhere i_category in ('Music'))\n and     cs_item_sk              = i_item_sk\n and     cs_sold_date_sk         = d_date_sk\n and     d_year                  = 1998\n and     d_moy                   = 10\n and     cs_bill_addr_sk         = ca_address_sk\n and     ca_gmt_offset           = -5 \n group by i_item_id),\n ws as (\n select\n          i_item_id,sum(ws_ext_sales_price) total_sales\n from\n \tweb_sales,\n \tdate_dim,\n         customer_address,\n         item\n where\n         i_item_id               in (select\n  i_item_id\nfrom\n item\nwhere i_category in ('Music'))\n and     ws_item_sk              = i_item_sk\n and     ws_sold_date_sk         = d_date_sk\n and     d_year                  = 1998\n and     d_moy                   
= 10\n and     ws_bill_addr_sk         = ca_address_sk\n and     ca_gmt_offset           = -5\n group by i_item_id)\n  select   \n  i_item_id\n,sum(total_sales) total_sales\n from  (select * from ss \n        union all\n        select * from cs \n        union all\n        select * from ws) tmp1\n group by i_item_id\n order by i_item_id\n      ,total_sales\n  LIMIT 100;\n\n-- end query 29 in stream 0 using template query60.tpl\n-- start query 30 in stream 0 using template query59.tpl\nwith wss as \n (select d_week_seq,\n        ss_store_sk,\n        sum(case when (d_day_name='Sunday') then ss_sales_price else null end) sun_sales,\n        sum(case when (d_day_name='Monday') then ss_sales_price else null end) mon_sales,\n        sum(case when (d_day_name='Tuesday') then ss_sales_price else  null end) tue_sales,\n        sum(case when (d_day_name='Wednesday') then ss_sales_price else null end) wed_sales,\n        sum(case when (d_day_name='Thursday') then ss_sales_price else null end) thu_sales,\n        sum(case when (d_day_name='Friday') then ss_sales_price else null end) fri_sales,\n        sum(case when (d_day_name='Saturday') then ss_sales_price else null end) sat_sales\n from store_sales,date_dim\n where d_date_sk = ss_sold_date_sk\n group by d_week_seq,ss_store_sk\n )\n  select  s_store_name1,s_store_id1,d_week_seq1\n       ,sun_sales1/sun_sales2,mon_sales1/mon_sales2\n       ,tue_sales1/tue_sales2,wed_sales1/wed_sales2,thu_sales1/thu_sales2\n       ,fri_sales1/fri_sales2,sat_sales1/sat_sales2\n from\n (select s_store_name s_store_name1,wss.d_week_seq d_week_seq1\n        ,s_store_id s_store_id1,sun_sales sun_sales1\n        ,mon_sales mon_sales1,tue_sales tue_sales1\n        ,wed_sales wed_sales1,thu_sales thu_sales1\n        ,fri_sales fri_sales1,sat_sales sat_sales1\n  from wss,store,date_dim d\n  where d.d_week_seq = wss.d_week_seq and\n        ss_store_sk = s_store_sk and \n        d_month_seq between 1202 and 1202 + 11) y,\n (select s_store_name 
s_store_name2,wss.d_week_seq d_week_seq2\n        ,s_store_id s_store_id2,sun_sales sun_sales2\n        ,mon_sales mon_sales2,tue_sales tue_sales2\n        ,wed_sales wed_sales2,thu_sales thu_sales2\n        ,fri_sales fri_sales2,sat_sales sat_sales2\n  from wss,store,date_dim d\n  where d.d_week_seq = wss.d_week_seq and\n        ss_store_sk = s_store_sk and \n        d_month_seq between 1202+ 12 and 1202 + 23) x\n where s_store_id1=s_store_id2\n   and d_week_seq1=d_week_seq2-52\n order by s_store_name1,s_store_id1,d_week_seq1\n LIMIT 100;\n\n-- end query 30 in stream 0 using template query59.tpl\n-- start query 31 in stream 0 using template query37.tpl\nselect  i_item_id\n       ,i_item_desc\n       ,i_current_price\n from item, inventory, date_dim, catalog_sales\n where i_current_price between 16 and 16 + 30\n and inv_item_sk = i_item_sk\n and d_date_sk=inv_date_sk\n and d_date between cast('1999-03-27' as date) and (cast('1999-03-27' as date) + interval 60 days)\n and i_manufact_id in (821,673,849,745)\n and inv_quantity_on_hand between 100 and 500\n and cs_item_sk = i_item_sk\n group by i_item_id,i_item_desc,i_current_price\n order by i_item_id\n  LIMIT 100;\n\n-- end query 31 in stream 0 using template query37.tpl\n-- start query 32 in stream 0 using template query98.tpl\nselect i_item_id\n      ,i_item_desc \n      ,i_category \n      ,i_class \n      ,i_current_price\n      ,sum(ss_ext_sales_price) as itemrevenue \n      ,sum(ss_ext_sales_price)*100/sum(sum(ss_ext_sales_price)) over\n          (partition by i_class) as revenueratio\nfrom\t\n\tstore_sales\n    \t,item \n    \t,date_dim\nwhere \n\tss_item_sk = i_item_sk \n  \tand i_category in ('Children', 'Women', 'Shoes')\n  \tand ss_sold_date_sk = d_date_sk\n\tand d_date between cast('2001-03-09' as date) \n\t\t\t\tand (cast('2001-03-09' as date) + interval 30 days)\ngroup by \n\ti_item_id\n        ,i_item_desc \n        ,i_category\n        ,i_class\n        ,i_current_price\norder by \n\ti_category\n      
  ,i_class\n        ,i_item_id\n        ,i_item_desc\n        ,revenueratio;\n\n-- end query 32 in stream 0 using template query98.tpl\n-- start query 33 in stream 0 using template query85.tpl\nselect  substr(r_reason_desc,1,20)\n       ,avg(ws_quantity)\n       ,avg(wr_refunded_cash)\n       ,avg(wr_fee)\n from web_sales, web_returns, web_page, customer_demographics cd1,\n      customer_demographics cd2, customer_address, date_dim, reason \n where ws_web_page_sk = wp_web_page_sk\n   and ws_item_sk = wr_item_sk\n   and ws_order_number = wr_order_number\n   and ws_sold_date_sk = d_date_sk and d_year = 2001\n   and cd1.cd_demo_sk = wr_refunded_cdemo_sk \n   and cd2.cd_demo_sk = wr_returning_cdemo_sk\n   and ca_address_sk = wr_refunded_addr_sk\n   and r_reason_sk = wr_reason_sk\n   and\n   (\n    (\n     cd1.cd_marital_status = 'W'\n     and\n     cd1.cd_marital_status = cd2.cd_marital_status\n     and\n     cd1.cd_education_status = 'Primary'\n     and \n     cd1.cd_education_status = cd2.cd_education_status\n     and\n     ws_sales_price between 100.00 and 150.00\n    )\n   or\n    (\n     cd1.cd_marital_status = 'D'\n     and\n     cd1.cd_marital_status = cd2.cd_marital_status\n     and\n     cd1.cd_education_status = 'College' \n     and\n     cd1.cd_education_status = cd2.cd_education_status\n     and\n     ws_sales_price between 50.00 and 100.00\n    )\n   or\n    (\n     cd1.cd_marital_status = 'S'\n     and\n     cd1.cd_marital_status = cd2.cd_marital_status\n     and\n     cd1.cd_education_status = '2 yr Degree'\n     and\n     cd1.cd_education_status = cd2.cd_education_status\n     and\n     ws_sales_price between 150.00 and 200.00\n    )\n   )\n   and\n   (\n    (\n     ca_country = 'United States'\n     and\n     ca_state in ('PA', 'IN', 'VA')\n     and ws_net_profit between 100 and 200  \n    )\n    or\n    (\n     ca_country = 'United States'\n     and\n     ca_state in ('TX', 'MO', 'MS')\n     and ws_net_profit between 150 and 300  \n    )\n    or\n    
(\n     ca_country = 'United States'\n     and\n     ca_state in ('MT', 'OR', 'MN')\n     and ws_net_profit between 50 and 250  \n    )\n   )\ngroup by r_reason_desc\norder by substr(r_reason_desc,1,20)\n        ,avg(ws_quantity)\n        ,avg(wr_refunded_cash)\n        ,avg(wr_fee)\n LIMIT 100;\n\n-- end query 33 in stream 0 using template query85.tpl\n-- start query 34 in stream 0 using template query70.tpl\nselect  \n    sum(ss_net_profit) as total_sum\n   ,s_state\n   ,s_county\n   ,grouping(s_state)+grouping(s_county) as lochierarchy\n   ,rank() over (\n \tpartition by grouping(s_state)+grouping(s_county),\n \tcase when grouping(s_county) = 0 then s_state end \n \torder by sum(ss_net_profit) desc) as rank_within_parent\n from\n    store_sales\n   ,date_dim       d1\n   ,store\n where\n    d1.d_month_seq between 1191 and 1191+11\n and d1.d_date_sk = ss_sold_date_sk\n and s_store_sk  = ss_store_sk\n and s_state in\n             ( select s_state\n               from  (select s_state as s_state,\n \t\t\t    rank() over ( partition by s_state order by sum(ss_net_profit) desc) as ranking\n                      from   store_sales, store, date_dim\n                      where  d_month_seq between 1191 and 1191+11\n \t\t\t    and d_date_sk = ss_sold_date_sk\n \t\t\t    and s_store_sk  = ss_store_sk\n                      group by s_state\n                     ) tmp1 \n               where ranking <= 5\n             )\n group by rollup(s_state,s_county)\n order by\n   lochierarchy desc\n  ,case when lochierarchy = 0 then s_state end\n  ,rank_within_parent\n  LIMIT 100;\n\n-- end query 34 in stream 0 using template query70.tpl\n-- start query 35 in stream 0 using template query67.tpl\nselect  *\nfrom (select i_category\n            ,i_class\n            ,i_brand\n            ,i_product_name\n            ,d_year\n            ,d_qoy\n            ,d_moy\n            ,s_store_id\n            ,sumsales\n            ,rank() over (partition by i_category order by sumsales desc) 
rk\n      from (select i_category\n                  ,i_class\n                  ,i_brand\n                  ,i_product_name\n                  ,d_year\n                  ,d_qoy\n                  ,d_moy\n                  ,s_store_id\n                  ,sum(coalesce(ss_sales_price*ss_quantity,0)) sumsales\n            from store_sales\n                ,date_dim\n                ,store\n                ,item\n       where  ss_sold_date_sk=d_date_sk\n          and ss_item_sk=i_item_sk\n          and ss_store_sk = s_store_sk\n          and d_month_seq between 1192 and 1192+11\n       group by  rollup(i_category, i_class, i_brand, i_product_name, d_year, d_qoy, d_moy,s_store_id))dw1) dw2\nwhere rk <= 100\norder by i_category\n        ,i_class\n        ,i_brand\n        ,i_product_name\n        ,d_year\n        ,d_qoy\n        ,d_moy\n        ,s_store_id\n        ,sumsales\n        ,rk\n LIMIT 100;\n\n-- end query 35 in stream 0 using template query67.tpl\n-- start query 36 in stream 0 using template query28.tpl\nselect  *\nfrom (select avg(ss_list_price) B1_LP\n            ,count(ss_list_price) B1_CNT\n            ,count(distinct ss_list_price) B1_CNTD\n      from store_sales\n      where ss_quantity between 0 and 5\n        and (ss_list_price between 49 and 49+10 \n             or ss_coupon_amt between 5040 and 5040+1000\n             or ss_wholesale_cost between 4 and 4+20)) B1,\n     (select avg(ss_list_price) B2_LP\n            ,count(ss_list_price) B2_CNT\n            ,count(distinct ss_list_price) B2_CNTD\n      from store_sales\n      where ss_quantity between 6 and 10\n        and (ss_list_price between 5 and 5+10\n          or ss_coupon_amt between 441 and 441+1000\n          or ss_wholesale_cost between 80 and 80+20)) B2,\n     (select avg(ss_list_price) B3_LP\n            ,count(ss_list_price) B3_CNT\n            ,count(distinct ss_list_price) B3_CNTD\n      from store_sales\n      where ss_quantity between 11 and 15\n        and (ss_list_price between 153 
and 153+10\n          or ss_coupon_amt between 10459 and 10459+1000\n          or ss_wholesale_cost between 3 and 3+20)) B3,\n     (select avg(ss_list_price) B4_LP\n            ,count(ss_list_price) B4_CNT\n            ,count(distinct ss_list_price) B4_CNTD\n      from store_sales\n      where ss_quantity between 16 and 20\n        and (ss_list_price between 14 and 14+10\n          or ss_coupon_amt between 13311 and 13311+1000\n          or ss_wholesale_cost between 1 and 1+20)) B4,\n     (select avg(ss_list_price) B5_LP\n            ,count(ss_list_price) B5_CNT\n            ,count(distinct ss_list_price) B5_CNTD\n      from store_sales\n      where ss_quantity between 21 and 25\n        and (ss_list_price between 29 and 29+10\n          or ss_coupon_amt between 6047 and 6047+1000\n          or ss_wholesale_cost between 27 and 27+20)) B5,\n     (select avg(ss_list_price) B6_LP\n            ,count(ss_list_price) B6_CNT\n            ,count(distinct ss_list_price) B6_CNTD\n      from store_sales\n      where ss_quantity between 26 and 30\n        and (ss_list_price between 159 and 159+10\n          or ss_coupon_amt between 2432 and 2432+1000\n          or ss_wholesale_cost between 48 and 48+20)) B6\n LIMIT 100;\n\n-- end query 36 in stream 0 using template query28.tpl\n-- start query 37 in stream 0 using template query81.tpl\nwith customer_total_return as\n (select cr_returning_customer_sk as ctr_customer_sk\n        ,ca_state as ctr_state, \n \tsum(cr_return_amt_inc_tax) as ctr_total_return\n from catalog_returns\n     ,date_dim\n     ,customer_address\n where cr_returned_date_sk = d_date_sk \n   and d_year =2002\n   and cr_returning_addr_sk = ca_address_sk \n group by cr_returning_customer_sk\n         ,ca_state )\n  select  c_customer_id,c_salutation,c_first_name,c_last_name,ca_street_number,ca_street_name\n                   ,ca_street_type,ca_suite_number,ca_city,ca_county,ca_state,ca_zip,ca_country,ca_gmt_offset\n                  
,ca_location_type,ctr_total_return\n from customer_total_return ctr1\n     ,customer_address\n     ,customer\n where ctr1.ctr_total_return > (select avg(ctr_total_return)*1.2\n \t\t\t  from customer_total_return ctr2 \n                  \t  where ctr1.ctr_state = ctr2.ctr_state)\n       and ca_address_sk = c_current_addr_sk\n       and ca_state = 'IL'\n       and ctr1.ctr_customer_sk = c_customer_sk\n order by c_customer_id,c_salutation,c_first_name,c_last_name,ca_street_number,ca_street_name\n                   ,ca_street_type,ca_suite_number,ca_city,ca_county,ca_state,ca_zip,ca_country,ca_gmt_offset\n                  ,ca_location_type,ctr_total_return\n  LIMIT 100;\n\n-- end query 37 in stream 0 using template query81.tpl\n-- start query 38 in stream 0 using template query97.tpl\nwith ssci as (\nselect ss_customer_sk customer_sk\n      ,ss_item_sk item_sk\nfrom store_sales,date_dim\nwhere ss_sold_date_sk = d_date_sk\n  and d_month_seq between 1176 and 1176 + 11\ngroup by ss_customer_sk\n        ,ss_item_sk),\ncsci as(\n select cs_bill_customer_sk customer_sk\n      ,cs_item_sk item_sk\nfrom catalog_sales,date_dim\nwhere cs_sold_date_sk = d_date_sk\n  and d_month_seq between 1176 and 1176 + 11\ngroup by cs_bill_customer_sk\n        ,cs_item_sk)\n select  sum(case when ssci.customer_sk is not null and csci.customer_sk is null then 1 else 0 end) store_only\n      ,sum(case when ssci.customer_sk is null and csci.customer_sk is not null then 1 else 0 end) catalog_only\n      ,sum(case when ssci.customer_sk is not null and csci.customer_sk is not null then 1 else 0 end) store_and_catalog\nfrom ssci full outer join csci on (ssci.customer_sk=csci.customer_sk\n                               and ssci.item_sk = csci.item_sk)\n LIMIT 100;\n\n-- end query 38 in stream 0 using template query97.tpl\n-- start query 39 in stream 0 using template query66.tpl\nselect   \n         w_warehouse_name\n \t,w_warehouse_sq_ft\n \t,w_city\n \t,w_county\n \t,w_state\n \t,w_country\n        
,ship_carriers\n        ,year\n \t,sum(jan_sales) as jan_sales\n \t,sum(feb_sales) as feb_sales\n \t,sum(mar_sales) as mar_sales\n \t,sum(apr_sales) as apr_sales\n \t,sum(may_sales) as may_sales\n \t,sum(jun_sales) as jun_sales\n \t,sum(jul_sales) as jul_sales\n \t,sum(aug_sales) as aug_sales\n \t,sum(sep_sales) as sep_sales\n \t,sum(oct_sales) as oct_sales\n \t,sum(nov_sales) as nov_sales\n \t,sum(dec_sales) as dec_sales\n \t,sum(jan_sales/w_warehouse_sq_ft) as jan_sales_per_sq_foot\n \t,sum(feb_sales/w_warehouse_sq_ft) as feb_sales_per_sq_foot\n \t,sum(mar_sales/w_warehouse_sq_ft) as mar_sales_per_sq_foot\n \t,sum(apr_sales/w_warehouse_sq_ft) as apr_sales_per_sq_foot\n \t,sum(may_sales/w_warehouse_sq_ft) as may_sales_per_sq_foot\n \t,sum(jun_sales/w_warehouse_sq_ft) as jun_sales_per_sq_foot\n \t,sum(jul_sales/w_warehouse_sq_ft) as jul_sales_per_sq_foot\n \t,sum(aug_sales/w_warehouse_sq_ft) as aug_sales_per_sq_foot\n \t,sum(sep_sales/w_warehouse_sq_ft) as sep_sales_per_sq_foot\n \t,sum(oct_sales/w_warehouse_sq_ft) as oct_sales_per_sq_foot\n \t,sum(nov_sales/w_warehouse_sq_ft) as nov_sales_per_sq_foot\n \t,sum(dec_sales/w_warehouse_sq_ft) as dec_sales_per_sq_foot\n \t,sum(jan_net) as jan_net\n \t,sum(feb_net) as feb_net\n \t,sum(mar_net) as mar_net\n \t,sum(apr_net) as apr_net\n \t,sum(may_net) as may_net\n \t,sum(jun_net) as jun_net\n \t,sum(jul_net) as jul_net\n \t,sum(aug_net) as aug_net\n \t,sum(sep_net) as sep_net\n \t,sum(oct_net) as oct_net\n \t,sum(nov_net) as nov_net\n \t,sum(dec_net) as dec_net\n from (\n     select \n \tw_warehouse_name\n \t,w_warehouse_sq_ft\n \t,w_city\n \t,w_county\n \t,w_state\n \t,w_country\n \t,'ZOUROS' || ',' || 'ZHOU' as ship_carriers\n       ,d_year as year\n \t,sum(case when d_moy = 1 \n \t\tthen ws_sales_price* ws_quantity else 0 end) as jan_sales\n \t,sum(case when d_moy = 2 \n \t\tthen ws_sales_price* ws_quantity else 0 end) as feb_sales\n \t,sum(case when d_moy = 3 \n \t\tthen ws_sales_price* ws_quantity else 0 end) as 
mar_sales\n \t,sum(case when d_moy = 4 \n \t\tthen ws_sales_price* ws_quantity else 0 end) as apr_sales\n \t,sum(case when d_moy = 5 \n \t\tthen ws_sales_price* ws_quantity else 0 end) as may_sales\n \t,sum(case when d_moy = 6 \n \t\tthen ws_sales_price* ws_quantity else 0 end) as jun_sales\n \t,sum(case when d_moy = 7 \n \t\tthen ws_sales_price* ws_quantity else 0 end) as jul_sales\n \t,sum(case when d_moy = 8 \n \t\tthen ws_sales_price* ws_quantity else 0 end) as aug_sales\n \t,sum(case when d_moy = 9 \n \t\tthen ws_sales_price* ws_quantity else 0 end) as sep_sales\n \t,sum(case when d_moy = 10 \n \t\tthen ws_sales_price* ws_quantity else 0 end) as oct_sales\n \t,sum(case when d_moy = 11\n \t\tthen ws_sales_price* ws_quantity else 0 end) as nov_sales\n \t,sum(case when d_moy = 12\n \t\tthen ws_sales_price* ws_quantity else 0 end) as dec_sales\n \t,sum(case when d_moy = 1 \n \t\tthen ws_net_paid * ws_quantity else 0 end) as jan_net\n \t,sum(case when d_moy = 2\n \t\tthen ws_net_paid * ws_quantity else 0 end) as feb_net\n \t,sum(case when d_moy = 3 \n \t\tthen ws_net_paid * ws_quantity else 0 end) as mar_net\n \t,sum(case when d_moy = 4 \n \t\tthen ws_net_paid * ws_quantity else 0 end) as apr_net\n \t,sum(case when d_moy = 5 \n \t\tthen ws_net_paid * ws_quantity else 0 end) as may_net\n \t,sum(case when d_moy = 6 \n \t\tthen ws_net_paid * ws_quantity else 0 end) as jun_net\n \t,sum(case when d_moy = 7 \n \t\tthen ws_net_paid * ws_quantity else 0 end) as jul_net\n \t,sum(case when d_moy = 8 \n \t\tthen ws_net_paid * ws_quantity else 0 end) as aug_net\n \t,sum(case when d_moy = 9 \n \t\tthen ws_net_paid * ws_quantity else 0 end) as sep_net\n \t,sum(case when d_moy = 10 \n \t\tthen ws_net_paid * ws_quantity else 0 end) as oct_net\n \t,sum(case when d_moy = 11\n \t\tthen ws_net_paid * ws_quantity else 0 end) as nov_net\n \t,sum(case when d_moy = 12\n \t\tthen ws_net_paid * ws_quantity else 0 end) as dec_net\n     from\n          web_sales\n         ,warehouse\n         
,date_dim\n         ,time_dim\n \t  ,ship_mode\n     where\n            ws_warehouse_sk =  w_warehouse_sk\n        and ws_sold_date_sk = d_date_sk\n        and ws_sold_time_sk = t_time_sk\n \tand ws_ship_mode_sk = sm_ship_mode_sk\n        and d_year = 2000\n \tand t_time between 18479 and 18479+28800 \n \tand sm_carrier in ('ZOUROS','ZHOU')\n     group by \n        w_warehouse_name\n \t,w_warehouse_sq_ft\n \t,w_city\n \t,w_county\n \t,w_state\n \t,w_country\n       ,d_year\n union all\n     select \n \tw_warehouse_name\n \t,w_warehouse_sq_ft\n \t,w_city\n \t,w_county\n \t,w_state\n \t,w_country\n \t,'ZOUROS' || ',' || 'ZHOU' as ship_carriers\n       ,d_year as year\n \t,sum(case when d_moy = 1 \n \t\tthen cs_ext_sales_price* cs_quantity else 0 end) as jan_sales\n \t,sum(case when d_moy = 2 \n \t\tthen cs_ext_sales_price* cs_quantity else 0 end) as feb_sales\n \t,sum(case when d_moy = 3 \n \t\tthen cs_ext_sales_price* cs_quantity else 0 end) as mar_sales\n \t,sum(case when d_moy = 4 \n \t\tthen cs_ext_sales_price* cs_quantity else 0 end) as apr_sales\n \t,sum(case when d_moy = 5 \n \t\tthen cs_ext_sales_price* cs_quantity else 0 end) as may_sales\n \t,sum(case when d_moy = 6 \n \t\tthen cs_ext_sales_price* cs_quantity else 0 end) as jun_sales\n \t,sum(case when d_moy = 7 \n \t\tthen cs_ext_sales_price* cs_quantity else 0 end) as jul_sales\n \t,sum(case when d_moy = 8 \n \t\tthen cs_ext_sales_price* cs_quantity else 0 end) as aug_sales\n \t,sum(case when d_moy = 9 \n \t\tthen cs_ext_sales_price* cs_quantity else 0 end) as sep_sales\n \t,sum(case when d_moy = 10 \n \t\tthen cs_ext_sales_price* cs_quantity else 0 end) as oct_sales\n \t,sum(case when d_moy = 11\n \t\tthen cs_ext_sales_price* cs_quantity else 0 end) as nov_sales\n \t,sum(case when d_moy = 12\n \t\tthen cs_ext_sales_price* cs_quantity else 0 end) as dec_sales\n \t,sum(case when d_moy = 1 \n \t\tthen cs_net_paid_inc_ship * cs_quantity else 0 end) as jan_net\n \t,sum(case when d_moy = 2 \n \t\tthen 
cs_net_paid_inc_ship * cs_quantity else 0 end) as feb_net\n \t,sum(case when d_moy = 3 \n \t\tthen cs_net_paid_inc_ship * cs_quantity else 0 end) as mar_net\n \t,sum(case when d_moy = 4 \n \t\tthen cs_net_paid_inc_ship * cs_quantity else 0 end) as apr_net\n \t,sum(case when d_moy = 5 \n \t\tthen cs_net_paid_inc_ship * cs_quantity else 0 end) as may_net\n \t,sum(case when d_moy = 6 \n \t\tthen cs_net_paid_inc_ship * cs_quantity else 0 end) as jun_net\n \t,sum(case when d_moy = 7 \n \t\tthen cs_net_paid_inc_ship * cs_quantity else 0 end) as jul_net\n \t,sum(case when d_moy = 8 \n \t\tthen cs_net_paid_inc_ship * cs_quantity else 0 end) as aug_net\n \t,sum(case when d_moy = 9 \n \t\tthen cs_net_paid_inc_ship * cs_quantity else 0 end) as sep_net\n \t,sum(case when d_moy = 10 \n \t\tthen cs_net_paid_inc_ship * cs_quantity else 0 end) as oct_net\n \t,sum(case when d_moy = 11\n \t\tthen cs_net_paid_inc_ship * cs_quantity else 0 end) as nov_net\n \t,sum(case when d_moy = 12\n \t\tthen cs_net_paid_inc_ship * cs_quantity else 0 end) as dec_net\n     from\n          catalog_sales\n         ,warehouse\n         ,date_dim\n         ,time_dim\n \t ,ship_mode\n     where\n            cs_warehouse_sk =  w_warehouse_sk\n        and cs_sold_date_sk = d_date_sk\n        and cs_sold_time_sk = t_time_sk\n \tand cs_ship_mode_sk = sm_ship_mode_sk\n        and d_year = 2000\n \tand t_time between 18479 AND 18479+28800 \n \tand sm_carrier in ('ZOUROS','ZHOU')\n     group by \n        w_warehouse_name\n \t,w_warehouse_sq_ft\n \t,w_city\n \t,w_county\n \t,w_state\n \t,w_country\n       ,d_year\n ) x\n group by \n        w_warehouse_name\n \t,w_warehouse_sq_ft\n \t,w_city\n \t,w_county\n \t,w_state\n \t,w_country\n \t,ship_carriers\n       ,year\n order by w_warehouse_name\n  LIMIT 100;\n\n-- end query 39 in stream 0 using template query66.tpl\n-- start query 40 in stream 0 using template query90.tpl\nselect  cast(amc as decimal(15,4))/cast(pmc as decimal(15,4)) am_pm_ratio\n from ( select 
count(*) amc\n       from web_sales, household_demographics , time_dim, web_page\n       where ws_sold_time_sk = time_dim.t_time_sk\n         and ws_ship_hdemo_sk = household_demographics.hd_demo_sk\n         and ws_web_page_sk = web_page.wp_web_page_sk\n         and time_dim.t_hour between 12 and 12+1\n         and household_demographics.hd_dep_count = 0\n         and web_page.wp_char_count between 5000 and 5200) at,\n      ( select count(*) pmc\n       from web_sales, household_demographics , time_dim, web_page\n       where ws_sold_time_sk = time_dim.t_time_sk\n         and ws_ship_hdemo_sk = household_demographics.hd_demo_sk\n         and ws_web_page_sk = web_page.wp_web_page_sk\n         and time_dim.t_hour between 15 and 15+1\n         and household_demographics.hd_dep_count = 0\n         and web_page.wp_char_count between 5000 and 5200) pt\n order by am_pm_ratio\n  LIMIT 100;\n\n-- end query 40 in stream 0 using template query90.tpl\n-- start query 41 in stream 0 using template query17.tpl\nselect  i_item_id\n       ,i_item_desc\n       ,s_state\n       ,count(ss_quantity) as store_sales_quantitycount\n       ,avg(ss_quantity) as store_sales_quantityave\n       ,stddev_samp(ss_quantity) as store_sales_quantitystdev\n       ,stddev_samp(ss_quantity)/avg(ss_quantity) as store_sales_quantitycov\n       ,count(sr_return_quantity) as store_returns_quantitycount\n       ,avg(sr_return_quantity) as store_returns_quantityave\n       ,stddev_samp(sr_return_quantity) as store_returns_quantitystdev\n       ,stddev_samp(sr_return_quantity)/avg(sr_return_quantity) as store_returns_quantitycov\n       ,count(cs_quantity) as catalog_sales_quantitycount ,avg(cs_quantity) as catalog_sales_quantityave\n       ,stddev_samp(cs_quantity) as catalog_sales_quantitystdev\n       ,stddev_samp(cs_quantity)/avg(cs_quantity) as catalog_sales_quantitycov\n from store_sales\n     ,store_returns\n     ,catalog_sales\n     ,date_dim d1\n     ,date_dim d2\n     ,date_dim d3\n     ,store\n   
  ,item\n where d1.d_quarter_name = '2001Q1'\n   and d1.d_date_sk = ss_sold_date_sk\n   and i_item_sk = ss_item_sk\n   and s_store_sk = ss_store_sk\n   and ss_customer_sk = sr_customer_sk\n   and ss_item_sk = sr_item_sk\n   and ss_ticket_number = sr_ticket_number\n   and sr_returned_date_sk = d2.d_date_sk\n   and d2.d_quarter_name in ('2001Q1','2001Q2','2001Q3')\n   and sr_customer_sk = cs_bill_customer_sk\n   and sr_item_sk = cs_item_sk\n   and cs_sold_date_sk = d3.d_date_sk\n   and d3.d_quarter_name in ('2001Q1','2001Q2','2001Q3')\n group by i_item_id\n         ,i_item_desc\n         ,s_state\n order by i_item_id\n         ,i_item_desc\n         ,s_state\n LIMIT 100;\n\n-- end query 41 in stream 0 using template query17.tpl\n-- start query 42 in stream 0 using template query47.tpl\nwith v1 as(\n select i_category, i_brand,\n        s_store_name, s_company_name,\n        d_year, d_moy,\n        sum(ss_sales_price) sum_sales,\n        avg(sum(ss_sales_price)) over\n          (partition by i_category, i_brand,\n                     s_store_name, s_company_name, d_year)\n          avg_monthly_sales,\n        rank() over\n          (partition by i_category, i_brand,\n                     s_store_name, s_company_name\n           order by d_year, d_moy) rn\n from item, store_sales, date_dim, store\n where ss_item_sk = i_item_sk and\n       ss_sold_date_sk = d_date_sk and\n       ss_store_sk = s_store_sk and\n       (\n         d_year = 2001 or\n         ( d_year = 2001-1 and d_moy =12) or\n         ( d_year = 2001+1 and d_moy =1)\n       )\n group by i_category, i_brand,\n          s_store_name, s_company_name,\n          d_year, d_moy),\n v2 as(\n select v1.s_company_name\n        ,v1.d_year, v1.d_moy\n        ,v1.avg_monthly_sales\n        ,v1.sum_sales, v1_lag.sum_sales psum, v1_lead.sum_sales nsum\n from v1, v1 v1_lag, v1 v1_lead\n where v1.i_category = v1_lag.i_category and\n       v1.i_category = v1_lead.i_category and\n       v1.i_brand = v1_lag.i_brand and\n     
  v1.i_brand = v1_lead.i_brand and\n       v1.s_store_name = v1_lag.s_store_name and\n       v1.s_store_name = v1_lead.s_store_name and\n       v1.s_company_name = v1_lag.s_company_name and\n       v1.s_company_name = v1_lead.s_company_name and\n       v1.rn = v1_lag.rn + 1 and\n       v1.rn = v1_lead.rn - 1)\n  select  *\n from v2\n where  d_year = 2001 and    \n        avg_monthly_sales > 0 and\n        case when avg_monthly_sales > 0 then abs(sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1\n order by sum_sales - avg_monthly_sales, avg_monthly_sales\n  LIMIT 100;\n\n-- end query 42 in stream 0 using template query47.tpl\n-- start query 43 in stream 0 using template query95.tpl\nwith ws_wh as\n(select ws1.ws_order_number,ws1.ws_warehouse_sk wh1,ws2.ws_warehouse_sk wh2\n from web_sales ws1,web_sales ws2\n where ws1.ws_order_number = ws2.ws_order_number\n   and ws1.ws_warehouse_sk <> ws2.ws_warehouse_sk)\n select  \n   count(distinct ws_order_number) as `order count`\n  ,sum(ws_ext_ship_cost) as `total shipping cost`\n  ,sum(ws_net_profit) as `total net profit`\nfrom\n   web_sales ws1\n  ,date_dim\n  ,customer_address\n  ,web_site\nwhere\n    d_date between '1999-3-01' and \n           (cast('1999-3-01' as date) + interval 60 days)\nand ws1.ws_ship_date_sk = d_date_sk\nand ws1.ws_ship_addr_sk = ca_address_sk\nand ca_state = 'OR'\nand ws1.ws_web_site_sk = web_site_sk\nand web_company_name = 'pri'\nand ws1.ws_order_number in (select ws_order_number\n                            from ws_wh)\nand ws1.ws_order_number in (select wr_order_number\n                            from web_returns,ws_wh\n                            where wr_order_number = ws_wh.ws_order_number)\norder by count(distinct ws_order_number)\n LIMIT 100;\n\n-- end query 43 in stream 0 using template query95.tpl\n-- start query 44 in stream 0 using template query92.tpl\nselect  \n   sum(ws_ext_discount_amt)  as `Excess Discount Amount`\nfrom \n    web_sales \n   ,item \n   
,date_dim\nwhere\ni_manufact_id = 783\nand i_item_sk = ws_item_sk \nand d_date between '1999-03-21' and \n        (cast('1999-03-21' as date) + interval 90 days)\nand d_date_sk = ws_sold_date_sk \nand ws_ext_discount_amt  \n     > ( \n         SELECT \n            1.3 * avg(ws_ext_discount_amt) \n         FROM \n            web_sales \n           ,date_dim\n         WHERE \n              ws_item_sk = i_item_sk \n          and d_date between '1999-03-21' and\n                             (cast('1999-03-21' as date) + interval 90 days)\n          and d_date_sk = ws_sold_date_sk \n      ) \norder by sum(ws_ext_discount_amt)\n LIMIT 100;\n\n-- end query 44 in stream 0 using template query92.tpl\n-- start query 45 in stream 0 using template query3.tpl\nselect  dt.d_year \n       ,item.i_brand_id brand_id \n       ,item.i_brand brand\n       ,sum(ss_sales_price) sum_agg\n from  date_dim dt \n      ,store_sales\n      ,item\n where dt.d_date_sk = store_sales.ss_sold_date_sk\n   and store_sales.ss_item_sk = item.i_item_sk\n   and item.i_manufact_id = 211\n   and dt.d_moy=11\n group by dt.d_year\n      ,item.i_brand\n      ,item.i_brand_id\n order by dt.d_year\n         ,sum_agg desc\n         ,brand_id\n  LIMIT 100;\n\n-- end query 45 in stream 0 using template query3.tpl\n-- start query 46 in stream 0 using template query51.tpl\nWITH web_v1 as (\nselect\n  ws_item_sk item_sk, d_date,\n  sum(sum(ws_sales_price))\n      over (partition by ws_item_sk order by d_date rows between unbounded preceding and current row) cume_sales\nfrom web_sales\n    ,date_dim\nwhere ws_sold_date_sk=d_date_sk\n  and d_month_seq between 1195 and 1195+11\n  and ws_item_sk is not NULL\ngroup by ws_item_sk, d_date),\nstore_v1 as (\nselect\n  ss_item_sk item_sk, d_date,\n  sum(sum(ss_sales_price))\n      over (partition by ss_item_sk order by d_date rows between unbounded preceding and current row) cume_sales\nfrom store_sales\n    ,date_dim\nwhere ss_sold_date_sk=d_date_sk\n  and d_month_seq between 
1195 and 1195+11\n  and ss_item_sk is not NULL\ngroup by ss_item_sk, d_date)\n select  *\nfrom (select item_sk\n     ,d_date\n     ,web_sales\n     ,store_sales\n     ,max(web_sales)\n         over (partition by item_sk order by d_date rows between unbounded preceding and current row) web_cumulative\n     ,max(store_sales)\n         over (partition by item_sk order by d_date rows between unbounded preceding and current row) store_cumulative\n     from (select case when web.item_sk is not null then web.item_sk else store.item_sk end item_sk\n                 ,case when web.d_date is not null then web.d_date else store.d_date end d_date\n                 ,web.cume_sales web_sales\n                 ,store.cume_sales store_sales\n           from web_v1 web full outer join store_v1 store on (web.item_sk = store.item_sk\n                                                          and web.d_date = store.d_date)\n          )x )y\nwhere web_cumulative > store_cumulative\norder by item_sk\n        ,d_date\n LIMIT 100;\n\n-- end query 46 in stream 0 using template query51.tpl\n-- start query 47 in stream 0 using template query35.tpl\nselect   \n  ca_state,\n  cd_gender,\n  cd_marital_status,\n  cd_dep_count,\n  count(*) cnt1,\n  stddev_samp(cd_dep_count) aggone1,\n  sum(cd_dep_count) aggtwo1,\n  min(cd_dep_count) aggthree1,\n  cd_dep_employed_count,\n  count(*) cnt2,\n  stddev_samp(cd_dep_employed_count) aggone2,\n  sum(cd_dep_employed_count) aggtwo2,\n  min(cd_dep_employed_count) aggthree2,\n  cd_dep_college_count,\n  count(*) cnt3,\n  stddev_samp(cd_dep_college_count) aggone3,\n  sum(cd_dep_college_count) aggtwo3,\n  min(cd_dep_college_count) aggthree3\n from\n  customer c,customer_address ca,customer_demographics\n where\n  c.c_current_addr_sk = ca.ca_address_sk and\n  cd_demo_sk = c.c_current_cdemo_sk and \n  exists (select *\n          from store_sales,date_dim\n          where c.c_customer_sk = ss_customer_sk and\n                ss_sold_date_sk = d_date_sk and\n          
      d_year = 2001 and\n                d_qoy < 4) and\n   (exists (select *\n            from web_sales,date_dim\n            where c.c_customer_sk = ws_bill_customer_sk and\n                  ws_sold_date_sk = d_date_sk and\n                  d_year = 2001 and\n                  d_qoy < 4) or \n    exists (select * \n            from catalog_sales,date_dim\n            where c.c_customer_sk = cs_ship_customer_sk and\n                  cs_sold_date_sk = d_date_sk and\n                  d_year = 2001 and\n                  d_qoy < 4))\n group by ca_state,\n          cd_gender,\n          cd_marital_status,\n          cd_dep_count,\n          cd_dep_employed_count,\n          cd_dep_college_count\n order by ca_state,\n          cd_gender,\n          cd_marital_status,\n          cd_dep_count,\n          cd_dep_employed_count,\n          cd_dep_college_count\n  LIMIT 100;\n\n-- end query 47 in stream 0 using template query35.tpl\n-- start query 48 in stream 0 using template query49.tpl\nselect  channel, item, return_ratio, return_rank, currency_rank from\n (select\n 'web' as channel\n ,web.item\n ,web.return_ratio\n ,web.return_rank\n ,web.currency_rank\n from (\n \tselect \n \t item\n \t,return_ratio\n \t,currency_ratio\n \t,rank() over (order by return_ratio) as return_rank\n \t,rank() over (order by currency_ratio) as currency_rank\n \tfrom\n \t(\tselect ws.ws_item_sk as item\n \t\t,(cast(sum(coalesce(wr.wr_return_quantity,0)) as decimal(15,4))/\n \t\tcast(sum(coalesce(ws.ws_quantity,0)) as decimal(15,4) )) as return_ratio\n \t\t,(cast(sum(coalesce(wr.wr_return_amt,0)) as decimal(15,4))/\n \t\tcast(sum(coalesce(ws.ws_net_paid,0)) as decimal(15,4) )) as currency_ratio\n \t\tfrom \n \t\t web_sales ws left outer join web_returns wr \n \t\t\ton (ws.ws_order_number = wr.wr_order_number and \n \t\t\tws.ws_item_sk = wr.wr_item_sk)\n                 ,date_dim\n \t\twhere \n \t\t\twr.wr_return_amt > 10000 \n \t\t\tand ws.ws_net_profit > 1\n                         and 
ws.ws_net_paid > 0\n                         and ws.ws_quantity > 0\n                         and ws_sold_date_sk = d_date_sk\n                         and d_year = 2000\n                         and d_moy = 12\n \t\tgroup by ws.ws_item_sk\n \t) in_web\n ) web\n where \n (\n web.return_rank <= 10\n or\n web.currency_rank <= 10\n )\n union\n select \n 'catalog' as channel\n ,catalog.item\n ,catalog.return_ratio\n ,catalog.return_rank\n ,catalog.currency_rank\n from (\n \tselect \n \t item\n \t,return_ratio\n \t,currency_ratio\n \t,rank() over (order by return_ratio) as return_rank\n \t,rank() over (order by currency_ratio) as currency_rank\n \tfrom\n \t(\tselect \n \t\tcs.cs_item_sk as item\n \t\t,(cast(sum(coalesce(cr.cr_return_quantity,0)) as decimal(15,4))/\n \t\tcast(sum(coalesce(cs.cs_quantity,0)) as decimal(15,4) )) as return_ratio\n \t\t,(cast(sum(coalesce(cr.cr_return_amount,0)) as decimal(15,4))/\n \t\tcast(sum(coalesce(cs.cs_net_paid,0)) as decimal(15,4) )) as currency_ratio\n \t\tfrom \n \t\tcatalog_sales cs left outer join catalog_returns cr\n \t\t\ton (cs.cs_order_number = cr.cr_order_number and \n \t\t\tcs.cs_item_sk = cr.cr_item_sk)\n                ,date_dim\n \t\twhere \n \t\t\tcr.cr_return_amount > 10000 \n \t\t\tand cs.cs_net_profit > 1\n                         and cs.cs_net_paid > 0\n                         and cs.cs_quantity > 0\n                         and cs_sold_date_sk = d_date_sk\n                         and d_year = 2000\n                         and d_moy = 12\n                 group by cs.cs_item_sk\n \t) in_cat\n ) catalog\n where \n (\n catalog.return_rank <= 10\n or\n catalog.currency_rank <=10\n )\n union\n select \n 'store' as channel\n ,store.item\n ,store.return_ratio\n ,store.return_rank\n ,store.currency_rank\n from (\n \tselect \n \t item\n \t,return_ratio\n \t,currency_ratio\n \t,rank() over (order by return_ratio) as return_rank\n \t,rank() over (order by currency_ratio) as currency_rank\n \tfrom\n \t(\tselect 
sts.ss_item_sk as item\n \t\t,(cast(sum(coalesce(sr.sr_return_quantity,0)) as decimal(15,4))/cast(sum(coalesce(sts.ss_quantity,0)) as decimal(15,4) )) as return_ratio\n \t\t,(cast(sum(coalesce(sr.sr_return_amt,0)) as decimal(15,4))/cast(sum(coalesce(sts.ss_net_paid,0)) as decimal(15,4) )) as currency_ratio\n \t\tfrom \n \t\tstore_sales sts left outer join store_returns sr\n \t\t\ton (sts.ss_ticket_number = sr.sr_ticket_number and sts.ss_item_sk = sr.sr_item_sk)\n                ,date_dim\n \t\twhere \n \t\t\tsr.sr_return_amt > 10000 \n \t\t\tand sts.ss_net_profit > 1\n                         and sts.ss_net_paid > 0 \n                         and sts.ss_quantity > 0\n                         and ss_sold_date_sk = d_date_sk\n                         and d_year = 2000\n                         and d_moy = 12\n \t\tgroup by sts.ss_item_sk\n \t) in_store\n ) store\n where  (\n store.return_rank <= 10\n or \n store.currency_rank <= 10\n )\n )\n order by 1,4,5,2\n  LIMIT 100;\n\n-- end query 48 in stream 0 using template query49.tpl\n-- start query 49 in stream 0 using template query9.tpl\nselect case when (select count(*) \n                  from store_sales \n                  where ss_quantity between 1 and 20) > 144610\n            then (select avg(ss_ext_tax) \n                  from store_sales \n                  where ss_quantity between 1 and 20) \n            else (select avg(ss_net_paid)\n                  from store_sales\n                  where ss_quantity between 1 and 20) end bucket1 ,\n       case when (select count(*)\n                  from store_sales\n                  where ss_quantity between 21 and 40) > 162498\n            then (select avg(ss_ext_tax)\n                  from store_sales\n                  where ss_quantity between 21 and 40) \n            else (select avg(ss_net_paid)\n                  from store_sales\n                  where ss_quantity between 21 and 40) end bucket2,\n       case when (select count(*)\n                  from 
store_sales\n                  where ss_quantity between 41 and 60) > 28387\n            then (select avg(ss_ext_tax)\n                  from store_sales\n                  where ss_quantity between 41 and 60)\n            else (select avg(ss_net_paid)\n                  from store_sales\n                  where ss_quantity between 41 and 60) end bucket3,\n       case when (select count(*)\n                  from store_sales\n                  where ss_quantity between 61 and 80) > 442573\n            then (select avg(ss_ext_tax)\n                  from store_sales\n                  where ss_quantity between 61 and 80)\n            else (select avg(ss_net_paid)\n                  from store_sales\n                  where ss_quantity between 61 and 80) end bucket4,\n       case when (select count(*)\n                  from store_sales\n                  where ss_quantity between 81 and 100) > 212532\n            then (select avg(ss_ext_tax)\n                  from store_sales\n                  where ss_quantity between 81 and 100)\n            else (select avg(ss_net_paid)\n                  from store_sales\n                  where ss_quantity between 81 and 100) end bucket5\nfrom reason\nwhere r_reason_sk = 1\n;\n\n-- end query 49 in stream 0 using template query9.tpl\n-- start query 50 in stream 0 using template query31.tpl\nwith ss as\n (select ca_county,d_qoy, d_year,sum(ss_ext_sales_price) as store_sales\n from store_sales,date_dim,customer_address\n where ss_sold_date_sk = d_date_sk\n  and ss_addr_sk=ca_address_sk\n group by ca_county,d_qoy, d_year),\n ws as\n (select ca_county,d_qoy, d_year,sum(ws_ext_sales_price) as web_sales\n from web_sales,date_dim,customer_address\n where ws_sold_date_sk = d_date_sk\n  and ws_bill_addr_sk=ca_address_sk\n group by ca_county,d_qoy, d_year)\n select \n        ss1.ca_county\n       ,ss1.d_year\n       ,ws2.web_sales/ws1.web_sales web_q1_q2_increase\n       ,ss2.store_sales/ss1.store_sales store_q1_q2_increase\n       
,ws3.web_sales/ws2.web_sales web_q2_q3_increase\n       ,ss3.store_sales/ss2.store_sales store_q2_q3_increase\n from\n        ss ss1\n       ,ss ss2\n       ,ss ss3\n       ,ws ws1\n       ,ws ws2\n       ,ws ws3\n where\n    ss1.d_qoy = 1\n    and ss1.d_year = 2000\n    and ss1.ca_county = ss2.ca_county\n    and ss2.d_qoy = 2\n    and ss2.d_year = 2000\n and ss2.ca_county = ss3.ca_county\n    and ss3.d_qoy = 3\n    and ss3.d_year = 2000\n    and ss1.ca_county = ws1.ca_county\n    and ws1.d_qoy = 1\n    and ws1.d_year = 2000\n    and ws1.ca_county = ws2.ca_county\n    and ws2.d_qoy = 2\n    and ws2.d_year = 2000\n    and ws1.ca_county = ws3.ca_county\n    and ws3.d_qoy = 3\n    and ws3.d_year =2000\n    and case when ws1.web_sales > 0 then ws2.web_sales/ws1.web_sales else null end \n       > case when ss1.store_sales > 0 then ss2.store_sales/ss1.store_sales else null end\n    and case when ws2.web_sales > 0 then ws3.web_sales/ws2.web_sales else null end\n       > case when ss2.store_sales > 0 then ss3.store_sales/ss2.store_sales else null end\n order by web_q2_q3_increase;\n\n-- end query 50 in stream 0 using template query31.tpl\n-- start query 51 in stream 0 using template query11.tpl\nwith year_total as (\n select c_customer_id customer_id\n       ,c_first_name customer_first_name\n       ,c_last_name customer_last_name\n       ,c_preferred_cust_flag customer_preferred_cust_flag\n       ,c_birth_country customer_birth_country\n       ,c_login customer_login\n       ,c_email_address customer_email_address\n       ,d_year dyear\n       ,sum(ss_ext_list_price-ss_ext_discount_amt) year_total\n       ,'s' sale_type\n from customer\n     ,store_sales\n     ,date_dim\n where c_customer_sk = ss_customer_sk\n   and ss_sold_date_sk = d_date_sk\n group by c_customer_id\n         ,c_first_name\n         ,c_last_name\n         ,c_preferred_cust_flag \n         ,c_birth_country\n         ,c_login\n         ,c_email_address\n         ,d_year \n union all\n select c_customer_id 
customer_id\n       ,c_first_name customer_first_name\n       ,c_last_name customer_last_name\n       ,c_preferred_cust_flag customer_preferred_cust_flag\n       ,c_birth_country customer_birth_country\n       ,c_login customer_login\n       ,c_email_address customer_email_address\n       ,d_year dyear\n       ,sum(ws_ext_list_price-ws_ext_discount_amt) year_total\n       ,'w' sale_type\n from customer\n     ,web_sales\n     ,date_dim\n where c_customer_sk = ws_bill_customer_sk\n   and ws_sold_date_sk = d_date_sk\n group by c_customer_id\n         ,c_first_name\n         ,c_last_name\n         ,c_preferred_cust_flag \n         ,c_birth_country\n         ,c_login\n         ,c_email_address\n         ,d_year\n         )\n  select  \n                  t_s_secyear.customer_id\n                 ,t_s_secyear.customer_first_name\n                 ,t_s_secyear.customer_last_name\n                 ,t_s_secyear.customer_preferred_cust_flag\n from year_total t_s_firstyear\n     ,year_total t_s_secyear\n     ,year_total t_w_firstyear\n     ,year_total t_w_secyear\n where t_s_secyear.customer_id = t_s_firstyear.customer_id\n         and t_s_firstyear.customer_id = t_w_secyear.customer_id\n         and t_s_firstyear.customer_id = t_w_firstyear.customer_id\n         and t_s_firstyear.sale_type = 's'\n         and t_w_firstyear.sale_type = 'w'\n         and t_s_secyear.sale_type = 's'\n         and t_w_secyear.sale_type = 'w'\n         and t_s_firstyear.dyear = 1998\n         and t_s_secyear.dyear = 1998+1\n         and t_w_firstyear.dyear = 1998\n         and t_w_secyear.dyear = 1998+1\n         and t_s_firstyear.year_total > 0\n         and t_w_firstyear.year_total > 0\n         and case when t_w_firstyear.year_total > 0 then t_w_secyear.year_total / t_w_firstyear.year_total else 0.0 end\n             > case when t_s_firstyear.year_total > 0 then t_s_secyear.year_total / t_s_firstyear.year_total else 0.0 end\n order by t_s_secyear.customer_id\n         
,t_s_secyear.customer_first_name\n         ,t_s_secyear.customer_last_name\n         ,t_s_secyear.customer_preferred_cust_flag\n LIMIT 100;\n\n-- end query 51 in stream 0 using template query11.tpl\n-- start query 52 in stream 0 using template query93.tpl\nselect  ss_customer_sk\n            ,sum(act_sales) sumsales\n      from (select ss_item_sk\n                  ,ss_ticket_number\n                  ,ss_customer_sk\n                  ,case when sr_return_quantity is not null then (ss_quantity-sr_return_quantity)*ss_sales_price\n                                                            else (ss_quantity*ss_sales_price) end act_sales\n            from store_sales left outer join store_returns on (sr_item_sk = ss_item_sk\n                                                               and sr_ticket_number = ss_ticket_number)\n                ,reason\n            where sr_reason_sk = r_reason_sk\n              and r_reason_desc = 'reason 56') t\n      group by ss_customer_sk\n      order by sumsales, ss_customer_sk\n LIMIT 100;\n\n-- end query 52 in stream 0 using template query93.tpl\n-- start query 53 in stream 0 using template query29.tpl\nselect   \n     i_item_id\n    ,i_item_desc\n    ,s_store_id\n    ,s_store_name\n    ,max(ss_quantity)        as store_sales_quantity\n    ,max(sr_return_quantity) as store_returns_quantity\n    ,max(cs_quantity)        as catalog_sales_quantity\n from\n    store_sales\n   ,store_returns\n   ,catalog_sales\n   ,date_dim             d1\n   ,date_dim             d2\n   ,date_dim             d3\n   ,store\n   ,item\n where\n     d1.d_moy               = 4 \n and d1.d_year              = 2000\n and d1.d_date_sk           = ss_sold_date_sk\n and i_item_sk              = ss_item_sk\n and s_store_sk             = ss_store_sk\n and ss_customer_sk         = sr_customer_sk\n and ss_item_sk             = sr_item_sk\n and ss_ticket_number       = sr_ticket_number\n and sr_returned_date_sk    = d2.d_date_sk\n and d2.d_moy               
between 4 and  4 + 3 \n and d2.d_year              = 2000\n and sr_customer_sk         = cs_bill_customer_sk\n and sr_item_sk             = cs_item_sk\n and cs_sold_date_sk        = d3.d_date_sk     \n and d3.d_year              in (2000,2000+1,2000+2)\n group by\n    i_item_id\n   ,i_item_desc\n   ,s_store_id\n   ,s_store_name\n order by\n    i_item_id \n   ,i_item_desc\n   ,s_store_id\n   ,s_store_name\n  LIMIT 100;\n\n-- end query 53 in stream 0 using template query29.tpl\n-- start query 54 in stream 0 using template query38.tpl\nselect  count(*) from (\n    select distinct c_last_name, c_first_name, d_date\n    from store_sales, date_dim, customer\n          where store_sales.ss_sold_date_sk = date_dim.d_date_sk\n      and store_sales.ss_customer_sk = customer.c_customer_sk\n      and d_month_seq between 1212 and 1212 + 11\n  intersect\n    select distinct c_last_name, c_first_name, d_date\n    from catalog_sales, date_dim, customer\n          where catalog_sales.cs_sold_date_sk = date_dim.d_date_sk\n      and catalog_sales.cs_bill_customer_sk = customer.c_customer_sk\n      and d_month_seq between 1212 and 1212 + 11\n  intersect\n    select distinct c_last_name, c_first_name, d_date\n    from web_sales, date_dim, customer\n          where web_sales.ws_sold_date_sk = date_dim.d_date_sk\n      and web_sales.ws_bill_customer_sk = customer.c_customer_sk\n      and d_month_seq between 1212 and 1212 + 11\n) hot_cust\n LIMIT 100;\n\n-- end query 54 in stream 0 using template query38.tpl\n-- start query 55 in stream 0 using template query22.tpl\nselect  i_product_name\n             ,i_brand\n             ,i_class\n             ,i_category\n             ,avg(inv_quantity_on_hand) qoh\n       from inventory\n           ,date_dim\n           ,item\n       where inv_date_sk=d_date_sk\n              and inv_item_sk=i_item_sk\n              and d_month_seq between 1188 and 1188 + 11\n       group by rollup(i_product_name\n                       ,i_brand\n                    
   ,i_class\n                       ,i_category)\norder by qoh, i_product_name, i_brand, i_class, i_category\n LIMIT 100;\n\n-- end query 55 in stream 0 using template query22.tpl\n-- start query 56 in stream 0 using template query89.tpl\nselect  *\nfrom(\nselect i_category, i_class, i_brand,\n       s_store_name, s_company_name,\n       d_moy,\n       sum(ss_sales_price) sum_sales,\n       avg(sum(ss_sales_price)) over\n         (partition by i_category, i_brand, s_store_name, s_company_name)\n         avg_monthly_sales\nfrom item, store_sales, date_dim, store\nwhere ss_item_sk = i_item_sk and\n      ss_sold_date_sk = d_date_sk and\n      ss_store_sk = s_store_sk and\n      d_year in (2001) and\n        ((i_category in ('Electronics','Books','Home') and\n          i_class in ('scanners','parenting','wallpaper')\n         )\n      or (i_category in ('Shoes','Sports','Women') and\n          i_class in ('kids','archery','dresses') \n        ))\ngroup by i_category, i_class, i_brand,\n         s_store_name, s_company_name, d_moy) tmp1\nwhere case when (avg_monthly_sales <> 0) then (abs(sum_sales - avg_monthly_sales) / avg_monthly_sales) else null end > 0.1\norder by sum_sales - avg_monthly_sales, s_store_name\n LIMIT 100;\n\n-- end query 56 in stream 0 using template query89.tpl\n-- start query 57 in stream 0 using template query15.tpl\nselect  ca_zip\n       ,sum(cs_sales_price)\n from catalog_sales\n     ,customer\n     ,customer_address\n     ,date_dim\n where cs_bill_customer_sk = c_customer_sk\n \tand c_current_addr_sk = ca_address_sk \n \tand ( substr(ca_zip,1,5) in ('85669', '86197','88274','83405','86475',\n                                   '85392', '85460', '80348', '81792')\n \t      or ca_state in ('CA','WA','GA')\n \t      or cs_sales_price > 500)\n \tand cs_sold_date_sk = d_date_sk\n \tand d_qoy = 2 and d_year = 2002\n group by ca_zip\n order by ca_zip\n  LIMIT 100;\n\n-- end query 57 in stream 0 using template query15.tpl\n-- start query 58 in stream 0 
using template query6.tpl\nselect  a.ca_state state, count(*) cnt\n from customer_address a\n     ,customer c\n     ,store_sales s\n     ,date_dim d\n     ,item i\n where       a.ca_address_sk = c.c_current_addr_sk\n \tand c.c_customer_sk = s.ss_customer_sk\n \tand s.ss_sold_date_sk = d.d_date_sk\n \tand s.ss_item_sk = i.i_item_sk\n \tand d.d_month_seq = \n \t     (select distinct (d_month_seq)\n \t      from date_dim\n               where d_year = 1998\n \t        and d_moy = 6 )\n \tand i.i_current_price > 1.2 * \n             (select avg(j.i_current_price) \n \t     from item j \n \t     where j.i_category = i.i_category)\n group by a.ca_state\n having count(*) >= 10\n order by cnt, a.ca_state \n  LIMIT 100;\n\n-- end query 58 in stream 0 using template query6.tpl\n-- start query 59 in stream 0 using template query52.tpl\nselect  dt.d_year\n \t,item.i_brand_id brand_id\n \t,item.i_brand brand\n \t,sum(ss_ext_sales_price) ext_price\n from date_dim dt\n     ,store_sales\n     ,item\n where dt.d_date_sk = store_sales.ss_sold_date_sk\n    and store_sales.ss_item_sk = item.i_item_sk\n    and item.i_manager_id = 1\n    and dt.d_moy=12\n    and dt.d_year=2002\n group by dt.d_year\n \t,item.i_brand\n \t,item.i_brand_id\n order by dt.d_year\n \t,ext_price desc\n \t,brand_id\n LIMIT 100 ;\n\n-- end query 59 in stream 0 using template query52.tpl\n-- start query 60 in stream 0 using template query50.tpl\nselect  \n   s_store_name\n  ,s_company_id\n  ,s_street_number\n  ,s_street_name\n  ,s_street_type\n  ,s_suite_number\n  ,s_city\n  ,s_county\n  ,s_state\n  ,s_zip\n  ,sum(case when (sr_returned_date_sk - ss_sold_date_sk <= 30 ) then 1 else 0 end)  as `30 days`\n  ,sum(case when (sr_returned_date_sk - ss_sold_date_sk > 30) and \n                 (sr_returned_date_sk - ss_sold_date_sk <= 60) then 1 else 0 end )  as `31-60 days`\n  ,sum(case when (sr_returned_date_sk - ss_sold_date_sk > 60) and \n                 (sr_returned_date_sk - ss_sold_date_sk <= 90) then 1 else 0 
end)  as `61-90 days`\n  ,sum(case when (sr_returned_date_sk - ss_sold_date_sk > 90) and\n                 (sr_returned_date_sk - ss_sold_date_sk <= 120) then 1 else 0 end)  as `91-120 days`\n  ,sum(case when (sr_returned_date_sk - ss_sold_date_sk  > 120) then 1 else 0 end)  as `>120 days`\nfrom\n   store_sales\n  ,store_returns\n  ,store\n  ,date_dim d1\n  ,date_dim d2\nwhere\n    d2.d_year = 2002\nand d2.d_moy  = 8\nand ss_ticket_number = sr_ticket_number\nand ss_item_sk = sr_item_sk\nand ss_sold_date_sk   = d1.d_date_sk\nand sr_returned_date_sk   = d2.d_date_sk\nand ss_customer_sk = sr_customer_sk\nand ss_store_sk = s_store_sk\ngroup by\n   s_store_name\n  ,s_company_id\n  ,s_street_number\n  ,s_street_name\n  ,s_street_type\n  ,s_suite_number\n  ,s_city\n  ,s_county\n  ,s_state\n  ,s_zip\norder by s_store_name\n        ,s_company_id\n        ,s_street_number\n        ,s_street_name\n        ,s_street_type\n        ,s_suite_number\n        ,s_city\n        ,s_county\n        ,s_state\n        ,s_zip\n LIMIT 100;\n\n-- end query 60 in stream 0 using template query50.tpl\n-- start query 61 in stream 0 using template query42.tpl\nselect  dt.d_year\n \t,item.i_category_id\n \t,item.i_category\n \t,sum(ss_ext_sales_price)\n from \tdate_dim dt\n \t,store_sales\n \t,item\n where dt.d_date_sk = store_sales.ss_sold_date_sk\n \tand store_sales.ss_item_sk = item.i_item_sk\n \tand item.i_manager_id = 1  \t\n \tand dt.d_moy=11\n \tand dt.d_year=1999\n group by \tdt.d_year\n \t\t,item.i_category_id\n \t\t,item.i_category\n order by       sum(ss_ext_sales_price) desc,dt.d_year\n \t\t,item.i_category_id\n \t\t,item.i_category\n LIMIT 100 ;\n\n-- end query 61 in stream 0 using template query42.tpl\n-- start query 62 in stream 0 using template query41.tpl\nselect  distinct(i_product_name)\n from item i1\n where i_manufact_id between 794 and 794+40 \n   and (select count(*) as item_cnt\n        from item\n        where (i_manufact = i1.i_manufact and\n        ((i_category = 
'Women' and \n        (i_color = 'pink' or i_color = 'yellow') and \n        (i_units = 'Lb' or i_units = 'Pallet') and\n        (i_size = 'small' or i_size = 'petite')\n        ) or\n        (i_category = 'Women' and\n        (i_color = 'deep' or i_color = 'goldenrod') and\n        (i_units = 'Bundle' or i_units = 'Oz') and\n        (i_size = 'extra large' or i_size = 'economy')\n        ) or\n        (i_category = 'Men' and\n        (i_color = 'peru' or i_color = 'cream') and\n        (i_units = 'Case' or i_units = 'Ounce') and\n        (i_size = 'medium' or i_size = 'N/A')\n        ) or\n        (i_category = 'Men' and\n        (i_color = 'purple' or i_color = 'floral') and\n        (i_units = 'Each' or i_units = 'Cup') and\n        (i_size = 'small' or i_size = 'petite')\n        ))) or\n       (i_manufact = i1.i_manufact and\n        ((i_category = 'Women' and \n        (i_color = 'blue' or i_color = 'seashell') and \n        (i_units = 'Pound' or i_units = 'Carton') and\n        (i_size = 'small' or i_size = 'petite')\n        ) or\n        (i_category = 'Women' and\n        (i_color = 'slate' or i_color = 'saddle') and\n        (i_units = 'Gram' or i_units = 'Tsp') and\n        (i_size = 'extra large' or i_size = 'economy')\n        ) or\n        (i_category = 'Men' and\n        (i_color = 'midnight' or i_color = 'chiffon') and\n        (i_units = 'Box' or i_units = 'Ton') and\n        (i_size = 'medium' or i_size = 'N/A')\n        ) or\n        (i_category = 'Men' and\n        (i_color = 'orchid' or i_color = 'magenta') and\n        (i_units = 'Unknown' or i_units = 'Tbl') and\n        (i_size = 'small' or i_size = 'petite')\n        )))) > 0\n order by i_product_name\n  LIMIT 100;\n\n-- end query 62 in stream 0 using template query41.tpl\n-- start query 63 in stream 0 using template query8.tpl\nselect  s_store_name\n      ,sum(ss_net_profit)\n from store_sales\n     ,date_dim\n     ,store,\n     (select ca_zip\n     from (\n      SELECT substr(ca_zip,1,5) 
ca_zip\n      FROM customer_address\n      WHERE substr(ca_zip,1,5) IN (\n                          '43758','76357','20728','59309','19777','27690',\n                          '23681','52275','64367','24674','79465',\n                          '52936','53936','91889','89248','70394',\n                          '66020','56289','45541','29900','99055',\n                          '47395','16654','26748','74456','31039',\n                          '77674','87076','92273','31667','20150',\n                          '84426','75885','61588','57973','29487',\n                          '95008','65615','24339','84923','38463',\n                          '13811','44227','18570','40389','14584',\n                          '33007','61590','47363','57853','43499',\n                          '90755','47141','14392','33991','77031',\n                          '22854','20127','10624','15730','75295',\n                          '98460','17059','26953','82996','17095',\n                          '53227','34618','86978','33613','12541',\n                          '63977','53929','55459','11516','85350',\n                          '99888','23506','10569','66837','50031',\n                          '28282','83901','98554','54828','14616',\n                          '12743','42473','95507','30542','12883',\n                          '95097','61307','32530','37753','53116',\n                          '10989','87430','22114','68848','21246',\n                          '68327','28446','85870','11697','30541',\n                          '22933','70727','17570','55311','73355',\n                          '16347','61573','81229','95480','92091',\n                          '52603','51232','62666','12173','31993',\n                          '98202','78325','46798','63259','34167',\n                          '50435','56182','29390','51732','88435',\n                          '10366','46637','69283','18218','33324',\n                          '24139','16122','53142','16832','98386',\n              
            '41451','85109','32534','83953','76537',\n                          '60857','59939','22271','38788','26296',\n                          '59937','14272','98651','38185','16322',\n                          '13735','56321','81398','36035','36512',\n                          '96290','40596','22748','77965','28512',\n                          '15540','20574','72340','81870','31905',\n                          '18121','26282','30345','38703','74274',\n                          '71129','23244','68810','10106','55461',\n                          '25528','71474','37071','21552','81846',\n                          '64930','13233','11694','17829','43790',\n                          '60379','11482','22714','40977','73320',\n                          '13928','78952','92802','66663','95765',\n                          '86101','19813','90867','81258','93891',\n                          '32755','21548','36452','50931','95773',\n                          '57046','14736','30562','44667','80519',\n                          '99886','97296','38505','29732','38693',\n                          '83898','88032','64442','25944','39303',\n                          '70781','92448','64252','89641','88070',\n                          '38159','27654','72120','41689','37122',\n                          '63776','90416','28479','14787','18038',\n                          '39783','50062','28010','13042','86777',\n                          '32380','80664','33558','43641','14627',\n                          '68858','57733','53458','73016','76141',\n                          '42375','12248','38778','50092','80825',\n                          '58934','12145','78407','57009','52782',\n                          '72140','35635','63926','35282','29292',\n                          '30149','33576','95945','48303','56310',\n                          '32214','69726','48249','91163','57311',\n                          '12361','20491','13551','61620','59648',\n                          
'44466','53607','18410','99090','37973',\n                          '17986','80713','95948','35103','51799',\n                          '54707','52269','86117','44909','15530',\n                          '28999','80844','62823','46487','15144',\n                          '51445','81050','34943','45141','28541',\n                          '12414','56922','50548','16422','16780',\n                          '53104','60629','24405','61768','48257',\n                          '92852','27390','24411','17776','81487',\n                          '34848','45773','64188','24209','55276',\n                          '11379','33956','46173','67361','32337',\n                          '82112','73196','38461','43987','17980',\n                          '65414','12247','42107','15326','73018',\n                          '59993','85526','50231','60176','23889',\n                          '88012','27859','44921','50915','21742',\n                          '21272','64763','78761','62002','18502',\n                          '42208','49675','69413','46013','67034',\n                          '52739','94050','76249','25105','67299',\n                          '77588','50637','14333','39372','98030',\n                          '79792','12014','56236','61057','51347',\n                          '87879','71564','48478','33078','23325',\n                          '25526','52855','27570','78396','18695',\n                          '24397','76087','35195','97232','29136',\n                          '15812','18408','40746','78749')\n     intersect\n      select ca_zip\n      from (SELECT substr(ca_zip,1,5) ca_zip,count(*) cnt\n            FROM customer_address, customer\n            WHERE ca_address_sk = c_current_addr_sk and\n                  c_preferred_cust_flag='Y'\n            group by ca_zip\n            having count(*) > 10)A1)A2) V1\n where ss_store_sk = s_store_sk\n  and ss_sold_date_sk = d_date_sk\n  and d_qoy = 1 and d_year = 2000\n  and (substr(s_zip,1,2) = 
substr(V1.ca_zip,1,2))\n group by s_store_name\n order by s_store_name\n  LIMIT 100;\n\n-- end query 63 in stream 0 using template query8.tpl\n-- start query 64 in stream 0 using template query12.tpl\nselect  i_item_id\n      ,i_item_desc \n      ,i_category \n      ,i_class \n      ,i_current_price\n      ,sum(ws_ext_sales_price) as itemrevenue \n      ,sum(ws_ext_sales_price)*100/sum(sum(ws_ext_sales_price)) over\n          (partition by i_class) as revenueratio\nfrom\t\n\tweb_sales\n    \t,item \n    \t,date_dim\nwhere \n\tws_item_sk = i_item_sk \n  \tand i_category in ('Women', 'Children', 'Books')\n  \tand ws_sold_date_sk = d_date_sk\n\tand d_date between cast('2001-02-28' as date) \n\t\t\t\tand (cast('2001-02-28' as date) + interval 30 days)\ngroup by \n\ti_item_id\n        ,i_item_desc \n        ,i_category\n        ,i_class\n        ,i_current_price\norder by \n\ti_category\n        ,i_class\n        ,i_item_id\n        ,i_item_desc\n        ,revenueratio\n LIMIT 100;\n\n-- end query 64 in stream 0 using template query12.tpl\n-- start query 65 in stream 0 using template query20.tpl\nselect  i_item_id\n       ,i_item_desc \n       ,i_category \n       ,i_class \n       ,i_current_price\n       ,sum(cs_ext_sales_price) as itemrevenue \n       ,sum(cs_ext_sales_price)*100/sum(sum(cs_ext_sales_price)) over\n           (partition by i_class) as revenueratio\n from\tcatalog_sales\n     ,item \n     ,date_dim\n where cs_item_sk = i_item_sk \n   and i_category in ('Men', 'Home', 'Music')\n   and cs_sold_date_sk = d_date_sk\n and d_date between cast('1999-03-08' as date) \n \t\t\t\tand (cast('1999-03-08' as date) + interval 30 days)\n group by i_item_id\n         ,i_item_desc \n         ,i_category\n         ,i_class\n         ,i_current_price\n order by i_category\n         ,i_class\n         ,i_item_id\n         ,i_item_desc\n         ,revenueratio\n LIMIT 100;\n\n-- end query 65 in stream 0 using template query20.tpl\n-- start query 66 in stream 0 using template 
query88.tpl\nselect  *\nfrom\n (select count(*) h8_30_to_9\n from store_sales, household_demographics , time_dim, store\n where ss_sold_time_sk = time_dim.t_time_sk   \n     and ss_hdemo_sk = household_demographics.hd_demo_sk \n     and ss_store_sk = s_store_sk\n     and time_dim.t_hour = 8\n     and time_dim.t_minute >= 30\n     and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or\n          (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or\n          (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2)) \n     and store.s_store_name = 'ese') s1,\n (select count(*) h9_to_9_30 \n from store_sales, household_demographics , time_dim, store\n where ss_sold_time_sk = time_dim.t_time_sk\n     and ss_hdemo_sk = household_demographics.hd_demo_sk\n     and ss_store_sk = s_store_sk \n     and time_dim.t_hour = 9 \n     and time_dim.t_minute < 30\n     and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or\n          (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or\n          (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2))\n     and store.s_store_name = 'ese') s2,\n (select count(*) h9_30_to_10 \n from store_sales, household_demographics , time_dim, store\n where ss_sold_time_sk = time_dim.t_time_sk\n     and ss_hdemo_sk = household_demographics.hd_demo_sk\n     and ss_store_sk = s_store_sk\n     and time_dim.t_hour = 9\n     and time_dim.t_minute >= 30\n     and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or\n          (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or\n          (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2))\n     and store.s_store_name = 'ese') s3,\n (select 
count(*) h10_to_10_30\n from store_sales, household_demographics , time_dim, store\n where ss_sold_time_sk = time_dim.t_time_sk\n     and ss_hdemo_sk = household_demographics.hd_demo_sk\n     and ss_store_sk = s_store_sk\n     and time_dim.t_hour = 10 \n     and time_dim.t_minute < 30\n     and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or\n          (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or\n          (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2))\n     and store.s_store_name = 'ese') s4,\n (select count(*) h10_30_to_11\n from store_sales, household_demographics , time_dim, store\n where ss_sold_time_sk = time_dim.t_time_sk\n     and ss_hdemo_sk = household_demographics.hd_demo_sk\n     and ss_store_sk = s_store_sk\n     and time_dim.t_hour = 10 \n     and time_dim.t_minute >= 30\n     and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or\n          (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or\n          (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2))\n     and store.s_store_name = 'ese') s5,\n (select count(*) h11_to_11_30\n from store_sales, household_demographics , time_dim, store\n where ss_sold_time_sk = time_dim.t_time_sk\n     and ss_hdemo_sk = household_demographics.hd_demo_sk\n     and ss_store_sk = s_store_sk \n     and time_dim.t_hour = 11\n     and time_dim.t_minute < 30\n     and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or\n          (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or\n          (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2))\n     and store.s_store_name = 'ese') s6,\n (select count(*) h11_30_to_12\n from 
store_sales, household_demographics , time_dim, store\n where ss_sold_time_sk = time_dim.t_time_sk\n     and ss_hdemo_sk = household_demographics.hd_demo_sk\n     and ss_store_sk = s_store_sk\n     and time_dim.t_hour = 11\n     and time_dim.t_minute >= 30\n     and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or\n          (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or\n          (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2))\n     and store.s_store_name = 'ese') s7,\n (select count(*) h12_to_12_30\n from store_sales, household_demographics , time_dim, store\n where ss_sold_time_sk = time_dim.t_time_sk\n     and ss_hdemo_sk = household_demographics.hd_demo_sk\n     and ss_store_sk = s_store_sk\n     and time_dim.t_hour = 12\n     and time_dim.t_minute < 30\n     and ((household_demographics.hd_dep_count = -1 and household_demographics.hd_vehicle_count<=-1+2) or\n          (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or\n          (household_demographics.hd_dep_count = 2 and household_demographics.hd_vehicle_count<=2+2))\n     and store.s_store_name = 'ese') s8\n;\n\n-- end query 66 in stream 0 using template query88.tpl\n-- start query 67 in stream 0 using template query82.tpl\nselect  i_item_id\n       ,i_item_desc\n       ,i_current_price\n from item, inventory, date_dim, store_sales\n where i_current_price between 9 and 9+30\n and inv_item_sk = i_item_sk\n and d_date_sk=inv_date_sk\n and d_date between cast('2001-06-07' as date) and (cast('2001-06-07' as date) + interval 60 days)\n and i_manufact_id in (797,412,331,589)\n and inv_quantity_on_hand between 100 and 500\n and ss_item_sk = i_item_sk\n group by i_item_id,i_item_desc,i_current_price\n order by i_item_id\n  LIMIT 100;\n\n-- end query 67 in stream 0 using template query82.tpl\n-- start query 68 in stream 0 using 
template query23.tpl\nwith frequent_ss_items as \n (select substr(i_item_desc,1,30) itemdesc,i_item_sk item_sk,d_date solddate,count(*) cnt\n  from store_sales\n      ,date_dim \n      ,item\n  where ss_sold_date_sk = d_date_sk\n    and ss_item_sk = i_item_sk \n    and d_year in (2000,2000+1,2000+2,2000+3)\n  group by substr(i_item_desc,1,30),i_item_sk,d_date\n  having count(*) >4),\n max_store_sales as\n (select max(csales) tpcds_cmax \n  from (select c_customer_sk,sum(ss_quantity*ss_sales_price) csales\n        from store_sales\n            ,customer\n            ,date_dim \n        where ss_customer_sk = c_customer_sk\n         and ss_sold_date_sk = d_date_sk\n         and d_year in (2000,2000+1,2000+2,2000+3) \n        group by c_customer_sk)),\n best_ss_customer as\n (select c_customer_sk,sum(ss_quantity*ss_sales_price) ssales\n  from store_sales\n      ,customer\n  where ss_customer_sk = c_customer_sk\n  group by c_customer_sk\n  having sum(ss_quantity*ss_sales_price) > (95/100.0) * (select\n  *\nfrom\n max_store_sales))\n  select  sum(sales)\n from (select cs_quantity*cs_list_price sales\n       from catalog_sales\n           ,date_dim \n       where d_year = 2000 \n         and d_moy = 7 \n         and cs_sold_date_sk = d_date_sk \n         and cs_item_sk in (select item_sk from frequent_ss_items)\n         and cs_bill_customer_sk in (select c_customer_sk from best_ss_customer)\n      union all\n      select ws_quantity*ws_list_price sales\n       from web_sales \n           ,date_dim \n       where d_year = 2000 \n         and d_moy = 7 \n         and ws_sold_date_sk = d_date_sk \n         and ws_item_sk in (select item_sk from frequent_ss_items)\n         and ws_bill_customer_sk in (select c_customer_sk from best_ss_customer)) \n  LIMIT 100;\nwith frequent_ss_items as\n (select substr(i_item_desc,1,30) itemdesc,i_item_sk item_sk,d_date solddate,count(*) cnt\n  from store_sales\n      ,date_dim\n      ,item\n  where ss_sold_date_sk = d_date_sk\n    and 
ss_item_sk = i_item_sk\n    and d_year in (2000,2000 + 1,2000 + 2,2000 + 3)\n  group by substr(i_item_desc,1,30),i_item_sk,d_date\n  having count(*) >4),\n max_store_sales as\n (select max(csales) tpcds_cmax\n  from (select c_customer_sk,sum(ss_quantity*ss_sales_price) csales\n        from store_sales\n            ,customer\n            ,date_dim \n        where ss_customer_sk = c_customer_sk\n         and ss_sold_date_sk = d_date_sk\n         and d_year in (2000,2000+1,2000+2,2000+3)\n        group by c_customer_sk)),\n best_ss_customer as\n (select c_customer_sk,sum(ss_quantity*ss_sales_price) ssales\n  from store_sales\n      ,customer\n  where ss_customer_sk = c_customer_sk\n  group by c_customer_sk\n  having sum(ss_quantity*ss_sales_price) > (95/100.0) * (select\n  *\n from max_store_sales))\n  select  c_last_name,c_first_name,sales\n from (select c_last_name,c_first_name,sum(cs_quantity*cs_list_price) sales\n        from catalog_sales\n            ,customer\n            ,date_dim \n        where d_year = 2000 \n         and d_moy = 7 \n         and cs_sold_date_sk = d_date_sk \n         and cs_item_sk in (select item_sk from frequent_ss_items)\n         and cs_bill_customer_sk in (select c_customer_sk from best_ss_customer)\n         and cs_bill_customer_sk = c_customer_sk \n       group by c_last_name,c_first_name\n      union all\n      select c_last_name,c_first_name,sum(ws_quantity*ws_list_price) sales\n       from web_sales\n           ,customer\n           ,date_dim \n       where d_year = 2000 \n         and d_moy = 7 \n         and ws_sold_date_sk = d_date_sk \n         and ws_item_sk in (select item_sk from frequent_ss_items)\n         and ws_bill_customer_sk in (select c_customer_sk from best_ss_customer)\n         and ws_bill_customer_sk = c_customer_sk\n       group by c_last_name,c_first_name) \n     order by c_last_name,c_first_name,sales\n   LIMIT 100;\n\n-- end query 68 in stream 0 using template query23.tpl\n-- start query 69 in stream 0 
using template query14.tpl\nwith  cross_items as\n (select i_item_sk ss_item_sk\n from item,\n (select iss.i_brand_id brand_id\n     ,iss.i_class_id class_id\n     ,iss.i_category_id category_id\n from store_sales\n     ,item iss\n     ,date_dim d1\n where ss_item_sk = iss.i_item_sk\n   and ss_sold_date_sk = d1.d_date_sk\n   and d1.d_year between 1999 AND 1999 + 2\n intersect \n select ics.i_brand_id\n     ,ics.i_class_id\n     ,ics.i_category_id\n from catalog_sales\n     ,item ics\n     ,date_dim d2\n where cs_item_sk = ics.i_item_sk\n   and cs_sold_date_sk = d2.d_date_sk\n   and d2.d_year between 1999 AND 1999 + 2\n intersect\n select iws.i_brand_id\n     ,iws.i_class_id\n     ,iws.i_category_id\n from web_sales\n     ,item iws\n     ,date_dim d3\n where ws_item_sk = iws.i_item_sk\n   and ws_sold_date_sk = d3.d_date_sk\n   and d3.d_year between 1999 AND 1999 + 2)\n where i_brand_id = brand_id\n      and i_class_id = class_id\n      and i_category_id = category_id\n),\n avg_sales as\n (select avg(quantity*list_price) average_sales\n  from (select ss_quantity quantity\n             ,ss_list_price list_price\n       from store_sales\n           ,date_dim\n       where ss_sold_date_sk = d_date_sk\n         and d_year between 1999 and 1999 + 2\n       union all \n       select cs_quantity quantity \n             ,cs_list_price list_price\n       from catalog_sales\n           ,date_dim\n       where cs_sold_date_sk = d_date_sk\n         and d_year between 1999 and 1999 + 2 \n       union all\n       select ws_quantity quantity\n             ,ws_list_price list_price\n       from web_sales\n           ,date_dim\n       where ws_sold_date_sk = d_date_sk\n         and d_year between 1999 and 1999 + 2) x)\n  select  channel, i_brand_id,i_class_id,i_category_id,sum(sales), sum(number_sales)\n from(\n       select 'store' channel, i_brand_id,i_class_id\n             ,i_category_id,sum(ss_quantity*ss_list_price) sales\n             , count(*) number_sales\n       from 
store_sales\n           ,item\n           ,date_dim\n       where ss_item_sk in (select ss_item_sk from cross_items)\n         and ss_item_sk = i_item_sk\n         and ss_sold_date_sk = d_date_sk\n         and d_year = 1999+2 \n         and d_moy = 11\n       group by i_brand_id,i_class_id,i_category_id\n       having sum(ss_quantity*ss_list_price) > (select average_sales from avg_sales)\n       union all\n       select 'catalog' channel, i_brand_id,i_class_id,i_category_id, sum(cs_quantity*cs_list_price) sales, count(*) number_sales\n       from catalog_sales\n           ,item\n           ,date_dim\n       where cs_item_sk in (select ss_item_sk from cross_items)\n         and cs_item_sk = i_item_sk\n         and cs_sold_date_sk = d_date_sk\n         and d_year = 1999+2 \n         and d_moy = 11\n       group by i_brand_id,i_class_id,i_category_id\n       having sum(cs_quantity*cs_list_price) > (select average_sales from avg_sales)\n       union all\n       select 'web' channel, i_brand_id,i_class_id,i_category_id, sum(ws_quantity*ws_list_price) sales , count(*) number_sales\n       from web_sales\n           ,item\n           ,date_dim\n       where ws_item_sk in (select ss_item_sk from cross_items)\n         and ws_item_sk = i_item_sk\n         and ws_sold_date_sk = d_date_sk\n         and d_year = 1999+2\n         and d_moy = 11\n       group by i_brand_id,i_class_id,i_category_id\n       having sum(ws_quantity*ws_list_price) > (select average_sales from avg_sales)\n ) y\n group by rollup (channel, i_brand_id,i_class_id,i_category_id)\n order by channel,i_brand_id,i_class_id,i_category_id\n  LIMIT 100;\nwith  cross_items as\n (select i_item_sk ss_item_sk\n from item,\n (select iss.i_brand_id brand_id\n     ,iss.i_class_id class_id\n     ,iss.i_category_id category_id\n from store_sales\n     ,item iss\n     ,date_dim d1\n where ss_item_sk = iss.i_item_sk\n   and ss_sold_date_sk = d1.d_date_sk\n   and d1.d_year between 1999 AND 1999 + 2\n intersect\n select 
ics.i_brand_id\n     ,ics.i_class_id\n     ,ics.i_category_id\n from catalog_sales\n     ,item ics\n     ,date_dim d2\n where cs_item_sk = ics.i_item_sk\n   and cs_sold_date_sk = d2.d_date_sk\n   and d2.d_year between 1999 AND 1999 + 2\n intersect\n select iws.i_brand_id\n     ,iws.i_class_id\n     ,iws.i_category_id\n from web_sales\n     ,item iws\n     ,date_dim d3\n where ws_item_sk = iws.i_item_sk\n   and ws_sold_date_sk = d3.d_date_sk\n   and d3.d_year between 1999 AND 1999 + 2) x\n where i_brand_id = brand_id\n      and i_class_id = class_id\n      and i_category_id = category_id\n),\n avg_sales as\n(select avg(quantity*list_price) average_sales\n  from (select ss_quantity quantity\n             ,ss_list_price list_price\n       from store_sales\n           ,date_dim\n       where ss_sold_date_sk = d_date_sk\n         and d_year between 1999 and 1999 + 2\n       union all\n       select cs_quantity quantity\n             ,cs_list_price list_price\n       from catalog_sales\n           ,date_dim\n       where cs_sold_date_sk = d_date_sk\n         and d_year between 1999 and 1999 + 2\n       union all\n       select ws_quantity quantity\n             ,ws_list_price list_price\n       from web_sales\n           ,date_dim\n       where ws_sold_date_sk = d_date_sk\n         and d_year between 1999 and 1999 + 2) x)\n  select  this_year.channel ty_channel\n                           ,this_year.i_brand_id ty_brand\n                           ,this_year.i_class_id ty_class\n                           ,this_year.i_category_id ty_category\n                           ,this_year.sales ty_sales\n                           ,this_year.number_sales ty_number_sales\n                           ,last_year.channel ly_channel\n                           ,last_year.i_brand_id ly_brand\n                           ,last_year.i_class_id ly_class\n                           ,last_year.i_category_id ly_category\n                           ,last_year.sales ly_sales\n                     
      ,last_year.number_sales ly_number_sales \n from\n (select 'store' channel, i_brand_id,i_class_id,i_category_id\n        ,sum(ss_quantity*ss_list_price) sales, count(*) number_sales\n from store_sales \n     ,item\n     ,date_dim\n where ss_item_sk in (select ss_item_sk from cross_items)\n   and ss_item_sk = i_item_sk\n   and ss_sold_date_sk = d_date_sk\n   and d_week_seq = (select d_week_seq\n                     from date_dim\n                     where d_year = 1999 + 1\n                       and d_moy = 12\n                       and d_dom = 28)\n group by i_brand_id,i_class_id,i_category_id\n having sum(ss_quantity*ss_list_price) > (select average_sales from avg_sales)) this_year,\n (select 'store' channel, i_brand_id,i_class_id\n        ,i_category_id, sum(ss_quantity*ss_list_price) sales, count(*) number_sales\n from store_sales\n     ,item\n     ,date_dim\n where ss_item_sk in (select ss_item_sk from cross_items)\n   and ss_item_sk = i_item_sk\n   and ss_sold_date_sk = d_date_sk\n   and d_week_seq = (select d_week_seq\n                     from date_dim\n                     where d_year = 1999\n                       and d_moy = 12\n                       and d_dom = 28)\n group by i_brand_id,i_class_id,i_category_id\n having sum(ss_quantity*ss_list_price) > (select average_sales from avg_sales)) last_year\n where this_year.i_brand_id= last_year.i_brand_id\n   and this_year.i_class_id = last_year.i_class_id\n   and this_year.i_category_id = last_year.i_category_id\n order by this_year.channel, this_year.i_brand_id, this_year.i_class_id, this_year.i_category_id\n  LIMIT 100;\n\n-- end query 69 in stream 0 using template query14.tpl\n-- start query 70 in stream 0 using template query57.tpl\nwith v1 as(\n select i_category, i_brand,\n        cc_name,\n        d_year, d_moy,\n        sum(cs_sales_price) sum_sales,\n        avg(sum(cs_sales_price)) over\n          (partition by i_category, i_brand,\n                     cc_name, d_year)\n          
avg_monthly_sales,\n        rank() over\n          (partition by i_category, i_brand,\n                     cc_name\n           order by d_year, d_moy) rn\n from item, catalog_sales, date_dim, call_center\n where cs_item_sk = i_item_sk and\n       cs_sold_date_sk = d_date_sk and\n       cc_call_center_sk= cs_call_center_sk and\n       (\n         d_year = 1999 or\n         ( d_year = 1999-1 and d_moy =12) or\n         ( d_year = 1999+1 and d_moy =1)\n       )\n group by i_category, i_brand,\n          cc_name , d_year, d_moy),\n v2 as(\n select v1.i_category, v1.i_brand\n        ,v1.d_year, v1.d_moy\n        ,v1.avg_monthly_sales\n        ,v1.sum_sales, v1_lag.sum_sales psum, v1_lead.sum_sales nsum\n from v1, v1 v1_lag, v1 v1_lead\n where v1.i_category = v1_lag.i_category and\n       v1.i_category = v1_lead.i_category and\n       v1.i_brand = v1_lag.i_brand and\n       v1.i_brand = v1_lead.i_brand and\n       v1. cc_name = v1_lag. cc_name and\n       v1. cc_name = v1_lead. cc_name and\n       v1.rn = v1_lag.rn + 1 and\n       v1.rn = v1_lead.rn - 1)\n  select  *\n from v2\n where  d_year = 1999 and\n        avg_monthly_sales > 0 and\n        case when avg_monthly_sales > 0 then abs(sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1\n order by sum_sales - avg_monthly_sales, nsum\n  LIMIT 100;\n\n-- end query 70 in stream 0 using template query57.tpl\n-- start query 71 in stream 0 using template query65.tpl\nselect \n\ts_store_name,\n\ti_item_desc,\n\tsc.revenue,\n\ti_current_price,\n\ti_wholesale_cost,\n\ti_brand\n from store, item,\n     (select ss_store_sk, avg(revenue) as ave\n \tfrom\n \t    (select  ss_store_sk, ss_item_sk, \n \t\t     sum(ss_sales_price) as revenue\n \t\tfrom store_sales, date_dim\n \t\twhere ss_sold_date_sk = d_date_sk and d_month_seq between 1212 and 1212+11\n \t\tgroup by ss_store_sk, ss_item_sk) sa\n \tgroup by ss_store_sk) sb,\n     (select  ss_store_sk, ss_item_sk, sum(ss_sales_price) as revenue\n \tfrom store_sales, 
date_dim\n \twhere ss_sold_date_sk = d_date_sk and d_month_seq between 1212 and 1212+11\n \tgroup by ss_store_sk, ss_item_sk) sc\n where sb.ss_store_sk = sc.ss_store_sk and \n       sc.revenue <= 0.1 * sb.ave and\n       s_store_sk = sc.ss_store_sk and\n       i_item_sk = sc.ss_item_sk\n order by s_store_name, i_item_desc\n LIMIT 100;\n\n-- end query 71 in stream 0 using template query65.tpl\n-- start query 72 in stream 0 using template query71.tpl\nselect i_brand_id brand_id, i_brand brand,t_hour,t_minute,\n \tsum(ext_price) ext_price\n from item, (select ws_ext_sales_price as ext_price, \n                        ws_sold_date_sk as sold_date_sk,\n                        ws_item_sk as sold_item_sk,\n                        ws_sold_time_sk as time_sk  \n                 from web_sales,date_dim\n                 where d_date_sk = ws_sold_date_sk\n                   and d_moy=12\n                   and d_year=2002\n                 union all\n                 select cs_ext_sales_price as ext_price,\n                        cs_sold_date_sk as sold_date_sk,\n                        cs_item_sk as sold_item_sk,\n                        cs_sold_time_sk as time_sk\n                 from catalog_sales,date_dim\n                 where d_date_sk = cs_sold_date_sk\n                   and d_moy=12\n                   and d_year=2002\n                 union all\n                 select ss_ext_sales_price as ext_price,\n                        ss_sold_date_sk as sold_date_sk,\n                        ss_item_sk as sold_item_sk,\n                        ss_sold_time_sk as time_sk\n                 from store_sales,date_dim\n                 where d_date_sk = ss_sold_date_sk\n                   and d_moy=12\n                   and d_year=2002\n                 ) tmp,time_dim\n where\n   sold_item_sk = i_item_sk\n   and i_manager_id=1\n   and time_sk = t_time_sk\n   and (t_meal_time = 'breakfast' or t_meal_time = 'dinner')\n group by i_brand, i_brand_id,t_hour,t_minute\n order by 
ext_price desc, i_brand_id\n ;\n\n-- end query 72 in stream 0 using template query71.tpl\n-- start query 73 in stream 0 using template query34.tpl\nselect c_last_name\n       ,c_first_name\n       ,c_salutation\n       ,c_preferred_cust_flag\n       ,ss_ticket_number\n       ,cnt from\n   (select ss_ticket_number\n          ,ss_customer_sk\n          ,count(*) cnt\n    from store_sales,date_dim,store,household_demographics\n    where store_sales.ss_sold_date_sk = date_dim.d_date_sk\n    and store_sales.ss_store_sk = store.s_store_sk  \n    and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk\n    and (date_dim.d_dom between 1 and 3 or date_dim.d_dom between 25 and 28)\n    and (household_demographics.hd_buy_potential = '1001-5000' or\n         household_demographics.hd_buy_potential = '0-500')\n    and household_demographics.hd_vehicle_count > 0\n    and (case when household_demographics.hd_vehicle_count > 0 \n\tthen household_demographics.hd_dep_count/ household_demographics.hd_vehicle_count \n\telse null \n\tend)  > 1.2\n    and date_dim.d_year in (2000,2000+1,2000+2)\n    and store.s_county in ('Williamson County','Walker County','Ziebach County','Walker County',\n                           'Ziebach County','Ziebach County','Ziebach County','Ziebach County')\n    group by ss_ticket_number,ss_customer_sk) dn,customer\n    where ss_customer_sk = c_customer_sk\n      and cnt between 15 and 20\n    order by c_last_name,c_first_name,c_salutation,c_preferred_cust_flag desc, ss_ticket_number;\n\n-- end query 73 in stream 0 using template query34.tpl\n-- start query 74 in stream 0 using template query48.tpl\nselect sum (ss_quantity)\n from store_sales, store, customer_demographics, customer_address, date_dim\n where s_store_sk = ss_store_sk\n and  ss_sold_date_sk = d_date_sk and d_year = 1998\n and  \n (\n  (\n   cd_demo_sk = ss_cdemo_sk\n   and \n   cd_marital_status = 'S'\n   and \n   cd_education_status = 'Secondary'\n   and \n   ss_sales_price between 
100.00 and 150.00  \n   )\n or\n  (\n  cd_demo_sk = ss_cdemo_sk\n   and \n   cd_marital_status = 'M'\n   and \n   cd_education_status = 'Primary'\n   and \n   ss_sales_price between 50.00 and 100.00   \n  )\n or \n (\n  cd_demo_sk = ss_cdemo_sk\n  and \n   cd_marital_status = 'W'\n   and \n   cd_education_status = '2 yr Degree'\n   and \n   ss_sales_price between 150.00 and 200.00  \n )\n )\n and\n (\n  (\n  ss_addr_sk = ca_address_sk\n  and\n  ca_country = 'United States'\n  and\n  ca_state in ('ND', 'KY', 'TX')\n  and ss_net_profit between 0 and 2000  \n  )\n or\n  (ss_addr_sk = ca_address_sk\n  and\n  ca_country = 'United States'\n  and\n  ca_state in ('WI', 'AR', 'GA')\n  and ss_net_profit between 150 and 3000 \n  )\n or\n  (ss_addr_sk = ca_address_sk\n  and\n  ca_country = 'United States'\n  and\n  ca_state in ('NC', 'SD', 'IL')\n  and ss_net_profit between 50 and 25000 \n  )\n )\n;\n\n-- end query 74 in stream 0 using template query48.tpl\n-- start query 75 in stream 0 using template query30.tpl\nwith customer_total_return as\n (select wr_returning_customer_sk as ctr_customer_sk\n        ,ca_state as ctr_state, \n \tsum(wr_return_amt) as ctr_total_return\n from web_returns\n     ,date_dim\n     ,customer_address\n where wr_returned_date_sk = d_date_sk \n   and d_year =2001\n   and wr_returning_addr_sk = ca_address_sk \n group by wr_returning_customer_sk\n         ,ca_state)\n  select  c_customer_id,c_salutation,c_first_name,c_last_name,c_preferred_cust_flag\n       ,c_birth_day,c_birth_month,c_birth_year,c_birth_country,c_login,c_email_address\n       ,c_last_review_date_sk,ctr_total_return\n from customer_total_return ctr1\n     ,customer_address\n     ,customer\n where ctr1.ctr_total_return > (select avg(ctr_total_return)*1.2\n \t\t\t  from customer_total_return ctr2 \n                  \t  where ctr1.ctr_state = ctr2.ctr_state)\n       and ca_address_sk = c_current_addr_sk\n       and ca_state = 'MO'\n       and ctr1.ctr_customer_sk = c_customer_sk\n order 
by c_customer_id,c_salutation,c_first_name,c_last_name,c_preferred_cust_flag\n                  ,c_birth_day,c_birth_month,c_birth_year,c_birth_country,c_login,c_email_address\n                  ,c_last_review_date_sk,ctr_total_return\n LIMIT 100;\n\n-- end query 75 in stream 0 using template query30.tpl\n-- start query 76 in stream 0 using template query74.tpl\nwith year_total as (\n select c_customer_id customer_id\n       ,c_first_name customer_first_name\n       ,c_last_name customer_last_name\n       ,d_year as year\n       ,sum(ss_net_paid) year_total\n       ,'s' sale_type\n from customer\n     ,store_sales\n     ,date_dim\n where c_customer_sk = ss_customer_sk\n   and ss_sold_date_sk = d_date_sk\n   and d_year in (1998,1998+1)\n group by c_customer_id\n         ,c_first_name\n         ,c_last_name\n         ,d_year\n union all\n select c_customer_id customer_id\n       ,c_first_name customer_first_name\n       ,c_last_name customer_last_name\n       ,d_year as year\n       ,sum(ws_net_paid) year_total\n       ,'w' sale_type\n from customer\n     ,web_sales\n     ,date_dim\n where c_customer_sk = ws_bill_customer_sk\n   and ws_sold_date_sk = d_date_sk\n   and d_year in (1998,1998+1)\n group by c_customer_id\n         ,c_first_name\n         ,c_last_name\n         ,d_year\n         )\n  select \n        t_s_secyear.customer_id, t_s_secyear.customer_first_name, t_s_secyear.customer_last_name\n from year_total t_s_firstyear\n     ,year_total t_s_secyear\n     ,year_total t_w_firstyear\n     ,year_total t_w_secyear\n where t_s_secyear.customer_id = t_s_firstyear.customer_id\n         and t_s_firstyear.customer_id = t_w_secyear.customer_id\n         and t_s_firstyear.customer_id = t_w_firstyear.customer_id\n         and t_s_firstyear.sale_type = 's'\n         and t_w_firstyear.sale_type = 'w'\n         and t_s_secyear.sale_type = 's'\n         and t_w_secyear.sale_type = 'w'\n         and t_s_firstyear.year = 1998\n         and t_s_secyear.year = 1998+1\n         
and t_w_firstyear.year = 1998\n         and t_w_secyear.year = 1998+1\n         and t_s_firstyear.year_total > 0\n         and t_w_firstyear.year_total > 0\n         and case when t_w_firstyear.year_total > 0 then t_w_secyear.year_total / t_w_firstyear.year_total else null end\n           > case when t_s_firstyear.year_total > 0 then t_s_secyear.year_total / t_s_firstyear.year_total else null end\n order by 2,1,3\n LIMIT 100;\n\n-- end query 76 in stream 0 using template query74.tpl\n-- start query 77 in stream 0 using template query87.tpl\nselect count(*) \nfrom ((select distinct c_last_name, c_first_name, d_date\n       from store_sales, date_dim, customer\n       where store_sales.ss_sold_date_sk = date_dim.d_date_sk\n         and store_sales.ss_customer_sk = customer.c_customer_sk\n         and d_month_seq between 1212 and 1212+11)\n       except\n      (select distinct c_last_name, c_first_name, d_date\n       from catalog_sales, date_dim, customer\n       where catalog_sales.cs_sold_date_sk = date_dim.d_date_sk\n         and catalog_sales.cs_bill_customer_sk = customer.c_customer_sk\n         and d_month_seq between 1212 and 1212+11)\n       except\n      (select distinct c_last_name, c_first_name, d_date\n       from web_sales, date_dim, customer\n       where web_sales.ws_sold_date_sk = date_dim.d_date_sk\n         and web_sales.ws_bill_customer_sk = customer.c_customer_sk\n         and d_month_seq between 1212 and 1212+11)\n) cool_cust\n;\n\n-- end query 77 in stream 0 using template query87.tpl\n-- start query 78 in stream 0 using template query77.tpl\nwith ss as\n (select s_store_sk,\n         sum(ss_ext_sales_price) as sales,\n         sum(ss_net_profit) as profit\n from store_sales,\n      date_dim,\n      store\n where ss_sold_date_sk = d_date_sk\n       and d_date between cast('2002-08-18' as date) \n                  and (cast('2002-08-18' as date) + interval 30 days)\n       and ss_store_sk = s_store_sk\n group by s_store_sk)\n ,\n sr as\n (select 
s_store_sk,\n         sum(sr_return_amt) as returns,\n         sum(sr_net_loss) as profit_loss\n from store_returns,\n      date_dim,\n      store\n where sr_returned_date_sk = d_date_sk\n       and d_date between cast('2002-08-18' as date)\n                  and (cast('2002-08-18' as date) + interval 30 days)\n       and sr_store_sk = s_store_sk\n group by s_store_sk), \n cs as\n (select cs_call_center_sk,\n        sum(cs_ext_sales_price) as sales,\n        sum(cs_net_profit) as profit\n from catalog_sales,\n      date_dim\n where cs_sold_date_sk = d_date_sk\n       and d_date between cast('2002-08-18' as date)\n                  and (cast('2002-08-18' as date) + interval 30 days)\n group by cs_call_center_sk \n ), \n cr as\n (select cr_call_center_sk,\n         sum(cr_return_amount) as returns,\n         sum(cr_net_loss) as profit_loss\n from catalog_returns,\n      date_dim\n where cr_returned_date_sk = d_date_sk\n       and d_date between cast('2002-08-18' as date)\n                  and (cast('2002-08-18' as date) + interval 30 days)\n group by cr_call_center_sk\n ), \n ws as\n ( select wp_web_page_sk,\n        sum(ws_ext_sales_price) as sales,\n        sum(ws_net_profit) as profit\n from web_sales,\n      date_dim,\n      web_page\n where ws_sold_date_sk = d_date_sk\n       and d_date between cast('2002-08-18' as date)\n                  and (cast('2002-08-18' as date) + interval 30 days)\n       and ws_web_page_sk = wp_web_page_sk\n group by wp_web_page_sk), \n wr as\n (select wp_web_page_sk,\n        sum(wr_return_amt) as returns,\n        sum(wr_net_loss) as profit_loss\n from web_returns,\n      date_dim,\n      web_page\n where wr_returned_date_sk = d_date_sk\n       and d_date between cast('2002-08-18' as date)\n                  and (cast('2002-08-18' as date) + interval 30 days)\n       and wr_web_page_sk = wp_web_page_sk\n group by wp_web_page_sk)\n  select  channel\n        , id\n        , sum(sales) as sales\n        , sum(returns) as returns\n     
   , sum(profit) as profit\n from \n (select 'store channel' as channel\n        , ss.s_store_sk as id\n        , sales\n        , coalesce(returns, 0) as returns\n        , (profit - coalesce(profit_loss,0)) as profit\n from   ss left join sr\n        on  ss.s_store_sk = sr.s_store_sk\n union all\n select 'catalog channel' as channel\n        , cs_call_center_sk as id\n        , sales\n        , returns\n        , (profit - profit_loss) as profit\n from  cs\n       , cr\n union all\n select 'web channel' as channel\n        , ws.wp_web_page_sk as id\n        , sales\n        , coalesce(returns, 0) returns\n        , (profit - coalesce(profit_loss,0)) as profit\n from   ws left join wr\n        on  ws.wp_web_page_sk = wr.wp_web_page_sk\n ) x\n group by rollup (channel, id)\n order by channel\n         ,id\n  LIMIT 100;\n\n-- end query 78 in stream 0 using template query77.tpl\n-- start query 79 in stream 0 using template query73.tpl\nselect c_last_name\n       ,c_first_name\n       ,c_salutation\n       ,c_preferred_cust_flag \n       ,ss_ticket_number\n       ,cnt from\n   (select ss_ticket_number\n          ,ss_customer_sk\n          ,count(*) cnt\n    from store_sales,date_dim,store,household_demographics\n    where store_sales.ss_sold_date_sk = date_dim.d_date_sk\n    and store_sales.ss_store_sk = store.s_store_sk  \n    and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk\n    and date_dim.d_dom between 1 and 2 \n    and (household_demographics.hd_buy_potential = '1001-5000' or\n         household_demographics.hd_buy_potential = 'Unknown')\n    and household_demographics.hd_vehicle_count > 0\n    and case when household_demographics.hd_vehicle_count > 0 then \n             household_demographics.hd_dep_count/ household_demographics.hd_vehicle_count else null end > 1\n    and date_dim.d_year in (1999,1999+1,1999+2)\n    and store.s_county in ('Walker County','Williamson County','Ziebach County','Walker County')\n    group by 
ss_ticket_number,ss_customer_sk) dj,customer\n    where ss_customer_sk = c_customer_sk\n      and cnt between 1 and 5\n    order by cnt desc, c_last_name asc;\n\n-- end query 79 in stream 0 using template query73.tpl\n-- start query 80 in stream 0 using template query84.tpl\nselect  c_customer_id as customer_id\n       , coalesce(c_last_name,'') || ', ' || coalesce(c_first_name,'') as customername\n from customer\n     ,customer_address\n     ,customer_demographics\n     ,household_demographics\n     ,income_band\n     ,store_returns\n where ca_city\t        =  'Fairfield'\n   and c_current_addr_sk = ca_address_sk\n   and ib_lower_bound   >=  58125\n   and ib_upper_bound   <=  58125 + 50000\n   and ib_income_band_sk = hd_income_band_sk\n   and cd_demo_sk = c_current_cdemo_sk\n   and hd_demo_sk = c_current_hdemo_sk\n   and sr_cdemo_sk = cd_demo_sk\n order by c_customer_id\n  LIMIT 100;\n\n-- end query 80 in stream 0 using template query84.tpl\n-- start query 81 in stream 0 using template query54.tpl\nwith my_customers as (\n select distinct c_customer_sk\n        , c_current_addr_sk\n from   \n        ( select cs_sold_date_sk sold_date_sk,\n                 cs_bill_customer_sk customer_sk,\n                 cs_item_sk item_sk\n          from   catalog_sales\n          union all\n          select ws_sold_date_sk sold_date_sk,\n                 ws_bill_customer_sk customer_sk,\n                 ws_item_sk item_sk\n          from   web_sales\n         ) cs_or_ws_sales,\n         item,\n         date_dim,\n         customer\n where   sold_date_sk = d_date_sk\n         and item_sk = i_item_sk\n         and i_category = 'Children'\n         and i_class = 'toddlers'\n         and c_customer_sk = cs_or_ws_sales.customer_sk\n         and d_moy = 4\n         and d_year = 1999\n )\n , my_revenue as (\n select c_customer_sk,\n        sum(ss_ext_sales_price) as revenue\n from   my_customers,\n        store_sales,\n        customer_address,\n        store,\n        date_dim\n 
where  c_current_addr_sk = ca_address_sk\n        and ca_county = s_county\n        and ca_state = s_state\n        and ss_sold_date_sk = d_date_sk\n        and c_customer_sk = ss_customer_sk\n        and d_month_seq between (select distinct d_month_seq+1\n                                 from   date_dim where d_year = 1999 and d_moy = 4)\n                           and  (select distinct d_month_seq+3\n                                 from   date_dim where d_year = 1999 and d_moy = 4)\n group by c_customer_sk\n )\n , segments as\n (select cast((revenue/50) as int) as segment\n  from   my_revenue\n )\n  select  segment, count(*) as num_customers, segment*50 as segment_base\n from segments\n group by segment\n order by segment, num_customers\n  LIMIT 100;\n\n-- end query 81 in stream 0 using template query54.tpl\n-- start query 82 in stream 0 using template query55.tpl\nselect  i_brand_id brand_id, i_brand brand,\n \tsum(ss_ext_sales_price) ext_price\n from date_dim, store_sales, item\n where d_date_sk = ss_sold_date_sk\n \tand ss_item_sk = i_item_sk\n \tand i_manager_id=76\n \tand d_moy=12\n \tand d_year=1999\n group by i_brand, i_brand_id\n order by ext_price desc, i_brand_id\n LIMIT 100 ;\n\n-- end query 82 in stream 0 using template query55.tpl\n-- start query 83 in stream 0 using template query56.tpl\nwith ss as (\n select i_item_id,sum(ss_ext_sales_price) total_sales\n from\n \tstore_sales,\n \tdate_dim,\n         customer_address,\n         item\n where i_item_id in (select\n     i_item_id\nfrom item\nwhere i_color in ('blush','hot','orange'))\n and     ss_item_sk              = i_item_sk\n and     ss_sold_date_sk         = d_date_sk\n and     d_year                  = 2000\n and     d_moy                   = 5\n and     ss_addr_sk              = ca_address_sk\n and     ca_gmt_offset           = -5 \n group by i_item_id),\n cs as (\n select i_item_id,sum(cs_ext_sales_price) total_sales\n from\n \tcatalog_sales,\n \tdate_dim,\n         customer_address,\n       
  item\n where\n         i_item_id               in (select\n  i_item_id\nfrom item\nwhere i_color in ('blush','hot','orange'))\n and     cs_item_sk              = i_item_sk\n and     cs_sold_date_sk         = d_date_sk\n and     d_year                  = 2000\n and     d_moy                   = 5\n and     cs_bill_addr_sk         = ca_address_sk\n and     ca_gmt_offset           = -5 \n group by i_item_id),\n ws as (\n select i_item_id,sum(ws_ext_sales_price) total_sales\n from\n \tweb_sales,\n \tdate_dim,\n         customer_address,\n         item\n where\n         i_item_id               in (select\n  i_item_id\nfrom item\nwhere i_color in ('blush','hot','orange'))\n and     ws_item_sk              = i_item_sk\n and     ws_sold_date_sk         = d_date_sk\n and     d_year                  = 2000\n and     d_moy                   = 5\n and     ws_bill_addr_sk         = ca_address_sk\n and     ca_gmt_offset           = -5\n group by i_item_id)\n  select  i_item_id ,sum(total_sales) total_sales\n from  (select * from ss \n        union all\n        select * from cs \n        union all\n        select * from ws) tmp1\n group by i_item_id\n order by total_sales,\n          i_item_id\n  LIMIT 100;\n\n-- end query 83 in stream 0 using template query56.tpl\n-- start query 84 in stream 0 using template query2.tpl\nwith wscs as\n (select sold_date_sk\n        ,sales_price\n  from (select ws_sold_date_sk sold_date_sk\n              ,ws_ext_sales_price sales_price\n        from web_sales \n        union all\n        select cs_sold_date_sk sold_date_sk\n              ,cs_ext_sales_price sales_price\n        from catalog_sales)),\n wswscs as \n (select d_week_seq,\n        sum(case when (d_day_name='Sunday') then sales_price else null end) sun_sales,\n        sum(case when (d_day_name='Monday') then sales_price else null end) mon_sales,\n        sum(case when (d_day_name='Tuesday') then sales_price else  null end) tue_sales,\n        sum(case when (d_day_name='Wednesday') 
then sales_price else null end) wed_sales,\n        sum(case when (d_day_name='Thursday') then sales_price else null end) thu_sales,\n        sum(case when (d_day_name='Friday') then sales_price else null end) fri_sales,\n        sum(case when (d_day_name='Saturday') then sales_price else null end) sat_sales\n from wscs\n     ,date_dim\n where d_date_sk = sold_date_sk\n group by d_week_seq)\n select d_week_seq1\n       ,round(sun_sales1/sun_sales2,2)\n       ,round(mon_sales1/mon_sales2,2)\n       ,round(tue_sales1/tue_sales2,2)\n       ,round(wed_sales1/wed_sales2,2)\n       ,round(thu_sales1/thu_sales2,2)\n       ,round(fri_sales1/fri_sales2,2)\n       ,round(sat_sales1/sat_sales2,2)\n from\n (select wswscs.d_week_seq d_week_seq1\n        ,sun_sales sun_sales1\n        ,mon_sales mon_sales1\n        ,tue_sales tue_sales1\n        ,wed_sales wed_sales1\n        ,thu_sales thu_sales1\n        ,fri_sales fri_sales1\n        ,sat_sales sat_sales1\n  from wswscs,date_dim \n  where date_dim.d_week_seq = wswscs.d_week_seq and\n        d_year = 1998) y,\n (select wswscs.d_week_seq d_week_seq2\n        ,sun_sales sun_sales2\n        ,mon_sales mon_sales2\n        ,tue_sales tue_sales2\n        ,wed_sales wed_sales2\n        ,thu_sales thu_sales2\n        ,fri_sales fri_sales2\n        ,sat_sales sat_sales2\n  from wswscs\n      ,date_dim \n  where date_dim.d_week_seq = wswscs.d_week_seq and\n        d_year = 1998+1) z\n where d_week_seq1=d_week_seq2-53\n order by d_week_seq1;\n\n-- end query 84 in stream 0 using template query2.tpl\n-- start query 85 in stream 0 using template query26.tpl\nselect  i_item_id, \n        avg(cs_quantity) agg1,\n        avg(cs_list_price) agg2,\n        avg(cs_coupon_amt) agg3,\n        avg(cs_sales_price) agg4 \n from catalog_sales, customer_demographics, date_dim, item, promotion\n where cs_sold_date_sk = d_date_sk and\n       cs_item_sk = i_item_sk and\n       cs_bill_cdemo_sk = cd_demo_sk and\n       cs_promo_sk = p_promo_sk and\n       
cd_gender = 'M' and \n       cd_marital_status = 'S' and\n       cd_education_status = '4 yr Degree' and\n       (p_channel_email = 'N' or p_channel_event = 'N') and\n       d_year = 1999 \n group by i_item_id\n order by i_item_id\n  LIMIT 100;\n\n-- end query 85 in stream 0 using template query26.tpl\n-- start query 86 in stream 0 using template query40.tpl\nselect  \n   w_state\n  ,i_item_id\n  ,sum(case when (cast(d_date as date) < cast ('1998-03-13' as date)) \n \t\tthen cs_sales_price - coalesce(cr_refunded_cash,0) else 0 end) as sales_before\n  ,sum(case when (cast(d_date as date) >= cast ('1998-03-13' as date)) \n \t\tthen cs_sales_price - coalesce(cr_refunded_cash,0) else 0 end) as sales_after\n from\n   catalog_sales left outer join catalog_returns on\n       (cs_order_number = cr_order_number \n        and cs_item_sk = cr_item_sk)\n  ,warehouse \n  ,item\n  ,date_dim\n where\n     i_current_price between 0.99 and 1.49\n and i_item_sk          = cs_item_sk\n and cs_warehouse_sk    = w_warehouse_sk \n and cs_sold_date_sk    = d_date_sk\n and d_date between (cast ('1998-03-13' as date) - interval 30 days)\n                and (cast ('1998-03-13' as date) + interval 30 days)\n group by\n    w_state,i_item_id\n order by w_state,i_item_id\n LIMIT 100;\n\n-- end query 86 in stream 0 using template query40.tpl\n-- start query 87 in stream 0 using template query72.tpl\nselect  i_item_desc\n      ,w_warehouse_name\n      ,d1.d_week_seq\n      ,sum(case when p_promo_sk is null then 1 else 0 end) no_promo\n      ,sum(case when p_promo_sk is not null then 1 else 0 end) promo\n      ,count(*) total_cnt\nfrom catalog_sales\njoin inventory on (cs_item_sk = inv_item_sk)\njoin warehouse on (w_warehouse_sk=inv_warehouse_sk)\njoin item on (i_item_sk = cs_item_sk)\njoin customer_demographics on (cs_bill_cdemo_sk = cd_demo_sk)\njoin household_demographics on (cs_bill_hdemo_sk = hd_demo_sk)\njoin date_dim d1 on (cs_sold_date_sk = d1.d_date_sk)\njoin date_dim d2 on (inv_date_sk 
= d2.d_date_sk)\njoin date_dim d3 on (cs_ship_date_sk = d3.d_date_sk)\nleft outer join promotion on (cs_promo_sk=p_promo_sk)\nleft outer join catalog_returns on (cr_item_sk = cs_item_sk and cr_order_number = cs_order_number)\nwhere d1.d_week_seq = d2.d_week_seq\n  and inv_quantity_on_hand < cs_quantity \n  and d3.d_date > d1.d_date + 5\n  and hd_buy_potential = '501-1000'\n  and d1.d_year = 2002\n  and cd_marital_status = 'M'\ngroup by i_item_desc,w_warehouse_name,d1.d_week_seq\norder by total_cnt desc, i_item_desc, w_warehouse_name, d_week_seq\n LIMIT 100;\n\n-- end query 87 in stream 0 using template query72.tpl\n-- start query 88 in stream 0 using template query53.tpl\nselect  * from \n(select i_manufact_id,\nsum(ss_sales_price) sum_sales,\navg(sum(ss_sales_price)) over (partition by i_manufact_id) avg_quarterly_sales\nfrom item, store_sales, date_dim, store\nwhere ss_item_sk = i_item_sk and\nss_sold_date_sk = d_date_sk and\nss_store_sk = s_store_sk and\nd_month_seq in (1202,1202+1,1202+2,1202+3,1202+4,1202+5,1202+6,1202+7,1202+8,1202+9,1202+10,1202+11) and\n((i_category in ('Books','Children','Electronics') and\ni_class in ('personal','portable','reference','self-help') and\ni_brand in ('scholaramalgamalg #14','scholaramalgamalg #7',\n\t\t'exportiunivamalg #9','scholaramalgamalg #9'))\nor(i_category in ('Women','Music','Men') and\ni_class in ('accessories','classical','fragrances','pants') and\ni_brand in ('amalgimporto #1','edu packscholar #1','exportiimporto #1',\n\t\t'importoamalg #1')))\ngroup by i_manufact_id, d_qoy ) tmp1\nwhere case when avg_quarterly_sales > 0 \n\tthen abs (sum_sales - avg_quarterly_sales)/ avg_quarterly_sales \n\telse null end > 0.1\norder by avg_quarterly_sales,\n\t sum_sales,\n\t i_manufact_id\n LIMIT 100;\n\n-- end query 88 in stream 0 using template query53.tpl\n-- start query 89 in stream 0 using template query79.tpl\nselect \n  c_last_name,c_first_name,substr(s_city,1,30),ss_ticket_number,amt,profit\n  from\n   (select 
ss_ticket_number\n          ,ss_customer_sk\n          ,store.s_city\n          ,sum(ss_coupon_amt) amt\n          ,sum(ss_net_profit) profit\n    from store_sales,date_dim,store,household_demographics\n    where store_sales.ss_sold_date_sk = date_dim.d_date_sk\n    and store_sales.ss_store_sk = store.s_store_sk  \n    and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk\n    and (household_demographics.hd_dep_count = 9 or household_demographics.hd_vehicle_count > -1)\n    and date_dim.d_dow = 1\n    and date_dim.d_year in (2000,2000+1,2000+2) \n    and store.s_number_employees between 200 and 295\n    group by ss_ticket_number,ss_customer_sk,ss_addr_sk,store.s_city) ms,customer\n    where ss_customer_sk = c_customer_sk\n order by c_last_name,c_first_name,substr(s_city,1,30), profit\n LIMIT 100;\n\n-- end query 89 in stream 0 using template query79.tpl\n-- start query 90 in stream 0 using template query18.tpl\nselect  i_item_id,\n        ca_country,\n        ca_state, \n        ca_county,\n        avg( cast(cs_quantity as decimal(12,2))) agg1,\n        avg( cast(cs_list_price as decimal(12,2))) agg2,\n        avg( cast(cs_coupon_amt as decimal(12,2))) agg3,\n        avg( cast(cs_sales_price as decimal(12,2))) agg4,\n        avg( cast(cs_net_profit as decimal(12,2))) agg5,\n        avg( cast(c_birth_year as decimal(12,2))) agg6,\n        avg( cast(cd1.cd_dep_count as decimal(12,2))) agg7\n from catalog_sales, customer_demographics cd1, \n      customer_demographics cd2, customer, customer_address, date_dim, item\n where cs_sold_date_sk = d_date_sk and\n       cs_item_sk = i_item_sk and\n       cs_bill_cdemo_sk = cd1.cd_demo_sk and\n       cs_bill_customer_sk = c_customer_sk and\n       cd1.cd_gender = 'F' and \n       cd1.cd_education_status = '4 yr Degree' and\n       c_current_cdemo_sk = cd2.cd_demo_sk and\n       c_current_addr_sk = ca_address_sk and\n       c_birth_month in (4,2,12,10,11,3) and\n       d_year = 2001 and\n       ca_state in 
('AR','GA','CO'\n                   ,'MS','ND','KS','KY')\n group by rollup (i_item_id, ca_country, ca_state, ca_county)\n order by ca_country,\n        ca_state, \n        ca_county,\n\ti_item_id\n  LIMIT 100;\n\n-- end query 90 in stream 0 using template query18.tpl\n-- start query 91 in stream 0 using template query13.tpl\nselect avg(ss_quantity)\n       ,avg(ss_ext_sales_price)\n       ,avg(ss_ext_wholesale_cost)\n       ,sum(ss_ext_wholesale_cost)\n from store_sales\n     ,store\n     ,customer_demographics\n     ,household_demographics\n     ,customer_address\n     ,date_dim\n where s_store_sk = ss_store_sk\n and  ss_sold_date_sk = d_date_sk and d_year = 2001\n and((ss_hdemo_sk=hd_demo_sk\n  and cd_demo_sk = ss_cdemo_sk\n  and cd_marital_status = 'D'\n  and cd_education_status = 'Advanced Degree'\n  and ss_sales_price between 100.00 and 150.00\n  and hd_dep_count = 3   \n     )or\n     (ss_hdemo_sk=hd_demo_sk\n  and cd_demo_sk = ss_cdemo_sk\n  and cd_marital_status = 'U'\n  and cd_education_status = '2 yr Degree'\n  and ss_sales_price between 50.00 and 100.00   \n  and hd_dep_count = 1\n     ) or \n     (ss_hdemo_sk=hd_demo_sk\n  and cd_demo_sk = ss_cdemo_sk\n  and cd_marital_status = 'W'\n  and cd_education_status = '4 yr Degree'\n  and ss_sales_price between 150.00 and 200.00 \n  and hd_dep_count = 1  \n     ))\n and((ss_addr_sk = ca_address_sk\n  and ca_country = 'United States'\n  and ca_state in ('TX', 'OH', 'OK')\n  and ss_net_profit between 100 and 200  \n     ) or\n     (ss_addr_sk = ca_address_sk\n  and ca_country = 'United States'\n  and ca_state in ('MS', 'NY', 'GA')\n  and ss_net_profit between 150 and 300  \n     ) or\n     (ss_addr_sk = ca_address_sk\n  and ca_country = 'United States'\n  and ca_state in ('TN', 'IN', 'AL')\n  and ss_net_profit between 50 and 250  \n     ))\n;\n\n-- end query 91 in stream 0 using template query13.tpl\n-- start query 92 in stream 0 using template query24.tpl\nwith ssales as\n(select c_last_name\n      
,c_first_name\n      ,s_store_name\n      ,ca_state\n      ,s_state\n      ,i_color\n      ,i_current_price\n      ,i_manager_id\n      ,i_units\n      ,i_size\n      ,sum(ss_net_profit) netpaid\nfrom store_sales\n    ,store_returns\n    ,store\n    ,item\n    ,customer\n    ,customer_address\nwhere ss_ticket_number = sr_ticket_number\n  and ss_item_sk = sr_item_sk\n  and ss_customer_sk = c_customer_sk\n  and ss_item_sk = i_item_sk\n  and ss_store_sk = s_store_sk\n  and c_current_addr_sk = ca_address_sk\n  and c_birth_country <> upper(ca_country)\n  and s_zip = ca_zip\nand s_market_id=10\ngroup by c_last_name\n        ,c_first_name\n        ,s_store_name\n        ,ca_state\n        ,s_state\n        ,i_color\n        ,i_current_price\n        ,i_manager_id\n        ,i_units\n        ,i_size)\nselect c_last_name\n      ,c_first_name\n      ,s_store_name\n      ,sum(netpaid) paid\nfrom ssales\nwhere i_color = 'firebrick'\ngroup by c_last_name\n        ,c_first_name\n        ,s_store_name\nhaving sum(netpaid) > (select 0.05*avg(netpaid)\n                                 from ssales)\norder by c_last_name\n        ,c_first_name\n        ,s_store_name\n;\nwith ssales as\n(select c_last_name\n      ,c_first_name\n      ,s_store_name\n      ,ca_state\n      ,s_state\n      ,i_color\n      ,i_current_price\n      ,i_manager_id\n      ,i_units\n      ,i_size\n      ,sum(ss_net_profit) netpaid\nfrom store_sales\n    ,store_returns\n    ,store\n    ,item\n    ,customer\n    ,customer_address\nwhere ss_ticket_number = sr_ticket_number\n  and ss_item_sk = sr_item_sk\n  and ss_customer_sk = c_customer_sk\n  and ss_item_sk = i_item_sk\n  and ss_store_sk = s_store_sk\n  and c_current_addr_sk = ca_address_sk\n  and c_birth_country <> upper(ca_country)\n  and s_zip = ca_zip\n  and s_market_id = 10\ngroup by c_last_name\n        ,c_first_name\n        ,s_store_name\n        ,ca_state\n        ,s_state\n        ,i_color\n        ,i_current_price\n        ,i_manager_id\n        
,i_units\n        ,i_size)\nselect c_last_name\n      ,c_first_name\n      ,s_store_name\n      ,sum(netpaid) paid\nfrom ssales\nwhere i_color = 'sienna'\ngroup by c_last_name\n        ,c_first_name\n        ,s_store_name\nhaving sum(netpaid) > (select 0.05*avg(netpaid)\n                           from ssales)\norder by c_last_name\n        ,c_first_name\n        ,s_store_name\n;\n\n-- end query 92 in stream 0 using template query24.tpl\n-- start query 93 in stream 0 using template query4.tpl\nwith year_total as (\n select c_customer_id customer_id\n       ,c_first_name customer_first_name\n       ,c_last_name customer_last_name\n       ,c_preferred_cust_flag customer_preferred_cust_flag\n       ,c_birth_country customer_birth_country\n       ,c_login customer_login\n       ,c_email_address customer_email_address\n       ,d_year dyear\n       ,sum(((ss_ext_list_price-ss_ext_wholesale_cost-ss_ext_discount_amt)+ss_ext_sales_price)/2) year_total\n       ,'s' sale_type\n from customer\n     ,store_sales\n     ,date_dim\n where c_customer_sk = ss_customer_sk\n   and ss_sold_date_sk = d_date_sk\n group by c_customer_id\n         ,c_first_name\n         ,c_last_name\n         ,c_preferred_cust_flag\n         ,c_birth_country\n         ,c_login\n         ,c_email_address\n         ,d_year\n union all\n select c_customer_id customer_id\n       ,c_first_name customer_first_name\n       ,c_last_name customer_last_name\n       ,c_preferred_cust_flag customer_preferred_cust_flag\n       ,c_birth_country customer_birth_country\n       ,c_login customer_login\n       ,c_email_address customer_email_address\n       ,d_year dyear\n       ,sum((((cs_ext_list_price-cs_ext_wholesale_cost-cs_ext_discount_amt)+cs_ext_sales_price)/2) ) year_total\n       ,'c' sale_type\n from customer\n     ,catalog_sales\n     ,date_dim\n where c_customer_sk = cs_bill_customer_sk\n   and cs_sold_date_sk = d_date_sk\n group by c_customer_id\n         ,c_first_name\n         ,c_last_name\n         
,c_preferred_cust_flag\n         ,c_birth_country\n         ,c_login\n         ,c_email_address\n         ,d_year\nunion all\n select c_customer_id customer_id\n       ,c_first_name customer_first_name\n       ,c_last_name customer_last_name\n       ,c_preferred_cust_flag customer_preferred_cust_flag\n       ,c_birth_country customer_birth_country\n       ,c_login customer_login\n       ,c_email_address customer_email_address\n       ,d_year dyear\n       ,sum((((ws_ext_list_price-ws_ext_wholesale_cost-ws_ext_discount_amt)+ws_ext_sales_price)/2) ) year_total\n       ,'w' sale_type\n from customer\n     ,web_sales\n     ,date_dim\n where c_customer_sk = ws_bill_customer_sk\n   and ws_sold_date_sk = d_date_sk\n group by c_customer_id\n         ,c_first_name\n         ,c_last_name\n         ,c_preferred_cust_flag\n         ,c_birth_country\n         ,c_login\n         ,c_email_address\n         ,d_year\n         )\n  select  \n                  t_s_secyear.customer_id\n                 ,t_s_secyear.customer_first_name\n                 ,t_s_secyear.customer_last_name\n                 ,t_s_secyear.customer_preferred_cust_flag\n from year_total t_s_firstyear\n     ,year_total t_s_secyear\n     ,year_total t_c_firstyear\n     ,year_total t_c_secyear\n     ,year_total t_w_firstyear\n     ,year_total t_w_secyear\n where t_s_secyear.customer_id = t_s_firstyear.customer_id\n   and t_s_firstyear.customer_id = t_c_secyear.customer_id\n   and t_s_firstyear.customer_id = t_c_firstyear.customer_id\n   and t_s_firstyear.customer_id = t_w_firstyear.customer_id\n   and t_s_firstyear.customer_id = t_w_secyear.customer_id\n   and t_s_firstyear.sale_type = 's'\n   and t_c_firstyear.sale_type = 'c'\n   and t_w_firstyear.sale_type = 'w'\n   and t_s_secyear.sale_type = 's'\n   and t_c_secyear.sale_type = 'c'\n   and t_w_secyear.sale_type = 'w'\n   and t_s_firstyear.dyear =  1999\n   and t_s_secyear.dyear = 1999+1\n   and t_c_firstyear.dyear =  1999\n   and t_c_secyear.dyear =  1999+1\n   
and t_w_firstyear.dyear = 1999\n   and t_w_secyear.dyear = 1999+1\n   and t_s_firstyear.year_total > 0\n   and t_c_firstyear.year_total > 0\n   and t_w_firstyear.year_total > 0\n   and case when t_c_firstyear.year_total > 0 then t_c_secyear.year_total / t_c_firstyear.year_total else null end\n           > case when t_s_firstyear.year_total > 0 then t_s_secyear.year_total / t_s_firstyear.year_total else null end\n   and case when t_c_firstyear.year_total > 0 then t_c_secyear.year_total / t_c_firstyear.year_total else null end\n           > case when t_w_firstyear.year_total > 0 then t_w_secyear.year_total / t_w_firstyear.year_total else null end\n order by t_s_secyear.customer_id\n         ,t_s_secyear.customer_first_name\n         ,t_s_secyear.customer_last_name\n         ,t_s_secyear.customer_preferred_cust_flag\n LIMIT 100;\n\n-- end query 93 in stream 0 using template query4.tpl\n-- start query 94 in stream 0 using template query99.tpl\nselect  \n   substr(w_warehouse_name,1,20)\n  ,sm_type\n  ,cc_name\n  ,sum(case when (cs_ship_date_sk - cs_sold_date_sk <= 30 ) then 1 else 0 end)  as `30 days`\n  ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 30) and \n                 (cs_ship_date_sk - cs_sold_date_sk <= 60) then 1 else 0 end )  as `31-60 days`\n  ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 60) and \n                 (cs_ship_date_sk - cs_sold_date_sk <= 90) then 1 else 0 end)  as `61-90 days`\n  ,sum(case when (cs_ship_date_sk - cs_sold_date_sk > 90) and\n                 (cs_ship_date_sk - cs_sold_date_sk <= 120) then 1 else 0 end)  as `91-120 days`\n  ,sum(case when (cs_ship_date_sk - cs_sold_date_sk  > 120) then 1 else 0 end)  as `>120 days`\nfrom\n   catalog_sales\n  ,warehouse\n  ,ship_mode\n  ,call_center\n  ,date_dim\nwhere\n    d_month_seq between 1222 and 1222 + 11\nand cs_ship_date_sk   = d_date_sk\nand cs_warehouse_sk   = w_warehouse_sk\nand cs_ship_mode_sk   = sm_ship_mode_sk\nand cs_call_center_sk = cc_call_center_sk\ngroup by\n   
substr(w_warehouse_name,1,20)\n  ,sm_type\n  ,cc_name\norder by substr(w_warehouse_name,1,20)\n        ,sm_type\n        ,cc_name\n LIMIT 100;\n\n-- end query 94 in stream 0 using template query99.tpl\n-- start query 95 in stream 0 using template query68.tpl\nselect  c_last_name\n       ,c_first_name\n       ,ca_city\n       ,bought_city\n       ,ss_ticket_number\n       ,extended_price\n       ,extended_tax\n       ,list_price\n from (select ss_ticket_number\n             ,ss_customer_sk\n             ,ca_city bought_city\n             ,sum(ss_ext_sales_price) extended_price \n             ,sum(ss_ext_list_price) list_price\n             ,sum(ss_ext_tax) extended_tax \n       from store_sales\n           ,date_dim\n           ,store\n           ,household_demographics\n           ,customer_address \n       where store_sales.ss_sold_date_sk = date_dim.d_date_sk\n         and store_sales.ss_store_sk = store.s_store_sk  \n        and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk\n        and store_sales.ss_addr_sk = customer_address.ca_address_sk\n        and date_dim.d_dom between 1 and 2 \n        and (household_demographics.hd_dep_count = 6 or\n             household_demographics.hd_vehicle_count= 1)\n        and date_dim.d_year in (1998,1998+1,1998+2)\n        and store.s_city in ('Midway','Pleasant Hill')\n       group by ss_ticket_number\n               ,ss_customer_sk\n               ,ss_addr_sk,ca_city) dn\n      ,customer\n      ,customer_address current_addr\n where ss_customer_sk = c_customer_sk\n   and customer.c_current_addr_sk = current_addr.ca_address_sk\n   and current_addr.ca_city <> bought_city\n order by c_last_name\n         ,ss_ticket_number\n  LIMIT 100;\n\n-- end query 95 in stream 0 using template query68.tpl\n-- start query 96 in stream 0 using template query83.tpl\nwith sr_items as\n (select i_item_id item_id,\n        sum(sr_return_quantity) sr_item_qty\n from store_returns,\n      item,\n      date_dim\n where sr_item_sk = 
i_item_sk\n and   d_date    in \n\t(select d_date\n\tfrom date_dim\n\twhere d_week_seq in \n\t\t(select d_week_seq\n\t\tfrom date_dim\n\t  where d_date in ('1998-05-29','1998-08-19','1998-11-10')))\n and   sr_returned_date_sk   = d_date_sk\n group by i_item_id),\n cr_items as\n (select i_item_id item_id,\n        sum(cr_return_quantity) cr_item_qty\n from catalog_returns,\n      item,\n      date_dim\n where cr_item_sk = i_item_sk\n and   d_date    in \n\t(select d_date\n\tfrom date_dim\n\twhere d_week_seq in \n\t\t(select d_week_seq\n\t\tfrom date_dim\n\t  where d_date in ('1998-05-29','1998-08-19','1998-11-10')))\n and   cr_returned_date_sk   = d_date_sk\n group by i_item_id),\n wr_items as\n (select i_item_id item_id,\n        sum(wr_return_quantity) wr_item_qty\n from web_returns,\n      item,\n      date_dim\n where wr_item_sk = i_item_sk\n and   d_date    in \n\t(select d_date\n\tfrom date_dim\n\twhere d_week_seq in \n\t\t(select d_week_seq\n\t\tfrom date_dim\n\t\twhere d_date in ('1998-05-29','1998-08-19','1998-11-10')))\n and   wr_returned_date_sk   = d_date_sk\n group by i_item_id)\n  select  sr_items.item_id\n       ,sr_item_qty\n       ,sr_item_qty/(sr_item_qty+cr_item_qty+wr_item_qty)/3.0 * 100 sr_dev\n       ,cr_item_qty\n       ,cr_item_qty/(sr_item_qty+cr_item_qty+wr_item_qty)/3.0 * 100 cr_dev\n       ,wr_item_qty\n       ,wr_item_qty/(sr_item_qty+cr_item_qty+wr_item_qty)/3.0 * 100 wr_dev\n       ,(sr_item_qty+cr_item_qty+wr_item_qty)/3.0 average\n from sr_items\n     ,cr_items\n     ,wr_items\n where sr_items.item_id=cr_items.item_id\n   and sr_items.item_id=wr_items.item_id \n order by sr_items.item_id\n         ,sr_item_qty\n  LIMIT 100;\n\n-- end query 96 in stream 0 using template query83.tpl\n-- start query 97 in stream 0 using template query61.tpl\nselect  promotions,total,cast(promotions as decimal(15,4))/cast(total as decimal(15,4))*100\nfrom\n  (select sum(ss_ext_sales_price) promotions\n   from  store_sales\n        ,store\n        
,promotion\n        ,date_dim\n        ,customer\n        ,customer_address \n        ,item\n   where ss_sold_date_sk = d_date_sk\n   and   ss_store_sk = s_store_sk\n   and   ss_promo_sk = p_promo_sk\n   and   ss_customer_sk= c_customer_sk\n   and   ca_address_sk = c_current_addr_sk\n   and   ss_item_sk = i_item_sk \n   and   ca_gmt_offset = -6\n   and   i_category = 'Sports'\n   and   (p_channel_dmail = 'Y' or p_channel_email = 'Y' or p_channel_tv = 'Y')\n   and   s_gmt_offset = -6\n   and   d_year = 1998\n   and   d_moy  = 12) promotional_sales,\n  (select sum(ss_ext_sales_price) total\n   from  store_sales\n        ,store\n        ,date_dim\n        ,customer\n        ,customer_address\n        ,item\n   where ss_sold_date_sk = d_date_sk\n   and   ss_store_sk = s_store_sk\n   and   ss_customer_sk= c_customer_sk\n   and   ca_address_sk = c_current_addr_sk\n   and   ss_item_sk = i_item_sk\n   and   ca_gmt_offset = -6\n   and   i_category = 'Sports'\n   and   s_gmt_offset = -6\n   and   d_year = 1998\n   and   d_moy  = 12) all_sales\norder by promotions, total\n LIMIT 100;\n\n-- end query 97 in stream 0 using template query61.tpl\n-- start query 98 in stream 0 using template query5.tpl\nwith ssr as\n (select s_store_id,\n        sum(sales_price) as sales,\n        sum(profit) as profit,\n        sum(return_amt) as returns,\n        sum(net_loss) as profit_loss\n from\n  ( select  ss_store_sk as store_sk,\n            ss_sold_date_sk  as date_sk,\n            ss_ext_sales_price as sales_price,\n            ss_net_profit as profit,\n            cast(0 as decimal(7,2)) as return_amt,\n            cast(0 as decimal(7,2)) as net_loss\n    from store_sales\n    union all\n    select sr_store_sk as store_sk,\n           sr_returned_date_sk as date_sk,\n           cast(0 as decimal(7,2)) as sales_price,\n           cast(0 as decimal(7,2)) as profit,\n           sr_return_amt as return_amt,\n           sr_net_loss as net_loss\n    from store_returns\n   ) salesreturns,\n    
 date_dim,\n     store\n where date_sk = d_date_sk\n       and d_date between cast('1998-08-21' as date) \n                  and (cast('1998-08-21' as date) + interval 14 days)\n       and store_sk = s_store_sk\n group by s_store_id)\n ,\n csr as\n (select cp_catalog_page_id,\n        sum(sales_price) as sales,\n        sum(profit) as profit,\n        sum(return_amt) as returns,\n        sum(net_loss) as profit_loss\n from\n  ( select  cs_catalog_page_sk as page_sk,\n            cs_sold_date_sk  as date_sk,\n            cs_ext_sales_price as sales_price,\n            cs_net_profit as profit,\n            cast(0 as decimal(7,2)) as return_amt,\n            cast(0 as decimal(7,2)) as net_loss\n    from catalog_sales\n    union all\n    select cr_catalog_page_sk as page_sk,\n           cr_returned_date_sk as date_sk,\n           cast(0 as decimal(7,2)) as sales_price,\n           cast(0 as decimal(7,2)) as profit,\n           cr_return_amount as return_amt,\n           cr_net_loss as net_loss\n    from catalog_returns\n   ) salesreturns,\n     date_dim,\n     catalog_page\n where date_sk = d_date_sk\n       and d_date between cast('1998-08-21' as date)\n                  and (cast('1998-08-21' as date) + interval 14 days)\n       and page_sk = cp_catalog_page_sk\n group by cp_catalog_page_id)\n ,\n wsr as\n (select web_site_id,\n        sum(sales_price) as sales,\n        sum(profit) as profit,\n        sum(return_amt) as returns,\n        sum(net_loss) as profit_loss\n from\n  ( select  ws_web_site_sk as wsr_web_site_sk,\n            ws_sold_date_sk  as date_sk,\n            ws_ext_sales_price as sales_price,\n            ws_net_profit as profit,\n            cast(0 as decimal(7,2)) as return_amt,\n            cast(0 as decimal(7,2)) as net_loss\n    from web_sales\n    union all\n    select ws_web_site_sk as wsr_web_site_sk,\n           wr_returned_date_sk as date_sk,\n           cast(0 as decimal(7,2)) as sales_price,\n           cast(0 as decimal(7,2)) as 
profit,\n           wr_return_amt as return_amt,\n           wr_net_loss as net_loss\n    from web_returns left outer join web_sales on\n         ( wr_item_sk = ws_item_sk\n           and wr_order_number = ws_order_number)\n   ) salesreturns,\n     date_dim,\n     web_site\n where date_sk = d_date_sk\n       and d_date between cast('1998-08-21' as date)\n                  and (cast('1998-08-21' as date) + interval 14 days)\n       and wsr_web_site_sk = web_site_sk\n group by web_site_id)\n  select  channel\n        , id\n        , sum(sales) as sales\n        , sum(returns) as returns\n        , sum(profit) as profit\n from \n (select 'store channel' as channel\n        , 'store' || s_store_id as id\n        , sales\n        , returns\n        , (profit - profit_loss) as profit\n from   ssr\n union all\n select 'catalog channel' as channel\n        , 'catalog_page' || cp_catalog_page_id as id\n        , sales\n        , returns\n        , (profit - profit_loss) as profit\n from  csr\n union all\n select 'web channel' as channel\n        , 'web_site' || web_site_id as id\n        , sales\n        , returns\n        , (profit - profit_loss) as profit\n from   wsr\n ) x\n group by rollup (channel, id)\n order by channel\n         ,id\n  LIMIT 100;\n\n-- end query 98 in stream 0 using template query5.tpl\n-- start query 99 in stream 0 using template query76.tpl\nselect  channel, col_name, d_year, d_qoy, i_category, COUNT(*) sales_cnt, SUM(ext_sales_price) sales_amt FROM (\n        SELECT 'store' as channel, 'ss_addr_sk' col_name, d_year, d_qoy, i_category, ss_ext_sales_price ext_sales_price\n         FROM store_sales, item, date_dim\n         WHERE ss_addr_sk IS NULL\n           AND ss_sold_date_sk=d_date_sk\n           AND ss_item_sk=i_item_sk\n        UNION ALL\n        SELECT 'web' as channel, 'ws_web_page_sk' col_name, d_year, d_qoy, i_category, ws_ext_sales_price ext_sales_price\n         FROM web_sales, item, date_dim\n         WHERE ws_web_page_sk IS NULL\n      
     AND ws_sold_date_sk=d_date_sk\n           AND ws_item_sk=i_item_sk\n        UNION ALL\n        SELECT 'catalog' as channel, 'cs_ship_mode_sk' col_name, d_year, d_qoy, i_category, cs_ext_sales_price ext_sales_price\n         FROM catalog_sales, item, date_dim\n         WHERE cs_ship_mode_sk IS NULL\n           AND cs_sold_date_sk=d_date_sk\n           AND cs_item_sk=i_item_sk) foo\nGROUP BY channel, col_name, d_year, d_qoy, i_category\nORDER BY channel, col_name, d_year, d_qoy, i_category\n LIMIT 100;\n\n-- end query 99 in stream 0 using template query76.tpl\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/notebook/README.md",
    "content": "# Demo Notebook Overview\n\nThe `spark-connect-gpu-etl-ml.ipynb` notebook demonstrates:\n\n## ETL Pipeline\n- **Data ingestion** from CSV with custom schema\n- **Complex transformations** including date parsing and delinquency calculations\n- **String-to-numeric encoding** for categorical features\n- **Data joins and aggregations** with mortgage performance data\n\n## Machine Learning Workflow\n- **Feature engineering** with FeatureHasher and VectorAssembler\n- **Logistic Regression** training for multi-class prediction\n- **Model evaluation** with performance metrics\n- **GPU vs CPU timing comparisons**\n\n## Key Code Examples\n\n**Connecting to Spark with GPU acceleration:**\n```python\nfrom pyspark.sql import SparkSession\n\nspark = (\n  SparkSession.builder\n    .remote('sc://spark-connect-server')\n    .appName('GPU-Accelerated-ETL-ML-Demo')\n    .getOrCreate()\n)\n```\n\nIn the actual demo code we find it handier to use the `SPARK_REMOTE` environment variable instead\nof having it in the code\nso it is easy to run it in a Spark Classic way as well.\n\n**Machine Learning with GPU acceleration:**\n```python\nfrom pyspark.ml import Pipeline\nfrom pyspark.ml.classification import LogisticRegression\nfrom pyspark.ml.feature import VectorAssembler, FeatureHasher\n\nspark.conf.set('spark.connect.ml.backend.classes', 'com.nvidia.rapids.ml.Plugin')\n\n# Feature preparation\nhasher = FeatureHasher(inputCols=categorical_cols, outputCol='hashed_categorical')\nassembler = VectorAssembler().setInputCols(numerical_cols + ['hashed_categorical']).setOutputCol('features')\n\n# Model training\nlogistic = LogisticRegression().setFeaturesCol('features').setLabelCol('delinquency_12')\npipeline = Pipeline().setStages([hasher, assembler, logistic])\nmodel = pipeline.fit(training_data)\n```\n\n## Results\n\nThe demo at the Data+AI Summit'25 used the following mortgage quarters\n\n```bash\n$ du -h *\n503M    2023Q1.csv\n412M    2023Q2.csv\n162M    2023Q3.csv\n1.1G    
2023Q4.csv\n```\n\nand was tested on a machine with a 6GiB RTX A3000 Laptop GPU\n\n```bash\n$ nvidia-smi\n+-----------------------------------------------------------------------------------------+\n| NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     |\n|-----------------------------------------+------------------------+----------------------+\n| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |\n| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |\n|                                         |                        |               MIG M. |\n|=========================================+========================+======================|\n|   0  NVIDIA RTX A3000 Laptop GPU    Off |   00000000:01:00.0 Off |                  N/A |\n| N/A   56C    P8             13W /   60W |    1353MiB /   6144MiB |      1%      Default |\n|                                         |                        |                  N/A |\n+-----------------------------------------+------------------------+----------------------+\n```\n\nand a 2x8-core CPU\n\n![GPU Acceleration Results](example-acceleration-chart.png)\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/notebook/spark-connect-gpu-etl-ml.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# GPU-Accelerated Spark Connect - SQL/DF ETL and MLlib on Mortgage Dataset (Spark 4.0+)\\n\",\n    \"\\n\",\n    \"Based on the Data and AI Summit 2025 session: [GPU Accelerated Spark Connect](https://www.databricks.com/dataaisummit/session/gpu-accelerated-spark-connect)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Import packages\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pyspark.ml import Pipeline\\n\",\n    \"from pyspark.ml.classification import LogisticRegression\\n\",\n    \"from pyspark.ml.evaluation import MulticlassClassificationEvaluator\\n\",\n    \"from pyspark.ml.feature import VectorAssembler, FeatureHasher\\n\",\n    \"from pyspark.sql import SparkSession\\n\",\n    \"from pyspark.sql.functions import *\\n\",\n    \"from pyspark.sql.types import IntegerType\\n\",\n    \"from pyspark.sql.window import Window\\n\",\n    \"import csv\\n\",\n    \"import os\\n\",\n    \"import pandas as pd\\n\",\n    \"import time\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Connect to Spark via Spark Connect\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Create GPU-accelerated Spark session using Spark Connect 4.0+\\n\",\n    \"spark = (\\n\",\n    \"  SparkSession.builder\\n\",\n    \"    .appName('GPU-Accelerated Spark Connect - SQL/ETL and MLlib') \\n\",\n    \"    .getOrCreate()\\n\",\n    \")\\n\",\n    \"print(f'Spark Connect session id: {spark.session_id}')\\n\",\n    \"print(f'Spark version: {spark.version}')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Local and 
Global Storage Access \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# This can be a local storage location accessible to the thin Spark Connect app\\n\",\n    \"# such as a IPython kernel\\n\",\n    \"local_data_dir = 'work'\\n\",\n    \"\\n\",\n    \"# This would normally be a global storage location such as Cloud Object Storage\\n\",\n    \"# This notebook requires a writable directory on the host. It is mounted into containers\\n\",\n    \"# requiring access to it as /data from the host \\n\",\n    \"# This directory should contain directory `mortgage.input.csv` with files from the Mortgage dataset.\\n\",\n    \"# We also store here data useful across the container life cycle such as metrics from the previous runs\\n\",\n    \"# and Spark event logs. \\n\",\n    \"global_data_dir = '/data'\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Normalize references to the same bank \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"with open(f'{local_data_dir}/name_mapping.csv', 'r') as name_mapping_file:\\n\",\n    \"  nm_reader = csv.reader(name_mapping_file,)\\n\",\n    \"  name_mapping = [r for r in nm_reader]\\n\",\n    \"name_mapping_df = spark.createDataFrame(name_mapping, ['from_seller_name', 'to_seller_name'])\\n\",\n    \"\\n\",\n    \"(\\n\",\n    \"  name_mapping_df\\n\",\n    \"    .where(col('to_seller_name') == 'Wells Fargo' )\\n\",\n    \"    .show(truncate=False)\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# String columns\\n\",\n    \"cate_col_names = [\\n\",\n    \"  'orig_channel',\\n\",\n    \"  'first_home_buyer',\\n\",\n    \"  'loan_purpose',\\n\",\n    \"  'property_type',\\n\",\n  
  \"  'occupancy_status',\\n\",\n    \"  'property_state',\\n\",\n    \"  'product_type',\\n\",\n    \"  'relocation_mortgage_indicator',\\n\",\n    \"  'seller_name',\\n\",\n    \"  'mod_flag'\\n\",\n    \"]\\n\",\n    \"# Numeric columns\\n\",\n    \"label_col_name = 'delinquency_12'\\n\",\n    \"numeric_col_names = [\\n\",\n    \"  'orig_interest_rate',\\n\",\n    \"  'orig_upb',\\n\",\n    \"  'orig_loan_term',\\n\",\n    \"  'orig_ltv',\\n\",\n    \"  'orig_cltv',\\n\",\n    \"  'num_borrowers',\\n\",\n    \"  'dti',\\n\",\n    \"  'borrower_credit_score',\\n\",\n    \"  'num_units',\\n\",\n    \"  'zip',\\n\",\n    \"  'mortgage_insurance_percent',\\n\",\n    \"  'current_loan_delinquency_status',\\n\",\n    \"  'current_actual_upb',\\n\",\n    \"  'interest_rate',\\n\",\n    \"  'loan_age',\\n\",\n    \"  'msa',\\n\",\n    \"  'non_interest_bearing_upb',\\n\",\n    \"  label_col_name\\n\",\n    \"]\\n\",\n    \"all_col_names = cate_col_names + numeric_col_names\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Define ETL Process\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Functions to read raw columns\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def read_raw_csv(spark, path):\\n\",\n    \"  def _get_quarter_from_csv_file_name():\\n\",\n    \"    return substring_index(substring_index(input_file_name(), '.', 1), '/', -1)\\n\",\n    \"\\n\",\n    \"  with open(f'{local_data_dir}/csv_raw_schema.ddl', 'r') as f:\\n\",\n    \"    _csv_raw_schema_str = f.read()\\n\",\n    \"  \\n\",\n    \"  return (\\n\",\n    \"    spark.read\\n\",\n    \"    .format('csv') \\n\",\n    \"    .option('nullValue', '') \\n\",\n    \"    .option('header', False) \\n\",\n    \"    .option('delimiter', '|') \\n\",\n    \"    .schema(_csv_raw_schema_str) \\n\",\n    \"    
.load(path) \\n\",\n    \"    .withColumn('quarter', _get_quarter_from_csv_file_name())\\n\",\n    \"  )\\n\",\n    \"\\n\",\n    \"def extract_perf_columns(rawDf):\\n\",\n    \"  perfDf = rawDf.select(\\n\",\n    \"    col('loan_id'),\\n\",\n    \"    date_format(to_date(col('monthly_reporting_period'),'MMyyyy'), 'MM/dd/yyyy').alias('monthly_reporting_period'),\\n\",\n    \"    upper(col('servicer')).alias('servicer'),\\n\",\n    \"    col('interest_rate'),\\n\",\n    \"    col('current_actual_upb'),\\n\",\n    \"    col('loan_age'),\\n\",\n    \"    col('remaining_months_to_legal_maturity'),\\n\",\n    \"    col('adj_remaining_months_to_maturity'),\\n\",\n    \"    date_format(to_date(col('maturity_date'),'MMyyyy'), 'MM/yyyy').alias('maturity_date'),\\n\",\n    \"    col('msa'),\\n\",\n    \"    col('current_loan_delinquency_status'),\\n\",\n    \"    col('mod_flag'),\\n\",\n    \"    col('zero_balance_code'),\\n\",\n    \"    date_format(to_date(col('zero_balance_effective_date'),'MMyyyy'), 'MM/yyyy').alias('zero_balance_effective_date'),\\n\",\n    \"    date_format(to_date(col('last_paid_installment_date'),'MMyyyy'), 'MM/dd/yyyy').alias('last_paid_installment_date'),\\n\",\n    \"    date_format(to_date(col('foreclosed_after'),'MMyyyy'), 'MM/dd/yyyy').alias('foreclosed_after'),\\n\",\n    \"    date_format(to_date(col('disposition_date'),'MMyyyy'), 'MM/dd/yyyy').alias('disposition_date'),\\n\",\n    \"    col('foreclosure_costs'),\\n\",\n    \"    col('prop_preservation_and_repair_costs'),\\n\",\n    \"    col('asset_recovery_costs'),\\n\",\n    \"    col('misc_holding_expenses'),\\n\",\n    \"    col('holding_taxes'),\\n\",\n    \"    col('net_sale_proceeds'),\\n\",\n    \"    col('credit_enhancement_proceeds'),\\n\",\n    \"    col('repurchase_make_whole_proceeds'),\\n\",\n    \"    col('other_foreclosure_proceeds'),\\n\",\n    \"    col('non_interest_bearing_upb'),\\n\",\n    \"    col('principal_forgiveness_upb'),\\n\",\n    \"    
col('repurchase_make_whole_proceeds_flag'),\\n\",\n    \"    col('foreclosure_principal_write_off_amount'),\\n\",\n    \"    col('servicing_activity_indicator'),\\n\",\n    \"    col('quarter')\\n\",\n    \"  )\\n\",\n    \"  return perfDf.select('*').filter('current_actual_upb != 0.0')\\n\",\n    \"\\n\",\n    \"def extract_acq_columns(rawDf):\\n\",\n    \"  acqDf = rawDf.select(\\n\",\n    \"    col('loan_id'),\\n\",\n    \"    col('orig_channel'),\\n\",\n    \"    upper(col('seller_name')).alias('seller_name'),\\n\",\n    \"    col('orig_interest_rate'),\\n\",\n    \"    col('orig_upb'),\\n\",\n    \"    col('orig_loan_term'),\\n\",\n    \"    date_format(to_date(col('orig_date'),'MMyyyy'), 'MM/yyyy').alias('orig_date'),\\n\",\n    \"    date_format(to_date(col('first_pay_date'),'MMyyyy'), 'MM/yyyy').alias('first_pay_date'),\\n\",\n    \"    col('orig_ltv'),\\n\",\n    \"    col('orig_cltv'),\\n\",\n    \"    col('num_borrowers'),\\n\",\n    \"    col('dti'),\\n\",\n    \"    col('borrower_credit_score'),\\n\",\n    \"    col('first_home_buyer'),\\n\",\n    \"    col('loan_purpose'),\\n\",\n    \"    col('property_type'),\\n\",\n    \"    col('num_units'),\\n\",\n    \"    col('occupancy_status'),\\n\",\n    \"    col('property_state'),\\n\",\n    \"    col('zip'),\\n\",\n    \"    col('mortgage_insurance_percent'),\\n\",\n    \"    col('product_type'),\\n\",\n    \"    col('coborrow_credit_score'),\\n\",\n    \"    col('mortgage_insurance_type'),\\n\",\n    \"    col('relocation_mortgage_indicator'),\\n\",\n    \"    dense_rank().over(Window.partitionBy('loan_id').orderBy(to_date(col('monthly_reporting_period'),'MMyyyy'))).alias('rank'),\\n\",\n    \"    col('quarter')\\n\",\n    \"  )\\n\",\n    \"\\n\",\n    \"  return acqDf.select('*').filter(col('rank')==1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Define function to parse date in Performance data \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n 
  \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _parse_dates(perf):\\n\",\n    \"  return (\\n\",\n    \"    perf.withColumn('monthly_reporting_period', to_date(col('monthly_reporting_period'), 'MM/dd/yyyy')) \\n\",\n    \"      .withColumn('monthly_reporting_period_month', month(col('monthly_reporting_period'))) \\n\",\n    \"      .withColumn('monthly_reporting_period_year', year(col('monthly_reporting_period'))) \\n\",\n    \"      .withColumn('monthly_reporting_period_day', dayofmonth(col('monthly_reporting_period'))) \\n\",\n    \"      .withColumn('last_paid_installment_date', to_date(col('last_paid_installment_date'), 'MM/dd/yyyy')) \\n\",\n    \"      .withColumn('foreclosed_after', to_date(col('foreclosed_after'), 'MM/dd/yyyy')) \\n\",\n    \"      .withColumn('disposition_date', to_date(col('disposition_date'), 'MM/dd/yyyy')) \\n\",\n    \"      .withColumn('maturity_date', to_date(col('maturity_date'), 'MM/yyyy')) \\n\",\n    \"  )\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Define function to create deliquency data frame from Performance data.  \\n\",\n    \"\\n\",\n    \"The computed `delinquency_12` column denotes whether a loan will become delinquent by 3, 6, or 9 months, \\n\",\n    \"or not delinquent, within the next 12 month period.   
\\n\",\n    \"\\n\",\n    \"It will be the target label for ML multi-class prediction.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _create_perf_deliquency(spark, perf):\\n\",\n    \"  aggDF = (\\n\",\n    \"    perf\\n\",\n    \"      .select(\\n\",\n    \"        col('quarter'),\\n\",\n    \"        col('loan_id'),\\n\",\n    \"        col('current_loan_delinquency_status'),\\n\",\n    \"        when(col('current_loan_delinquency_status') >= 1, col('monthly_reporting_period')).alias('delinquency_30'),\\n\",\n    \"        when(col('current_loan_delinquency_status') >= 3, col('monthly_reporting_period')).alias('delinquency_90'),\\n\",\n    \"        when(col('current_loan_delinquency_status') >= 6, col('monthly_reporting_period')).alias('delinquency_180')\\n\",\n    \"      ).groupBy('quarter', 'loan_id')\\n\",\n    \"       .agg(\\n\",\n    \"         max('current_loan_delinquency_status').alias('delinquency_12'),\\n\",\n    \"         min('delinquency_30').alias('delinquency_30'),\\n\",\n    \"         min('delinquency_90').alias('delinquency_90'),\\n\",\n    \"         min('delinquency_180').alias('delinquency_180')\\n\",\n    \"       ).select(\\n\",\n    \"         col('quarter'),\\n\",\n    \"         col('loan_id'),\\n\",\n    \"         (col('delinquency_12') >= 1).alias('ever_30'),\\n\",\n    \"         (col('delinquency_12') >= 3).alias('ever_90'),\\n\",\n    \"         (col('delinquency_12') >= 6).alias('ever_180'),\\n\",\n    \"         col('delinquency_30'),\\n\",\n    \"         col('delinquency_90'),\\n\",\n    \"         col('delinquency_180')\\n\",\n    \"       )\\n\",\n    \"  )\\n\",\n    \"  #aggDF.printSchema()\\n\",\n    \"  joinedDf = (\\n\",\n    \"    perf\\n\",\n    \"      .withColumnRenamed('monthly_reporting_period', 'timestamp')\\n\",\n    \"      .withColumnRenamed('monthly_reporting_period_month', 'timestamp_month') \\n\",\n    
\"      .withColumnRenamed('monthly_reporting_period_year', 'timestamp_year') \\n\",\n    \"      .withColumnRenamed('current_loan_delinquency_status', 'delinquency_12') \\n\",\n    \"      .withColumnRenamed('current_actual_upb', 'upb_12') \\n\",\n    \"      .select('quarter', 'loan_id', 'timestamp', 'delinquency_12', 'upb_12', 'timestamp_month', 'timestamp_year') \\n\",\n    \"      .join(aggDF, ['loan_id', 'quarter'], 'left_outer')\\n\",\n    \"  )\\n\",\n    \"  # calculate the 12 month delinquency and upb values\\n\",\n    \"  months = 12\\n\",\n    \"  monthArray = [lit(x) for x in range(0, 12)]\\n\",\n    \"  \\n\",\n    \"  testDf = ( \\n\",\n    \"    joinedDf\\n\",\n    \"      .withColumn('month_y', explode(array(monthArray)))\\n\",\n    \"      .select(\\n\",\n    \"        col('quarter'),\\n\",\n    \"        floor(((col('timestamp_year') * 12 + col('timestamp_month')) - 24000) / months).alias('josh_mody'),\\n\",\n    \"        floor(((col('timestamp_year') * 12 + col('timestamp_month')) - 24000 - col('month_y')) / months).alias('josh_mody_n'),\\n\",\n    \"        col('ever_30'),\\n\",\n    \"        col('ever_90'),\\n\",\n    \"        col('ever_180'),\\n\",\n    \"        col('delinquency_30'),\\n\",\n    \"        col('delinquency_90'),\\n\",\n    \"        col('delinquency_180'),\\n\",\n    \"        col('loan_id'),\\n\",\n    \"        col('month_y'),\\n\",\n    \"        col('delinquency_12'),\\n\",\n    \"        col('upb_12')\\n\",\n    \"      ).groupBy('quarter', 'loan_id', 'josh_mody_n', 'ever_30', 'ever_90', 'ever_180', 'delinquency_30', 'delinquency_90', 'delinquency_180', 'month_y')\\n\",\n    \"    .agg(max('delinquency_12').alias('delinquency_12'), min('upb_12').alias('upb_12')) \\n\",\n    \"    .withColumn('timestamp_year', floor((lit(24000) + (col('josh_mody_n') * lit(months)) + (col('month_y') - 1)) / lit(12))) \\n\",\n    \"    .selectExpr('*', f'pmod(24000 + (josh_mody_n * {months}) + month_y, 12) as timestamp_month_tmp') 
\\n\",\n    \"    .withColumn('timestamp_month', when(col('timestamp_month_tmp') == lit(0), lit(12)).otherwise(col('timestamp_month_tmp'))) \\n\",\n    \"    .withColumn('delinquency_12', ((col('delinquency_12') > 9).cast('int') + (col('delinquency_12') > 6).cast('int') + (col('delinquency_12') > 3).cast('int') + (col('upb_12') == 0).cast('int')).alias('delinquency_12')) \\n\",\n    \"    .drop('timestamp_month_tmp', 'josh_mody_n', 'month_y')\\n\",\n    \"  )\\n\",\n    \"\\n\",\n    \"  return (\\n\",\n    \"    perf\\n\",\n    \"      .withColumnRenamed('monthly_reporting_period_month', 'timestamp_month')\\n\",\n    \"      .withColumnRenamed('monthly_reporting_period_year', 'timestamp_year')\\n\",\n    \"      .join(testDf, ['quarter', 'loan_id', 'timestamp_year', 'timestamp_month'], 'left')\\n\",\n    \"      .drop('timestamp_year', 'timestamp_month')\\n\",\n    \"  )\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Define function to create acquisition data frame from Acquisition data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _create_acquisition(spark, acq):\\n\",\n    \"  return (\\n\",\n    \"    acq.join(name_mapping_df, col('seller_name') == col('from_seller_name'), 'left')\\n\",\n    \"      .drop('from_seller_name') \\n\",\n    \"      .withColumn('old_name', col('seller_name')) \\n\",\n    \"      .withColumn('seller_name', coalesce(col('to_seller_name'), col('seller_name'))) \\n\",\n    \"      .drop('to_seller_name') \\n\",\n    \"      .withColumn('orig_date', to_date(col('orig_date'), 'MM/yyyy')) \\n\",\n    \"      .withColumn('first_pay_date', to_date(col('first_pay_date'), 'MM/yyyy')) \\n\",\n    \"  )\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Define Casting Process\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"This part is casting 
String column to Numeric one. \\n\",\n    \"Example:\\n\",\n    \"```\\n\",\n    \"col_1\\n\",\n    \" \\\"a\\\"\\n\",\n    \" \\\"b\\\"\\n\",\n    \" \\\"c\\\"\\n\",\n    \" \\\"a\\\"\\n\",\n    \"# After String ====> Numeric\\n\",\n    \"col_1\\n\",\n    \" 0\\n\",\n    \" 1\\n\",\n    \" 2\\n\",\n    \" 0\\n\",\n    \"```  \\n\",\n    \"\\n\",\n    \"### Define function to get column dictionary\\n\",\n    \"\\n\",\n    \"Example\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"col1 = [row(data=\\\"a\\\",id=0), row(data=\\\"b\\\",id=1)]\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _gen_dictionary(etl_df, col_names):\\n\",\n    \"  cnt_table = (\\n\",\n    \"    etl_df.select(posexplode(array([col(i) for i in col_names])))\\n\",\n    \"      .withColumnRenamed('pos', 'column_id')\\n\",\n    \"      .withColumnRenamed('col', 'data')\\n\",\n    \"      .filter('data is not null')\\n\",\n    \"      .groupBy('column_id', 'data')\\n\",\n    \"      .count()\\n\",\n    \"  )\\n\",\n    \"  windowed = Window.partitionBy('column_id').orderBy(desc('count'))\\n\",\n    \"  return cnt_table.withColumn('id', row_number().over(windowed)).drop('count')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Define function to convert string columns to numeric\\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _cast_string_columns_to_numeric(spark, input_df):\\n\",\n    \"  cached_dict_df = _gen_dictionary(input_df, cate_col_names)  \\n\",\n    \"  # .cache()\\n\",\n    \"  #  Uncomment above line to cache the dictionary dataframe. 
You need to spark.catalog.clearCache()\\n\",\n    \"  #  when running the notebook multiple times switching between CPU and GPU.\\n\",\n    \"  \\n\",\n    \"  output_df = input_df\\n\",\n    \"  #  Generate the final table with all columns being numeric.\\n\",\n    \"  for col_pos, col_name in enumerate(cate_col_names):\\n\",\n    \"    col_dict_df = (\\n\",\n    \"      cached_dict_df.filter(col('column_id') == col_pos)\\n\",\n    \"        .drop('column_id')\\n\",\n    \"        .withColumnRenamed('data', col_name)\\n\",\n    \"    )\\n\",\n    \"    output_df = (\\n\",\n    \"      output_df.join(broadcast(col_dict_df), col_name, 'left')\\n\",\n    \"        .drop(col_name)\\n\",\n    \"        .withColumnRenamed('id', col_name)\\n\",\n    \"    )\\n\",\n    \"  return output_df     \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Define Main Function\\n\",\n    \"\\n\",\n    \"In this function:\\n\",\n    \"1. Parse date in Performance data by calling _parse_dates (parsed_perf)\\n\",\n    \"2. Create deliqency dataframe(perf_deliqency) form Performance data by calling _create_perf_deliquency\\n\",\n    \"3. Create cleaned acquisition dataframe(cleaned_acq) from Acquisition data by calling _create_acquisition\\n\",\n    \"4. Join deliqency dataframe(perf_deliqency) and cleaned acquisition dataframe(cleaned_acq), get clean_df\\n\",\n    \"5. Cast String column to Numeric in clean_df by calling _cast_string_columns_to_numeric, get casted_clean_df\\n\",\n    \"6. 
Return casted_clean_df as final result\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def run_mortgage(spark, perf, acq):\\n\",\n    \"  parsed_perf = _parse_dates(perf)\\n\",\n    \"  perf_deliqency = _create_perf_deliquency(spark, parsed_perf)\\n\",\n    \"  cleaned_acq = _create_acquisition(spark, acq)\\n\",\n    \"  clean_df = perf_deliqency.join(cleaned_acq, ['loan_id', 'quarter'], 'inner').drop('quarter')\\n\",\n    \"  casted_clean_df = (\\n\",\n    \"    _cast_string_columns_to_numeric(spark, clean_df)\\n\",\n    \"      .select(all_col_names)\\n\",\n    \"      .withColumn(label_col_name, when(col(label_col_name) > 0, col(label_col_name)).otherwise(0))\\n\",\n    \"      .fillna(float(0))\\n\",\n    \"  )\\n\",\n    \"  return casted_clean_df\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Knobs for running the pipelines\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Should raw csv input be used or input persisted to Parquet \\n\",\n    \"read_from_csv = False\\n\",\n    \"# if not read_from_csv, include conversion to Parquet in this run?\\n\",\n    \"convert_csv_to_parquet = True\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"###  Execute SQL and ML on GPU ?\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": \"accelerate_on_gpu = True\"\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### ETL on GPU?\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"spark.conf.set('spark.rapids.sql.enabled', accelerate_on_gpu)  \"\n   ]\n  },\n 
 {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### ML on GPU?\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if accelerate_on_gpu:\\n\",\n    \"  spark.conf.set('spark.connect.ml.backend.classes', 'com.nvidia.rapids.ml.Plugin')\\n\",\n    \"else:\\n\",\n    \"  spark.conf.unset('spark.connect.ml.backend.classes')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Run ETL Pipeline\\n\",\n    \"\\n\",\n    \"#### Read Raw Data and Run ETL Process, Save the Result\\n\",\n    \"\\n\",\n    \"##### Convert CSV to Parquet\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if read_from_csv:\\n\",\n    \"  mortgage_csv = read_raw_csv(spark, f'{global_data_dir}/mortgage.input.csv')\\n\",\n    \"elif convert_csv_to_parquet:\\n\",\n    \"  read_raw_csv(spark, f'{global_data_dir}/mortgage.input.csv')\\\\\\n\",\n    \"    .write.parquet(f'{global_data_dir}/mortgage_input.pq', mode='overwrite')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### ETL from Parquet or raw CSV Data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"mortgage = mortgage_csv if read_from_csv else spark.read.parquet(f'{global_data_dir}/mortgage_input.pq')\\n\",\n    \"acq = extract_acq_columns(mortgage)\\n\",\n    \"perf = extract_perf_columns(mortgage)\\n\",\n    \"# run main function to process data\\n\",\n    \"preprocessed = run_mortgage(spark, perf, acq)\\n\",\n    \"# save processed data\\n\",\n    \"\\n\",\n    \"start = time.time()\\n\",\n    \"preprocessed.write.parquet(f'{global_data_dir}/mortgage_preprocessed.pq' , mode='overwrite')\\n\",\n    \"end = 
time.time()\\n\",\n    \"\\n\",\n    \"etl_dur = end - start\\n\",\n    \"print(f'ETL takes {etl_dur}')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Modeling Pipeline\\n\",\n    \"\\n\",\n    \"#### The ML modeling phase of the example uses the `spark.ml` Pipeline API to carry out the following steps on a random subsample of the ETL output:\\n\",\n    \"  - use `spark.ml FeatureHasher` to map the int type columns in the ETL output to a 2^15 dimensional sparse feature vector with a non-zero entry in each location corresponding to hash value of each input column value + column name.\\n\",\n    \"  - use `spark.ml VectorAssembler` to combine the output of `FeatureHasher` with the original float type columns into a single `VectorUDT` type feature vector\\n\",\n    \"  - train a model using `LogisticRegression` to predict the multi-class (4 class values) label \\\"delinquency_12\\\".\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"etlDf = spark.read.parquet(f'{global_data_dir}/mortgage_preprocessed.pq')\\n\",\n    \"etlDf = etlDf.sample(fraction=0.1, seed=1234)\\n\",\n    \"etlDf.describe().filter(col('summary') == 'mean').show(vertical=True, truncate=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"etlDf = etlDf.withColumn('loc',(etlDf.msa*1000+etlDf.zip).cast('int')).drop('zip' ,'msa')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"label_col_name = 'delinquency_12'\\n\",\n    \"schema = etlDf.schema\\n\",\n    \"raw_features = [ x for x in schema.fields if x.name != label_col_name ]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"categorical_cols = [f.name for f in raw_features if f.dataType == IntegerType()]\\n\",\n    \"numerical_cols = [f.name for f in raw_features if f.name not in categorical_cols]\\n\",\n    \"hasher = FeatureHasher(inputCols=categorical_cols, outputCol='hashed_categorical', \\n\",\n    \"                       categoricalCols=categorical_cols, numFeatures=(1 << 15))\\n\",\n    \"va = VectorAssembler().setInputCols(numerical_cols + [hasher.getOutputCol()]).setOutputCol('features')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"logistic =  ( \\n\",\n    \"  LogisticRegression()\\n\",\n    \"    .setMaxIter(200)\\n\",\n    \"    .setRegParam(0.00002)\\n\",\n    \"    .setElasticNetParam(0.1)\\n\",\n    \"    .setTol(1.0e-12)\\n\",\n    \"    .setFeaturesCol('features')\\n\",\n    \"    .setLabelCol(label_col_name)\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"[df_train, df_test] = etlDf.randomSplit([0.8, 0.2], seed=1234)\\n\",\n    \"pipeline = Pipeline().setStages([hasher, va, logistic])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"start = time.time()\\n\",\n    \"# gpu lr, gpu etl, gpu transform, 200 iters, double precision, elasticnet=0.1, featurehasher, 0.1 sample, multiclass, float64\\n\",\n    \"pipeline_model = pipeline.fit(df_train)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"predictions = pipeline_model.transform(df_test)\\n\",\n    \"predictions.sample(0.1).show(1, vertical=True, truncate=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   
\"source\": [\n    \"evaluator = MulticlassClassificationEvaluator().setMetricName('logLoss').setLabelCol(label_col_name)\\n\",\n    \"eval_res = evaluator.evaluate(predictions)\\n\",\n    \"end = time.time()\\n\",\n    \"print(f'Evaluation result: {eval_res}')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"ml_dur = end - start\\n\",\n    \"print(f'ML takes {ml_dur}')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Save current run times  \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Preserve across container restarts\\n\",\n    \"\\n\",\n    \"local_gpu_times_file = f'gpu_times.csv'\\n\",\n    \"local_cpu_times_file = f'cpu_times.csv'\\n\",\n    \"\\n\",\n    \"run_times = pd.Series({'etl' : etl_dur, 'ml' : ml_dur})\\n\",\n    \"run_times.to_csv(local_gpu_times_file if accelerate_on_gpu else local_cpu_times_file, index=True, header=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Visualize acceleration\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if os.path.exists(local_cpu_times_file) and os.path.exists(local_gpu_times_file):\\n\",\n    \"  cpu_times = pd.read_csv(local_cpu_times_file, header=None, index_col=0)\\n\",\n    \"  gpu_times = pd.read_csv(local_gpu_times_file, header=None, index_col=0)\\n\",\n    \"  gpu_speedup = cpu_times / gpu_times\\n\",\n    \"  gpu_speedup.plot(kind='bar', \\n\",\n    \"    title='GPU Acceleration Factor (> 1.0 is good)', \\n\",\n    \"    color='#76B900', \\n\",\n    \"    legend=False)\\n\",\n    \"  cpu_times = cpu_times[1].rename('cpu')\\n\",\n    \"  gpu_times = gpu_times[1].rename('gpu')\\n\",\n    \"  times = 
pd.DataFrame([cpu_times, gpu_times]).transpose()\\n\",\n    \"  times.plot(kind='bar', \\n\",\n    \"    title = 'ETL and ML elapsed times for CPU and GPU (lower is better)', \\n\",\n    \"    color=['blue', '#76B900'])\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/notebook/work/csv_raw_schema.ddl",
    "content": "reference_pool_id STRING,\nloan_id BIGINT,\nmonthly_reporting_period STRING,\norig_channel STRING,\nseller_name STRING,\nservicer STRING,\nmaster_servicer STRING,\norig_interest_rate DOUBLE,\ninterest_rate DOUBLE,\norig_upb DOUBLE,\nupb_at_issuance STRING,\ncurrent_actual_upb DOUBLE,\norig_loan_term INT,\norig_date STRING,\nfirst_pay_date STRING,\nloan_age DOUBLE,\nremaining_months_to_legal_maturity DOUBLE,\nadj_remaining_months_to_maturity DOUBLE,\nmaturity_date STRING,\norig_ltv DOUBLE,\norig_cltv DOUBLE,\nnum_borrowers DOUBLE,\ndti DOUBLE,\nborrower_credit_score DOUBLE,\ncoborrow_credit_score DOUBLE,\nfirst_home_buyer STRING,\nloan_purpose STRING,\nproperty_type STRING,\nnum_units INT,\noccupancy_status STRING,\nproperty_state STRING,\nmsa DOUBLE,\nzip INT,\nmortgage_insurance_percent DOUBLE,\nproduct_type STRING,\nprepayment_penalty_indicator STRING,\ninterest_only_loan_indicator STRING,\ninterest_only_first_principal_and_interest_payment_date STRING,\nmonths_to_amortization STRING,\ncurrent_loan_delinquency_status INT,\nloan_payment_history STRING,\nmod_flag STRING,\nmortgage_insurance_cancellation_indicator STRING,\nzero_balance_code STRING,\nzero_balance_effective_date STRING,\nupb_at_the_time_of_removal STRING,\nrepurchase_date STRING,\nscheduled_principal_current STRING,\ntotal_principal_current STRING,\nunscheduled_principal_current STRING,\nlast_paid_installment_date STRING,\nforeclosed_after STRING,\ndisposition_date STRING,\nforeclosure_costs DOUBLE,\nprop_preservation_and_repair_costs DOUBLE,\nasset_recovery_costs DOUBLE,\nmisc_holding_expenses DOUBLE,\nholding_taxes DOUBLE,\nnet_sale_proceeds DOUBLE,\ncredit_enhancement_proceeds DOUBLE,\nrepurchase_make_whole_proceeds STRING,\nother_foreclosure_proceeds DOUBLE,\nnon_interest_bearing_upb DOUBLE,\nprincipal_forgiveness_upb STRING,\noriginal_list_start_date STRING,\noriginal_list_price STRING,\ncurrent_list_start_date STRING,\ncurrent_list_price STRING,\nborrower_credit_score_at_issuance 
STRING,\n`co-borrower_credit_score_at_issuance` STRING,\nborrower_credit_score_current STRING,\n`co-Borrower_credit_score_current` STRING,\nmortgage_insurance_type DOUBLE,\nservicing_activity_indicator STRING,\ncurrent_period_modification_loss_amount STRING,\ncumulative_modification_loss_amount STRING,\ncurrent_period_credit_event_net_gain_or_loss STRING,\ncumulative_credit_event_net_gain_or_loss STRING,\nhomeready_program_indicator STRING,\nforeclosure_principal_write_off_amount STRING,\nrelocation_mortgage_indicator STRING,\nzero_balance_code_change_date STRING,\nloan_holdback_indicator STRING,\nloan_holdback_effective_date STRING,\ndelinquent_accrued_interest STRING,\nproperty_valuation_method STRING,\nhigh_balance_loan_indicator STRING,\n`arm_initial_fixed-rate_period_lt_5_yr_indicator` STRING,\narm_product_type STRING,\n`initial_fixed-rate_period` STRING,\ninterest_rate_adjustment_frequency STRING,\nnext_interest_rate_adjustment_date STRING,\nnext_payment_change_date STRING,\nindex STRING,\narm_cap_structure STRING,\ninitial_interest_rate_cap_up_percent STRING,\nperiodic_interest_rate_cap_up_percent STRING,\nlifetime_interest_rate_cap_up_percent STRING,\nmortgage_margin STRING,\narm_balloon_indicator STRING,\narm_plan_number STRING,\nborrower_assistance_plan STRING,\nhltv_refinance_option_indicator STRING,\ndeal_name STRING,\nrepurchase_make_whole_proceeds_flag STRING,\nalternative_delinquency_resolution STRING,\nalternative_delinquency_resolution_count STRING,\ntotal_deferral_amount STRING"
  },
  {
    "path": "examples/spark-connect-gpu/client/notebook/work/name_mapping.csv",
    "content": "\"WITMER FUNDING, LLC\",Witmer\nWELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015,Wells Fargo\n\"WELLS FARGO BANK,  NA\",Wells Fargo\n\"WELLS FARGO BANK, N.A.\",Wells Fargo\n\"WELLS FARGO BANK, NA\",Wells Fargo\nUSAA FEDERAL SAVINGS BANK,USAA\n\"UNITED SHORE FINANCIAL SERVICES, LLC D\\/B\\/A UNITED WHOLESALE MORTGAGE\",United Shore\nU.S. BANK N.A.,US Bank\nSUNTRUST MORTGAGE INC.,Suntrust\nSTONEGATE MORTGAGE CORPORATION,Stonegate Mortgage\n\"STEARNS LENDING, LLC\",Stearns Lending\n\"STEARNS LENDING, INC.\",Stearns Lending\n\"SIERRA PACIFIC MORTGAGE COMPANY, INC.\",Sierra Pacific Mortgage\nREGIONS BANK,Regions\nRBC MORTGAGE COMPANY,RBC\nQUICKEN LOANS INC.,Quicken Loans\n\"PULTE MORTGAGE, L.L.C.\",Pulte Mortgage\n\"PROVIDENT FUNDING ASSOCIATES, L.P.\",Provident Funding\n\"PROSPECT MORTGAGE, LLC\",Prospect Mortgage\n\"PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC\",Principal Residential\n\"PNC BANK, N.A.\",PNC\nPMT CREDIT RISK TRANSFER TRUST 2015-2,PennyMac\nPHH MORTGAGE CORPORATION,PHH Mortgage\nPENNYMAC CORP.,PennyMac\n\"PACIFIC UNION FINANCIAL, LLC\",Other\nOTHER,Other\n\"NYCB MORTGAGE COMPANY, LLC\",NYCB\nNEW YORK COMMUNITY BANK,NYCB\nNETBANK FUNDING SERVICES,Netbank\n\"NATIONSTAR MORTGAGE, LLC\",Nationstar Mortgage\n\"METLIFE BANK, NA\",Metlife\n\"LOANDEPOT.COM, LLC\",LoanDepot.com\n\"J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1\",JP Morgan Chase\n\"J.P. 
MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1\",JP Morgan Chase\n\"JPMORGAN CHASE BANK, NATIONAL ASSOCIATION\",JP Morgan Chase\n\"JPMORGAN CHASE BANK, NA\",JP Morgan Chase\n\"JP MORGAN CHASE BANK, NA\",JP Morgan Chase\n\"IRWIN MORTGAGE, CORPORATION\",Irwin Mortgage\nIMPAC MORTGAGE CORP.,Impac Mortgage\n\"HSBC BANK USA, NATIONAL ASSOCIATION\",HSBC\n\"HOMEWARD RESIDENTIAL, INC.\",Homeward Mortgage\nHOMESTREET BANK,Other\n\"HOMEBRIDGE FINANCIAL SERVICES, INC.\",HomeBridge\n\"HARWOOD STREET FUNDING I, LLC\",Harwood Mortgage\nGUILD MORTGAGE COMPANY,Guild Mortgage\n\"GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)\",GMAC\n\"GMAC MORTGAGE, LLC\",GMAC\nGMAC (USAA),GMAC\nFREMONT BANK,Fremont Bank\nFREEDOM MORTGAGE CORP.,Freedom Mortgage\nFRANKLIN AMERICAN MORTGAGE COMPANY,Franklin America\nFLEET NATIONAL BANK,Fleet National\nFLAGSTAR CAPITAL MARKETS CORPORATION,Flagstar Bank\n\"FLAGSTAR BANK, FSB\",Flagstar Bank\nFIRST TENNESSEE BANK NATIONAL ASSOCIATION,Other\nFIFTH THIRD BANK,Fifth Third Bank\nFEDERAL HOME LOAN BANK OF CHICAGO,Federal Home of Chicago\n\"FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB\",FDIC\n\"DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A.\",Downey Mortgage\nDITECH FINANCIAL LLC,Ditech\n\"CITIMORTGAGE, INC.\",Citi\nCHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY,Chicago Mortgage\nCHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY,Chicago Mortgage\n\"CHASE HOME FINANCE, LLC\",JP Morgan Chase\nCHASE HOME FINANCE FRANKLIN AMERICAN MORTGAGE COMPANY,JP Morgan Chase\nCHASE HOME FINANCE (CIE 1),JP Morgan Chase\nCHASE HOME FINANCE,JP Morgan Chase\n\"CASHCALL, INC.\",CashCall\n\"CAPITAL ONE, NATIONAL ASSOCIATION\",Capital One\n\"CALIBER HOME LOANS, INC.\",Caliber Funding\nBISHOPS GATE RESIDENTIAL MORTGAGE TRUST,Bishops Gate Mortgage\n\"BANK OF AMERICA, N.A.\",Bank of America\nAMTRUST BANK,AmTrust\nAMERISAVE MORTGAGE CORPORATION,Amerisave\n\"AMERIHOME MORTGAGE COMPANY, LLC\",AmeriHome Mortgage\nALLY BANK,Ally Bank\nACADEMY MORTGAGE CORPORATION,Academy 
Mortgage\nNO CASH-OUT REFINANCE,OTHER REFINANCE\nREFINANCE - NOT SPECIFIED,OTHER REFINANCE\nOther REFINANCE,OTHER REFINANCE\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/python/batch-job.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"initial_id\",\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": \"%run batch-job.py\"\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 2\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython2\",\n   \"version\": \"2.7.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/python/batch-job.py",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom pyspark.sql import SparkSession\nfrom pyspark.sql.functions import *\n\nspark = (SparkSession\n         .builder\n         .getOrCreate()\n         )\n\ndf = (\n    spark.range(2 ** 12)\n    .withColumn(\"mod10\", col(\"id\") % lit(10))\n    .groupBy(\"mod10\").agg(count(\"*\"))\n    .orderBy(\"mod10\")\n)\n# workaround to get a plan with GpuOverrides applied by disabling adaptive execution\ndef explain(dataframe):\n    spark.conf.set(\"spark.sql.adaptive.enabled\", False)\n    dataframe.explain(mode=\"extended\")\n    spark.conf.set(\"spark.sql.adaptive.enabled\", True)\n\n## Disable GPU accelerating\nprint(\"--------------- CPU running by disabling spark.rapids.sql.enabled ---------------\")\nspark.conf.set(\"spark.rapids.sql.enabled\", False)\nexplain(df)\ndf.show()\n\n## Enable GPU accelerating\nspark.conf.set(\"spark.rapids.sql.enabled\", True)\nprint(\"--------------- GPU running by enabling spark.rapids.sql.enabled ---------------\")\nexplain(df)\ndf.show()\n\nspark.stop()\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/requirements.txt",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\njupyterlab\nmatplotlib\n# https://spark.apache.org/docs/latest/api/python/getting_started/install.html#python-spark-connect-client\n# ... pure Python library\npyspark-client==4.0.0\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/scala/.gitignore",
    "content": "target\n.idea\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/scala/pom.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project xmlns=\"http://maven.apache.org/POM/4.0.0\"\n    xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n    xsi:schemaLocation=\"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd\">\n    <modelVersion>4.0.0</modelVersion>\n    <groupId>com.example</groupId>\n    <artifactId>spark-connect-demo</artifactId>\n    <version>1.0-SNAPSHOT</version>\n\n    <properties>\n        <spark.connect.version>4.0.0</spark.connect.version>\n        <scala.version>2.13.16</scala.version>\n        <scala.binary.version>2.13</scala.binary.version>\n    </properties>\n\n    <dependencies>\n        <dependency>\n            <groupId>org.apache.spark</groupId>\n            <artifactId>spark-connect-client-jvm_${scala.binary.version}</artifactId>\n            <version>${spark.connect.version}</version>\n        </dependency>\n        <dependency>\n            <groupId>org.scala-lang</groupId>\n            <artifactId>scala-library</artifactId>\n            <version>${scala.version}</version>\n        </dependency>\n    </dependencies>\n\n    <build>\n        <plugins>\n            <plugin>\n                <groupId>net.alchim31.maven</groupId>\n                <artifactId>scala-maven-plugin</artifactId>\n                <version>4.9.6</version>\n                <configuration>\n                    <javacArgs combine.children=\"append\">\n                        <javacArg>-XDignore.symbol.file</javacArg>\n                    </javacArgs>\n                    <fork>true</fork>\n                </configuration>\n                <executions>\n                    <execution>\n                        <id>scala-compile-first</id>\n                        <phase>process-resources</phase>\n                        <goals>\n                            <goal>add-source</goal>\n                            <goal>compile</goal>\n                        </goals>\n                    </execution>\n            
    </executions>\n            </plugin>\n            <plugin>\n                <groupId>org.apache.maven.plugins</groupId>\n                <artifactId>maven-assembly-plugin</artifactId>\n                <version>3.7.1</version>\n                <configuration>\n                    <descriptorRefs>\n                        <descriptorRef>jar-with-dependencies</descriptorRef>\n                    </descriptorRefs>\n                </configuration>\n                <executions>\n                    <execution>\n                        <id>assembly</id>\n                        <phase>package</phase>\n                        <goals>\n                            <goal>single</goal>\n                        </goals>\n                    </execution>\n                </executions>\n            </plugin>\n        </plugins>\n    </build>\n</project>"
  },
  {
    "path": "examples/spark-connect-gpu/client/scala/run.sh",
    "content": "#! /bin/bash\n\n# work for jdk 17\njava \\\n--add-exports=java.base/sun.nio.ch=ALL-UNNAMED \\\n--add-opens=java.base/java.nio=ALL-UNNAMED \\\n--add-opens=java.base/java.lang.invoke=ALL-UNNAMED \\\n--add-opens=java.base/java.util=ALL-UNNAMED \\\n--add-opens=java.base/sun.security.action=ALL-UNNAMED  \\\n  -cp spark-connect-demo-1.0-SNAPSHOT-jar-with-dependencies.jar connect\n\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/scala/scala-run.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"initial_id\",\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"%%bash\\n\",\n    \"java \\\\\\n\",\n    \"  --add-exports=java.base/sun.nio.ch=ALL-UNNAMED \\\\\\n\",\n    \"  --add-opens=java.base/java.nio=ALL-UNNAMED \\\\\\n\",\n    \"  --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \\\\\\n\",\n    \"  --add-opens=java.base/java.util=ALL-UNNAMED \\\\\\n\",\n    \"  --add-opens=java.base/sun.security.action=ALL-UNNAMED  \\\\\\n\",\n    \"  -cp spark-connect-demo-1.0-SNAPSHOT-jar-with-dependencies.jar connect\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 2\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython2\",\n   \"version\": \"2.7.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/spark-connect-gpu/client/scala/src/main/scala/connect.scala",
    "content": "// Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n// Licensed to the Apache Software Foundation (ASF) under one or more\n// contributor license agreements.  See the NOTICE file distributed with\n// this work for additional information regarding copyright ownership.\n// The ASF licenses this file to You under the Apache License, Version 2.0\n// (the \"License\"); you may not use this file except in compliance with\n// the License.  You may obtain a copy of the License at\n//\n//    http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\nimport org.apache.spark.sql.{DataFrame, SparkSession}\nimport org.apache.spark.sql.functions._\n\nobject connect extends Serializable {\n\n  def explain(df: DataFrame): Unit = {\n    // workaround to get a plan with GpuOverrides applied by disabling adaptive execution\n    df.sparkSession.conf.set(\"spark.sql.adaptive.enabled\", \"false\")\n    df.explain(mode = \"extended\")\n    df.sparkSession.conf.set(\"spark.sql.adaptive.enabled\", \"true\")\n  }\n\n  def main(args: Array[String]): Unit = {\n    val spark = SparkSession.builder().getOrCreate()\n    val df = spark.range(1L << 12)\n      .withColumn(\"mod10\", col(\"id\") % lit(10))\n      .groupBy(\"mod10\").agg(count(\"*\"))\n      .orderBy(\"mod10\")\n\n    // Disable GPU accelerating\n    println(\"--------------- CPU running by disabling spark.rapids.sql.enabled ---------------\")\n    spark.conf.set(\"spark.rapids.sql.enabled\", \"false\")\n    explain(df)\n    df.show()\n\n    // Enable GPU accelerating\n    spark.conf.set(\"spark.rapids.sql.enabled\", \"true\")\n    println(\"--------------- GPU running by enabling 
spark.rapids.sql.enabled ---------------\")\n    explain(df)\n    df.show()\n\n    spark.stop()\n  }\n}\n"
  },
  {
    "path": "examples/spark-connect-gpu/server/README.md",
    "content": "# GPU-Accelerated Spark Connect Server\n\nThis project demonstrates how to set up a GPU-accelerated Spark server using Apache Spark 4.0\nwith Spark Connect, featuring the RAPIDS Accelerator.\n\n## 🚀 Key Features\n\n- **Apache Spark 4.0** with cutting-edge Spark Connect capabilities\n- **GPU acceleration** via RAPIDS Accelerator\n- **MLlib over Spark Connect** - new in Spark 4.0\n- **Zero-code-change acceleration** - existing Spark applications automatically benefit\n- **Jupyter Lab integration** for interactive development\n- **Docker Compose** setup for easy deployment with a clear distinction of what dependencies are\nrequired by what service and where GPUs are really used\n\n## 🏗️ Architecture\n\nThe setup consists of four Docker services:\n\n### Apache Spark Standalone Cluster \n1. **Spark Master** (`spark-master`) - Cluster coordination and job scheduling. This container does \nnot have GPU capability\n\n2. **Spark Worker** (`spark-worker`) - GPU-enabled worker node for task execution. This is the only \nservice requiring and having access to the host GPUs \n\n### Middle Tier \n3. **Spark Connect Server** (`spark-connect-server`) - gRPC interface with the RAPIDS integration\n\n### Proxy Service\n4. nginx configured to provide access to the various Apache Spark WebUIs using the Docker network\n\n### Frontend Web Browser\n5. WebUI for the Spark Connect Server and the Spark Standalone Cluster\n\nTo reduce the complexity of the demo, no services for global storage are included.\nThe demo relies on the **DATA_DIR** location mounted from the host in place of a storage\nservice. This location is also used for convenience to preserve metrics and\nSpark event logs beyond the container life cycle for analysis or debugging.\n\nWhen the **DATA_DIR** is accessed in a way that would normally require a global access\nwe indicate this by using the `global_` prefix for the variable storing the complete\npath. 
Otherwise, we use variables starting with `local_`.\n\n## 📋 Prerequisites\n\n### Required\n- [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/linux)\n- At least 8GB of available RAM\n- Available ports: 2080, 8080, 8081, 8888, 7077, 4040, 15002\n\n### For GPU Acceleration\n- NVIDIA GPU with CUDA compute capability supported by RAPIDS\n- [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)\n- Docker Compose version should be `2.30.x` or newer to avoid an NVIDIA Container Toolkit related bug.  [Update](https://docs.docker.com/compose/install/linux) if necessary\n- CUDA 12.x drivers\n\n## 🚀 Quick Start\n\n1. **Clone and navigate to the project:**\n   ```bash\n   cd examples/spark-connect-gpu/server\n   ```\n\n2. **Set up data directory (if needed):**\n   ```bash\n   export DATA_DIR=$(pwd)/data\n   mkdir -p $DATA_DIR/mortgage.input.csv $DATA_DIR/spark-events $DATA_DIR/nds\n   chmod 1777 $DATA_DIR $DATA_DIR/spark-events \n   ```\n   Download a few quarters worth of the [Mortgage Dataset](https://capitalmarkets.fanniemae.com/credit-risk-transfer/single-family-credit-risk-transfer/fannie-mae-single-family-loan-performance-data)\n   to the `$DATA_DIR/mortgage.input.csv` location. More details can refer to [How to download the Mortgage dataset](https://github.com/NVIDIA/spark-rapids-examples/blob/main/docs/get-started/xgboost-examples/dataset/mortgage.md)\n\n   To run NDS (see [NDS v2.0 Automation](https://github.com/NVIDIA/spark-rapids-benchmarks/tree/dev/nds#nds-v20-automation)),\n   generate the dataset and place it in \"$DATA_DIR/nds\". For more details,\n   refer to [NDS Data Generation](https://github.com/NVIDIA/spark-rapids-benchmarks/tree/dev/nds#data-generation).\n\n3. **Start all services:**\n\n   ```bash\n   $ docker compose up -d\n   ```\n   (`docker compose` can be used in place of `docker-compose` here and throughout)\n\n4. 
**Access the Web UI interfaces:**\n\n   ***Option 1 (default)***\n\n   All containers' webUI are available using localhost URI's by default\n\n   - **Spark Master UI**: http://localhost:8080 - Cluster coordination and resource management\n   - **Spark Worker UI**: http://localhost:8081 - GPU-enabled worker node status and tasks\n   - **Spark Driver UI**: http://localhost:4040 - Application monitoring and SQL queries\n    \n   ***Option 2***\n\n   if you launch docker compose in the environment with `SPARK_PUBLIC_DNS=container-hostname`, all containers'\n   web UI but Jupyter Lab is available using the corresponding container host names such as spark-master\n  \n   - **Spark Master UI**: http://spark-master:8080 - Cluster coordination and resource management\n   - **Spark Worker UI**: http://spark-worker:8081 - GPU-enabled worker node status and tasks\n   - **Spark Driver UI**: http://spark-connect-server:4040 - Application monitoring and SQL queries\n   \n   Docker DNS names require configuring your browser an http proxy on the Docker network exposed at http://localhost:2080.\n  \n   Here are examples of launching Google Chrome with a temporary user profile without making persistent changes on the browser\n\n   ***Linux***\n\n   ```bash\n   $ google-chrome --user-data-dir=\"/tmp/chrome-proxy-profile\" --proxy-server=\"http=http://localhost:2080\"\n   ```\n\n   ***macOS***\n\n   ```bash\n   $ open -n -a \"Google Chrome\" --args --user-data-dir=\"/tmp/chrome-proxy-profile\" --proxy-server=\"http=http://localhost:2080\"\n   ```\n\n   ***Launching containers on a remote machine***\n\n   Your local machine might not have a GPU, and it is common in this case to use a\n   remote machine/cluster with GPUs residing in a remote Cloud or on-prem environment\n\n   If you followed the default Option 1 make sure to create local port forwards for\n   every webUI port\n\n   ```bash\n   ssh <user@gpu-host> -L 8888:localhost:8888 -L 8080:localhost:8080 -L 8081:localhost:8081 -L 
4040:localhost:4040\n   ```\n\n   if you used Option 2 it is sufficient to forward ports only for the HTTP proxy and the Notebook app:\n  \n   ```bash\n   ssh <user@gpu-host> -L 2080:localhost:2080 -L 8888:localhost:8888\n   ```\n\n## 🐳 Service Details\n\n### Spark Master\n- **Image**: `apache/spark:4.0.0`\n- **Ports**: 8080 (Web UI), 7077 (Master)\n- **Role**: Cluster coordination and resource management\n\n### Spark Worker (the only GPU node role)\n- **Image**: Custom build based on `apache/spark:4.0.0`\n- **GPU**: NVIDIA GPU support via Docker Compose deploy configuration\n- **Ports**: 8081 (Web UI)\n- **Features**: GPU resource discovery and task execution\n\n### Spark Connect Server\n- **Image**: Custom build based on `apache/spark:4.0.0` with Spark RAPIDS ETL and ML Plugins\n- **RAPIDS Version**: 26.02.0 for CUDA 12\n- **Ports**: 15002 (gRPC), 4040 (Driver UI)\n- **Configuration**: Optimized for GPU acceleration with memory management\n\n## 📊 Performance Monitoring\n\nYou can use tools like nvtop, nvitop, btop or jupyterlab_nvdashboard running on the GPU host(s)\n\n\n## 🧹 Cleanup\n\nStop and remove all services:\n```bash\ndocker-compose down -v\n```\n\nRemove built images:\n```bash\ndocker-compose down --rmi all -v\n```\n\n### Logs\nLogs for the spark driver/connect server, standalone master, standalone worker, and jupyter server can be viewed using the respective commands:\n```bash\ndocker logs spark-connect-server\ndocker logs spark-master\ndocker logs spark-worker\n```\n\nSpark executor logs can be accessed via the Spark UI as usual.\n\n## 📖 Additional Resources\n\n- [Apache Spark 4.0 Documentation](https://spark.apache.org/docs/latest/)\n- [Spark Connect Guide](https://spark.apache.org/docs/latest/spark-connect-overview.html)\n- [NVIDIA RAPIDS Accelerator](https://nvidia.github.io/spark-rapids/)\n- [Data and AI Summit Session](https://www.databricks.com/dataaisummit/session/gpu-accelerated-spark-connect)\n"
  },
  {
    "path": "examples/spark-connect-gpu/server/docker-compose.yaml",
    "content": "# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n# YAML anchors for shared configurations\nx-spark-common: &spark-common\n  networks:\n    - spark-network\n  volumes:\n    - ${DATA_DIR:-${PWD}/data}:/data\nx-spark-common-env: &spark-common-env\n  SPARK_PUBLIC_DNS: \"${SPARK_PUBLIC_DNS:-localhost}\"\n  SPARK_NO_DAEMONIZE: \"1\"\n\nservices:\n  # Spark Master Node\n  spark-master:\n    <<: *spark-common\n    image: spark-master-image\n    build:\n      context: ./spark-master\n      dockerfile: Dockerfile\n    container_name: spark-master\n    hostname: spark-master\n    environment:\n      <<: *spark-common-env\n    ports:\n      - \"8080:8080\"   # Spark Master Web UI\n      - \"7077:7077\"   # Spark Master Port\n    command: /opt/spark/sbin/start-master.sh\n\n  # Spark Worker Node (GPU-enabled)\n  spark-worker:\n    <<: *spark-common\n    image: spark-worker-image\n    build:\n      context: ./spark-worker\n      dockerfile: Dockerfile\n    container_name: spark-worker\n    hostname: spark-worker\n    environment:\n      <<: *spark-common-env\n    ports:\n      - \"8081:8081\"   # Spark Worker WebUI\n    depends_on:\n      - spark-master\n    
command: /opt/spark/sbin/start-worker.sh spark://spark-master:7077\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              capabilities: [gpu]\n\n  # Spark Connect Server\n  spark-connect-server:\n    <<: *spark-common\n    image: spark-connect-server-image\n    build:\n      context: ./spark-connect-server\n      dockerfile: Dockerfile\n      args:\n        - CUDA_VERSION=${CUDA_VERSION:-12}\n        - RAPIDS_VERSION=${RAPIDS_VERSION:-26.02.0}\n        - REPO_URL=${REPO_URL:-https://repo1.maven.org/maven2}\n    container_name: spark-connect-server\n    hostname: spark-connect-server\n    environment:\n      <<: *spark-common-env\n    ports:\n      - \"4040:4040\"               # Spark Driver WebUI\n      - \"15002:15002\"             # Spark Connect grpc\n    depends_on:\n      - spark-master\n      - spark-worker\n    command: >\n      /opt/spark/sbin/start-connect-server.sh\n        --driver-memory=24G\n        --conf spark.executor.memory=28G\n        --conf spark.executor.cores=8\n\n  proxy-service:\n    build:\n      context: ./proxy-service\n      dockerfile: Dockerfile\n    container_name: proxy-service\n    ports:\n      - \"2080:2080\"\n    networks:\n      - spark-network\n    depends_on:\n      - spark-master\n      - spark-worker\n      - spark-connect-server\n    restart: unless-stopped\n\nnetworks:\n  spark-network:\n    driver: bridge\n"
  },
  {
    "path": "examples/spark-connect-gpu/server/proxy-service/Dockerfile",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nFROM nginx:latest\n\n# Copy the nginx configuration file into the container\nCOPY nginx.conf /etc/nginx/nginx.conf\n\n"
  },
  {
    "path": "examples/spark-connect-gpu/server/proxy-service/nginx.conf",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nevents {\n    worker_connections 1024;\n}\n\nhttp {\n    resolver 127.0.0.11;\n\n    server {\n        listen 2080;\n\n        location / {\n            proxy_set_header        Host $http_host;\n            proxy_set_header        X-Real-IP $remote_addr;\n            proxy_set_header        X-Forwarded-For $proxy_add_x_forwarded_for;\n            proxy_set_header        X-Forwarded-Proto $scheme;\n            proxy_pass              $scheme://$http_host;\n            proxy_read_timeout      90;\n        }\n    }\n}"
  },
  {
    "path": "examples/spark-connect-gpu/server/spark-connect-server/Dockerfile",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nFROM apache/spark:4.0.0\n\nARG CUDA_VERSION\nARG RAPIDS_VERSION\nARG REPO_URL\n\nUSER root\nCOPY requirements.txt /tmp/requirements.txt\nRUN pip3 install -r /tmp/requirements.txt\n\nRUN mkdir -p /opt/spark-rapids/jars\nRUN chown -R spark:spark /opt/spark-rapids\n\nUSER spark\n\nENV CUDA_VERSION=${CUDA_VERSION}\nENV RAPIDS_VERSION=${RAPIDS_VERSION}\nENV REPO_URL=${REPO_URL}\n\nRUN wget -q ${REPO_URL}/com/nvidia/rapids-4-spark_2.13/${RAPIDS_VERSION}/rapids-4-spark_2.13-${RAPIDS_VERSION}-cuda${CUDA_VERSION}.jar -O /opt/spark-rapids/jars/rapids-4-spark-sql.jar\nCOPY spark-defaults.conf /opt/spark/conf/spark-defaults.conf\nCOPY spark-env.sh /opt/spark/conf/spark-env.sh\n"
  },
  {
    "path": "examples/spark-connect-gpu/server/spark-connect-server/requirements.txt",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\npackaging==23.1\npandas==2.2.3\npsutil\npyarrow\nscikit-learn>=1.2.1\nspark-rapids-ml==25.8.0"
  },
  {
    "path": "examples/spark-connect-gpu/server/spark-connect-server/spark-defaults.conf",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nspark.driver.maxResultSize=3g\nspark.eventLog.compress=true\nspark.eventLog.dir=/data/spark-events\nspark.eventLog.enabled=true\nspark.executor.resource.gpu.amount=1\nspark.executor.resource.gpu.discoveryScript=/opt/spark/examples/src/main/scripts/getGpusResources.sh\nspark.jars=/opt/spark-rapids/jars/rapids-4-spark-sql.jar,/usr/local/lib/python3.10/dist-packages/spark_rapids_ml/jars/com.nvidia.rapids.ml-25.08.0.jar\nspark.local.dir=/opt/spark/work\nspark.locality.wait=0\nspark.master=spark://spark-master:7077\nspark.plugins=com.nvidia.spark.SQLPlugin\nspark.rapids.memory.gpu.allocFraction=0.45\nspark.rapids.memory.gpu.maxAllocFraction=0.45\nspark.rapids.memory.gpu.minAllocFraction=0.0\nspark.rapids.ml.float32_inputs=false\nspark.rapids.ml.python.transform.enabled=false\nspark.rapids.ml.verbose=6\nspark.rapids.sql.batchSizeBytes=512m\nspark.rapids.sql.concurrentGpuTasks=4\nspark.rapids.sql.debug.logTransformations=true\nspark.rapids.sql.explain=ALL\nspark.shuffle.manager=com.nvidia.spark.rapids.spark400.RapidsShuffleManager\nspark.sql.ansi.enabled=false\nspark.sql.files.max
PartitionBytes=512m\nspark.sql.session.timeZone=UTC\nspark.task.resource.gpu.amount=0.0625\n"
  },
  {
    "path": "examples/spark-connect-gpu/server/spark-connect-server/spark-env.sh",
    "content": "#!/bin/bash\n\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nif [[ \"$SPARK_PUBLIC_DNS\" == \"container-hostname\" ]]; then\n  export SPARK_PUBLIC_DNS=$(hostname)\nelif [[ \"$SPARK_PUBLIC_DNS\" != \"\" ]]; then\n  # handles default localhost or any other custom value\n  export SPARK_PUBLIC_DNS\nfi"
  },
  {
    "path": "examples/spark-connect-gpu/server/spark-master/Dockerfile",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nFROM apache/spark:4.0.0\n\nCOPY spark-env.sh /opt/spark/conf/spark-env.sh"
  },
  {
    "path": "examples/spark-connect-gpu/server/spark-master/spark-env.sh",
    "content": "#!/bin/bash\n\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nif [[ \"$SPARK_PUBLIC_DNS\" == \"container-hostname\" ]]; then\n  export SPARK_PUBLIC_DNS=$(hostname)\nelif [[ \"$SPARK_PUBLIC_DNS\" != \"\" ]]; then\n  # handles default localhost or any other custom value\n  export SPARK_PUBLIC_DNS\nfi"
  },
  {
    "path": "examples/spark-connect-gpu/server/spark-worker/Dockerfile",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nFROM apache/spark:4.0.0\n\nUSER root\nCOPY requirements.txt /tmp/requirements.txt\nRUN pip3 install --extra-index-url=https://pypi.nvidia.com -r /tmp/requirements.txt\n\n# TODO hack to avoid configuring cupy compiler path\nRUN mkdir -p /home/spark\nRUN chown -R spark:spark /home/spark\nRUN usermod -d /home/spark spark\n\nUSER spark\nCOPY spark-env.sh /opt/spark/conf/spark-env.sh\n"
  },
  {
    "path": "examples/spark-connect-gpu/server/spark-worker/requirements.txt",
    "content": "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\ncuml-cu12~=25.8.0\nnumpy~=1.0\nspark-rapids-ml==25.8.0"
  },
  {
    "path": "examples/spark-connect-gpu/server/spark-worker/spark-env.sh",
    "content": "#!/bin/bash\n\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Licensed to the Apache Software Foundation (ASF) under one or more\n# contributor license agreements.  See the NOTICE file distributed with\n# this work for additional information regarding copyright ownership.\n# The ASF licenses this file to You under the Apache License, Version 2.0\n# (the \"License\"); you may not use this file except in compliance with\n# the License.  You may obtain a copy of the License at\n#\n#    http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nif [[ \"$SPARK_PUBLIC_DNS\" == \"container-hostname\" ]]; then\n  export SPARK_PUBLIC_DNS=$(hostname)\nelif [[ \"$SPARK_PUBLIC_DNS\" != \"\" ]]; then\n  # handles default localhost or any other custom value\n  export SPARK_PUBLIC_DNS\nfi\n\nGPU_COUNT_MAX=$(nvidia-smi -L | wc -l)\nexport SPARK_WORKER_OPTS=\"\n  -Dspark.worker.resource.gpu.amount=${GPU_COUNT_MAX}\n  -Dspark.worker.resource.gpu.discoveryScript=/opt/spark/examples/src/main/scripts/getGpusResources.sh\n\"\n\n# workaround for wheels installation not setting the correct LD_LIBRARY_PATH\n# https://github.com/rapidsai/cuml/issues/5300#issuecomment-2084646729\nLD_LIBRARY_PATH=$(find /usr/local/lib/python3.10/dist-packages/nvidia -name lib -type d | xargs printf '%s:'):$LD_LIBRARY_PATH\nexport LD_LIBRARY_PATH\n"
  },
  {
    "path": "scripts/README.md",
    "content": "### Encoding Tool\nThis tool is to convert the values from categorical type to numerical type in certain columns. Currently we support `mean encoding` and `one-hot encoding`.\n\n### Main Procedure\n1. User should firstly use our tool to profile the raw data source to get a \"dictionary\" (We call this dictionary `model`) that maps categorical values to certain numerical values. We call this method `train`. Each column will have its own `model`\n2. User will use the `model` they got from step 1 to replace those categorical values with numerical values. \n\n### Usage\n1. `cd encoding/python`\n2. `zip -r sample.zip com` to get a python encoding tool library\n3. submit the encoding job to your Spark host\n\nYou can find full use cases in `encoding-sample/run.sh`\n\n### Application Parameters\n - mainClass: \n   \n   - `com.nvidia.spark.encoding.criteo.one_hot_cpu_main`: one-hot encoding\n   - `com.nvidia.spark.encoding.criteo.target_cpu_main`: target(mean) encoding\n - mode: \n   - `train`: use raw data to get encoding model\n   - `transform`: use encoding model to convert raw data\n - format:\n   - `csv`: only csv is supported\n - columns: \n   - the target columns user wants to convert, e.g. `_34,_35` means user wants to get dictionary for both `_34` and `_35` columns\n - modelPaths: \n   - for `train` mode, it points to the path where user wants to save the encoding model\n   - for `transform` mode, it points to the model that the encoding conversion needs.\n   - it is 1-1 mapped to `columns`. If user wants to encode 2 columns, he must provide 2 `modelPaths`. e.g. `model_34,model_35`\n - inputPaths: \n   - raw data user wants to get encoding model from, or to convert\n - outputPaths: \n   - only used in `transform` mode.\n - overwrite:\n   - whether to overwrite the existing model or output data\n - numRows:\n   - optional. show some rows in command line when encoding is finished. \n - labelColumn:\n   - required in `target encoding`. Set the label column of raw data.\n\n### Optimization\n1. Due to default behaviors from some Spark methods, some values may contain useless precision which causes the large size of `model`. e.g. 0.000000 and 1.000000 are identical to 0 and 1 in value perspective, but the csv model file that contains those values costs more disk space. We provide `truncate-model.py` in `encoding-sample` to remove the extra useless precisions.\n2. We provide a repartition kit `repartition.py` to repartition your output data.\n\nThe usage can also be found in `encoding-sample/run.sh`"
  },
  {
    "path": "scripts/building/python_build.sh",
    "content": "#!/bin/bash\n# Copyright (c) 2024-2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Follow these steps to package the Python zip file\ncd ../../examples/XGBoost-Examples\ncd agaricus/python ; zip -r ../../samples.zip com ; cd ../..\ncd mortgage/python ; zip -r ../../samples.zip com ; cd ../..\ncd taxi/python ; zip -r ../../samples.zip com ; cd ../..\ncd utility/python ; zip -r ../../samples.zip com ; cd ../..\n"
  },
  {
    "path": "scripts/csp-startup-scripts/README.md",
    "content": "# Startup Scripts for CSPs with Spark Rapids\n\nWith the exception of Dataproc, CSP offerings like EMR have a specific set of steps that are required to enable the Spark Rapids Plugin in their environment. The set of scripts here automates parts of that process, currently for EMR. The exact usage can be found in our docs [here](https://docs.nvidia.com/spark-rapids/user-guide/latest/getting-started/aws-emr.html)\n"
  },
  {
    "path": "scripts/csp-startup-scripts/emr/cgroup-bootstrap-action-emr6.sh",
    "content": "#!/bin/bash\n#\n# Copyright (c) 2024-2026, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\nset -ex\n\nsudo chmod a+rwx -R /sys/fs/cgroup/cpu,cpuacct\nsudo chmod a+rwx -R /sys/fs/cgroup/devices\n"
  },
  {
    "path": "scripts/csp-startup-scripts/emr/cgroup-bootstrap-action-emr7.sh",
    "content": "#!/bin/bash\n#\n# Copyright (c) 2024-2026, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nset -ex\n\nsudo mkdir -p /spark-rapids-cgroup/devices\nsudo mount -t cgroup -o devices cgroupv1-devices /spark-rapids-cgroup/devices\nsudo chmod a+rwx -R /spark-rapids-cgroup\n"
  },
  {
    "path": "scripts/csp-startup-scripts/emr/config-emr6.json",
    "content": "[\n  {\n    \"Classification\":\"spark\",\n    \"Properties\":{\n      \"enableSparkRapids\":\"true\"\n    }\n  },\n  {\n    \"Classification\":\"yarn-site\",\n    \"Properties\":{\n      \"yarn.nodemanager.resource-plugins\":\"yarn.io/gpu\",\n      \"yarn.resource-types\":\"yarn.io/gpu\",\n      \"yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices\":\"auto\",\n      \"yarn.nodemanager.resource-plugins.gpu.path-to-discovery-executables\":\"/usr/bin\",\n      \"yarn.nodemanager.linux-container-executor.cgroups.mount\":\"true\",\n      \"yarn.nodemanager.linux-container-executor.cgroups.mount-path\":\"/sys/fs/cgroup\",\n      \"yarn.nodemanager.linux-container-executor.cgroups.hierarchy\":\"yarn\",\n      \"yarn.nodemanager.container-executor.class\":\"org.apache.hadoop.yarn.server.nodemanager.LinuxContainerExecutor\"\n    }\n  },\n  {\n    \"Classification\":\"container-executor\",\n    \"Properties\":{\n\n    },\n    \"Configurations\":[\n      {\n        \"Classification\":\"gpu\",\n        \"Properties\":{\n          \"module.enabled\":\"true\"\n        }\n      },\n      {\n        \"Classification\":\"cgroups\",\n        \"Properties\":{\n          \"root\":\"/sys/fs/cgroup\",\n          \"yarn-hierarchy\":\"yarn\"\n        }\n      }\n    ]\n  },\n  {\n    \"Classification\":\"spark-defaults\",\n    \"Properties\":{\n      \"spark.plugins\":\"com.nvidia.spark.SQLPlugin\",\n      \"spark.executor.resource.gpu.discoveryScript\":\"/usr/lib/spark/scripts/gpu/getGpusResources.sh\",\n      \"spark.submit.pyFiles\":\"/usr/lib/spark/jars/xgboost4j-spark_3.0-1.4.2-0.3.0.jar\",\n      \"spark.executor.extraLibraryPath\":\"/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/compat/lib:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/lib/hadoop/lib/native:/usr/lib/hadoop-lzo/lib/native:/docker/usr/lib/hadoop/lib/native:/docker/usr/lib/hadoop-lzo/lib/native\",\n      
\"spark.rapids.sql.concurrentGpuTasks\":\"2\",\n      \"spark.executor.resource.gpu.amount\":\"1\",\n      \"spark.executor.cores\":\"${executor_cores}\",\n      \"spark.task.cpus\":\"1\",\n      \"spark.task.resource.gpu.amount\":\"${task_gpu_amount}\",\n      \"spark.rapids.memory.pinnedPool.size\":\"2G\",\n      \"spark.executor.memoryOverhead\":\"2G\",\n      \"spark.sql.files.maxPartitionBytes\":\"256m\",\n      \"spark.sql.adaptive.enabled\":\"false\"\n    }\n  },\n  {\n    \"Classification\":\"capacity-scheduler\",\n    \"Properties\":{\n      \"yarn.scheduler.capacity.resource-calculator\":\"org.apache.hadoop.yarn.util.resource.DominantResourceCalculator\"\n    }\n  }\n]\n"
  },
  {
    "path": "scripts/csp-startup-scripts/emr/config-emr7.json",
    "content": "[\n  {\n    \"Classification\": \"spark\",\n    \"Properties\": {\n      \"enableSparkRapids\": \"true\"\n    }\n  },\n  {\n    \"Classification\": \"yarn-site\",\n    \"Properties\": {\n      \"yarn.nodemanager.resource-plugins\": \"yarn.io/gpu\",\n      \"yarn.resource-types\": \"yarn.io/gpu\",\n      \"yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices\": \"auto\",\n      \"yarn.nodemanager.resource-plugins.gpu.path-to-discovery-executables\": \"/usr/bin\",\n      \"yarn.nodemanager.linux-container-executor.cgroups.mount\": \"true\",\n      \"yarn.nodemanager.linux-container-executor.cgroups.mount-path\": \"/spark-rapids-cgroup\",\n      \"yarn.nodemanager.linux-container-executor.cgroups.hierarchy\": \"yarn\",\n      \"yarn.nodemanager.container-executor.class\": \"org.apache.hadoop.yarn.server.nodemanager.LinuxContainerExecutor\"\n    }\n  },\n  {\n    \"Classification\": \"container-executor\",\n    \"Properties\": {},\n    \"Configurations\": [\n      {\n        \"Classification\": \"gpu\",\n        \"Properties\": {\n          \"module.enabled\": \"true\"\n        }\n      },\n      {\n        \"Classification\": \"cgroups\",\n        \"Properties\": {\n          \"root\": \"/spark-rapids-cgroup\",\n          \"yarn-hierarchy\": \"yarn\"\n        }\n      }\n    ]\n  },\n  {\n    \"Classification\": \"spark-defaults\",\n    \"Properties\": {\n      \"spark.plugins\": \"com.nvidia.spark.SQLPlugin\",\n      \"spark.executor.resource.gpu.discoveryScript\": \"/usr/lib/spark/scripts/gpu/getGpusResources.sh\",\n      \"spark.submit.pyFiles\": \"/usr/lib/spark/jars/xgboost4j-spark_3.0-1.4.2-0.3.0.jar\",\n      \"spark.executor.extraLibraryPath\": \"/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/compat/lib:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/lib/hadoop/lib/native:/usr/lib/hadoop-lzo/lib/native:/docker/usr/lib/hadoop/lib/native:/docker/usr/lib/hadoop-lzo/lib/native\",\n      
\"spark.rapids.sql.concurrentGpuTasks\": \"2\",\n      \"spark.executor.resource.gpu.amount\": \"1\",\n      \"spark.executor.cores\": \"${executor_cores}\",\n      \"spark.task.cpus\": \"1\",\n      \"spark.task.resource.gpu.amount\": \"${task_gpu_amount}\",\n      \"spark.rapids.memory.pinnedPool.size\": \"2G\",\n      \"spark.executor.memoryOverhead\": \"2G\",\n      \"spark.sql.files.maxPartitionBytes\": \"256m\",\n      \"spark.sql.adaptive.enabled\": \"false\"\n    }\n  },\n  {\n    \"Classification\": \"capacity-scheduler\",\n    \"Properties\": {\n      \"yarn.scheduler.capacity.resource-calculator\": \"org.apache.hadoop.yarn.util.resource.DominantResourceCalculator\"\n    }\n  }\n]\n"
  },
  {
    "path": "scripts/csp-startup-scripts/emr/emr-spark-plugin-startup.py",
    "content": "#\n# Copyright (c) 2024-2026, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\nimport argparse\nimport json\nimport os\nimport subprocess\nimport tempfile\nimport boto3\nfrom botocore.exceptions import NoCredentialsError, PartialCredentialsError\n\ndef upload_file_to_s3(file_name, bucket_name, object_name=None):\n    s3 = boto3.client('s3')\n\n    # If no object name is specified, use the file name\n    if object_name is None:\n        object_name = file_name\n\n    try:\n        s3.upload_file(file_name, bucket_name, object_name)\n        print(f\"File '{file_name}' uploaded successfully to bucket '{bucket_name}' as '{object_name}'\")\n        return True\n    except FileNotFoundError:\n        print(f\"Error: The file {file_name} was not found.\")\n    except NoCredentialsError:\n        print(\"Error: AWS credentials not found.\")\n    except PartialCredentialsError:\n        print(\"Error: Incomplete AWS credentials.\")\n    except Exception as e:\n        print(f\"An error occurred: {e}\")\n    return False\n\ng4dn_instance_map = {\n    \"g4dn.xlarge\": 4,\n    \"g4dn.2xlarge\": 8,\n    \"g4dn.4xlarge\": 16,\n    \"g4dn.12xlarge\": 48,\n    \"g4dn.16xlarge\": 64\n}\n\n_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))\n\ndef create_emr_cluster(release_label, key_name, service_role, subnet_id, az, instance_profile, worker_instance, s3_bucket_name):\n    try:\n        conf_json_fn = None\n        bootstrap_fn = 
None\n        if \"emr-7\" in release_label:\n            conf_json_fn=\"config-emr7.json\"\n            bootstrap_fn=\"cgroup-bootstrap-action-emr7.sh\"\n        else:\n            conf_json_fn=\"config-emr6.json\"\n            bootstrap_fn=\"cgroup-bootstrap-action-emr6.sh\"\n        # Replace the fields in the json\n        exec_cores = g4dn_instance_map.get(worker_instance)\n        if exec_cores is None:\n            print(f\"Error: Unsupported worker instance type '{worker_instance}'. \"\n                  f\"Supported types: {list(g4dn_instance_map.keys())}\")\n            return\n\n        conf_json_path = os.path.join(_SCRIPT_DIR, conf_json_fn)\n        bootstrap_path = os.path.join(_SCRIPT_DIR, bootstrap_fn)\n        print(\"Config Json\" + conf_json_fn)\n        with open(conf_json_path, 'r') as file:\n            data = json.load(file)\n        json_string = json.dumps(data)\n\n        # Replace the placeholder with the actual variable\n        json_string = json_string.replace(\"${task_gpu_amount}\", str(1/exec_cores))\n        json_string = json_string.replace(\"${executor_cores}\", str(exec_cores))\n        updated_data = json.loads(json_string)\n\n        print(json.dumps(updated_data, indent=4))\n        if not upload_file_to_s3(bootstrap_path, s3_bucket_name, bootstrap_fn):\n            return\n\n        with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".json\") as config_file:\n            json.dump(updated_data, config_file)\n            config_file.flush()\n\n            command = [\n                \"aws\", \"emr\", \"create-cluster\",\n                \"--release-label\", release_label,\n                \"--applications\", \"Name=Hadoop\", \"Name=Spark\", \"Name=Livy\", \"Name=JupyterEnterpriseGateway\",\n                \"--service-role\", service_role,\n                \"--ec2-attributes\",\n                f\"KeyName={key_name},SubnetId={subnet_id},AvailabilityZone={az},InstanceProfile={instance_profile}\",\n                
\"--instance-groups\",\n                \"InstanceGroupType=MASTER,InstanceCount=1,InstanceType=m4.4xlarge\",\n                f\"InstanceGroupType=CORE,InstanceCount=1,InstanceType={worker_instance}\",\n                \"--configurations\", f\"file://{config_file.name}\",\n                \"--bootstrap-actions\", f\"Name='Setup cgroups bootstrap',Path=s3://{s3_bucket_name}/{bootstrap_fn}\"\n            ]\n\n            result = subprocess.run(command, check=True, text=True, capture_output=True)\n\n        print(\"Cluster created successfully!\")\n        print(result.stdout)\n\n    except subprocess.CalledProcessError as e:\n        print(\"Error creating EMR cluster:\", e.stderr)\n\n\nparser = argparse.ArgumentParser(description=\"A script that takes command-line arguments.\")\n\n# Define arguments\nparser.add_argument(\"-r\", \"--release_label\", type=str, default=\"emr-7.1.0\",  help=\"EMR Release Label, emr-7.1.0 for example\")\nparser.add_argument(\"-k\", \"--key_name\", type=str, required=True, help=\"Access Key Name\")\nparser.add_argument(\"-s\", \"--service_role\", type=str, required=True, help=\"AWS EMR service Role\")\nparser.add_argument(\"-n\", \"--subnet\", type=str, required=True, help=\"Subnet ID\")\nparser.add_argument(\"-z\", \"--availability_zone\", type=str, default=\"us-west-2b\", help=\"Availability Zone\")\nparser.add_argument(\"-i\", \"--instance_profile\", type=str, required=True, help=\"Instance Profile\")\nparser.add_argument(\"-w\", \"--worker_instance\", type=str, default=\"g4dn.2xlarge\",  help=\"Worker Instance g4dn.xxxx\")\nparser.add_argument(\"-b\", \"--s3_bucket_name\", type=str, required=True, help=\"S3 Bucket Name to store the bootstrap and config info\")\n\nargs = parser.parse_args()\n\nrelease_label = args.release_label\nkey_name = args.key_name\nservice_role = args.service_role\nsubnet_id = args.subnet\naz = args.availability_zone\ninstance_profile = args.instance_profile\nworker_instance = 
args.worker_instance\ns3_bucket_name = args.s3_bucket_name\n\ncreate_emr_cluster(release_label, key_name, service_role, subnet_id, az, instance_profile, worker_instance, s3_bucket_name)\n"
  },
  {
    "path": "scripts/encoding/python/.gitignore",
    "content": ".idea\n"
  },
  {
    "path": "scripts/encoding/python/com/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "scripts/encoding/python/com/nvidia/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "scripts/encoding/python/com/nvidia/spark/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "scripts/encoding/python/com/nvidia/spark/encoding/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "scripts/encoding/python/com/nvidia/spark/encoding/criteo/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "scripts/encoding/python/com/nvidia/spark/encoding/criteo/common.py",
    "content": "#\n# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\ndef customize_reader(reader):\n    (reader\n        .option('sep', '\\t'))\n\ndef customize_writer(writer):\n    (writer\n        .option('sep', '\\t')\n        .option('nullValue', None))\n"
  },
  {
    "path": "scripts/encoding/python/com/nvidia/spark/encoding/criteo/one_hot_cpu_main.py",
    "content": "#\n# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nfrom com.nvidia.spark.encoding.criteo.common import *\nfrom com.nvidia.spark.encoding.utility.utils import *\nfrom pyspark.ml.feature import StringIndexer, StringIndexerModel\nfrom pyspark.sql import SparkSession\nfrom pyspark.sql.functions import col\n\ndef index(df, column):\n    column_index = column + '_index'\n    return (StringIndexer(inputCol=column, outputCol=column_index)\n        .setHandleInvalid('keep')\n        .fit(df))\n\ndef expand(indexer, df, column):\n    column_index = column + '_index'\n    df = (indexer\n        .transform(df)\n        .withColumn(column_index, col(column_index).cast('int')))\n    for i in range(0, len(indexer.labels)):\n        df = df.withColumn(column + '_' + str(i), (col(column_index) == i).cast('int'))\n    return df.drop(column, column_index)\n\ndef main(args):\n    spark = (SparkSession\n        .builder\n        .appName(args.mainClass)\n        .getOrCreate())\n\n    if args.mode == 'train':\n        df = load_data(spark, args.inputPaths, args, customize_reader).cache()\n        for column, path in zip(args.columns, args.modelPaths):\n            indexer = index(df, column)\n            save_model(indexer, path, args)\n\n    if args.mode == 'transform':\n        indexers = list(zip(args.columns, load_models(StringIndexerModel, args.modelPaths)))\n        for input_path, output_path in 
zip(args.inputPaths, args.outputPaths):\n            df = load_data(spark, input_path, args, customize_reader)\n            for column, indexer in indexers:\n                df = expand(indexer, df, column)\n            args.numRows and df.show(args.numRows)\n            save_data(df, output_path, args, customize_writer)\n\n    spark.stop()\n"
  },
  {
    "path": "scripts/encoding/python/com/nvidia/spark/encoding/criteo/target_cpu_main.py",
    "content": "#\n# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nfrom com.nvidia.spark.encoding.criteo.common import *\nfrom com.nvidia.spark.encoding.utility.utils import *\nfrom pyspark.sql import SparkSession\nfrom pyspark.sql.functions import udf\nfrom pyspark.sql import functions as F\nfrom pyspark.sql.types import FloatType, DoubleType\nimport time\n\n\ndef get_dict_df(train_df, target_col, label_col):\n    '''\n    get one dict dataframe for one column\n    '''\n    col_target_df = train_df.groupBy(target_col).agg(F.mean(label_col))\n    return col_target_df\n\ndef encode_df(original_df, dict_df, col_name):\n    dict_df_rename = dict_df.withColumnRenamed('_c0', 'hash').withColumnRenamed('_c1', col_name+'_mean')\n    df_mean = (original_df.join(dict_df_rename, original_df[col_name] == dict_df_rename['hash'], how='left').drop('hash').drop(col_name)\n        .na.fill(-1, [col_name + '_mean']))\n    return df_mean\n\n\ndef main(args):\n    spark = (SparkSession\n        .builder\n        .appName(args.mainClass)\n        .getOrCreate())\n    if args.mode == 'train':\n        for col_name, model_path in zip(args.columns, args.modelPaths):\n            df = load_data(spark, args.inputPaths, args, customize_reader).cache()\n            dict_df = get_dict_df(df, col_name, args.labelColumn)\n            dict_df.repartition(1).write.csv(model_path)\n\n    if args.mode == 'transform':\n        
dict_dfs = [\n            load_dict_df(spark, path).withColumn('_c1', F.col('_c1').cast(DoubleType())).cache()\n            for path in args.modelPaths\n        ]\n        for input_path, output_path in zip(args.inputPaths, args.outputPaths):\n            df = load_data(spark, input_path, args, customize_reader)\n            for col_name, dict_df in zip(args.columns, dict_dfs):\n                df = encode_df(df, dict_df, col_name)\n            save_data(df, output_path, args, customize_writer)"
  },
  {
    "path": "scripts/encoding/python/com/nvidia/spark/encoding/main.py",
    "content": "#\n# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nfrom com.nvidia.spark.encoding.utility.args import parse_arguments\nfrom importlib import import_module\n\ndef main():\n    args = parse_arguments()\n    getattr(import_module(args.mainClass), 'main')(args)\n"
  },
  {
    "path": "scripts/encoding/python/com/nvidia/spark/encoding/utility/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "scripts/encoding/python/com/nvidia/spark/encoding/utility/args.py",
    "content": "#\n# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nimport sys\n\nfrom argparse import ArgumentParser\nfrom distutils.util import strtobool\n\ndef _to_bool(literal):\n    return bool(strtobool(literal))\n\ndef _to_str_list(literal):\n    return [x for x in literal.split(',') if x]\n\n_examples = [\n    'com.nvidia.spark.encoding.criteo.one_hot_cpu_main',\n    'com.nvidia.spark.encoding.criteo.target_cpu_main'\n]\n\ndef _validate_args(args):\n    usage = ''\n    if args.mode == 'transform' and not args.outputPaths:\n        usage += '  --outputPaths required for transform.\\n'\n    # for production:\n    #     validates that --columns and --inputPaths exists\n    #     validates that --inputPath and --outputPath matches for transform\n    if (args.mainClass == 'com.nvidia.spark.encoding.criteo.target_cpu_main'\n            and args.mode == 'train'\n            and not args.labelColumn):\n        usage += '  --labelColumn required for target encoding. 
\\n'\n    if usage:\n        print('-' * 80)\n        print('Usage:\\n' + usage)\n        sys.exit(1)\n\ndef parse_arguments():\n    parser = ArgumentParser()\n\n    # application arguments\n    parser.add_argument('--mainClass', required=True, choices=_examples)\n    parser.add_argument('--mode', choices=['train', 'transform'], required=True)\n    parser.add_argument('--format', choices=['csv'], default='csv')\n    parser.add_argument('--columns', type=_to_str_list, required=True)\n    parser.add_argument('--modelPaths', type=_to_str_list, required=True)\n    parser.add_argument('--inputPaths', type=_to_str_list, required=True)\n    parser.add_argument('--outputPaths', type=_to_str_list)             # for transform, required\n    parser.add_argument('--overwrite', type=_to_bool, default=False)\n    parser.add_argument('--numRows', type=int)                          # for transform, optional\n    parser.add_argument('--labelColumn', help='name of the label column') # for target encoding, required\n\n    parsed = parser.parse_args()\n    _validate_args(parsed)\n\n    return parsed\n"
  },
  {
    "path": "scripts/encoding/python/com/nvidia/spark/encoding/utility/utils.py",
    "content": "#\n# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nimport pickle\n\ndef load_data(spark, paths, args, customize=None):\n    reader = (spark\n        .read\n        .format(args.format))\n    customize and customize(reader)\n    return reader.load(paths)\n\ndef save_data(data_frame, path, args, customize=None):\n    writer = (data_frame\n        .write\n        .format(args.format))\n    args.overwrite and writer.mode('overwrite')\n    customize and customize(writer)\n    writer.save(path)\n\ndef load_model(model_class, path):\n    return model_class.load(path)\n\ndef load_models(model_class, paths):\n    return [load_model(model_class, path) for path in paths]\n\ndef save_model(model, path, args):\n    writer = model.write().overwrite() if args.overwrite else model\n    writer.save(path)\n\ndef save_dict(mean_dict, target_path):\n    '''\n    target_path: full path of the target location to save the dict\n    '''\n    with open(target_path+'.pkl', 'wb') as f:\n        pickle.dump(mean_dict, f, pickle.HIGHEST_PROTOCOL)\n\ndef load_dict(dict_path):\n    '''\n    dict_path: full path of target dict with '.pkl' tail.\n    '''\n    with open(dict_path, 'rb') as f:\n        return pickle.load(f)\n\ndef load_dict_df(spark, dict_df_path):\n    return spark.read.option(\"header\",\"false\").csv(dict_df_path)\n"
  },
  {
    "path": "scripts/encoding/python/main.py",
    "content": "#\n# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\nfrom com.nvidia.spark.encoding.main import main\n\nmain()\n"
  },
  {
    "path": "scripts/encoding-sample/repartition.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Note: Plase modify the data source options for your case.\n\nimport sys\n\nfrom pyspark.sql import SparkSession\n\n(SparkSession\n    .builder\n    .getOrCreate()\n    .read\n    .option('sep', '\\t')\n    .csv(sys.argv[1])\n    .repartition(int(sys.argv[3]))\n    .write\n    .option('sep', '\\t')\n    .option('nullValue', None)\n    .csv(sys.argv[2]))\n"
  },
  {
    "path": "scripts/encoding-sample/run.sh",
    "content": "#!/bin/bash\n# Copyright (c) 2024-2025, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# clear\nrm -f encoding.zip main.py\nrm -f raw-*.csv\nrm -rf model target-* onehot-* final-*\n\n# prepare data\nhead -n 500 ../../datasets/clicklog.csv > raw-1.csv\nhead -n 750 ../../datasets/clicklog.csv | tail -n 250 > raw-2.csv\ntail -n 250 ../../datasets/clicklog.csv > raw-3.csv\n\n# assemble python libs\npushd ../encoding/python/\nzip -r ../../encoding-sample/encoding.zip ai/\ncp main.py ../../encoding-sample/\npopd\n\n# train target models/dicts\nspark-submit --py-files encoding.zip main.py \\\n    --mainClass=com.nvidia.spark.encoding.criteo.target_cpu_main --mode=train \\\n    --format=csv --inputPaths=raw-1.csv,raw-2.csv \\\n    --labelColumn=_c0 --columns=_c34,_c35 --modelPaths=model/c34.dict,model/c35.dict\nspark-submit truncate-model.py model/c34.dict model/c34_truncated.dict\nspark-submit truncate-model.py model/c35.dict model/c35_truncated.dict\n\n# train onehot models/indexers\nspark-submit --py-files encoding.zip main.py \\\n    --mainClass=com.nvidia.spark.encoding.criteo.one_hot_cpu_main --mode=train \\\n    --format=csv --inputPaths=raw-1.csv,raw-2.csv \\\n    --columns=_c19,_c26 --modelPaths=model/_c19,model/_c26\n\n# target encoding\nspark-submit --py-files encoding.zip main.py \\\n    --mainClass=com.nvidia.spark.encoding.criteo.target_cpu_main --mode=transform \\\n    --columns=_c34,_c35 
--modelPaths=model/c34_truncated.dict,model/c35_truncated.dict \\\n    --format=csv --inputPaths=raw-1.csv,raw-2.csv,raw-3.csv --outputPaths=target-1,target-2,target-3\n\n# onehot encoding\n# NOTE: If the column index changed after target encoding, you should change the metadata of all\n#       models accordingly. E.g., change \"outputCol\":\"_c26_index\",\"inputCol\":\"_c26\" to\n#       \"outputCol\":\"_c25_index\",\"inputCol\":\"_c25\" for file model/_c26/metadata/part-00000.\n#       This is verified on Spark 2.x.\nspark-submit --py-files encoding.zip main.py \\\n    --mainClass=com.nvidia.spark.encoding.criteo.one_hot_cpu_main --mode=transform \\\n    --columns=_c19,_c26 --modelPaths=model/_c19,model/_c26 \\\n    --format=csv --inputPaths=target-1,target-2,target-3 --outputPaths=onehot-1,onehot-2,onehot-3\n\n# NOTE: As an example, not all categorical columns are encoded here.\n#       But please encode all categorical columns in production environment.\n\n# repartition <input path> <output path> <number of partitions>\nspark-submit repartition.py onehot-1 final-1 5\nspark-submit repartition.py onehot-2 final-2 5\nspark-submit repartition.py onehot-3 final-3 5\n\n# known issues:\n#   - Issue: \"org.apache.spark.shuffle.FetchFailedException: Too large frame: ...\":\n#     Solution: Add \"--conf spark.maxRemoteBlockSizeFetchToMem=1G\"\n"
  },
  {
    "path": "scripts/encoding-sample/truncate-model.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport sys\n\nfrom pyspark.sql import SparkSession\nfrom pyspark.sql.functions import *\n\n(SparkSession\n    .builder\n    .getOrCreate()\n    .read\n    .csv(sys.argv[1])\n    .withColumn('_c1', format_string('%.6f', col('_c1').cast('float')))\n    .withColumn('_c1', when(col('_c1') == '0.000000', lit('0.0')).otherwise(col('_c1')))\n    .withColumn('_c1', when(col('_c1') == '1.000000', lit('1.0')).otherwise(col('_c1')))\n    .repartition(1)\n    .write\n    .option('nullValue', None)\n    .csv(sys.argv[2]))\n"
  },
  {
    "path": "tools/databricks/README.md",
    "content": "# Databricks Qualification/Profiling Quick Start Notebooks\n\nThe RAPIDS Accelerator for Apache Spark includes two key tools for understanding the benefits of\nGPU acceleration as well as analyzing GPU Spark jobs.  For customers on Databricks, the quick start notebooks offer a simple interface for running the tools given a set of Spark event logs from\nCPU (qualification) or GPU (profiling) application runs.\n\nTo use a demo notebook, you can import the notebook in the Databricks Notebook UI via File->Import Notebook.\n\nOnce the demo notebook is imported, you can select run to activate the notebook to an available compute\ncluster.  Once the notebook is activated, you can enter in the log path location in the text widget at the\ntop of the notebook.  After that, select *Run all* to execute the tools for the specific logs in the log path.\n\n## Limitations\n1. Currently local, S3 or DBFS event log paths are supported.\n1. S3 path is only supported on Databricks AWS using [instance profiles](https://docs.databricks.com/en/connect/storage/tutorial-s3-instance-profile.html).\n1. Eventlog path must follow the formats `/dbfs/path/to/eventlog` or `dbfs:/path/to/eventlog` for logs stored in DBFS.\n1. Use wildcards for nested lookup of eventlogs. \n   - For example: `/dbfs/path/to/clusterlogs/*/*`\n1. Multiple event logs must be comma-separated. \n   - For example: `/dbfs/path/to/eventlog1,/dbfs/path/to/eventlog2`\n\n**Latest Tools Version Supported** 26.02.0"
  },
  {
    "path": "tools/databricks/[RAPIDS Accelerator for Apache Spark] Profiling Tool Notebook Template.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"metadata\": {},\n   \"cell_type\": \"raw\",\n   \"source\": [\n    \"{\\n\",\n    \" \\\"cells\\\": [\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"df33c614-2ecc-47a0-8600-bc891681997f\\\",\\n\",\n    \"     \\\"showTitle\\\": false,\\n\",\n    \"     \\\"title\\\": \\\"\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"## Welcome to the Profiling Tool for the RAPIDS Accelerator for Apache Spark\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"To run the profiling tool, enter the log path that represents the DBFS location of your Spark GPU event logs. Then, select \\\\\\\"Run all\\\\\\\" to execute the notebook. Once the notebook completes, various output tables will appear below. For more options on running the profiling tool, please refer to the [Profiling Tool User Guide](https://docs.nvidia.com/spark-rapids/user-guide/latest/profiling/quickstart.html#running-the-tool).\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"### Note\\\\n\\\",\\n\",\n    \"    \\\"- Currently, local, S3 or DBFS event log paths are supported.\\\\n\\\",\\n\",\n    \"    \\\"- S3 path is only supported on Databricks AWS using [instance profiles](https://docs.databricks.com/en/connect/storage/tutorial-s3-instance-profile.html).\\\\n\\\",\\n\",\n    \"    \\\"- Eventlog path must follow the formats `/dbfs/path/to/eventlog` or `dbfs:/path/to/eventlog` for logs stored in DBFS.\\\\n\\\",\\n\",\n    \"    \\\"- Use wildcards for nested lookup of eventlogs. 
\\\\n\\\",\\n\",\n    \"    \\\"   - For example: `/dbfs/path/to/clusterlogs/*/*`\\\\n\\\",\\n\",\n    \"    \\\"- Multiple event logs must be comma-separated. \\\\n\\\",\\n\",\n    \"    \\\"   - For example: `/dbfs/path/to/eventlog1,/dbfs/path/to/eventlog2`\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"### Per-Job Profile\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"The profiler output includes information about the application, data sources, executors, SQL stages, Spark properties, and key application metrics at the job and stage levels.\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"5e9f5796-46ed-49ac-9d08-c8b98a87c39d\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Set Tools Version\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"DEFAULT_TOOLS_VER = \\\\\\\"24.12.4\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"TOOLS_VER_ARG = dbutils.widgets.get(\\\\\\\"Tools Version\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"TOOLS_VER = TOOLS_VER_ARG if TOOLS_VER_ARG else DEFAULT_TOOLS_VER\\\\n\\\",\\n\",\n    \"    \\\"print(f\\\\\\\"Using Tools Version: {TOOLS_VER}\\\\\\\")\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    
\\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"313ee58b-61b3-4010-9d60-d21eceea796c\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Install Package\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"%pip install spark-rapids-user-tools==$TOOLS_VER > /dev/null\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"34492d18-1130-45be-b9f7-e6931d3fa66b\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Environment Setup\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"import os\\\\n\\\",\\n\",\n    \"    \\\"import pandas as pd\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"def convert_dbfs_path(path):\\\\n\\\",\\n\",\n    \"    \\\"    return path.replace(\\\\\\\"dbfs:/\\\\\\\", \\\\\\\"/dbfs/\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"  \\\\n\\\",\\n\",\n    \"    \\\"# Detect cloud provider from cluster usage tags\\\\n\\\",\\n\",\n    \"    \\\"valid_csps = [\\\\\\\"aws\\\\\\\", \\\\\\\"azure\\\\\\\"]\\\\n\\\",\\n\",\n    \"    
\\\"CSP=spark.conf.get(\\\\\\\"spark.databricks.clusterUsageTags.cloudProvider\\\\\\\", \\\\\\\"\\\\\\\").lower()\\\\n\\\",\\n\",\n    \"    \\\"if CSP not in valid_csps:\\\\n\\\",\\n\",\n    \"    \\\"    print(f\\\\\\\"ERROR: Cannot detect cloud provider from cluster usage tags. Using '{valid_csps[0]}' as default. \\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"    CSP = valid_csps[0]\\\\n\\\",\\n\",\n    \"    \\\"else:\\\\n\\\",\\n\",\n    \"    \\\"    print(f\\\\\\\"Detected Cloud Provider from Spark Configs: '{CSP}'\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"# Initialize variables from widgets\\\\n\\\",\\n\",\n    \"    \\\"dbutils.widgets.text(\\\\\\\"Eventlog Path\\\\\\\", \\\\\\\"/dbfs/user1/profiling_logs\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"EVENTLOG_PATH=dbutils.widgets.get(\\\\\\\"Eventlog Path\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"EVENTLOG_PATH=convert_dbfs_path(EVENTLOG_PATH)\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"dbutils.widgets.text(\\\\\\\"Output Path\\\\\\\", \\\\\\\"/tmp\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"OUTPUT_PATH=dbutils.widgets.get(\\\\\\\"Output Path\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"# Setup environment variables\\\\n\\\",\\n\",\n    \"    \\\"os.environ[\\\\\\\"CSP\\\\\\\"] = CSP\\\\n\\\",\\n\",\n    \"    \\\"os.environ[\\\\\\\"EVENTLOG_PATH\\\\\\\"] = EVENTLOG_PATH\\\\n\\\",\\n\",\n    \"    \\\"os.environ[\\\\\\\"OUTPUT_PATH\\\\\\\"] = OUTPUT_PATH\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"# Setup console output file\\\\n\\\",\\n\",\n    \"    \\\"CONSOLE_OUTPUT_PATH = os.path.join(OUTPUT_PATH, 'console_output.log')\\\\n\\\",\\n\",\n    \"    \\\"CONSOLE_ERROR_PATH = os.path.join(OUTPUT_PATH, 'console_error.log')\\\\n\\\",\\n\",\n    \"    \\\"os.environ['CONSOLE_OUTPUT_PATH'] = CONSOLE_OUTPUT_PATH\\\\n\\\",\\n\",\n    \"    \\\"os.environ['CONSOLE_ERROR_PATH'] = CONSOLE_ERROR_PATH\\\\n\\\",\\n\",\n    \"    
\\\"print(f'Console output will be stored at {CONSOLE_OUTPUT_PATH} and errors will be stored at {CONSOLE_ERROR_PATH}')\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"693b5ee0-7500-43f3-b3e2-717fd5468aa8\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Run Profiling Tool\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"%sh\\\\n\\\",\\n\",\n    \"    \\\"spark_rapids profiling --platform databricks-$CSP --eventlogs \\\\\\\"$EVENTLOG_PATH\\\\\\\" -o \\\\\\\"$OUTPUT_PATH\\\\\\\" --verbose > \\\\\\\"$CONSOLE_OUTPUT_PATH\\\\\\\" 2> \\\\\\\"$CONSOLE_ERROR_PATH\\\\\\\"\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"f83af6c8-5a79-4a46-965b-38a4cb621877\\\",\\n\",\n    \"     \\\"showTitle\\\": false,\\n\",\n    \"     \\\"title\\\": \\\"\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"## Console Output\\\\n\\\",\\n\",\n    \"    \\\"Console output shows the recommended configurations for each app\\\\n\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": 
\\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"c61527b7-a21a-492c-bab8-77f83dc5cabf\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Show Console Output\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"%sh\\\\n\\\",\\n\",\n    \"    \\\"cat $CONSOLE_OUTPUT_PATH\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"f3c68b28-fc62-40ae-8528-799f3fc7507e\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Show Logs\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"%sh\\\\n\\\",\\n\",\n    \"    \\\"cat $CONSOLE_ERROR_PATH\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \" 
    \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"05f96ca1-1b08-494c-a12b-7e6cc3dcc546\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Parse Output\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"import re\\\\n\\\",\\n\",\n    \"    \\\"import shutil\\\\n\\\",\\n\",\n    \"    \\\"import os\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"def extract_file_info(console_output_path, output_base_path):\\\\n\\\",\\n\",\n    \"    \\\"    try:\\\\n\\\",\\n\",\n    \"    \\\"        with open(console_output_path, 'r') as file:\\\\n\\\",\\n\",\n    \"    \\\"            stdout_text = file.read()\\\\n\\\",\\n\",\n    \"    \\\"        \\\\n\\\",\\n\",\n    \"    \\\"        # Extract log file location\\\\n\\\",\\n\",\n    \"    \\\"        location_match = re.search(r\\\\\\\"Location: (.+)\\\\\\\", stdout_text)\\\\n\\\",\\n\",\n    \"    \\\"        if not location_match:\\\\n\\\",\\n\",\n    \"    \\\"            raise ValueError(\\\\\\\"Log file location not found in the provided text.\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"        \\\\n\\\",\\n\",\n    \"    \\\"        log_file_location = location_match.group(1)\\\\n\\\",\\n\",\n    \"    \\\"        \\\\n\\\",\\n\",\n    \"    \\\"        # Extract profiling output folder\\\\n\\\",\\n\",\n    \"    \\\"        prof_match = re.search(r\\\\\\\"prof_[^/]+(?=\\\\\\\\.log)\\\\\\\", log_file_location)\\\\n\\\",\\n\",\n    \"    \\\"        if not prof_match:\\\\n\\\",\\n\",\n    \"    \\\"            raise ValueError(\\\\\\\"Output folder not found in the log file location.\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"        
\\\\n\\\",\\n\",\n    \"    \\\"        output_folder_name = prof_match.group(0)\\\\n\\\",\\n\",\n    \"    \\\"        output_folder = os.path.join(output_base_path, output_folder_name)\\\\n\\\",\\n\",\n    \"    \\\"        return output_folder, log_file_location\\\\n\\\",\\n\",\n    \"    \\\"    \\\\n\\\",\\n\",\n    \"    \\\"    except Exception as e:\\\\n\\\",\\n\",\n    \"    \\\"        raise RuntimeError(f\\\\\\\"Cannot parse console output. Reason: {e}\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"def copy_logs(destination_folder, *log_files):\\\\n\\\",\\n\",\n    \"    \\\"    try:\\\\n\\\",\\n\",\n    \"    \\\"        log_folder = os.path.join(destination_folder, \\\\\\\"logs\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"        os.makedirs(log_folder, exist_ok=True)\\\\n\\\",\\n\",\n    \"    \\\"        \\\\n\\\",\\n\",\n    \"    \\\"        for log_file in log_files:\\\\n\\\",\\n\",\n    \"    \\\"            if os.path.exists(log_file):\\\\n\\\",\\n\",\n    \"    \\\"                shutil.copy2(log_file, log_folder)\\\\n\\\",\\n\",\n    \"    \\\"            else:\\\\n\\\",\\n\",\n    \"    \\\"                print(f\\\\\\\"Log file not found: {log_file}\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"    except Exception as e:\\\\n\\\",\\n\",\n    \"    \\\"        raise RuntimeError(f\\\\\\\"Cannot copy logs to output. 
Reason: {e}\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"try:\\\\n\\\",\\n\",\n    \"    \\\"    output_folder, log_file_location = extract_file_info(CONSOLE_OUTPUT_PATH, OUTPUT_PATH)\\\\n\\\",\\n\",\n    \"    \\\"    print(f\\\\\\\"Output folder detected {output_folder}\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"    copy_logs(output_folder, log_file_location, CONSOLE_OUTPUT_PATH, CONSOLE_ERROR_PATH)\\\\n\\\",\\n\",\n    \"    \\\"    print(f\\\\\\\"Logs successfully copied to {output_folder}\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"except Exception as e:\\\\n\\\",\\n\",\n    \"    \\\"    print(e)\\\\n\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"8c65adcd-a933-482e-a50b-d40fa8f50e16\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Download Output\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"import shutil\\\\n\\\",\\n\",\n    \"    \\\"import os\\\\n\\\",\\n\",\n    \"    \\\"import re\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"current_working_directory = os.getcwd()\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"def create_destination_folders(folder_name):\\\\n\\\",\\n\",\n    \"    \\\"    os.makedirs(folder_name, exist_ok=True)\\\\n\\\",\\n\",\n    \"    \\\"    base_download_folder_path = os.path.join(\\\\\\\"/dbfs/FileStore/\\\\\\\", 
folder_name)\\\\n\\\",\\n\",\n    \"    \\\"    os.makedirs(base_download_folder_path, exist_ok=True) \\\\n\\\",\\n\",\n    \"    \\\"    return base_download_folder_path\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"def create_download_link(source_folder, destination_folder_name):\\\\n\\\",\\n\",\n    \"    \\\"    folder_to_compress = os.path.basename(source_folder)\\\\n\\\",\\n\",\n    \"    \\\"    zip_file_name = folder_to_compress + '.zip'\\\\n\\\",\\n\",\n    \"    \\\"    local_zip_file_path = os.path.join(current_working_directory, destination_folder_name, zip_file_name)\\\\n\\\",\\n\",\n    \"    \\\"    download_folder_path = os.path.join(destination_folder_name, zip_file_name)\\\\n\\\",\\n\",\n    \"    \\\"    try:\\\\n\\\",\\n\",\n    \"    \\\"        base_download_folder_path = create_destination_folders(destination_folder_name)\\\\n\\\",\\n\",\n    \"    \\\"        shutil.make_archive(folder_to_compress, 'zip', source_folder)\\\\n\\\",\\n\",\n    \"    \\\"        shutil.copy2(zip_file_name, base_download_folder_path)\\\\n\\\",\\n\",\n    \"    \\\"        if os.path.exists(local_zip_file_path):\\\\n\\\",\\n\",\n    \"    \\\"            os.remove(local_zip_file_path)\\\\n\\\",\\n\",\n    \"    \\\"        shutil.move(zip_file_name, local_zip_file_path)\\\\n\\\",\\n\",\n    \"    \\\"    \\\\n\\\",\\n\",\n    \"    \\\"        download_button_html = f\\\\\\\"\\\\\\\"\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"        <style>\\\\n\\\",\\n\",\n    \"    \\\"            .download-btn {{\\\\n\\\",\\n\",\n    \"    \\\"                display: inline-block;\\\\n\\\",\\n\",\n    \"    \\\"                padding: 10px 20px;\\\\n\\\",\\n\",\n    \"    \\\"                font-size: 16px;\\\\n\\\",\\n\",\n    \"    \\\"                color: white;\\\\n\\\",\\n\",\n    \"    \\\"                background-color: #4CAF50;\\\\n\\\",\\n\",\n    \"    \\\"                text-align: center;\\\\n\\\",\\n\",\n    \"    \\\"                
text-decoration: none;\\\\n\\\",\\n\",\n    \"    \\\"                border-radius: 5px;\\\\n\\\",\\n\",\n    \"    \\\"                border: none;\\\\n\\\",\\n\",\n    \"    \\\"                cursor: pointer;\\\\n\\\",\\n\",\n    \"    \\\"                margin: 15px auto;\\\\n\\\",\\n\",\n    \"    \\\"            }}\\\\n\\\",\\n\",\n    \"    \\\"            .download-btn:hover {{\\\\n\\\",\\n\",\n    \"    \\\"                background-color: #45a049;\\\\n\\\",\\n\",\n    \"    \\\"            }}\\\\n\\\",\\n\",\n    \"    \\\"            .button-container {{\\\\n\\\",\\n\",\n    \"    \\\"                display: flex;\\\\n\\\",\\n\",\n    \"    \\\"                justify-content: center;\\\\n\\\",\\n\",\n    \"    \\\"                align-items: center;\\\\n\\\",\\n\",\n    \"    \\\"            }}\\\\n\\\",\\n\",\n    \"    \\\"        </style>\\\\n\\\",\\n\",\n    \"    \\\"        \\\\n\\\",\\n\",\n    \"    \\\"        <div style=\\\\\\\"color: #444; font-size: 14px; text-align: center; margin: 10px;\\\\\\\">\\\\n\\\",\\n\",\n    \"    \\\"            Zipped output file created at {local_zip_file_path}\\\\n\\\",\\n\",\n    \"    \\\"        </div>\\\\n\\\",\\n\",\n    \"    \\\"        <div class='button-container'>\\\\n\\\",\\n\",\n    \"    \\\"            <a href='/files/{download_folder_path}' class='download-btn'>Download Output</a>\\\\n\\\",\\n\",\n    \"    \\\"        </div>\\\\n\\\",\\n\",\n    \"    \\\"        \\\\\\\"\\\\\\\"\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"        displayHTML(download_button_html)\\\\n\\\",\\n\",\n    \"    \\\"    except Exception as e:\\\\n\\\",\\n\",\n    \"    \\\"        error_message_html = f\\\\\\\"\\\\\\\"\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"        <div style=\\\\\\\"color: red; text-align: center; margin: 20px;\\\\\\\">\\\\n\\\",\\n\",\n    \"    \\\"            <strong>Error:</strong> Cannot create download link for {source_folder}. 
Reason: {e}\\\\n\\\",\\n\",\n    \"    \\\"        </div>\\\\n\\\",\\n\",\n    \"    \\\"        \\\\\\\"\\\\\\\"\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"        displayHTML(error_message_html)\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"destination_folder_name = \\\\\\\"Tools_Output\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"create_download_link(output_folder, destination_folder_name)\\\\n\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"bbe50fde-0bd6-4281-95fd-6a1ec6f17ab2\\\",\\n\",\n    \"     \\\"showTitle\\\": false,\\n\",\n    \"     \\\"title\\\": \\\"\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"%md\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"## GPU Job Tuning Recommendations\\\\n\\\",\\n\",\n    \"    \\\"This has general suggestions for tuning your applications to run optimally on GPUs.\\\\n\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"b8bca4a6-16d8-4b60-ba7b-9aff64bdcaa1\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Show Recommendations\\\"\\n\",\n    \"    },\\n\",\n    \"    
\\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"jar_output_folder = os.path.join(output_folder, \\\\\\\"rapids_4_spark_profile\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"app_df = pd.DataFrame(columns=['appId', 'appName'])\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"for x in os.scandir(jar_output_folder):\\\\n\\\",\\n\",\n    \"    \\\"    if x.is_dir():\\\\n\\\",\\n\",\n    \"    \\\"        csv_path = os.path.join(x.path, \\\\\\\"application_information.csv\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"        if os.path.exists(csv_path):\\\\n\\\",\\n\",\n    \"    \\\"          tmp_df = pd.read_csv(csv_path)\\\\n\\\",\\n\",\n    \"    \\\"          app_df = pd.concat([app_df, tmp_df[['appId', 'appName']]])\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"app_list = app_df[\\\\\\\"appId\\\\\\\"].tolist()\\\\n\\\",\\n\",\n    \"    \\\"app_recommendations = pd.DataFrame(columns=['app', 'recommendations'])\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"for app in app_list:\\\\n\\\",\\n\",\n    \"    \\\"  app_file = open(os.path.join(jar_output_folder, app, \\\\\\\"profile.log\\\\\\\"))\\\\n\\\",\\n\",\n    \"    \\\"  recommendations_start = 0\\\\n\\\",\\n\",\n    \"    \\\"  recommendations_str = \\\\\\\"\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"  for line in app_file:\\\\n\\\",\\n\",\n    \"    \\\"    if recommendations_start == 1:\\\\n\\\",\\n\",\n    \"    \\\"      recommendations_str = recommendations_str + line\\\\n\\\",\\n\",\n    \"    \\\"    if \\\\\\\"### D. 
Recommended Configuration ###\\\\\\\" in line:\\\\n\\\",\\n\",\n    \"    \\\"      recommendations_start = 1\\\\n\\\",\\n\",\n    \"    \\\"  app_recommendations = pd.concat([app_recommendations, pd.DataFrame({'app': [app], 'recommendations': [recommendations_str]})], ignore_index=True)\\\\n\\\",\\n\",\n    \"    \\\"display(app_recommendations)\\\"\\n\",\n    \"   ]\\n\",\n    \"  }\\n\",\n    \" ],\\n\",\n    \" \\\"metadata\\\": {\\n\",\n    \"  \\\"application/vnd.databricks.v1+notebook\\\": {\\n\",\n    \"   \\\"dashboards\\\": [\\n\",\n    \"    {\\n\",\n    \"     \\\"elements\\\": [],\\n\",\n    \"     \\\"globalVars\\\": {},\\n\",\n    \"     \\\"guid\\\": \\\"\\\",\\n\",\n    \"     \\\"layoutOption\\\": {\\n\",\n    \"      \\\"grid\\\": true,\\n\",\n    \"      \\\"stack\\\": true\\n\",\n    \"     },\\n\",\n    \"     \\\"nuid\\\": \\\"91c1bfb2-695a-4e5c-8a25-848a433108dc\\\",\\n\",\n    \"     \\\"origId\\\": 2173122769183713,\\n\",\n    \"     \\\"title\\\": \\\"Executive View\\\",\\n\",\n    \"     \\\"version\\\": \\\"DashboardViewV1\\\",\\n\",\n    \"     \\\"width\\\": 1600\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"     \\\"elements\\\": [],\\n\",\n    \"     \\\"globalVars\\\": {},\\n\",\n    \"     \\\"guid\\\": \\\"\\\",\\n\",\n    \"     \\\"layoutOption\\\": {\\n\",\n    \"      \\\"grid\\\": true,\\n\",\n    \"      \\\"stack\\\": true\\n\",\n    \"     },\\n\",\n    \"     \\\"nuid\\\": \\\"62243296-4562-4f06-90ac-d7a609f19c16\\\",\\n\",\n    \"     \\\"origId\\\": 2173122769183714,\\n\",\n    \"     \\\"title\\\": \\\"App View\\\",\\n\",\n    \"     \\\"version\\\": \\\"DashboardViewV1\\\",\\n\",\n    \"     \\\"width\\\": 1920\\n\",\n    \"    }\\n\",\n    \"   ],\\n\",\n    \"   \\\"environmentMetadata\\\": null,\\n\",\n    \"   \\\"language\\\": \\\"python\\\",\\n\",\n    \"   \\\"notebookMetadata\\\": {\\n\",\n    \"    \\\"mostRecentlyExecutedCommandWithImplicitDF\\\": {\\n\",\n    \"     \\\"commandId\\\": 
2173122769183692,\\n\",\n    \"     \\\"dataframes\\\": [\\n\",\n    \"      \\\"_sqldf\\\"\\n\",\n    \"     ]\\n\",\n    \"    },\\n\",\n    \"    \\\"pythonIndentUnit\\\": 2,\\n\",\n    \"    \\\"widgetLayout\\\": [\\n\",\n    \"     {\\n\",\n    \"      \\\"breakBefore\\\": false,\\n\",\n    \"      \\\"name\\\": \\\"Eventlog Path\\\",\\n\",\n    \"      \\\"width\\\": 778\\n\",\n    \"     },\\n\",\n    \"     {\\n\",\n    \"      \\\"breakBefore\\\": false,\\n\",\n    \"      \\\"name\\\": \\\"Output Path\\\",\\n\",\n    \"      \\\"width\\\": 302\\n\",\n    \"     }\\n\",\n    \"    ]\\n\",\n    \"   },\\n\",\n    \"   \\\"notebookName\\\": \\\"[RAPIDS Accelerator for Apache Spark] Profiling Tool Notebook Template\\\",\\n\",\n    \"   \\\"widgets\\\": {\\n\",\n    \"    \\\"Eventlog Path\\\": {\\n\",\n    \"     \\\"currentValue\\\": \\\"/dbfs/user1/profiling_logs\\\",\\n\",\n    \"     \\\"nuid\\\": \\\"1272501d-5ad9-42be-ab62-35768b2fc384\\\",\\n\",\n    \"     \\\"typedWidgetInfo\\\": null,\\n\",\n    \"     \\\"widgetInfo\\\": {\\n\",\n    \"      \\\"defaultValue\\\": \\\"/dbfs/user1/profiling_logs\\\",\\n\",\n    \"      \\\"label\\\": \\\"\\\",\\n\",\n    \"      \\\"name\\\": \\\"Eventlog Path\\\",\\n\",\n    \"      \\\"options\\\": {\\n\",\n    \"       \\\"autoCreated\\\": false,\\n\",\n    \"       \\\"validationRegex\\\": null,\\n\",\n    \"       \\\"widgetType\\\": \\\"text\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"widgetType\\\": \\\"text\\\"\\n\",\n    \"     }\\n\",\n    \"    },\\n\",\n    \"    \\\"Output Path\\\": {\\n\",\n    \"     \\\"currentValue\\\": \\\"/tmp\\\",\\n\",\n    \"     \\\"nuid\\\": \\\"ab7e082c-1ef9-4912-8fd7-51bf985eb9c1\\\",\\n\",\n    \"     \\\"typedWidgetInfo\\\": null,\\n\",\n    \"     \\\"widgetInfo\\\": {\\n\",\n    \"      \\\"defaultValue\\\": \\\"/tmp\\\",\\n\",\n    \"      \\\"label\\\": null,\\n\",\n    \"      \\\"name\\\": \\\"Output Path\\\",\\n\",\n    \"      \\\"options\\\": {\\n\",\n    \"  
     \\\"autoCreated\\\": null,\\n\",\n    \"       \\\"validationRegex\\\": null,\\n\",\n    \"       \\\"widgetType\\\": \\\"text\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"widgetType\\\": \\\"text\\\"\\n\",\n    \"     }\\n\",\n    \"    }\\n\",\n    \"   }\\n\",\n    \"  },\\n\",\n    \"  \\\"language_info\\\": {\\n\",\n    \"   \\\"name\\\": \\\"python\\\"\\n\",\n    \"  }\\n\",\n    \" },\\n\",\n    \" \\\"nbformat\\\": 4,\\n\",\n    \" \\\"nbformat_minor\\\": 0\\n\",\n    \"}\\n\"\n   ],\n   \"id\": \"4e6b53d2fff1910e\"\n  }\n ],\n \"metadata\": {},\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "tools/databricks/[RAPIDS Accelerator for Apache Spark] Qualification Tool Notebook Template.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"metadata\": {},\n   \"cell_type\": \"raw\",\n   \"source\": [\n    \"{\\n\",\n    \" \\\"cells\\\": [\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"df33c614-2ecc-47a0-8600-bc891681997f\\\",\\n\",\n    \"     \\\"showTitle\\\": false,\\n\",\n    \"     \\\"title\\\": \\\"\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"## Welcome to the Qualification Tool for the RAPIDS Accelerator for Apache Spark\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"To run the qualification tool, enter the log path that represents the DBFS location of your Spark CPU event logs. Then, select \\\\\\\"Run all\\\\\\\" to execute the notebook. Once the notebook completes, various output tables will appear below. For more options on running the qualification tool, please refer to the [Qualification Tool User Guide](https://docs.nvidia.com/spark-rapids/user-guide/latest/qualification/quickstart.html#running-the-tool).\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"### Note\\\\n\\\",\\n\",\n    \"    \\\"- Currently, local, S3 or DBFS event log paths are supported.\\\\n\\\",\\n\",\n    \"    \\\"- S3 path is only supported on Databricks AWS using [instance profiles](https://docs.databricks.com/en/connect/storage/tutorial-s3-instance-profile.html).\\\\n\\\",\\n\",\n    \"    \\\"- Eventlog path must follow the formats `/dbfs/path/to/eventlog` or `dbfs:/path/to/eventlog` for logs stored in DBFS.\\\\n\\\",\\n\",\n    \"    \\\"- Use wildcards for nested lookup of eventlogs. 
\\\\n\\\",\\n\",\n    \"    \\\"   - For example: `/dbfs/path/to/clusterlogs/*/*`\\\\n\\\",\\n\",\n    \"    \\\"- Multiple event logs must be comma-separated. \\\\n\\\",\\n\",\n    \"    \\\"   - For example: `/dbfs/path/to/eventlog1,/dbfs/path/to/eventlog2`\\\\n\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"5e9f5796-46ed-49ac-9d08-c8b98a87c39d\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Set Tools Version\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"DEFAULT_TOOLS_VER = \\\\\\\"24.12.4\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"TOOLS_VER_ARG = dbutils.widgets.get(\\\\\\\"Tools Version\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"TOOLS_VER = TOOLS_VER_ARG if TOOLS_VER_ARG else DEFAULT_TOOLS_VER\\\\n\\\",\\n\",\n    \"    \\\"print(f\\\\\\\"Using Tools Version: {TOOLS_VER}\\\\\\\")\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"313ee58b-61b3-4010-9d60-d21eceea796c\\\",\\n\",\n    \"     
\\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Install Package\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"%pip install spark-rapids-user-tools==$TOOLS_VER > /dev/null\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"acf401a3-12d3-4236-a6c5-8fe8990b153a\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Environment Setup\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"import os\\\\n\\\",\\n\",\n    \"    \\\"import pandas as pd\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"def convert_dbfs_path(path):\\\\n\\\",\\n\",\n    \"    \\\"    return path.replace(\\\\\\\"dbfs:/\\\\\\\", \\\\\\\"/dbfs/\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"  \\\\n\\\",\\n\",\n    \"    \\\"# Detect cloud provider from cluster usage tags\\\\n\\\",\\n\",\n    \"    \\\"valid_csps = [\\\\\\\"aws\\\\\\\", \\\\\\\"azure\\\\\\\"]\\\\n\\\",\\n\",\n    \"    \\\"CSP=spark.conf.get(\\\\\\\"spark.databricks.clusterUsageTags.cloudProvider\\\\\\\", \\\\\\\"\\\\\\\").lower()\\\\n\\\",\\n\",\n    \"    \\\"if CSP not in valid_csps:\\\\n\\\",\\n\",\n    \"    \\\"    print(f\\\\\\\"ERROR: Cannot detect cloud provider from cluster usage tags. Using '{valid_csps[0]}' as default. 
\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"    CSP = valid_csps[0]\\\\n\\\",\\n\",\n    \"    \\\"else:\\\\n\\\",\\n\",\n    \"    \\\"    print(f\\\\\\\"Detected Cloud Provider from Spark Configs: '{CSP}'\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"# Initialize variables from widgets\\\\n\\\",\\n\",\n    \"    \\\"dbutils.widgets.text(\\\\\\\"Eventlog Path\\\\\\\", \\\\\\\"/dbfs/user1/qualification_logs\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"EVENTLOG_PATH=dbutils.widgets.get(\\\\\\\"Eventlog Path\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"EVENTLOG_PATH=convert_dbfs_path(EVENTLOG_PATH)\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"dbutils.widgets.text(\\\\\\\"Output Path\\\\\\\", \\\\\\\"/tmp\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"OUTPUT_PATH=dbutils.widgets.get(\\\\\\\"Output Path\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"  \\\\n\\\",\\n\",\n    \"    \\\"# Setup environment variables\\\\n\\\",\\n\",\n    \"    \\\"os.environ[\\\\\\\"CSP\\\\\\\"] = CSP\\\\n\\\",\\n\",\n    \"    \\\"os.environ[\\\\\\\"EVENTLOG_PATH\\\\\\\"] = EVENTLOG_PATH\\\\n\\\",\\n\",\n    \"    \\\"os.environ[\\\\\\\"OUTPUT_PATH\\\\\\\"] = OUTPUT_PATH\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"# Setup console output file\\\\n\\\",\\n\",\n    \"    \\\"CONSOLE_OUTPUT_PATH = os.path.join(OUTPUT_PATH, 'console_output.log')\\\\n\\\",\\n\",\n    \"    \\\"CONSOLE_ERROR_PATH = os.path.join(OUTPUT_PATH, 'console_error.log')\\\\n\\\",\\n\",\n    \"    \\\"os.environ['CONSOLE_OUTPUT_PATH'] = CONSOLE_OUTPUT_PATH\\\\n\\\",\\n\",\n    \"    \\\"os.environ['CONSOLE_ERROR_PATH'] = CONSOLE_ERROR_PATH\\\\n\\\",\\n\",\n    \"    \\\"print(f'Console output will be stored at {CONSOLE_OUTPUT_PATH} and errors will be stored at {CONSOLE_ERROR_PATH}')\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   
\\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"693b5ee0-7500-43f3-b3e2-717fd5468aa8\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Run Qualification Tool\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"%sh\\\\n\\\",\\n\",\n    \"    \\\"spark_rapids qualification --platform databricks-$CSP --eventlogs \\\\\\\"$EVENTLOG_PATH\\\\\\\" -o \\\\\\\"$OUTPUT_PATH\\\\\\\" --verbose > \\\\\\\"$CONSOLE_OUTPUT_PATH\\\\\\\" 2> \\\\\\\"$CONSOLE_ERROR_PATH\\\\\\\"\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"f83af6c8-5a79-4a46-965b-38a4cb621877\\\",\\n\",\n    \"     \\\"showTitle\\\": false,\\n\",\n    \"     \\\"title\\\": \\\"\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"## Console Output\\\\n\\\",\\n\",\n    \"    \\\"Console output shows the top candidates and their estimated GPU speedup.\\\\n\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      
\\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"c61527b7-a21a-492c-bab8-77f83dc5cabf\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Show Console Output\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"%sh\\\\n\\\",\\n\",\n    \"    \\\"cat $CONSOLE_OUTPUT_PATH\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"f3c68b28-fc62-40ae-8528-799f3fc7507e\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Show Logs\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"%sh\\\\n\\\",\\n\",\n    \"    \\\"cat $CONSOLE_ERROR_PATH\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": 
\\\"05f96ca1-1b08-494c-a12b-7e6cc3dcc546\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Parse Output\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"import re\\\\n\\\",\\n\",\n    \"    \\\"import shutil\\\\n\\\",\\n\",\n    \"    \\\"import os\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"def extract_file_info(console_output_path, output_base_path):\\\\n\\\",\\n\",\n    \"    \\\"    try:\\\\n\\\",\\n\",\n    \"    \\\"        with open(console_output_path, 'r') as file:\\\\n\\\",\\n\",\n    \"    \\\"            stdout_text = file.read()\\\\n\\\",\\n\",\n    \"    \\\"        \\\\n\\\",\\n\",\n    \"    \\\"        # Extract log file location\\\\n\\\",\\n\",\n    \"    \\\"        location_match = re.search(r\\\\\\\"Location: (.+)\\\\\\\", stdout_text)\\\\n\\\",\\n\",\n    \"    \\\"        if not location_match:\\\\n\\\",\\n\",\n    \"    \\\"            raise ValueError(\\\\\\\"Log file location not found in the provided text.\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"        \\\\n\\\",\\n\",\n    \"    \\\"        log_file_location = location_match.group(1)\\\\n\\\",\\n\",\n    \"    \\\"        \\\\n\\\",\\n\",\n    \"    \\\"        # Extract qualification output folder\\\\n\\\",\\n\",\n    \"    \\\"        qual_match = re.search(r\\\\\\\"qual_[^/]+(?=\\\\\\\\.log)\\\\\\\", log_file_location)\\\\n\\\",\\n\",\n    \"    \\\"        if not qual_match:\\\\n\\\",\\n\",\n    \"    \\\"            raise ValueError(\\\\\\\"Output folder not found in the log file location.\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"        \\\\n\\\",\\n\",\n    \"    \\\"        output_folder_name = qual_match.group(0)\\\\n\\\",\\n\",\n    \"    \\\"        output_folder = os.path.join(output_base_path, output_folder_name)\\\\n\\\",\\n\",\n    \"   
 \\\"        return output_folder, log_file_location\\\\n\\\",\\n\",\n    \"    \\\"    \\\\n\\\",\\n\",\n    \"    \\\"    except Exception as e:\\\\n\\\",\\n\",\n    \"    \\\"        raise RuntimeError(f\\\\\\\"Cannot parse console output. Reason: {e}\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"def copy_logs(destination_folder, *log_files):\\\\n\\\",\\n\",\n    \"    \\\"    try:\\\\n\\\",\\n\",\n    \"    \\\"        log_folder = os.path.join(destination_folder, \\\\\\\"logs\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"        os.makedirs(log_folder, exist_ok=True)\\\\n\\\",\\n\",\n    \"    \\\"        \\\\n\\\",\\n\",\n    \"    \\\"        for log_file in log_files:\\\\n\\\",\\n\",\n    \"    \\\"            if os.path.exists(log_file):\\\\n\\\",\\n\",\n    \"    \\\"                shutil.copy2(log_file, log_folder)\\\\n\\\",\\n\",\n    \"    \\\"            else:\\\\n\\\",\\n\",\n    \"    \\\"                print(f\\\\\\\"Log file not found: {log_file}\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"    except Exception as e:\\\\n\\\",\\n\",\n    \"    \\\"        raise RuntimeError(f\\\\\\\"Cannot copy logs to output. 
Reason: {e}\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"try:\\\\n\\\",\\n\",\n    \"    \\\"    output_folder, log_file_location = extract_file_info(CONSOLE_OUTPUT_PATH, OUTPUT_PATH)\\\\n\\\",\\n\",\n    \"    \\\"    jar_output_folder = os.path.join(output_folder, \\\\\\\"rapids_4_spark_qualification_output\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"    print(f\\\\\\\"Output folder detected {output_folder}\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"    copy_logs(output_folder, log_file_location, CONSOLE_OUTPUT_PATH, CONSOLE_ERROR_PATH)\\\\n\\\",\\n\",\n    \"    \\\"    print(f\\\\\\\"Logs successfully copied to {output_folder}\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"except Exception as e:\\\\n\\\",\\n\",\n    \"    \\\"    print(e)\\\\n\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"8c65adcd-a933-482e-a50b-d40fa8f50e16\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"Download Output\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"import shutil\\\\n\\\",\\n\",\n    \"    \\\"import os\\\\n\\\",\\n\",\n    \"    \\\"import re\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"current_working_directory = os.getcwd()\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"def create_destination_folders(folder_name):\\\\n\\\",\\n\",\n    \"    \\\"    os.makedirs(folder_name, 
exist_ok=True)\\\\n\\\",\\n\",\n    \"    \\\"    base_download_folder_path = os.path.join(\\\\\\\"/dbfs/FileStore/\\\\\\\", folder_name)\\\\n\\\",\\n\",\n    \"    \\\"    os.makedirs(base_download_folder_path, exist_ok=True) \\\\n\\\",\\n\",\n    \"    \\\"    return base_download_folder_path\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"def create_download_link(source_folder, destination_folder_name):\\\\n\\\",\\n\",\n    \"    \\\"    folder_to_compress = os.path.basename(source_folder)\\\\n\\\",\\n\",\n    \"    \\\"    zip_file_name = folder_to_compress + '.zip'\\\\n\\\",\\n\",\n    \"    \\\"    local_zip_file_path = os.path.join(current_working_directory, destination_folder_name, zip_file_name)\\\\n\\\",\\n\",\n    \"    \\\"    download_folder_path = os.path.join(destination_folder_name, zip_file_name)\\\\n\\\",\\n\",\n    \"    \\\"    try:\\\\n\\\",\\n\",\n    \"    \\\"        base_download_folder_path = create_destination_folders(destination_folder_name)\\\\n\\\",\\n\",\n    \"    \\\"        shutil.make_archive(folder_to_compress, 'zip', source_folder)\\\\n\\\",\\n\",\n    \"    \\\"        shutil.copy2(zip_file_name, base_download_folder_path)\\\\n\\\",\\n\",\n    \"    \\\"        if os.path.exists(local_zip_file_path):\\\\n\\\",\\n\",\n    \"    \\\"            os.remove(local_zip_file_path)\\\\n\\\",\\n\",\n    \"    \\\"        shutil.move(zip_file_name, local_zip_file_path)\\\\n\\\",\\n\",\n    \"    \\\"    \\\\n\\\",\\n\",\n    \"    \\\"        download_button_html = f\\\\\\\"\\\\\\\"\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"        <style>\\\\n\\\",\\n\",\n    \"    \\\"            .download-btn {{\\\\n\\\",\\n\",\n    \"    \\\"                display: inline-block;\\\\n\\\",\\n\",\n    \"    \\\"                padding: 10px 20px;\\\\n\\\",\\n\",\n    \"    \\\"                font-size: 16px;\\\\n\\\",\\n\",\n    \"    \\\"                color: white;\\\\n\\\",\\n\",\n    \"    \\\"                background-color: 
#4CAF50;\\\\n\\\",\\n\",\n    \"    \\\"                text-align: center;\\\\n\\\",\\n\",\n    \"    \\\"                text-decoration: none;\\\\n\\\",\\n\",\n    \"    \\\"                border-radius: 5px;\\\\n\\\",\\n\",\n    \"    \\\"                border: none;\\\\n\\\",\\n\",\n    \"    \\\"                cursor: pointer;\\\\n\\\",\\n\",\n    \"    \\\"                margin: 15px auto;\\\\n\\\",\\n\",\n    \"    \\\"            }}\\\\n\\\",\\n\",\n    \"    \\\"            .download-btn:hover {{\\\\n\\\",\\n\",\n    \"    \\\"                background-color: #45a049;\\\\n\\\",\\n\",\n    \"    \\\"            }}\\\\n\\\",\\n\",\n    \"    \\\"            .button-container {{\\\\n\\\",\\n\",\n    \"    \\\"                display: flex;\\\\n\\\",\\n\",\n    \"    \\\"                justify-content: center;\\\\n\\\",\\n\",\n    \"    \\\"                align-items: center;\\\\n\\\",\\n\",\n    \"    \\\"            }}\\\\n\\\",\\n\",\n    \"    \\\"        </style>\\\\n\\\",\\n\",\n    \"    \\\"        \\\\n\\\",\\n\",\n    \"    \\\"        <div style=\\\\\\\"color: #444; font-size: 14px; text-align: center; margin: 10px;\\\\\\\">\\\\n\\\",\\n\",\n    \"    \\\"            Zipped output file created at {local_zip_file_path}\\\\n\\\",\\n\",\n    \"    \\\"        </div>\\\\n\\\",\\n\",\n    \"    \\\"        <div class='button-container'>\\\\n\\\",\\n\",\n    \"    \\\"            <a href='/files/{download_folder_path}' class='download-btn'>Download Output</a>\\\\n\\\",\\n\",\n    \"    \\\"        </div>\\\\n\\\",\\n\",\n    \"    \\\"        \\\\\\\"\\\\\\\"\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"        displayHTML(download_button_html)\\\\n\\\",\\n\",\n    \"    \\\"    except Exception as e:\\\\n\\\",\\n\",\n    \"    \\\"        error_message_html = f\\\\\\\"\\\\\\\"\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"        <div style=\\\\\\\"color: red; text-align: center; margin: 20px;\\\\\\\">\\\\n\\\",\\n\",\n    \"    \\\"            
<strong>Error:</strong> Cannot create download link for {source_folder}. Reason: {e}\\\\n\\\",\\n\",\n    \"    \\\"        </div>\\\\n\\\",\\n\",\n    \"    \\\"        \\\\\\\"\\\\\\\"\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"        displayHTML(error_message_html)\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"destination_folder_name = \\\\\\\"Tools_Output\\\\\\\"\\\\n\\\",\\n\",\n    \"    \\\"create_download_link(output_folder, destination_folder_name)\\\\n\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"bbe50fde-0bd6-4281-95fd-6a1ec6f17ab2\\\",\\n\",\n    \"     \\\"showTitle\\\": false,\\n\",\n    \"     \\\"title\\\": \\\"\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"## Summary Output\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"The report provides a comprehensive overview of the entire application execution, estimated speedup, including unsupported operators and non-SQL operations. 
By default, the applications and queries are sorted in descending order based on the following fields:\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"- Estimated GPU Speedup Category\\\\n\\\",\\n\",\n    \"    \\\"- Estimated GPU Speedup\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"b8bca4a6-16d8-4b60-ba7b-9aff64bdcaa1\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"qualification_summary.csv\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"summary_output=pd.read_csv(os.path.join(output_folder, \\\\\\\"qualification_summary.csv\\\\\\\"))\\\\n\\\",\\n\",\n    \"    \\\"summary_output=summary_output.drop(columns=[\\\\\\\"Unnamed: 0\\\\\\\"]).rename_axis('Index').reset_index()\\\\n\\\",\\n\",\n    \"    \\\"display(summary_output)\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {},\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"73b5e0b0-3a96-4cc6-8e6c-840e4b0d9d43\\\",\\n\",\n    \"     \\\"showTitle\\\": false,\\n\",\n    \"     \\\"title\\\": \\\"\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    
\\\"## Application Status\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"The report shows the status of each eventlog file that was provided\\\\n\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"c9ffbfdb-dbb6-4736-b9cb-2ac457cc6714\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"rapids_4_spark_qualification_output_status.csv\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"status_output=pd.read_csv(os.path.join(jar_output_folder, \\\\\\\"rapids_4_spark_qualification_output_status.csv\\\\\\\"))\\\\n\\\",\\n\",\n    \"    \\\"display(status_output)\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {},\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"09945d39-f9c2-4f4a-8afd-4f309f24f8e0\\\",\\n\",\n    \"     \\\"showTitle\\\": false,\\n\",\n    \"     \\\"title\\\": \\\"\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"## Metadata for Migration\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"The report shows the metadata of each app as:\\\\n\\\",\\n\",\n    \"    \\\"- Recommended GPU 
cluster\\\\n\\\",\\n\",\n    \"    \\\"- File location of full cluster config recommendations\\\\n\\\",\\n\",\n    \"    \\\"- File location of only Gpu specific config recommendations\\\\n\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"133cf1bd-33b6-4a62-9ae2-5505717092d1\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"app_metadata.json\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"import json\\\\n\\\",\\n\",\n    \"    \\\"metadata_file = os.path.join(output_folder, \\\\\\\"app_metadata.json\\\\\\\")\\\\n\\\",\\n\",\n    \"    \\\"def camel_to_title(name):\\\\n\\\",\\n\",\n    \"    \\\"    return re.sub('([a-z])([A-Z])', r'\\\\\\\\1 \\\\\\\\2', name).title()\\\\n\\\",\\n\",\n    \"    \\\"  \\\\n\\\",\\n\",\n    \"    \\\"with open(metadata_file, 'r') as file:\\\\n\\\",\\n\",\n    \"    \\\"    json_data = json.load(file)\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"df = pd.DataFrame(json_data)\\\\n\\\",\\n\",\n    \"    \\\"df['recommendedGpuCluster'] = df['clusterInfo'].apply(lambda x: x['recommendedCluster'])\\\\n\\\",\\n\",\n    \"    \\\"df['sourceCluster'] = df['clusterInfo'].apply(lambda x: x['sourceCluster'])\\\\n\\\",\\n\",\n    \"    \\\"df.drop(columns=['clusterInfo'], inplace=True)\\\\n\\\",\\n\",\n    \"    \\\"df = df[['appId', 'appName', 'estimatedGpuSpeedupCategory', 
'recommendedGpuCluster', 'fullClusterConfigRecommendations', 'gpuConfigRecommendationBreakdown']]\\\\n\\\",\\n\",\n    \"    \\\"df.columns = [camel_to_title(col) for col in df.columns]\\\\n\\\",\\n\",\n    \"    \\\"display(df)\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"6756159b-30ca-407a-ab6b-9c29ced01ea6\\\",\\n\",\n    \"     \\\"showTitle\\\": false,\\n\",\n    \"     \\\"title\\\": \\\"\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"## Stages Output\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"For each stage used in SQL operations, the Qualification tool generates the following information:\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"1. App ID\\\\n\\\",\\n\",\n    \"    \\\"2. Stage ID\\\\n\\\",\\n\",\n    \"    \\\"3. Average Speedup Factor: The average estimated speed-up of all the operators in the given stage.\\\\n\\\",\\n\",\n    \"    \\\"4. Stage Task Duration: The amount of time spent in tasks of SQL DataFrame operations for the given stage.\\\\n\\\",\\n\",\n    \"    \\\"5. Unsupported Task Duration: The sum of task durations for the unsupported operators. For more details, see [Supported Operators](https://nvidia.github.io/spark-rapids/docs/supported_ops.html).\\\\n\\\",\\n\",\n    \"    \\\"6. 
Stage Estimated: Indicates if the stage duration had to be estimated (True or False).\\\\n\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"cdde6177-db5f-434a-995b-776678a64a3a\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"rapids_4_spark_qualification_output_stages.csv\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"stages_output=pd.read_csv(os.path.join(jar_output_folder, \\\\\\\"rapids_4_spark_qualification_output_stages.csv\\\\\\\"))\\\\n\\\",\\n\",\n    \"    \\\"display(stages_output)\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"markdown\\\",\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"4d7ce219-ae75-4a0c-a78c-4e7f25b8cd6f\\\",\\n\",\n    \"     \\\"showTitle\\\": false,\\n\",\n    \"     \\\"title\\\": \\\"\\\"\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"## Execs Output\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"The Qualification tool generates a report of the “Exec” in the “SparkPlan” or “Executor Nodes” along 
with the estimated acceleration on the GPU. Please refer to the [Supported Operators guide](https://nvidia.github.io/spark-rapids/docs/supported_ops.html) for more details on limitations on UDFs and unsupported operators.\\\\n\\\",\\n\",\n    \"    \\\"\\\\n\\\",\\n\",\n    \"    \\\"1. App ID\\\\n\\\",\\n\",\n    \"    \\\"2. SQL ID\\\\n\\\",\\n\",\n    \"    \\\"3. Exec Name: Example: Filter, HashAggregate\\\\n\\\",\\n\",\n    \"    \\\"4. Expression Name\\\\n\\\",\\n\",\n    \"    \\\"5. Task Speedup Factor: The average acceleration of the operators based on the original CPU duration of the operator divided by the GPU duration. The tool uses historical queries and benchmarks to estimate a speed-up at an individual operator level to calculate how much a specific operator would accelerate on GPU.\\\\n\\\",\\n\",\n    \"    \\\"6. Exec Duration: Wall-clock time measured from when the operator starts until it is completed.\\\\n\\\",\\n\",\n    \"    \\\"7. SQL Node ID\\\\n\\\",\\n\",\n    \"    \\\"8. Exec Is Supported: Indicates whether the Exec is supported by RAPIDS. Refer to the Supported Operators section for details.\\\\n\\\",\\n\",\n    \"    \\\"9. Exec Stages: An array of stage IDs.\\\\n\\\",\\n\",\n    \"    \\\"10. Exec Children\\\\n\\\",\\n\",\n    \"    \\\"11. Exec Children Node IDs\\\\n\\\",\\n\",\n    \"    \\\"12. 
Exec Should Remove: Indicates whether the Op is removed from the migrated plan.\\\\n\\\"\\n\",\n    \"   ]\\n\",\n    \"  },\\n\",\n    \"  {\\n\",\n    \"   \\\"cell_type\\\": \\\"code\\\",\\n\",\n    \"   \\\"execution_count\\\": 0,\\n\",\n    \"   \\\"metadata\\\": {\\n\",\n    \"    \\\"application/vnd.databricks.v1+cell\\\": {\\n\",\n    \"     \\\"cellMetadata\\\": {\\n\",\n    \"      \\\"byteLimit\\\": 2048000,\\n\",\n    \"      \\\"rowLimit\\\": 10000\\n\",\n    \"     },\\n\",\n    \"     \\\"inputWidgets\\\": {},\\n\",\n    \"     \\\"nuid\\\": \\\"998b0c51-0cb6-408e-a01a-d1f5b1a61e1f\\\",\\n\",\n    \"     \\\"showTitle\\\": true,\\n\",\n    \"     \\\"title\\\": \\\"rapids_4_spark_qualification_output_execs.csv\\\"\\n\",\n    \"    },\\n\",\n    \"    \\\"jupyter\\\": {\\n\",\n    \"     \\\"source_hidden\\\": true\\n\",\n    \"    }\\n\",\n    \"   },\\n\",\n    \"   \\\"outputs\\\": [],\\n\",\n    \"   \\\"source\\\": [\\n\",\n    \"    \\\"execs_output=pd.read_csv(os.path.join(jar_output_folder, \\\\\\\"rapids_4_spark_qualification_output_execs.csv\\\\\\\"))\\\\n\\\",\\n\",\n    \"    \\\"display(execs_output)\\\"\\n\",\n    \"   ]\\n\",\n    \"  }\\n\",\n    \" ],\\n\",\n    \" \\\"metadata\\\": {\\n\",\n    \"  \\\"application/vnd.databricks.v1+notebook\\\": {\\n\",\n    \"   \\\"dashboards\\\": [\\n\",\n    \"    {\\n\",\n    \"     \\\"elements\\\": [],\\n\",\n    \"     \\\"globalVars\\\": {},\\n\",\n    \"     \\\"guid\\\": \\\"\\\",\\n\",\n    \"     \\\"layoutOption\\\": {\\n\",\n    \"      \\\"grid\\\": true,\\n\",\n    \"      \\\"stack\\\": true\\n\",\n    \"     },\\n\",\n    \"     \\\"nuid\\\": \\\"91c1bfb2-695a-4e5c-8a25-848a433108dc\\\",\\n\",\n    \"     \\\"origId\\\": 2173122769183715,\\n\",\n    \"     \\\"title\\\": \\\"Executive View\\\",\\n\",\n    \"     \\\"version\\\": \\\"DashboardViewV1\\\",\\n\",\n    \"     \\\"width\\\": 1600\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"     \\\"elements\\\": [],\\n\",\n    \" 
    \\\"globalVars\\\": {},\\n\",\n    \"     \\\"guid\\\": \\\"\\\",\\n\",\n    \"     \\\"layoutOption\\\": {\\n\",\n    \"      \\\"grid\\\": true,\\n\",\n    \"      \\\"stack\\\": true\\n\",\n    \"     },\\n\",\n    \"     \\\"nuid\\\": \\\"62243296-4562-4f06-90ac-d7a609f19c16\\\",\\n\",\n    \"     \\\"origId\\\": 2173122769183716,\\n\",\n    \"     \\\"title\\\": \\\"App View\\\",\\n\",\n    \"     \\\"version\\\": \\\"DashboardViewV1\\\",\\n\",\n    \"     \\\"width\\\": 1920\\n\",\n    \"    },\\n\",\n    \"    {\\n\",\n    \"     \\\"elements\\\": [],\\n\",\n    \"     \\\"globalVars\\\": {},\\n\",\n    \"     \\\"guid\\\": \\\"\\\",\\n\",\n    \"     \\\"layoutOption\\\": {\\n\",\n    \"      \\\"grid\\\": true,\\n\",\n    \"      \\\"stack\\\": true\\n\",\n    \"     },\\n\",\n    \"     \\\"nuid\\\": \\\"854f9c75-5977-42aa-b3dd-c680b8331f19\\\",\\n\",\n    \"     \\\"origId\\\": 2173122769183722,\\n\",\n    \"     \\\"title\\\": \\\"Untitled\\\",\\n\",\n    \"     \\\"version\\\": \\\"DashboardViewV1\\\",\\n\",\n    \"     \\\"width\\\": 1024\\n\",\n    \"    }\\n\",\n    \"   ],\\n\",\n    \"   \\\"environmentMetadata\\\": null,\\n\",\n    \"   \\\"language\\\": \\\"python\\\",\\n\",\n    \"   \\\"notebookMetadata\\\": {\\n\",\n    \"    \\\"mostRecentlyExecutedCommandWithImplicitDF\\\": {\\n\",\n    \"     \\\"commandId\\\": 2173122769183704,\\n\",\n    \"     \\\"dataframes\\\": [\\n\",\n    \"      \\\"_sqldf\\\"\\n\",\n    \"     ]\\n\",\n    \"    },\\n\",\n    \"    \\\"pythonIndentUnit\\\": 2,\\n\",\n    \"    \\\"widgetLayout\\\": [\\n\",\n    \"     {\\n\",\n    \"      \\\"breakBefore\\\": false,\\n\",\n    \"      \\\"name\\\": \\\"Eventlog Path\\\",\\n\",\n    \"      \\\"width\\\": 778\\n\",\n    \"     },\\n\",\n    \"     {\\n\",\n    \"      \\\"breakBefore\\\": false,\\n\",\n    \"      \\\"name\\\": \\\"Output Path\\\",\\n\",\n    \"      \\\"width\\\": 302\\n\",\n    \"     }\\n\",\n    \"    ]\\n\",\n    \"   },\\n\",\n    \"   
\\\"notebookName\\\": \\\"[RAPIDS Accelerator for Apache Spark] Qualification Tool Notebook Template\\\",\\n\",\n    \"   \\\"widgets\\\": {\\n\",\n    \"    \\\"Eventlog Path\\\": {\\n\",\n    \"     \\\"currentValue\\\": \\\"/dbfs/user1/qualification_logs\\\",\\n\",\n    \"     \\\"nuid\\\": \\\"1272501d-5ad9-42be-ab62-35768b2fc384\\\",\\n\",\n    \"     \\\"typedWidgetInfo\\\": null,\\n\",\n    \"     \\\"widgetInfo\\\": {\\n\",\n    \"      \\\"defaultValue\\\": \\\"/dbfs/user1/qualification_logs\\\",\\n\",\n    \"      \\\"label\\\": \\\"\\\",\\n\",\n    \"      \\\"name\\\": \\\"Eventlog Path\\\",\\n\",\n    \"      \\\"options\\\": {\\n\",\n    \"       \\\"autoCreated\\\": false,\\n\",\n    \"       \\\"validationRegex\\\": null,\\n\",\n    \"       \\\"widgetType\\\": \\\"text\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"widgetType\\\": \\\"text\\\"\\n\",\n    \"     }\\n\",\n    \"    },\\n\",\n    \"    \\\"Output Path\\\": {\\n\",\n    \"     \\\"currentValue\\\": \\\"/tmp\\\",\\n\",\n    \"     \\\"nuid\\\": \\\"ab7e082c-1ef9-4912-8fd7-51bf985eb9c1\\\",\\n\",\n    \"     \\\"typedWidgetInfo\\\": null,\\n\",\n    \"     \\\"widgetInfo\\\": {\\n\",\n    \"      \\\"defaultValue\\\": \\\"/tmp\\\",\\n\",\n    \"      \\\"label\\\": null,\\n\",\n    \"      \\\"name\\\": \\\"Output Path\\\",\\n\",\n    \"      \\\"options\\\": {\\n\",\n    \"       \\\"autoCreated\\\": null,\\n\",\n    \"       \\\"validationRegex\\\": null,\\n\",\n    \"       \\\"widgetType\\\": \\\"text\\\"\\n\",\n    \"      },\\n\",\n    \"      \\\"widgetType\\\": \\\"text\\\"\\n\",\n    \"     }\\n\",\n    \"    }\\n\",\n    \"   }\\n\",\n    \"  },\\n\",\n    \"  \\\"language_info\\\": {\\n\",\n    \"   \\\"name\\\": \\\"python\\\"\\n\",\n    \"  }\\n\",\n    \" },\\n\",\n    \" \\\"nbformat\\\": 4,\\n\",\n    \" \\\"nbformat_minor\\\": 0\\n\",\n    \"}\\n\"\n   ],\n   \"id\": \"4ba18da2c217d2f1\"\n  }\n ],\n \"metadata\": {},\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "tools/emr/README.md",
    "content": "# EMR Qualification/Profiling Quick Start Notebooks\n\nThe RAPIDS Accelerator for Apache Spark includes two key tools for understanding the benefits of\nGPU acceleration as well as analyzing GPU Spark jobs.  For customers on EMR, the quick start\nnotebooks offer a simple interface for running the tools given a set of Spark event logs from\nCPU (qualification) or GPU (profiling) application runs.\n\n## Usage\n\n### Pre-requisites: Setup EMR Studio and Workspace\n1. Ensure that you have an **EMR cluster** running.\n2. Set up **EMR Studio** and **Workspace** by following the instructions in the [AWS Documentation](https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-studio-create-studio.html):\n   - Select **Custom Settings** while creating the Studio.\n   - Choose the **VPC** and **Subnet** where the EMR cluster is running.\n3. Attach the Workspace to the running EMR cluster. For more details, refer to the [AWS Documentation](https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-studio-create-use-clusters.html).\n\n### Running the Notebook\n1. Import the notebook into the EMR Workspace by dragging and dropping the notebook file.\n2. In the **User Input** section of the notebook, enter the path to event log files.\n3. Click the **fast-forward** icon labeled *Restart the kernel, then re-run the whole notebook* to process the logs at the specified path.\n\n## Limitations\n1. Currently, local and S3 event log paths are supported.\n1. Eventlog path must follow the formats `/local/path/to/eventlog` for local logs or `s3://my-bucket/path/to/eventlog` for logs stored in S3.\n1. The specified path can also be a directory. In such cases, the tool will recursively search for event logs within the directory.\n   - For example: `/path/to/clusterlogs`\n1. To specify multiple event logs, separate the paths with commas.\n   - For example: `s3://my-bucket/path/to/eventlog1,s3://my-bucket/path/to/eventlog2`\n\n**Latest Tools Version Supported** 24.08.2\n"
  },
  {
    "path": "tools/emr/[RAPIDS Accelerator for Apache Spark] Profiling Tool Notebook Template.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"df33c614-2ecc-47a0-8600-bc891681997f\",\n     \"showTitle\": false,\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"## Profiling Tool for the RAPIDS Accelerator for Apache Spark\\n\",\n    \"\\n\",\n    \"To run the profiling tool, enter the log path that represents the location of your Spark GPU event logs. Then, select \\\"Run all\\\" to execute the notebook. Once the notebook completes, various output tables will appear below. For more options on running the profiling tool, please refer to the [Profiling Tool User Guide](https://docs.nvidia.com/spark-rapids/user-guide/latest/profiling/quickstart.html#running-the-tool).\\n\",\n    \"\\n\",\n    \"### Note\\n\",\n    \"- Currently, local and S3 event log paths are supported.\\n\",\n    \"- Eventlog path must follow the formats `/local/path/to/eventlog` for local logs or `s3://my-bucket/path/to/eventlog` for logs stored in S3.\\n\",\n    \"- The specified path can also be a directory. 
In such cases, the tool will recursively search for event logs within the directory.\\n\",\n    \"   - For example: `/path/to/clusterlogs`\\n\",\n    \"- To specify multiple event logs, separate the paths with commas.\\n\",\n    \"   - For example: `s3://my-bucket/path/to/eventlog1,s3://my-bucket/path/to/eventlog2`\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## User Input\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Path to the event log in S3 (or local path)\\n\",\n    \"EVENTLOG_PATH = \\\"s3://my-bucket/path/to/eventlog\\\"  # or \\\"/local/path/to/eventlog\\\"\\n\",\n    \"\\n\",\n    \"# S3 path with write access where the output will be copied. \\n\",\n    \"S3_OUTPUT_PATH = \\\"s3://my-bucket/path/to/output\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"## Setup Environment\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from IPython.display import display, Markdown\\n\",\n    \"\\n\",\n    \"TOOLS_VER = \\\"24.08.2\\\"\\n\",\n    \"display(Markdown(f\\\"**Using Spark RAPIDS Tools Version:** {TOOLS_VER}\\\"))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"%pip install spark-rapids-user-tools==$TOOLS_VER --user > /dev/null 2>&1\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": 
\"acf401a3-12d3-4236-a6c5-8fe8990b153a\",\n     \"showTitle\": true,\n     \"title\": \"Environment Setup\"\n    },\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"# Update PATH to include local binaries\\n\",\n    \"os.environ['PATH'] += os.pathsep + os.path.expanduser(\\\"~/.local/bin\\\")\\n\",\n    \"\\n\",\n    \"OUTPUT_PATH = \\\"/tmp\\\"\\n\",\n    \"DEST_FOLDER_NAME = \\\"prof-tool-result\\\"\\n\",\n    \"\\n\",\n    \"# Set environment variables\\n\",\n    \"os.environ[\\\"EVENTLOG_PATH\\\"] = EVENTLOG_PATH \\n\",\n    \"os.environ[\\\"OUTPUT_PATH\\\"] = OUTPUT_PATH\\n\",\n    \"\\n\",\n    \"CONSOLE_OUTPUT_PATH = os.path.join(OUTPUT_PATH, 'console_output.log')\\n\",\n    \"CONSOLE_ERROR_PATH = os.path.join(OUTPUT_PATH, 'console_error.log')\\n\",\n    \"\\n\",\n    \"os.environ['CONSOLE_OUTPUT_PATH'] = CONSOLE_OUTPUT_PATH\\n\",\n    \"os.environ['CONSOLE_ERROR_PATH'] = CONSOLE_ERROR_PATH\\n\",\n    \"\\n\",\n    \"print(f'Console output will be stored at {CONSOLE_OUTPUT_PATH} and errors will be stored at {CONSOLE_ERROR_PATH}')\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"execution\": {\n     \"iopub.execute_input\": \"2024-10-24T21:27:00.924906Z\",\n     \"iopub.status.busy\": \"2024-10-24T21:27:00.924587Z\",\n     \"iopub.status.idle\": \"2024-10-24T21:27:00.928129Z\",\n     \"shell.execute_reply\": \"2024-10-24T21:27:00.927454Z\",\n     \"shell.execute_reply.started\": \"2024-10-24T21:27:00.924879Z\"\n    }\n   },\n   \"source\": [\n    \"## Run Profiling Tool\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": 
\"693b5ee0-7500-43f3-b3e2-717fd5468aa8\",\n     \"showTitle\": true,\n     \"title\": \"Run Profiling Tool\"\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"%%sh\\n\",\n    \"spark_rapids profiling --platform emr --eventlogs \\\"$EVENTLOG_PATH\\\" -o \\\"$OUTPUT_PATH\\\" --verbose > \\\"$CONSOLE_OUTPUT_PATH\\\" 2> \\\"$CONSOLE_ERROR_PATH\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"f83af6c8-5a79-4a46-965b-38a4cb621877\",\n     \"showTitle\": false,\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"## Console Output\\n\",\n    \"Console output shows the top candidates and their estimated GPU speedup.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"c61527b7-a21a-492c-bab8-77f83dc5cabf\",\n     \"showTitle\": true,\n     \"title\": \"Show Console Output\"\n    },\n    \"scrolled\": true,\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"%%sh\\n\",\n    \"cat $CONSOLE_OUTPUT_PATH\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"f3c68b28-fc62-40ae-8528-799f3fc7507e\",\n     \"showTitle\": true,\n     \"title\": \"Show Logs\"\n    },\n    \"scrolled\": true,\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"%%sh\\n\",\n    \"cat $CONSOLE_ERROR_PATH\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"05f96ca1-1b08-494c-a12b-7e6cc3dcc546\",\n     \"showTitle\": true,\n     \"title\": \"Parse Output\"\n    },\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import re\\n\",\n    \"import shutil\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def extract_file_info(console_output_path, output_base_path):\\n\",\n    \"    try:\\n\",\n    \"        with open(console_output_path, 'r') as file:\\n\",\n    \"            stdout_text = file.read()\\n\",\n    \"\\n\",\n    \"        # Extract log file location\\n\",\n    \"        location_match = re.search(r\\\"Location: (.+)\\\", stdout_text)\\n\",\n    \"        if not location_match:\\n\",\n    \"            raise ValueError(\\n\",\n    \"                \\\"Log file location not found in the provided text.\\\")\\n\",\n    \"\\n\",\n    \"        log_file_location = location_match.group(1)\\n\",\n    \"\\n\",\n    \"        # Extract profiling output folder\\n\",\n    \"        qual_match = re.search(r\\\"prof_[^/]+(?=\\\\.log)\\\", log_file_location)\\n\",\n    \"        if not qual_match:\\n\",\n    \"            raise ValueError(\\n\",\n    \"                \\\"Output folder not found in the log file location.\\\")\\n\",\n    \"\\n\",\n    \"        output_folder_name = qual_match.group(0)\\n\",\n    \"        output_folder = os.path.join(output_base_path, output_folder_name)\\n\",\n    \"        return output_folder, log_file_location\\n\",\n    \"\\n\",\n    \"    except Exception as e:\\n\",\n    \"        raise RuntimeError(f\\\"Cannot parse console output. 
Reason: {e}\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def copy_logs(destination_folder, *log_files):\\n\",\n    \"    try:\\n\",\n    \"        log_folder = os.path.join(destination_folder, \\\"logs\\\")\\n\",\n    \"        os.makedirs(log_folder, exist_ok=True)\\n\",\n    \"\\n\",\n    \"        for log_file in log_files:\\n\",\n    \"            if os.path.exists(log_file):\\n\",\n    \"                shutil.copy2(log_file, log_folder)\\n\",\n    \"            else:\\n\",\n    \"                print(f\\\"Log file not found: {log_file}\\\")\\n\",\n    \"    except Exception as e:\\n\",\n    \"        raise RuntimeError(f\\\"Cannot copy logs to output. Reason: {e}\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"try:\\n\",\n    \"    output_folder, log_file_location = extract_file_info(\\n\",\n    \"        CONSOLE_OUTPUT_PATH, OUTPUT_PATH)\\n\",\n    \"    jar_output_folder = os.path.join(output_folder,\\n\",\n    \"                                     \\\"rapids_4_spark_profile\\\")\\n\",\n    \"    print(f\\\"Output folder detected {output_folder}\\\")\\n\",\n    \"    copy_logs(output_folder, log_file_location, CONSOLE_OUTPUT_PATH,\\n\",\n    \"              CONSOLE_ERROR_PATH)\\n\",\n    \"    print(f\\\"Logs successfully copied to {output_folder}\\\")\\n\",\n    \"except Exception as e:\\n\",\n    \"    print(e)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Download Output\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"8c65adcd-a933-482e-a50b-d40fa8f50e16\",\n     \"showTitle\": true,\n     \"title\": \"Download Output\"\n    },\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import 
shutil\\n\",\n    \"import os\\n\",\n    \"import subprocess\\n\",\n    \"from IPython.display import HTML, display\\n\",\n    \"from urllib.parse import urlparse\\n\",\n    \"\\n\",\n    \"def display_error_message(error_message, exception):\\n\",\n    \"    error_message_html = f\\\"\\\"\\\"\\n\",\n    \"    <div style=\\\"color: red; margin: 20px;\\\">\\n\",\n    \"        <strong>Error:</strong> {error_message}.\\n\",\n    \"        <br/>\\n\",\n    \"        <strong>Exception:</strong> {exception}\\n\",\n    \"    </div>\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    display(HTML(error_message_html))\\n\",\n    \"\\n\",\n    \"def copy_file_to_s3(local_file: str, bucket: str, destination_folder_name: str):\\n\",\n    \"    try:\\n\",\n    \"        file_name = os.path.basename(local_file)\\n\",\n    \"        s3_path = f\\\"s3://{bucket}/{destination_folder_name}/{file_name}\\\"\\n\",\n    \"        subprocess.run([\\\"aws\\\", \\\"s3\\\", \\\"cp\\\", local_file, s3_path], check=True, capture_output=True, text=True)\\n\",\n    \"        return construct_download_url(file_name, bucket, destination_folder_name)\\n\",\n    \"    except subprocess.CalledProcessError as e:\\n\",\n    \"        raise Exception(f\\\"Error copying file to S3: {e.stderr}\\\") from e\\n\",\n    \"\\n\",\n    \"def get_default_aws_region():\\n\",\n    \"    try:\\n\",\n    \"        return subprocess.check_output(\\n\",\n    \"            \\\"aws configure list | grep region | awk '{print $2}'\\\",\\n\",\n    \"            shell=True,\\n\",\n    \"            text=True\\n\",\n    \"        ).strip()\\n\",\n    \"    except subprocess.CalledProcessError:\\n\",\n    \"        return \\\"Error: Unable to retrieve the region.\\\"\\n\",\n    \"\\n\",\n    \"def construct_download_url(file_name: str, bucket_name: str, destination_folder_name: str):\\n\",\n    \"    region = get_default_aws_region()\\n\",\n    \"    return 
f\\\"https://{region}.console.aws.amazon.com/s3/object/{bucket_name}?region={region}&prefix={destination_folder_name}/{file_name}\\\"\\n\",\n    \"\\n\",\n    \"def create_download_link(source_folder, bucket_name, destination_folder_name):\\n\",\n    \"    folder_to_compress = os.path.join(\\\"/tmp\\\", os.path.basename(source_folder))\\n\",\n    \"    local_zip_file_path = shutil.make_archive(folder_to_compress, 'zip', source_folder)\\n\",\n    \"    download_url = copy_file_to_s3(local_zip_file_path, bucket_name, destination_folder_name)\\n\",\n    \"\\n\",\n    \"    download_button_html = f\\\"\\\"\\\"\\n\",\n    \"    <style>\\n\",\n    \"        .download-btn {{\\n\",\n    \"            display: inline-block;\\n\",\n    \"            padding: 10px 20px;\\n\",\n    \"            font-size: 16px;\\n\",\n    \"            color: white;\\n\",\n    \"            background-color: #4CAF50;\\n\",\n    \"            text-align: center;\\n\",\n    \"            text-decoration: none;\\n\",\n    \"            border-radius: 5px;\\n\",\n    \"            border: none;\\n\",\n    \"            cursor: pointer;\\n\",\n    \"            margin: 15px auto;\\n\",\n    \"        }}\\n\",\n    \"        .download-btn:hover {{\\n\",\n    \"            background-color: #45a049;\\n\",\n    \"        }}\\n\",\n    \"        .button-container {{\\n\",\n    \"            display: flex;\\n\",\n    \"            justify-content: center;\\n\",\n    \"            align-items: center;\\n\",\n    \"        }}\\n\",\n    \"        .button-container a {{\\n\",\n    \"            color: white !important;\\n\",\n    \"        }}\\n\",\n    \"    </style>\\n\",\n    \"\\n\",\n    \"    <div style=\\\"color: #444; font-size: 14px; text-align: center; margin: 10px;\\\">\\n\",\n    \"        Zipped output file created at {download_url}\\n\",\n    \"    </div>\\n\",\n    \"    <div class='button-container'>\\n\",\n    \"        <a href='{download_url}' class='download-btn'>Download 
Output</a>\\n\",\n    \"    </div>\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    display(HTML(download_button_html))\\n\",\n    \"\\n\",\n    \"try:\\n\",\n    \"    current_working_directory = os.getcwd()\\n\",\n    \"    parsed_s3_output_path = urlparse(S3_OUTPUT_PATH)\\n\",\n    \"    bucket_name = parsed_s3_output_path.netloc\\n\",\n    \"    destination_path = os.path.join(parsed_s3_output_path.path.strip(\\\"/\\\"), DEST_FOLDER_NAME.strip(\\\"/\\\"))\\n\",\n    \"    create_download_link(output_folder, bucket_name, destination_path)\\n\",\n    \"    \\n\",\n    \"except Exception as e:\\n\",\n    \"    error_msg = f\\\"Failed to create download link for {output_folder}\\\"\\n\",\n    \"    display_error_message(error_msg, e)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {},\n     \"inputWidgets\": {},\n     \"nuid\": \"73b5e0b0-3a96-4cc6-8e6c-840e4b0d9d43\",\n     \"showTitle\": false,\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"\\n\",\n    \"## Application Status\\n\",\n    \"\\n\",\n    \"The report shows the status of each eventlog file that was provided.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"c9ffbfdb-dbb6-4736-b9cb-2ac457cc6714\",\n     \"showTitle\": true,\n     \"title\": \"profiling_status.csv\"\n    },\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"try:\\n\",\n    \"    status_output = pd.read_csv(\\n\",\n    \"        os.path.join(jar_output_folder, \\\"profiling_status.csv\\\"))\\n\",\n    \"\\n\",\n    \"    # Set options to display the full content of the DataFrame\\n\",\n    \"    
pd.set_option('display.max_rows', None)  # Show all rows\\n\",\n    \"    pd.set_option('display.max_columns', None)  # Show all columns\\n\",\n    \"    pd.set_option('display.width', None)  # Adjust column width to fit the display\\n\",\n    \"    pd.set_option('display.max_colwidth', None)  # Display full content of each column\\n\",\n    \"\\n\",\n    \"    display(status_output)\\n\",\n    \"except Exception as e:\\n\",\n    \"    error_msg = \\\"Unable to show Application Status\\\"\\n\",\n    \"    display_error_message(error_msg, e)        \\n\",\n    \"        \\n\",\n    \"        \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"6756159b-30ca-407a-ab6b-9c29ced01ea6\",\n     \"showTitle\": false,\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"## GPU Job Tuning Recommendations\\n\",\n    \"This has general suggestions for tuning your applications to run optimally on GPUs.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"cdde6177-db5f-434a-995b-776678a64a3a\",\n     \"showTitle\": true,\n     \"title\": \"application_information.csv\"\n    },\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"scrolled\": true,\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"try:\\n\",\n    \"    jar_output_folder = os.path.join(output_folder, \\\"rapids_4_spark_profile\\\")\\n\",\n    \"    app_df = pd.DataFrame(columns=['appId', 'appName'])\\n\",\n    \"\\n\",\n    \"    for x in os.scandir(jar_output_folder):\\n\",\n    \"        if x.is_dir():\\n\",\n    \"            csv_path = 
os.path.join(x.path, \\\"application_information.csv\\\")\\n\",\n    \"            if os.path.exists(csv_path):\\n\",\n    \"              tmp_df = pd.read_csv(csv_path)\\n\",\n    \"              app_df = pd.concat([app_df, tmp_df[['appId', 'appName']]])\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"    app_list = app_df[\\\"appId\\\"].tolist()\\n\",\n    \"    app_recommendations = pd.DataFrame(columns=['App', 'Recommended Configuration'])\\n\",\n    \"\\n\",\n    \"    for app in app_list:\\n\",\n    \"      app_file = open(os.path.join(jar_output_folder, app, \\\"profile.log\\\"))\\n\",\n    \"      recommendations_start = 0\\n\",\n    \"      recommendations_str = \\\"\\\"\\n\",\n    \"      for line in app_file:\\n\",\n    \"        if recommendations_start == 1:\\n\",\n    \"          recommendations_str = recommendations_str + line\\n\",\n    \"        if \\\"### D. Recommended Configuration ###\\\" in line:\\n\",\n    \"          recommendations_start = 1\\n\",\n    \"      app_recommendations = pd.concat([app_recommendations, pd.DataFrame({'App': [app], 'Recommended Configuration': [recommendations_str]})], ignore_index=True)\\n\",\n    \"      html = app_recommendations.to_html().replace(\\\"\\\\\\\\n\\\", \\\"<br>\\\")\\n\",\n    \"      style = \\\"<style>table td { vertical-align: top !important; text-align: left !important; white-space: pre-wrap; } th { text-align: left !important; }</style>\\\"\\n\",\n    \"      display(HTML(html + style))\\n\",\n    \"except Exception as e:\\n\",\n    \"    error_msg = \\\"Unable to show stage output\\\"\\n\",\n    \"    display_error_message(error_msg, e) \"\n   ]\n  }\n ],\n \"metadata\": {\n  \"application/vnd.databricks.v1+notebook\": {\n   \"dashboards\": [\n    {\n     \"elements\": [],\n     \"globalVars\": {},\n     \"guid\": \"\",\n     \"layoutOption\": {\n      \"grid\": true,\n      \"stack\": true\n     },\n     \"nuid\": \"91c1bfb2-695a-4e5c-8a25-848a433108dc\",\n     \"origId\": 2173122769183715,\n     
\"title\": \"Executive View\",\n     \"version\": \"DashboardViewV1\",\n     \"width\": 1600\n    },\n    {\n     \"elements\": [],\n     \"globalVars\": {},\n     \"guid\": \"\",\n     \"layoutOption\": {\n      \"grid\": true,\n      \"stack\": true\n     },\n     \"nuid\": \"62243296-4562-4f06-90ac-d7a609f19c16\",\n     \"origId\": 2173122769183716,\n     \"title\": \"App View\",\n     \"version\": \"DashboardViewV1\",\n     \"width\": 1920\n    },\n    {\n     \"elements\": [],\n     \"globalVars\": {},\n     \"guid\": \"\",\n     \"layoutOption\": {\n      \"grid\": true,\n      \"stack\": true\n     },\n     \"nuid\": \"854f9c75-5977-42aa-b3dd-c680b8331f19\",\n     \"origId\": 2173122769183722,\n     \"title\": \"Untitled\",\n     \"version\": \"DashboardViewV1\",\n     \"width\": 1024\n    }\n   ],\n   \"environmentMetadata\": null,\n   \"language\": \"python\",\n   \"notebookMetadata\": {\n    \"mostRecentlyExecutedCommandWithImplicitDF\": {\n     \"commandId\": 2173122769183704,\n     \"dataframes\": [\n      \"_sqldf\"\n     ]\n    },\n    \"pythonIndentUnit\": 2,\n    \"widgetLayout\": [\n     {\n      \"breakBefore\": false,\n      \"name\": \"Eventlog Path\",\n      \"width\": 778\n     },\n     {\n      \"breakBefore\": false,\n      \"name\": \"Output Path\",\n      \"width\": 302\n     }\n    ]\n   },\n   \"notebookName\": \"[RAPIDS Accelerator for Apache Spark] Profiling Tool Notebook Template\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.18\"\n  },\n  \"widgets\": {\n   \"application/vnd.jupyter.widget-state+json\": {\n    \"state\": {},\n    \"version_major\": 2,\n    
\"version_minor\": 0\n   }\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "tools/emr/[RAPIDS Accelerator for Apache Spark] Qualification Tool Notebook Template.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"df33c614-2ecc-47a0-8600-bc891681997f\",\n     \"showTitle\": false,\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"## Qualification Tool for the RAPIDS Accelerator for Apache Spark\\n\",\n    \"\\n\",\n    \"To run the qualification tool, enter the log path that represents the location of your Spark CPU event logs. Then, select \\\"Run all\\\" to execute the notebook. Once the notebook completes, various output tables will appear below. For more options on running the qualification tool, please refer to the [Qualification Tool User Guide](https://docs.nvidia.com/spark-rapids/user-guide/latest/qualification/quickstart.html#running-the-tool).\\n\",\n    \"\\n\",\n    \"### Note\\n\",\n    \"- Currently, local and S3 event log paths are supported.\\n\",\n    \"- Eventlog path must follow the formats `/local/path/to/eventlog` for local logs or `s3://my-bucket/path/to/eventlog` for logs stored in S3.\\n\",\n    \"- The specified path can also be a directory. 
In such cases, the tool will recursively search for event logs within the directory.\\n\",\n    \"   - For example: `/path/to/clusterlogs`\\n\",\n    \"- To specify multiple event logs, separate the paths with commas.\\n\",\n    \"   - For example: `s3://my-bucket/path/to/eventlog1,s3://my-bucket/path/to/eventlog2`\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## User Input\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Path to the event log in S3 (or local path)\\n\",\n    \"EVENTLOG_PATH = \\\"s3://my-bucket/path/to/eventlog\\\"  # or \\\"/local/path/to/eventlog\\\"\\n\",\n    \"\\n\",\n    \"# S3 path with write access where the output will be copied. \\n\",\n    \"S3_OUTPUT_PATH = \\\"s3://my-bucket/path/to/output\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"## Setup Environment\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from IPython.display import display, Markdown\\n\",\n    \"\\n\",\n    \"TOOLS_VER = \\\"24.08.2\\\"\\n\",\n    \"display(Markdown(f\\\"**Using Spark RAPIDS Tools Version:** {TOOLS_VER}\\\"))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"%pip install spark-rapids-user-tools==$TOOLS_VER --user > /dev/null 2>&1\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     
\"nuid\": \"acf401a3-12d3-4236-a6c5-8fe8990b153a\",\n     \"showTitle\": true,\n     \"title\": \"Environment Setup\"\n    },\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"# Update PATH to include local binaries\\n\",\n    \"os.environ['PATH'] += os.pathsep + os.path.expanduser(\\\"~/.local/bin\\\")\\n\",\n    \"\\n\",\n    \"OUTPUT_PATH = \\\"/tmp\\\"\\n\",\n    \"DEST_FOLDER_NAME = \\\"qual-tool-result\\\"\\n\",\n    \"\\n\",\n    \"# Set environment variables\\n\",\n    \"os.environ[\\\"EVENTLOG_PATH\\\"] = EVENTLOG_PATH \\n\",\n    \"os.environ[\\\"OUTPUT_PATH\\\"] = OUTPUT_PATH\\n\",\n    \"\\n\",\n    \"CONSOLE_OUTPUT_PATH = os.path.join(OUTPUT_PATH, 'console_output.log')\\n\",\n    \"CONSOLE_ERROR_PATH = os.path.join(OUTPUT_PATH, 'console_error.log')\\n\",\n    \"\\n\",\n    \"os.environ['CONSOLE_OUTPUT_PATH'] = CONSOLE_OUTPUT_PATH\\n\",\n    \"os.environ['CONSOLE_ERROR_PATH'] = CONSOLE_ERROR_PATH\\n\",\n    \"\\n\",\n    \"print(f'Console output will be stored at {CONSOLE_OUTPUT_PATH} and errors will be stored at {CONSOLE_ERROR_PATH}')\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"execution\": {\n     \"iopub.execute_input\": \"2024-10-24T21:27:00.924906Z\",\n     \"iopub.status.busy\": \"2024-10-24T21:27:00.924587Z\",\n     \"iopub.status.idle\": \"2024-10-24T21:27:00.928129Z\",\n     \"shell.execute_reply\": \"2024-10-24T21:27:00.927454Z\",\n     \"shell.execute_reply.started\": \"2024-10-24T21:27:00.924879Z\"\n    }\n   },\n   \"source\": [\n    \"## Run Qualification Tool\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": 
\"693b5ee0-7500-43f3-b3e2-717fd5468aa8\",\n     \"showTitle\": true,\n     \"title\": \"Run Qualification Tool\"\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"%%sh\\n\",\n    \"spark_rapids qualification --platform emr --eventlogs \\\"$EVENTLOG_PATH\\\" -o \\\"$OUTPUT_PATH\\\" --verbose > \\\"$CONSOLE_OUTPUT_PATH\\\" 2> \\\"$CONSOLE_ERROR_PATH\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"f83af6c8-5a79-4a46-965b-38a4cb621877\",\n     \"showTitle\": false,\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"## Console Output\\n\",\n    \"Console output shows the top candidates and their estimated GPU speedup.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"c61527b7-a21a-492c-bab8-77f83dc5cabf\",\n     \"showTitle\": true,\n     \"title\": \"Show Console Output\"\n    },\n    \"scrolled\": true,\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"%%sh\\n\",\n    \"cat $CONSOLE_OUTPUT_PATH\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"f3c68b28-fc62-40ae-8528-799f3fc7507e\",\n     \"showTitle\": true,\n     \"title\": \"Show Logs\"\n    },\n    \"scrolled\": true,\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"%%sh\\n\",\n    \"cat $CONSOLE_ERROR_PATH\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"05f96ca1-1b08-494c-a12b-7e6cc3dcc546\",\n     \"showTitle\": true,\n     \"title\": \"Parse Output\"\n    },\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import re\\n\",\n    \"import shutil\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def extract_file_info(console_output_path, output_base_path):\\n\",\n    \"    try:\\n\",\n    \"        with open(console_output_path, 'r') as file:\\n\",\n    \"            stdout_text = file.read()\\n\",\n    \"\\n\",\n    \"        # Extract log file location\\n\",\n    \"        location_match = re.search(r\\\"Location: (.+)\\\", stdout_text)\\n\",\n    \"        if not location_match:\\n\",\n    \"            raise ValueError(\\n\",\n    \"                \\\"Log file location not found in the provided text.\\\")\\n\",\n    \"\\n\",\n    \"        log_file_location = location_match.group(1)\\n\",\n    \"\\n\",\n    \"        # Extract qualification output folder\\n\",\n    \"        qual_match = re.search(r\\\"qual_[^/]+(?=\\\\.log)\\\", log_file_location)\\n\",\n    \"        if not qual_match:\\n\",\n    \"            raise ValueError(\\n\",\n    \"                \\\"Output folder not found in the log file location.\\\")\\n\",\n    \"\\n\",\n    \"        output_folder_name = qual_match.group(0)\\n\",\n    \"        output_folder = os.path.join(output_base_path, output_folder_name)\\n\",\n    \"        return output_folder, log_file_location\\n\",\n    \"\\n\",\n    \"    except Exception as e:\\n\",\n    \"        raise RuntimeError(f\\\"Cannot parse console output. 
Reason: {e}\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def copy_logs(destination_folder, *log_files):\\n\",\n    \"    try:\\n\",\n    \"        log_folder = os.path.join(destination_folder, \\\"logs\\\")\\n\",\n    \"        os.makedirs(log_folder, exist_ok=True)\\n\",\n    \"\\n\",\n    \"        for log_file in log_files:\\n\",\n    \"            if os.path.exists(log_file):\\n\",\n    \"                shutil.copy2(log_file, log_folder)\\n\",\n    \"            else:\\n\",\n    \"                print(f\\\"Log file not found: {log_file}\\\")\\n\",\n    \"    except Exception as e:\\n\",\n    \"        raise RuntimeError(f\\\"Cannot copy logs to output. Reason: {e}\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"try:\\n\",\n    \"    output_folder, log_file_location = extract_file_info(\\n\",\n    \"        CONSOLE_OUTPUT_PATH, OUTPUT_PATH)\\n\",\n    \"    jar_output_folder = os.path.join(output_folder,\\n\",\n    \"                                     \\\"rapids_4_spark_qualification_output\\\")\\n\",\n    \"    print(f\\\"Output folder detected {output_folder}\\\")\\n\",\n    \"    copy_logs(output_folder, log_file_location, CONSOLE_OUTPUT_PATH,\\n\",\n    \"              CONSOLE_ERROR_PATH)\\n\",\n    \"    print(f\\\"Logs successfully copied to {output_folder}\\\")\\n\",\n    \"except Exception as e:\\n\",\n    \"    print(e)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Download Output\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"8c65adcd-a933-482e-a50b-d40fa8f50e16\",\n     \"showTitle\": true,\n     \"title\": \"Download Output\"\n    },\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    
\"import shutil\\n\",\n    \"import os\\n\",\n    \"import subprocess\\n\",\n    \"from IPython.display import HTML, display\\n\",\n    \"from urllib.parse import urlparse\\n\",\n    \"\\n\",\n    \"def display_error_message(error_message, exception):\\n\",\n    \"    error_message_html = f\\\"\\\"\\\"\\n\",\n    \"    <div style=\\\"color: red; margin: 20px;\\\">\\n\",\n    \"        <strong>Error:</strong> {error_message}.\\n\",\n    \"        <br/>\\n\",\n    \"        <strong>Exception:</strong> {exception}\\n\",\n    \"    </div>\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    display(HTML(error_message_html))\\n\",\n    \"\\n\",\n    \"def copy_file_to_s3(local_file: str, bucket: str, destination_folder_name: str):\\n\",\n    \"    try:\\n\",\n    \"        file_name = os.path.basename(local_file)\\n\",\n    \"        s3_path = f\\\"s3://{bucket}/{destination_folder_name}/{file_name}\\\"\\n\",\n    \"        subprocess.run([\\\"aws\\\", \\\"s3\\\", \\\"cp\\\", local_file, s3_path], check=True, capture_output=True, text=True)\\n\",\n    \"        return construct_download_url(file_name, bucket, destination_folder_name)\\n\",\n    \"    except subprocess.CalledProcessError as e:\\n\",\n    \"        raise Exception(f\\\"Error copying file to S3: {e.stderr}\\\") from e\\n\",\n    \"\\n\",\n    \"def get_default_aws_region():\\n\",\n    \"    try:\\n\",\n    \"        return subprocess.check_output(\\n\",\n    \"            \\\"aws configure list | grep region | awk '{print $2}'\\\",\\n\",\n    \"            shell=True,\\n\",\n    \"            text=True\\n\",\n    \"        ).strip()\\n\",\n    \"    except subprocess.CalledProcessError:\\n\",\n    \"        return \\\"Error: Unable to retrieve the region.\\\"\\n\",\n    \"\\n\",\n    \"def construct_download_url(file_name: str, bucket_name: str, destination_folder_name: str):\\n\",\n    \"    region = get_default_aws_region()\\n\",\n    \"    return 
f\\\"https://{region}.console.aws.amazon.com/s3/object/{bucket_name}?region={region}&prefix={destination_folder_name}/{file_name}\\\"\\n\",\n    \"\\n\",\n    \"def create_download_link(source_folder, bucket_name, destination_folder_name):\\n\",\n    \"    folder_to_compress = os.path.join(\\\"/tmp\\\", os.path.basename(source_folder))\\n\",\n    \"    local_zip_file_path = shutil.make_archive(folder_to_compress, 'zip', source_folder)\\n\",\n    \"    download_url = copy_file_to_s3(local_zip_file_path, bucket_name, destination_folder_name)\\n\",\n    \"\\n\",\n    \"    download_button_html = f\\\"\\\"\\\"\\n\",\n    \"    <style>\\n\",\n    \"        .download-btn {{\\n\",\n    \"            display: inline-block;\\n\",\n    \"            padding: 10px 20px;\\n\",\n    \"            font-size: 16px;\\n\",\n    \"            color: white;\\n\",\n    \"            background-color: #4CAF50;\\n\",\n    \"            text-align: center;\\n\",\n    \"            text-decoration: none;\\n\",\n    \"            border-radius: 5px;\\n\",\n    \"            border: none;\\n\",\n    \"            cursor: pointer;\\n\",\n    \"            margin: 15px auto;\\n\",\n    \"        }}\\n\",\n    \"        .download-btn:hover {{\\n\",\n    \"            background-color: #45a049;\\n\",\n    \"        }}\\n\",\n    \"        .button-container {{\\n\",\n    \"            display: flex;\\n\",\n    \"            justify-content: center;\\n\",\n    \"            align-items: center;\\n\",\n    \"        }}\\n\",\n    \"        .button-container a {{\\n\",\n    \"            color: white !important;\\n\",\n    \"        }}\\n\",\n    \"    </style>\\n\",\n    \"\\n\",\n    \"    <div style=\\\"color: #444; font-size: 14px; text-align: center; margin: 10px;\\\">\\n\",\n    \"        Zipped output file created at {download_url}\\n\",\n    \"    </div>\\n\",\n    \"    <div class='button-container'>\\n\",\n    \"        <a href='{download_url}' class='download-btn'>Download 
Output</a>\n",
    "    </div>\n",
    "    \"\"\"\n",
    "    display(HTML(download_button_html))\n",
    "\n",
    "try:\n",
    "    current_working_directory = os.getcwd()\n",
    "    parsed_s3_output_path = urlparse(S3_OUTPUT_PATH)\n",
    "    bucket_name = parsed_s3_output_path.netloc\n",
    "    destination_path = os.path.join(parsed_s3_output_path.path.strip(\"/\"), DEST_FOLDER_NAME.strip(\"/\"))\n",
    "    create_download_link(output_folder, bucket_name, destination_path)\n",
    "    \n",
    "except Exception as e:\n",
    "    error_msg = f\"Failed to create download link for {output_folder}\"\n",
    "    display_error_message(error_msg, e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "bbe50fde-0bd6-4281-95fd-6a1ec6f17ab2",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "\n",
    "## Summary\n",
    "\n",
    "The report provides a comprehensive overview of the entire application execution and estimated speedup, including unsupported operators and non-SQL operations. 
By default, the applications and queries are sorted in descending order based on the following fields:\\n\",\n    \"\\n\",\n    \"- Estimated GPU Speedup Category\\n\",\n    \"- Estimated GPU Speedup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"b8bca4a6-16d8-4b60-ba7b-9aff64bdcaa1\",\n     \"showTitle\": true,\n     \"title\": \"qualification_summary.csv\"\n    },\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"def millis_to_human_readable(millis):\\n\",\n    \"    seconds = int(millis) / 1000\\n\",\n    \"    if seconds < 60:\\n\",\n    \"        return f\\\"{seconds:.2f} sec\\\"\\n\",\n    \"    else:\\n\",\n    \"        minutes = seconds / 60\\n\",\n    \"        if minutes < 60:\\n\",\n    \"            return f\\\"{minutes:.2f} min\\\"\\n\",\n    \"        else:\\n\",\n    \"            hours = minutes / 60\\n\",\n    \"            return f\\\"{hours:.2f} hr\\\"\\n\",\n    \"\\n\",\n    \"try: \\n\",\n    \"    # Read qualification summary \\n\",\n    \"    summary_output = pd.read_csv(os.path.join(output_folder, \\\"qualification_summary.csv\\\"))\\n\",\n    \"    summary_output = summary_output.drop(columns=[\\\"Unnamed: 0\\\"]).rename_axis('Index').reset_index()\\n\",\n    \"    summary_output['Estimated GPU Duration'] = summary_output['Estimated GPU Duration'].apply(millis_to_human_readable)\\n\",\n    \"    summary_output['App Duration'] = summary_output['App Duration'].apply(millis_to_human_readable)\\n\",\n    \"    \\n\",\n    \"    summary_output = summary_output[[\\n\",\n    \"        'App Name', 'App ID', 'Estimated GPU Speedup Category', 'Estimated GPU Speedup', \\n\",\n    \"        'Estimated GPU Duration', 'App Duration'\\n\",\n  
  \"    ]]\\n\",\n    \"    \\n\",\n    \"    # Read cluster information\\n\",\n    \"    cluster_df = pd.read_json(os.path.join(output_folder, \\\"app_metadata.json\\\"))\\n\",\n    \"    cluster_df['Recommended GPU Cluster'] = cluster_df['clusterInfo'].apply(\\n\",\n    \"        lambda x: f\\\"{x['recommendedCluster']['numWorkerNodes']} x {x['recommendedCluster']['workerNodeType']}\\\"\\n\",\n    \"    )\\n\",\n    \"    cluster_df['App ID'] = cluster_df['appId']\\n\",\n    \"    cluster_df = cluster_df[['App ID', 'Recommended GPU Cluster']]\\n\",\n    \"    \\n\",\n    \"    # Merge the results\\n\",\n    \"    results = pd.merge(summary_output, cluster_df, on='App ID', how='left')\\n\",\n    \"    display(results)\\n\",\n    \"except Exception as e:\\n\",\n    \"    error_msg = \\\"Unable to show summary\\\"\\n\",\n    \"    display_error_message(error_msg, e)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {},\n     \"inputWidgets\": {},\n     \"nuid\": \"73b5e0b0-3a96-4cc6-8e6c-840e4b0d9d43\",\n     \"showTitle\": false,\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"\\n\",\n    \"## Application Status\\n\",\n    \"\\n\",\n    \"The report show the status of each eventlog file that was provided\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"c9ffbfdb-dbb6-4736-b9cb-2ac457cc6714\",\n     \"showTitle\": true,\n     \"title\": \"rapids_4_spark_qualification_output_status.csv\"\n    },\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"try:\\n\",\n    \"    status_output = pd.read_csv(\\n\",\n    \"        
os.path.join(jar_output_folder,\\n\",\n    \"                     \\\"rapids_4_spark_qualification_output_status.csv\\\"))\\n\",\n    \"\\n\",\n    \"    # Set options to display the full content of the DataFrame\\n\",\n    \"    pd.set_option('display.max_rows', None)  # Show all rows\\n\",\n    \"    pd.set_option('display.max_columns', None)  # Show all columns\\n\",\n    \"    pd.set_option('display.width', None)  # Adjust column width to fit the display\\n\",\n    \"    pd.set_option('display.max_colwidth', None)  # Display full content of each column\\n\",\n    \"\\n\",\n    \"    display(status_output)\\n\",\n    \"except Exception as e:\\n\",\n    \"    error_msg = \\\"Unable to show Application Status\\\"\\n\",\n    \"    display_error_message(error_msg, e)        \\n\",\n    \"        \\n\",\n    \"        \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"6756159b-30ca-407a-ab6b-9c29ced01ea6\",\n     \"showTitle\": false,\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"## Stages Output\\n\",\n    \"\\n\",\n    \"For each stage used in SQL operations, the Qualification tool generates the following information:\\n\",\n    \"\\n\",\n    \"1. App ID\\n\",\n    \"2. Stage ID\\n\",\n    \"3. Average Speedup Factor: The average estimated speed-up of all the operators in the given stage.\\n\",\n    \"4. Stage Task Duration: The amount of time spent in tasks of SQL DataFrame operations for the given stage.\\n\",\n    \"5. Unsupported Task Duration: The sum of task durations for the unsupported operators. For more details, see [Supported Operators](https://nvidia.github.io/spark-rapids/docs/supported_ops.html).\\n\",\n    \"6. 
Stage Estimated: Indicates if the stage duration had to be estimated (True or False).\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"cdde6177-db5f-434a-995b-776678a64a3a\",\n     \"showTitle\": true,\n     \"title\": \"rapids_4_spark_qualification_output_stages.csv\"\n    },\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"scrolled\": true,\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"try:\\n\",\n    \"    stages_output = pd.read_csv(\\n\",\n    \"        os.path.join(jar_output_folder,\\n\",\n    \"                     \\\"rapids_4_spark_qualification_output_stages.csv\\\"))\\n\",\n    \"    display(stages_output)\\n\",\n    \"except Exception as e:\\n\",\n    \"    error_msg = \\\"Unable to show stage output\\\"\\n\",\n    \"    display_error_message(error_msg, e) \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"4d7ce219-ae75-4a0c-a78c-4e7f25b8cd6f\",\n     \"showTitle\": false,\n     \"title\": \"\"\n    }\n   },\n   \"source\": [\n    \"## Execs Output\\n\",\n    \"\\n\",\n    \"The Qualification tool generates a report of the “Exec” in the “SparkPlan” or “Executor Nodes” along with the estimated acceleration on the GPU. Please refer to the [Supported Operators guide](https://nvidia.github.io/spark-rapids/docs/supported_ops.html) for more details on limitations on UDFs and unsupported operators.\\n\",\n    \"\\n\",\n    \"1. App ID\\n\",\n    \"2. SQL ID\\n\",\n    \"3. Exec Name: Example: Filter, HashAggregate\\n\",\n    \"4. Expression Name\\n\",\n    \"5. 
Task Speedup Factor: The average acceleration of the operators based on the original CPU duration of the operator divided by the GPU duration. The tool uses historical queries and benchmarks to estimate a speed-up at an individual operator level to calculate how much a specific operator would accelerate on GPU.\\n\",\n    \"6. Exec Duration: Wall-clock time measured from when the operator starts until it is completed.\\n\",\n    \"7. SQL Node ID\\n\",\n    \"8. Exec Is Supported: Indicates whether the Exec is supported by RAPIDS. Refer to the Supported Operators section for details.\\n\",\n    \"9. Exec Stages: An array of stage IDs.\\n\",\n    \"10. Exec Children\\n\",\n    \"11. Exec Children Node IDs\\n\",\n    \"12. Exec Should Remove: Indicates whether the Op is removed from the migrated plan.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"application/vnd.databricks.v1+cell\": {\n     \"cellMetadata\": {\n      \"byteLimit\": 2048000,\n      \"rowLimit\": 10000\n     },\n     \"inputWidgets\": {},\n     \"nuid\": \"998b0c51-0cb6-408e-a01a-d1f5b1a61e1f\",\n     \"showTitle\": true,\n     \"title\": \"rapids_4_spark_qualification_output_execs.csv\"\n    },\n    \"jupyter\": {\n     \"source_hidden\": true\n    },\n    \"scrolled\": true,\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"try:\\n\",\n    \"    execs_output = pd.read_csv(\\n\",\n    \"        os.path.join(jar_output_folder,\\n\",\n    \"                     \\\"rapids_4_spark_qualification_output_execs.csv\\\"))\\n\",\n    \"    display(execs_output)\\n\",\n    \"except Exception as e:\\n\",\n    \"    error_msg = \\\"Unable to show Execs output\\\"\\n\",\n    \"    display_error_message(error_msg, e) \"\n   ]\n  }\n ],\n \"metadata\": {\n  \"application/vnd.databricks.v1+notebook\": {\n   \"dashboards\": [\n    {\n     \"elements\": [],\n     \"globalVars\": {},\n     \"guid\": \"\",\n     \"layoutOption\": {\n      
\"grid\": true,\n      \"stack\": true\n     },\n     \"nuid\": \"91c1bfb2-695a-4e5c-8a25-848a433108dc\",\n     \"origId\": 2173122769183715,\n     \"title\": \"Executive View\",\n     \"version\": \"DashboardViewV1\",\n     \"width\": 1600\n    },\n    {\n     \"elements\": [],\n     \"globalVars\": {},\n     \"guid\": \"\",\n     \"layoutOption\": {\n      \"grid\": true,\n      \"stack\": true\n     },\n     \"nuid\": \"62243296-4562-4f06-90ac-d7a609f19c16\",\n     \"origId\": 2173122769183716,\n     \"title\": \"App View\",\n     \"version\": \"DashboardViewV1\",\n     \"width\": 1920\n    },\n    {\n     \"elements\": [],\n     \"globalVars\": {},\n     \"guid\": \"\",\n     \"layoutOption\": {\n      \"grid\": true,\n      \"stack\": true\n     },\n     \"nuid\": \"854f9c75-5977-42aa-b3dd-c680b8331f19\",\n     \"origId\": 2173122769183722,\n     \"title\": \"Untitled\",\n     \"version\": \"DashboardViewV1\",\n     \"width\": 1024\n    }\n   ],\n   \"environmentMetadata\": null,\n   \"language\": \"python\",\n   \"notebookMetadata\": {\n    \"mostRecentlyExecutedCommandWithImplicitDF\": {\n     \"commandId\": 2173122769183704,\n     \"dataframes\": [\n      \"_sqldf\"\n     ]\n    },\n    \"pythonIndentUnit\": 2,\n    \"widgetLayout\": [\n     {\n      \"breakBefore\": false,\n      \"name\": \"Eventlog Path\",\n      \"width\": 778\n     },\n     {\n      \"breakBefore\": false,\n      \"name\": \"Output Path\",\n      \"width\": 302\n     }\n    ]\n   },\n   \"notebookName\": \"[RAPIDS Accelerator for Apache Spark] Qualification Tool Notebook Template\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   
\"version\": \"3.9.18\"\n  },\n  \"widgets\": {\n   \"application/vnd.jupyter.widget-state+json\": {\n    \"state\": {},\n    \"version_major\": 2,\n    \"version_minor\": 0\n   }\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  }
]