[
  {
    "path": ".github/actions/build-whl/action.yml",
    "content": "name: 'Build Whl'\nauthor: 'EnricoMi'\ndescription: 'A GitHub Action that builds pyspark-extension package'\n\ninputs:\n  spark-version:\n    description: Spark version, e.g. 3.4.0, 3.4.0-SNAPSHOT, or 4.0.0-preview1\n    required: true\n  scala-version:\n    description: Scala version, e.g. 2.12.15\n    required: true\n  spark-compat-version:\n    description: Spark compatibility version, e.g. 3.4\n    required: true\n  scala-compat-version:\n    description: Scala compatibility version, e.g. 2.12\n    required: true\n  java-compat-version:\n    description: Java compatibility version, e.g. 8\n    required: true\n  python-version:\n    description: Python version, e.g. 3.8\n    required: true\n\nruns:\n  using: 'composite'\n  steps:\n  - name: Fetch Binaries Artifact\n    uses: actions/download-artifact@v4\n    with:\n      name: Binaries-${{ inputs.spark-compat-version }}-${{ inputs.scala-compat-version }}\n      path: .\n\n  - name: Set versions in pom.xml\n    run: |\n      ./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}\n      git diff\n    shell: bash\n\n  - name: Make this work with PySpark preview versions\n    if: contains(inputs.spark-version, 'preview')\n    run: |\n      sed -i -e 's/f\"\\(pyspark~=.*\\)\"/f\"\\1.dev1\"/' -e 's/f\"\\({spark_compat_version}.0\\)\"/\"${{ inputs.spark-version }}\"/g' python/setup.py\n      git diff python/setup.py\n    shell: bash\n\n  - name: Restore Maven packages cache\n    if: github.event_name != 'schedule'\n    uses: actions/cache/restore@v4\n    with:\n      path: ~/.m2/repository\n      key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}\n      restore-keys: |\n        ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}\n        ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-\n\n  - name: Setup JDK ${{ inputs.java-compat-version }}\n    uses: actions/setup-java@v4\n    with:\n      java-version: ${{ inputs.java-compat-version }}\n      distribution: 'zulu'\n\n  - name: Fetch Release Test Dependencies\n    run: |\n      # Fetch Release Test Dependencies\n      echo \"::group::mvn dependency:get\"\n      mvn dependency:get -Dtransitive=false -Dartifact=org.apache.parquet:parquet-hadoop:1.16.0:jar:tests\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Setup Python\n    uses: actions/setup-python@v5\n    with:\n      python-version: ${{ inputs.python-version }}\n\n  - name: Install Python dependencies\n    run: |\n      # Install Python dependencies\n      echo \"::group::mvn compile\"\n      python -m pip install --upgrade pip build twine\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Build whl\n    run: |\n      # Build whl\n      echo \"::group::build-whl.sh\"\n      ./build-whl.sh\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Test whl\n    run: |\n      # Test whl\n      echo \"::group::test-release.py\"\n      twine check python/dist/*\n      # .dev1 allows this to work with preview versions\n      pip install python/dist/*.whl \"pyspark~=${{ inputs.spark-compat-version }}.0.dev1\"\n      python test-release.py\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Upload whl\n    uses: actions/upload-artifact@v4\n    with:\n      name: Whl (Spark ${{ inputs.spark-compat-version }} Scala ${{ inputs.scala-compat-version }})\n      path: |\n        python/dist/*.whl\n\n  - name: Build whl with mvn\n    
env:\n      JDK_JAVA_OPTIONS: --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED\n    run: |\n      # Build whl with mvn\n      rm -rf target python/dist python/pyspark_extension.egg-info pyspark/jars/*.jar\n      echo \"::group::build-whl.sh\"\n      ./build-whl.sh\n      echo \"::endgroup::\"\n    shell: bash\n\nbranding:\n  icon: 'check-circle'\n  color: 'green'\n"
  },
  {
    "path": ".github/actions/check-compat/action.yml",
    "content": "name: 'Check'\nauthor: 'EnricoMi'\ndescription: 'A GitHub Action that checks compatibility of spark-extension'\n\ninputs:\n  spark-version:\n    description: Spark version, e.g. 3.4.0 or 3.4.0-SNAPSHOT\n    required: true\n  scala-version:\n    description: Scala version, e.g. 2.12.15\n    required: true\n  spark-compat-version:\n    description: Spark compatibility version, e.g. 3.4\n    required: true\n  scala-compat-version:\n    description: Scala compatibility version, e.g. 2.12\n    required: true\n  package-version:\n    description: Spark-Extension version to check against\n    required: true\n\nruns:\n  using: 'composite'\n  steps:\n  - name: Fetch Binaries Artifact\n    uses: actions/download-artifact@v4\n    with:\n      name: Binaries-${{ inputs.spark-compat-version }}-${{ inputs.scala-compat-version }}\n      path: .\n\n  - name: Set versions in pom.xml\n    run: |\n      ./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}\n      git diff\n    shell: bash\n\n  - name: Restore Maven packages cache\n    if: github.event_name != 'schedule'\n    uses: actions/cache/restore@v4\n    with:\n      path: ~/.m2/repository\n      key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}\n      restore-keys: |\n        ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}\n        ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-\n\n  - name: Setup JDK 1.8\n    uses: actions/setup-java@v4\n    with:\n      java-version: '8'\n      distribution: 'zulu'\n\n  - name: Install Checker\n    run: |\n      # Install Checker\n      echo \"::group::apt update install\"\n      sudo apt update\n      sudo apt install japi-compliance-checker\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Release exists\n    id: exists\n    continue-on-error: true\n    run: |\n      # Release exists\n      curl --head --fail https://repo1.maven.org/maven2/uk/co/gresearch/spark/spark-extension_${{ inputs.scala-compat-version }}/${{ inputs.package-version }}-${{ inputs.spark-compat-version }}/spark-extension_${{ inputs.scala-compat-version }}-${{ inputs.package-version }}-${{ inputs.spark-compat-version }}.jar\n    shell: bash\n\n  - name: Fetch package\n    if: steps.exists.outcome == 'success'\n    run: |\n      # Fetch package\n      echo \"::group::mvn dependency:get\"\n      mvn dependency:get -Dtransitive=false -DremoteRepositories -Dartifact=uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:${{ inputs.package-version }}-${{ inputs.spark-compat-version }}\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Check\n    if: steps.exists.outcome == 'success'\n    continue-on-error: ${{ github.ref == 'refs/heads/master' }}\n    run: |\n      # Check\n      echo \"::group::japi-compliance-checker\"\n      ls -lah ~/.m2/repository/uk/co/gresearch/spark/spark-extension_${{ inputs.scala-compat-version }}/${{ inputs.package-version }}-${{ inputs.spark-compat-version }}/spark-extension_${{ inputs.scala-compat-version }}-${{ inputs.package-version }}-${{ inputs.spark-compat-version }}.jar target/spark-extension*.jar\n      japi-compliance-checker ~/.m2/repository/uk/co/gresearch/spark/spark-extension_${{ inputs.scala-compat-version }}/${{ inputs.package-version }}-${{ inputs.spark-compat-version }}/spark-extension_${{ inputs.scala-compat-version }}-${{ inputs.package-version }}-${{ 
inputs.spark-compat-version }}.jar target/spark-extension*.jar\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Upload Report\n    uses: actions/upload-artifact@v4\n    if: always() && steps.exists.outcome == 'success'\n    with:\n      name: Compat-Report-${{ inputs.spark-compat-version }}\n      path: compat_reports/spark-extension/*\n\nbranding:\n  icon: 'check-circle'\n  color: 'green'\n"
  },
  {
    "path": ".github/actions/prime-caches/action.yml",
    "content": "name: 'Prime caches'\nauthor: 'EnricoMi'\ndescription: 'A GitHub Action that primes caches'\n\ninputs:\n  spark-version:\n    description: Spark version, e.g. 3.4.0 or 3.4.0-SNAPSHOT\n    required: true\n  scala-version:\n    description: Scala version, e.g. 2.12.15\n    required: true\n  spark-compat-version:\n    description: Spark compatibility version, e.g. 3.4\n    required: true\n  scala-compat-version:\n    description: Scala compatibility version, e.g. 2.12\n    required: true\n  java-compat-version:\n    description: Java compatibility version, e.g. 8\n    required: true\n  hadoop-version:\n    description: Hadoop version, e.g. 2.7 or 2\n    required: true\n\nruns:\n  using: 'composite'\n  steps:\n  - name: Set versions in pom.xml\n    run: |\n      ./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}\n      git diff\n    shell: bash\n\n  - name: Check Maven packages cache\n    id: mvn-build-cache\n    uses: actions/cache/restore@v4\n    with:\n      lookup-only: true\n      path: ~/.m2/repository\n      key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}\n\n  - name: Check Spark Binaries cache\n    id: spark-binaries-cache\n    uses: actions/cache/restore@v4\n    with:\n      lookup-only: true\n      path: ~/spark\n      key: ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}\n\n  - name: Prepare priming caches\n    id: setup\n    run: |\n      # Prepare priming caches\n      if [[ \"${{ inputs.spark-version }}\" == *\"-SNAPSHOT\" ]] || [[ -z \"${{ steps.mvn-build-cache.outputs.cache-hit }}\" ]]; then\n        echo \"prime-mvn-cache=true\" >> \"$GITHUB_ENV\"\n        echo \"prime-some-cache=true\" >> \"$GITHUB_ENV\"\n      fi;\n      if [[ \"${{ inputs.spark-version }}\" == *\"-SNAPSHOT\" ]] || [[ -z \"${{ steps.spark-binaries-cache.outputs.cache-hit }}\" ]]; then\n        echo \"prime-spark-cache=true\" >> \"$GITHUB_ENV\"\n        echo \"prime-some-cache=true\" >> \"$GITHUB_ENV\"\n      fi;\n    shell: bash\n\n  - name: Setup JDK ${{ inputs.java-compat-version }}\n    if: env.prime-some-cache\n    uses: actions/setup-java@v4\n    with:\n      java-version: ${{ inputs.java-compat-version }}\n      distribution: 'zulu'\n\n  - name: Build\n    if: env.prime-mvn-cache\n    env:\n      JDK_JAVA_OPTIONS: --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED\n    run: |\n      # Build\n      echo \"::group::mvn dependency:go-offline\"\n      mvn --batch-mode dependency:go-offline\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Save Maven packages cache\n    if: env.prime-mvn-cache\n    uses: actions/cache/save@v4\n    with:\n      path: ~/.m2/repository\n      key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}-${{ github.run_id }}\n\n  - name: Setup Spark Binaries\n    if: env.prime-spark-cache && ! 
contains(inputs.spark-version, '-SNAPSHOT')\n    env:\n      SPARK_PACKAGE: spark-${{ inputs.spark-version }}/spark-${{ inputs.spark-version }}-bin-hadoop${{ inputs.hadoop-version }}${{ startsWith(inputs.spark-version, '3.') && inputs.scala-compat-version == '2.13' && '-scala2.13' || '' }}.tgz\n    run: |\n      wget --progress=dot:giga \"https://www.apache.org/dyn/closer.lua/spark/${SPARK_PACKAGE}?action=download\" -O - | tar -xzC \"${{ runner.temp }}\"\n      archive=$(basename \"${SPARK_PACKAGE}\") bash -c \"mv -v \"${{ runner.temp }}/\\${archive/%.tgz/}\" ~/spark\"\n    shell: bash\n\n  - name: Save Spark Binaries cache\n    if: env.prime-spark-cache && ! contains(inputs.spark-version, '-SNAPSHOT')\n    uses: actions/cache/save@v4\n    with:\n      path: ~/spark\n      key: ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}-${{ github.run_id }}\n\nbranding:\n  icon: 'check-circle'\n  color: 'green'\n"
  },
  {
    "path": ".github/actions/test-jvm/action.yml",
    "content": "name: 'Test JVM'\nauthor: 'EnricoMi'\ndescription: 'A GitHub Action that tests JVM spark-extension'\n\ninputs:\n  spark-version:\n    description: Spark version, e.g. 3.4.0, 3.4.0-SNAPSHOT or 4.0.0-preview1\n    required: true\n  spark-compat-version:\n    description: Spark compatibility version, e.g. 3.4\n    required: true\n  spark-archive-url:\n    description: The URL to download the Spark binary distribution\n    required: false\n  scala-version:\n    description: Scala version, e.g. 2.12.15\n    required: true\n  scala-compat-version:\n    description: Scala compatibility version, e.g. 2.12\n    required: true\n  hadoop-version:\n    description: Hadoop version, e.g. 2.7 or 2\n    required: true\n  java-compat-version:\n    description: Java compatibility version, e.g. 8\n    required: true\n\nruns:\n  using: 'composite'\n  steps:\n  - name: Fetch Binaries Artifact\n    uses: actions/download-artifact@v4\n    with:\n      name: Binaries-${{ inputs.spark-compat-version }}-${{ inputs.scala-compat-version }}\n      path: .\n\n  - name: Set versions in pom.xml\n    run: |\n      ./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}\n      git diff\n    shell: bash\n\n  - name: Restore Spark Binaries cache\n    if: github.event_name != 'schedule' && ! contains(inputs.spark-version, '-SNAPSHOT')\n    uses: actions/cache/restore@v4\n    with:\n      path: ~/spark\n      key: ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}\n      restore-keys: |\n        ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}\n\n  - name: Setup Spark Binaries\n    if: ( ! contains(inputs.spark-version, '-SNAPSHOT') )\n    env:\n      SPARK_PACKAGE: spark-${{ inputs.spark-version }}/spark-${{ inputs.spark-version }}-bin-hadoop${{ inputs.hadoop-version }}${{ startsWith(inputs.spark-version, '3.') && inputs.scala-compat-version == '2.13' && '-scala2.13' || '' }}.tgz\n    run: |\n      # Setup Spark Binaries\n      if [[ ! 
-e ~/spark ]]\n      then\n        url=\"${{ inputs.spark-archive-url }}\"\n        wget --progress=dot:giga \"${url:-https://www.apache.org/dyn/closer.lua/spark/${SPARK_PACKAGE}?action=download}\" -O - | tar -xzC \"${{ runner.temp }}\"\n        archive=$(basename \"${SPARK_PACKAGE}\") bash -c \"mv -v \"${{ runner.temp }}/\\${archive/%.tgz/}\" ~/spark\"\n      fi\n      echo \"SPARK_HOME=$(cd ~/spark; pwd)\" >> $GITHUB_ENV\n    shell: bash\n\n  - name: Restore Maven packages cache\n    if: github.event_name != 'schedule'\n    uses: actions/cache/restore@v4\n    with:\n      path: ~/.m2/repository\n      key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}\n      restore-keys: |\n        ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}\n        ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-\n\n  - name: Setup JDK ${{ inputs.java-compat-version }}\n    uses: actions/setup-java@v4\n    with:\n      java-version: ${{ inputs.java-compat-version }}\n      distribution: 'zulu'\n\n  - name: Scala and Java Tests\n    env:\n      JDK_JAVA_OPTIONS: --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED\n    run: |\n      # Scala and Java Tests\n      echo \"::group::mvn test\"\n      mvn --batch-mode --update-snapshots -Dspotless.check.skip test integration-test\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Upload Test Results\n    if: always()\n    uses: actions/upload-artifact@v4\n    with:\n      name: JVM Test Results (Spark ${{ inputs.spark-version }} Scala ${{ inputs.scala-version }})\n      path: |\n        target/surefire-*reports/*.xml\n\nbranding:\n  icon: 'check-circle'\n  color: 'green'\n"
  },
  {
    "path": ".github/actions/test-python/action.yml",
    "content": "name: 'Test Python'\nauthor: 'EnricoMi'\ndescription: 'A GitHub Action that tests Python spark-extension'\n\n# pyspark is not available for snapshots or scala other than 2.12\n# we would have to compile spark from sources for this, not worth it\n# so this action only works with scala 2.12 and non-snapshot spark versions\ninputs:\n  spark-version:\n    description: Spark version, e.g. 3.4.0 or 4.0.0-preview1\n    required: true\n  scala-version:\n    description: Scala version, e.g. 2.12.15\n    required: true\n  spark-compat-version:\n    description: Spark compatibility version, e.g. 3.4\n    required: true\n  spark-archive-url:\n    description: The URL to download the Spark binary distribution\n    required: false\n  spark-package-repo:\n    description: The URL of an alternate maven repository to fetch Spark packages\n    required: false\n  scala-compat-version:\n    description: Scala compatibility version, e.g. 2.12\n    required: true\n  java-compat-version:\n    description: Java compatibility version, e.g. 8\n    required: true\n  hadoop-version:\n    description: Hadoop version, e.g. 2.7 or 2\n    required: true\n  python-version:\n    description: Python version, e.g. 3.8\n    required: true\n\nruns:\n  using: 'composite'\n  steps:\n  - name: Fetch Binaries Artifact\n    uses: actions/download-artifact@v4\n    with:\n      name: Binaries-${{ inputs.spark-compat-version }}-${{ inputs.scala-compat-version }}\n      path: .\n\n  - name: Set versions in pom.xml\n    run: |\n      ./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}\n      git diff\n\n      SPARK_EXTENSION_VERSION=$(grep --max-count=1 \"<version>.*</version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\")\n      echo \"SPARK_EXTENSION_VERSION=$SPARK_EXTENSION_VERSION\" | tee -a \"$GITHUB_ENV\"\n    shell: bash\n\n  - name: Make this work with PySpark preview versions\n    if: contains(inputs.spark-version, 'preview')\n    run: |\n      sed -i -e 's/\\({spark_compat_version}.0\\)\"/\\1.dev1\"/' python/setup.py\n      git diff python/setup.py\n    shell: bash\n\n  - name: Restore Spark Binaries cache\n    if: github.event_name != 'schedule' && ( startsWith(inputs.spark-version, '3.') && inputs.scala-compat-version == '2.12' || startsWith(inputs.spark-version, '4.') ) && ! contains(inputs.spark-version, '-SNAPSHOT')\n    uses: actions/cache/restore@v4\n    with:\n      path: ~/spark\n      key: ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}\n      restore-keys: |\n        ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}\n\n  - name: Setup Spark Binaries\n    if: ( startsWith(inputs.spark-version, '3.') && inputs.scala-compat-version == '2.12' || startsWith(inputs.spark-version, '4.') ) && ! contains(inputs.spark-version, '-SNAPSHOT')\n    env:\n      SPARK_PACKAGE: spark-${{ inputs.spark-version }}/spark-${{ inputs.spark-version }}-bin-hadoop${{ inputs.hadoop-version }}${{ startsWith(inputs.spark-version, '3.') && inputs.scala-compat-version == '2.13' && '-scala2.13' || '' }}.tgz\n    run: |\n      # Setup Spark Binaries\n      if [[ ! 
-e ~/spark ]]\n      then\n        url=\"${{ inputs.spark-archive-url }}\"\n        wget --progress=dot:giga \"${url:-https://www.apache.org/dyn/closer.lua/spark/${SPARK_PACKAGE}?action=download}\" -O - | tar -xzC \"${{ runner.temp }}\"\n        archive=$(basename \"${SPARK_PACKAGE}\") bash -c \"mv -v \"${{ runner.temp }}/\\${archive/%.tgz/}\" ~/spark\"\n      fi\n      echo \"SPARK_BIN_HOME=$(cd ~/spark; pwd)\" >> $GITHUB_ENV\n    shell: bash\n\n  - name: Restore Maven packages cache\n    if: github.event_name != 'schedule'\n    uses: actions/cache/restore@v4\n    with:\n      path: ~/.m2/repository\n      key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}\n      restore-keys: |\n        ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}\n        ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-\n\n  - name: Setup JDK ${{ inputs.java-compat-version }}\n    uses: actions/setup-java@v4\n    with:\n      java-version: ${{ inputs.java-compat-version }}\n      distribution: 'zulu'\n\n  - name: Setup Python\n    uses: actions/setup-python@v5\n    with:\n      python-version: ${{ inputs.python-version }}\n\n  - name: Install Python dependencies\n    run: |\n      # Install Python dependencies\n      echo \"::group::pip install\"\n      python -m venv .pytest-venv\n      .pytest-venv/bin/python -m pip install --upgrade pip\n      .pytest-venv/bin/pip install pypandoc\n      .pytest-venv/bin/pip install -e python/[test]\n      echo \"::endgroup::\"\n\n      PYSPARK_HOME=$(.pytest-venv/bin/python -c \"import os; import pyspark; print(os.path.dirname(pyspark.__file__))\")\n      PYSPARK_BIN_HOME=\"$(cd \".pytest-venv/\"; pwd)\"\n      PYSPARK_PYTHON=\"$PYSPARK_BIN_HOME/bin/python\"\n      echo \"PYSPARK_HOME=$PYSPARK_HOME\" | tee -a \"$GITHUB_ENV\"\n      echo \"PYSPARK_BIN_HOME=$PYSPARK_BIN_HOME\" | tee -a \"$GITHUB_ENV\"\n      echo \"PYSPARK_PYTHON=$PYSPARK_PYTHON\" | tee -a \"$GITHUB_ENV\"\n    shell: bash\n\n  - name: Prepare Poetry tests\n    run: |\n      # Prepare Poetry tests\n      echo \"::group::Prepare poetry tests\"\n      # install poetry in venv\n      python -m venv .poetry-venv\n      .poetry-venv/bin/python -m pip install poetry\n      # env var needed by poetry tests\n      echo \"POETRY_PYTHON=$PWD/.poetry-venv/bin/python\" | tee -a \"$GITHUB_ENV\"\n\n      # clone example poetry project\n      git clone https://github.com/Textualize/rich.git .rich\n      cd .rich\n      git reset --hard 20024635c06c22879fd2fd1e380ec4cccd9935dd\n      # env var needed by poetry tests\n      echo \"RICH_SOURCES=$PWD\" | tee -a \"$GITHUB_ENV\"\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Python Unit Tests\n    env:\n      SPARK_HOME: ${{ env.PYSPARK_HOME }}\n      PYTHONPATH: python/test\n    run: |\n      .pytest-venv/bin/python -m pytest python/test --junit-xml test-results/pytest-$(date +%s.%N)-$RANDOM.xml\n    shell: bash\n\n  - name: Install Spark Extension\n    run: |\n      # Install Spark Extension\n      echo \"::group::mvn install\"\n      mvn --batch-mode --update-snapshots install -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true -Dgpg.skip\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Start Spark Connect\n    id: spark-connect\n    if: ( contains('3.4,3.5', inputs.spark-compat-version) && inputs.scala-compat-version == '2.12' || startsWith(inputs.spark-version, '4.') ) && ! 
contains(inputs.spark-version, '-SNAPSHOT')\n    env:\n      SPARK_HOME: ${{ env.SPARK_BIN_HOME }}\n      CONNECT_GRPC_BINDING_ADDRESS: 127.0.0.1\n      CONNECT_GRPC_BINDING_PORT: 15002\n    run: |\n      # Start Spark Connect\n      for attempt in {1..10}; do\n        $SPARK_HOME/sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_${{ inputs.scala-compat-version }}:${{ inputs.spark-version }} --repositories \"${{ inputs.spark-package-repo }}\"\n        sleep 10\n        for log in $SPARK_HOME/logs/spark-*-org.apache.spark.sql.connect.service.SparkConnectServer-*.out; do\n          echo \"::group::Spark Connect server log: $log\"\n          eoc=\"EOC-$RANDOM\"\n          echo \"::stop-commands::$eoc\"\n          cat \"$log\" || true\n          echo \"::$eoc::\"\n          echo \"::endgroup::\"\n        done\n\n        if netstat -an | grep 15002; then\n          break;\n        fi\n        echo \"::warning title=Starting Spark Connect server failed::Attempt #$attempt to start Spark Connect server failed\"\n        $SPARK_HOME/sbin/stop-connect-server.sh --packages org.apache.spark:spark-connect_${{ inputs.scala-compat-version }}:${{ inputs.spark-version }}\n        sleep 5\n      done\n\n      if ! netstat -an | grep 15002; then\n        echo \"::error title=Starting Spark Connect server failed::All attempts to start Spark Connect server failed\"\n        exit 1\n      fi\n    shell: bash\n\n  - name: Python Unit Tests (Spark Connect)\n    if: steps.spark-connect.outcome == 'success'\n    env:\n      SPARK_HOME: ${{ env.PYSPARK_HOME }}\n      PYTHONPATH: python/test\n      TEST_SPARK_CONNECT_SERVER: sc://127.0.0.1:15002\n    run: |\n      # Python Unit Tests (Spark Connect)\n\n      echo \"::group::pip install\"\n      # .dev1 allows this to work with preview versions\n      .pytest-venv/bin/pip install \"pyspark[connect]~=${{ inputs.spark-compat-version }}.0.dev1\"\n      echo \"::endgroup::\"\n\n      .pytest-venv/bin/python -m pytest python/test --junit-xml test-results-connect/pytest-$(date +%s.%N)-$RANDOM.xml\n    shell: bash\n\n  - name: Stop Spark Connect\n    if: always() && steps.spark-connect.outcome == 'success'\n    env:\n      SPARK_HOME: ${{ env.SPARK_BIN_HOME }}\n    run: |\n      # Stop Spark Connect\n      $SPARK_HOME/sbin/stop-connect-server.sh\n      for log in $SPARK_HOME/logs/spark-*-org.apache.spark.sql.connect.service.SparkConnectServer-*.out; do\n        echo \"::group::Spark Connect server log: $log\"\n        eoc=\"EOC-$RANDOM\"\n        echo \"::stop-commands::$eoc\"\n        cat \"$log\" || true\n        echo \"::$eoc::\"\n        echo \"::endgroup::\"\n      done\n    shell: bash\n\n  - name: Upload Test Results\n    if: always()\n    uses: actions/upload-artifact@v4\n    with:\n      name: Python Test Results (Spark ${{ inputs.spark-version }} Scala ${{ inputs.scala-version }} Python ${{ inputs.python-version }})\n      path: |\n        test-results/*.xml\n        test-results-connect/*.xml\n\nbranding:\n  icon: 'check-circle'\n  color: 'green'\n"
  },
  {
    "path": ".github/actions/test-release/action.yml",
    "content": "name: 'Test Release'\nauthor: 'EnricoMi'\ndescription: 'A GitHub Action that tests spark-extension release'\n\n# pyspark is not available for snapshots or scala other than 2.12\n# we would have to compile spark from sources for this, not worth it\n# so this action only works with scala 2.12 and non-snapshot spark versions\ninputs:\n  spark-version:\n    description: Spark version, e.g. 3.4.0 or 4.0.0-preview1\n    required: true\n  scala-version:\n    description: Scala version, e.g. 2.12.15\n    required: true\n  spark-compat-version:\n    description: Spark compatibility version, e.g. 3.4\n    required: true\n  spark-archive-url:\n    description: The URL to download the Spark binary distribution\n    required: false\n  scala-compat-version:\n    description: Scala compatibility version, e.g. 2.12\n    required: true\n  java-compat-version:\n    description: Java compatibility version, e.g. 8\n    required: true\n  hadoop-version:\n    description: Hadoop version, e.g. 2.7 or 2\n    required: true\n  python-version:\n    description: Python version, e.g. 3.8\n    default: ''\n    required: false\n\nruns:\n  using: 'composite'\n  steps:\n  - name: Fetch Binaries Artifact\n    uses: actions/download-artifact@v4\n    with:\n      name: Binaries-${{ inputs.spark-compat-version }}-${{ inputs.scala-compat-version }}\n      path: .\n\n  - name: Set versions in pom.xml\n    run: |\n      ./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}\n      git diff\n\n      SPARK_EXTENSION_VERSION=$(grep --max-count=1 \"<version>.*</version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\")\n      echo \"SPARK_EXTENSION_VERSION=$SPARK_EXTENSION_VERSION\" | tee -a \"$GITHUB_ENV\"\n    shell: bash\n\n  - name: Restore Spark Binaries cache\n    if: github.event_name != 'schedule'\n    uses: actions/cache/restore@v4\n    with:\n      path: ~/spark\n      key: ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}\n      restore-keys: |\n        ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}\n\n  - name: Setup Spark Binaries\n    env:\n      SPARK_PACKAGE: spark-${{ inputs.spark-version }}/spark-${{ inputs.spark-version }}-bin-hadoop${{ inputs.hadoop-version }}${{ startsWith(inputs.spark-version, '3.') && inputs.scala-compat-version == '2.13' && '-scala2.13' || '' }}.tgz\n    run: |\n      # Setup Spark Binaries\n      if [[ ! 
-e ~/spark ]]\n      then\n        url=\"${{ inputs.spark-archive-url }}\"\n        wget --progress=dot:giga \"${url:-https://www.apache.org/dyn/closer.lua/spark/${SPARK_PACKAGE}?action=download}\" -O - | tar -xzC \"${{ runner.temp }}\"\n        archive=$(basename \"${SPARK_PACKAGE}\") bash -c \"mv -v \"${{ runner.temp }}/\\${archive/%.tgz/}\" ~/spark\"\n      fi\n      echo \"SPARK_BIN_HOME=$(cd ~/spark; pwd)\" >> $GITHUB_ENV\n    shell: bash\n\n  - name: Restore Maven packages cache\n    if: github.event_name != 'schedule'\n    uses: actions/cache/restore@v4\n    with:\n      path: ~/.m2/repository\n      key: ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}\n      restore-keys: |\n        ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-${{ hashFiles('pom.xml') }}\n        ${{ runner.os }}-mvn-build-${{ inputs.spark-version }}-${{ inputs.scala-version }}-\n\n  - name: Setup JDK ${{ inputs.java-compat-version }}\n    uses: actions/setup-java@v4\n    with:\n      java-version: ${{ inputs.java-compat-version }}\n      distribution: 'zulu'\n\n  - name: Diff App test\n    env:\n      SPARK_HOME: ${{ env.SPARK_BIN_HOME }}\n    run: |\n      # Diff App test\n      echo \"::group::spark-submit\"\n      $SPARK_HOME/bin/spark-submit --packages com.github.scopt:scopt_${{ inputs.scala-compat-version }}:4.1.0 target/spark-extension_*.jar --format parquet --id id src/test/files/test.parquet/file1.parquet src/test/files/test.parquet/file2.parquet diff.parquet\n      echo\n      echo \"::endgroup::\"\n\n      echo \"::group::spark-shell\"\n      $SPARK_HOME/bin/spark-shell <<< 'val df = spark.read.parquet(\"diff.parquet\").orderBy($\"id\").groupBy($\"diff\").count; df.show; if (df.count != 2) sys.exit(1)'\n      echo\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Install Spark Extension\n    run: |\n      # Install Spark Extension\n      echo \"::group::mvn install\"\n      mvn --batch-mode --update-snapshots install -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true -Dgpg.skip\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Fetch Release Test Dependencies\n    run: |\n      # Fetch Release Test Dependencies\n      echo \"::group::mvn dependency:get\"\n      mvn dependency:get -Dtransitive=false -Dartifact=org.apache.parquet:parquet-hadoop:1.16.0:jar:tests\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Scala Release Test\n    env:\n      SPARK_HOME: ${{ env.SPARK_BIN_HOME }}\n    run: |\n      # Scala Release Test\n      echo \"::group::spark-shell\"\n      $SPARK_BIN_HOME/bin/spark-shell --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION --jars ~/.m2/repository/org/apache/parquet/parquet-hadoop/1.16.0/parquet-hadoop-1.16.0-tests.jar < test-release.scala\n      echo\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Setup Python\n    uses: actions/setup-python@v5\n    if: inputs.python-version != ''\n    with:\n      python-version: ${{ inputs.python-version }}\n\n  - name: Python Release Test\n    if: inputs.python-version != ''\n    env:\n      SPARK_HOME: ${{ env.SPARK_BIN_HOME }}\n    run: |\n      # Python Release Test\n      echo \"::group::spark-submit\"\n      $SPARK_BIN_HOME/bin/spark-submit --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION test-release.py\n      echo\n      echo \"::endgroup::\"\n    shell: bash\n\n  - name: Fetch Whl 
Artifact\n    if: inputs.python-version != ''\n    uses: actions/download-artifact@v4\n    with:\n      name: Whl (Spark ${{ inputs.spark-compat-version }} Scala ${{ inputs.scala-compat-version }})\n      path: .\n\n  - name: Install Python dependencies\n    if: inputs.python-version != ''\n    run: |\n      # Install Python dependencies\n      echo \"::group::pip install\"\n      python -m venv .pytest-venv\n      .pytest-venv/bin/python -m pip install --upgrade pip\n      .pytest-venv/bin/pip install pypandoc\n      .pytest-venv/bin/pip install $(ls pyspark_extension-*.whl)[test]\n      echo \"::endgroup::\"\n\n      PYSPARK_HOME=$(.pytest-venv/bin/python -c \"import os; import pyspark; print(os.path.dirname(pyspark.__file__))\")\n      PYSPARK_BIN_HOME=\"$(cd \".pytest-venv/\"; pwd)\"\n      PYSPARK_PYTHON=\"$PYSPARK_BIN_HOME/bin/python\"\n      echo \"PYSPARK_HOME=$PYSPARK_HOME\" | tee -a \"$GITHUB_ENV\"\n      echo \"PYSPARK_BIN_HOME=$PYSPARK_BIN_HOME\" | tee -a \"$GITHUB_ENV\"\n      echo \"PYSPARK_PYTHON=$PYSPARK_PYTHON\" | tee -a \"$GITHUB_ENV\"\n    shell: bash\n\n  - name: PySpark Release Test\n    if: inputs.python-version != ''\n    run: |\n      .pytest-venv/bin/python3 test-release.py\n    shell: bash\n\n  - name: Python Integration Tests\n    if: inputs.python-version != ''\n    env:\n      SPARK_HOME: ${{ env.PYSPARK_HOME }}\n      PYTHONPATH: python:python/test\n    run: |\n      # Python Integration Tests\n      source .pytest-venv/bin/activate\n      find python/test -name 'test*.py' > tests\n      while read test\n      do\n        echo \"::group::spark-submit $test\"\n        if ! $PYSPARK_BIN_HOME/bin/spark-submit --master \"local[2]\" --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION \"$test\" test-results-submit\n        then\n          state=\"fail\"\n        fi\n        echo\n        echo \"::endgroup::\"\n      done < tests\n      if [[ \"$state\" == \"fail\" ]]; then exit 1; fi\n    shell: bash\n\n  - name: Upload Test Results\n    if: always() && inputs.python-version != ''\n    uses: actions/upload-artifact@v4\n    with:\n      name: Python Release Test Results (Spark ${{ inputs.spark-version }} Scala ${{ inputs.scala-version }} Python ${{ inputs.python-version }})\n      path: |\n        test-results-submit/*.xml\n\nbranding:\n  icon: 'check-circle'\n  color: 'green'\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "version: 2\nupdates:\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      interval: \"monthly\"\n\n  - package-ecosystem: \"maven\"\n    directory: \"/\"\n    schedule:\n      interval: \"daily\"\n"
  },
  {
    "path": ".github/show-spark-versions.sh",
    "content": "#!/bin/bash\n\nbase=$(cd \"$(dirname \"$0\")\"; pwd)\n\ngrep -- \"-version\" \"$base\"/workflows/prime-caches.yml | sed -e \"s/ -//g\" -e \"s/ //g\" -e \"s/'//g\" | grep -v -e \"matrix\" -e \"]\" | while read line\ndo\n  IFS=\":\" read var compat_version <<< \"$line\"\n  if [[ \"$var\" == \"spark-compat-version\" ]]\n  then\n    while read line\n    do\n      IFS=\":\" read var patch_version <<< \"$line\"\n      if [[ \"$var\" == \"spark-patch-version\" ]]\n      then\n        echo -n \"spark-version: $compat_version.$patch_version\"\n        read line\n        if [[ \"$line\" == \"spark-snapshot-version:true\" ]]\n        then\n          echo \"-SNAPSHOT\"\n        else\n          echo\n        fi\n        break\n      fi\n    done\n  fi\ndone > \"$base\"/workflows/prime-caches.yml.tmp\n\ngrep spark-version \"$base\"/workflows/*.yml \"$base\"/workflows/prime-caches.yml.tmp | cut -d : -f 2- | sed -e \"s/^[ -]*//\" -e \"s/'//g\" -e 's/{\"params\": {\"//g' -e 's/params: {//g' -e 's/\"//g' -e \"s/,.*//\" | grep \"^spark-version\" | grep -v \"matrix\" | sort | uniq\n\n"
  },
  {
    "path": ".github/workflows/build-jvm.yml",
    "content": "name: Build JVM\n\non:\n  workflow_call:\n\njobs:\n  build:\n    name: Build (Spark ${{ matrix.spark-version }} Scala ${{ matrix.scala-version }})\n    runs-on: ubuntu-latest\n\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - spark-version: '3.2.4'\n            spark-compat-version: '3.2'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            java-compat-version: '8'\n            hadoop-version: '2.7'\n          - spark-version: '3.3.4'\n            spark-compat-version: '3.3'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-version: '3.4.4'\n            spark-compat-version: '3.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.17'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-version: '3.5.8'\n            spark-compat-version: '3.5'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.18'\n            java-compat-version: '8'\n            hadoop-version: '3'\n\n          - spark-version: '3.2.4'\n            spark-compat-version: '3.2'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.5'\n            java-compat-version: '8'\n            hadoop-version: '3.2'\n          - spark-version: '3.3.4'\n            spark-compat-version: '3.3'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-version: '3.4.4'\n            spark-compat-version: '3.4'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-version: '3.5.8'\n            spark-compat-version: '3.5'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-version: '4.0.2'\n            spark-compat-version: '4.0'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.16'\n            java-compat-version: '17'\n            hadoop-version: '3'\n          - spark-version: '4.1.1'\n            spark-compat-version: '4.1'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.17'\n            java-compat-version: '17'\n            hadoop-version: '3'\n          - spark-version: '4.2.0-preview3'\n            spark-compat-version: '4.2'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.18'\n            java-compat-version: '17'\n            hadoop-version: '3'\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n\n      - name: Build\n        uses: ./.github/actions/build\n        with:\n          spark-version: ${{ matrix.spark-version }}\n          scala-version: ${{ matrix.scala-version }}\n          spark-compat-version: ${{ matrix.spark-compat-version }}\n          scala-compat-version: ${{ matrix.scala-compat-version }}\n          java-compat-version: ${{ matrix.java-compat-version }}\n          hadoop-version: ${{ matrix.hadoop-version }}\n"
  },
  {
    "path": ".github/workflows/build-python.yml",
    "content": "name: Build Python\n\non:\n  workflow_call:\n\njobs:\n  # pyspark<4 is not available for snapshots or scala other than 2.12\n  whl:\n    name: Build whl (Spark ${{ matrix.spark-version }} Scala ${{ matrix.scala-version }})\n    runs-on: ubuntu-latest\n\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - spark-compat-version: '3.2'\n            spark-version: '3.2.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            java-compat-version: '8'\n            python-version: '3.9'\n          - spark-compat-version: '3.3'\n            spark-version: '3.3.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            java-compat-version: '8'\n            python-version: '3.9'\n          - spark-compat-version: '3.4'\n            spark-version: '3.4.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.17'\n            java-compat-version: '8'\n            python-version: '3.9'\n          - spark-compat-version: '3.5'\n            spark-version: '3.5.8'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.18'\n            java-compat-version: '8'\n            python-version: '3.9'\n          - spark-compat-version: '4.0'\n            spark-version: '4.0.2'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.16'\n            java-compat-version: '17'\n            python-version: '3.9'\n          - spark-version: '4.1.1'\n            spark-compat-version: '4.1'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.17'\n            java-compat-version: '17'\n            hadoop-version: '3'\n            python-version: '3.10'\n          - spark-version: '4.2.0-preview3'\n            spark-compat-version: '4.2'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.18'\n            java-compat-version: '17'\n            hadoop-version: '3'\n            python-version: '3.10'\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n\n      - name: Build\n        uses: ./.github/actions/build-whl\n        with:\n          spark-version: ${{ matrix.spark-version }}\n          scala-version: ${{ matrix.scala-version }}\n          spark-compat-version: ${{ matrix.spark-compat-version }}\n          scala-compat-version: ${{ matrix.scala-compat-version }}\n          java-compat-version: ${{ matrix.java-compat-version }}\n          python-version: ${{ matrix.python-version }}\n"
  },
  {
    "path": ".github/workflows/build-snapshots.yml",
    "content": "name: Build Snapshots\n\non:\n  workflow_call:\n\njobs:\n  build:\n    name: Build (Spark ${{ matrix.spark-version }} Scala ${{ matrix.scala-version }})\n    runs-on: ubuntu-latest\n\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - spark-compat-version: '3.2'\n            spark-version: '3.2.5-SNAPSHOT'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            java-compat-version: '8'\n          - spark-compat-version: '3.3'\n            spark-version: '3.3.5-SNAPSHOT'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            java-compat-version: '8'\n          - spark-compat-version: '3.4'\n            spark-version: '3.4.5-SNAPSHOT'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.17'\n            java-compat-version: '8'\n          - spark-compat-version: '3.5'\n            spark-version: '3.5.9-SNAPSHOT'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.18'\n            java-compat-version: '8'\n\n          - spark-compat-version: '3.2'\n            spark-version: '3.2.5-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.5'\n            java-compat-version: '8'\n          - spark-compat-version: '3.3'\n            spark-version: '3.3.5-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            java-compat-version: '8'\n          - spark-compat-version: '3.4'\n            spark-version: '3.4.5-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            java-compat-version: '8'\n          - spark-compat-version: '3.5'\n            spark-version: '3.5.9-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            java-compat-version: '8'\n          - spark-compat-version: '4.0'\n            spark-version: '4.0.3-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.16'\n            java-compat-version: '17'\n          - spark-compat-version: '4.1'\n            spark-version: '4.1.2-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.17'\n            java-compat-version: '17'\n          - spark-compat-version: '4.2'\n            spark-version: '4.2.0-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.18'\n            java-compat-version: '17'\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n\n      - name: Build\n        uses: ./.github/actions/build\n        with:\n          spark-version: ${{ matrix.spark-version }}\n          scala-version: ${{ matrix.scala-version }}\n          spark-compat-version: ${{ matrix.spark-compat-version }}-SNAPSHOT\n          scala-compat-version: ${{ matrix.scala-compat-version }}\n          java-compat-version: ${{ matrix.java-compat-version }}\n"
  },
  {
    "path": ".github/workflows/check.yml",
    "content": "name: Check\n\non:\n  workflow_call:\n\njobs:\n  lint:\n    name: Scala lint\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0\n\n      - name: Setup JDK ${{ inputs.java-compat-version }}\n        uses: actions/setup-java@v4\n        with:\n          java-version: '11'\n          distribution: 'zulu'\n\n      - name: Check\n        id: check\n        run: |\n          mvn --batch-mode --update-snapshots spotless:check\n        shell: bash\n\n      - name: Changes\n        if: failure() && steps.check.outcome == 'failure'\n        run: |\n          mvn --batch-mode --update-snapshots spotless:apply\n          git diff\n        shell: bash\n\n  config:\n    name: Configure compat\n    runs-on: ubuntu-latest\n    outputs:\n      major-version: ${{ steps.versions.outputs.major-version }}\n      release-version: ${{ steps.versions.outputs.release-version }}\n      release-major-version: ${{ steps.versions.outputs.release-major-version }}\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0\n\n      - name: Get versions\n        id: versions\n        run: |\n          version=$(grep -m1 version pom.xml | sed -e \"s/<[^>]*>//g\" -e \"s/ //g\")\n          echo \"version: $version\"\n          echo \"major-version: ${version/.*/}\"\n          echo \"version=$version\" >> \"$GITHUB_OUTPUT\"\n          echo \"major-version=${version/.*/}\" >> \"$GITHUB_OUTPUT\"\n          release_version=$(git tag | grep \"^v\" | sort --version-sort | tail -n1 | sed \"s/^v//\")\n          echo \"release-version: $release_version\"\n          echo \"release-major-version: ${release_version/.*/}\"\n          echo \"release-version=$release_version\" >> \"$GITHUB_OUTPUT\"\n          echo \"release-major-version=${release_version/.*/}\" >> \"$GITHUB_OUTPUT\"\n        shell: bash\n\n  compat:\n    name: Compat (Spark ${{ matrix.spark-compat-version }} Scala ${{ matrix.scala-compat-version }})\n    needs: config\n    runs-on: ubuntu-latest\n    if: needs.config.outputs.major-version == needs.config.outputs.release-major-version\n\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - spark-compat-version: '3.2'\n            spark-version: '3.2.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n          - spark-compat-version: '3.3'\n            spark-version: '3.3.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n          - spark-compat-version: '3.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.17'\n            spark-version: '3.4.4'\n          - spark-compat-version: '3.5'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.18'\n            spark-version: '3.5.8'\n          - spark-compat-version: '4.0'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.16'\n            spark-version: '4.0.2'\n          - spark-compat-version: '4.1'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.17'\n            spark-version: '4.1.1'\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n\n      - name: Check\n        uses: ./.github/actions/check-compat\n        with:\n          spark-version: ${{ matrix.spark-version }}\n          scala-version: ${{ matrix.scala-version }}\n          spark-compat-version: ${{ matrix.spark-compat-version }}\n      
    scala-compat-version: ${{ matrix.scala-compat-version }}\n          package-version: ${{ needs.config.outputs.release-version }}\n"
  },
  {
    "path": ".github/workflows/ci.yml",
    "content": "name: CI\n\non:\n  schedule:\n    - cron: '0 8 */10 * *'\n  push:\n    branches:\n      - 'master'\n    tags:\n      - '*'\n  merge_group:\n  pull_request:\n  workflow_dispatch:\n\njobs:\n  event_file:\n    name: \"Event File\"\n    runs-on: ubuntu-latest\n    steps:\n      - name: Upload\n        uses: actions/upload-artifact@v4\n        with:\n          name: Event File\n          path: ${{ github.event_path }}\n\n  build-jvm:\n    name: \"Build JVM\"\n    uses: \"./.github/workflows/build-jvm.yml\"\n  build-snapshots:\n    name: \"Build Snapshots\"\n    uses: \"./.github/workflows/build-snapshots.yml\"\n  build-python:\n    name: \"Build Python\"\n    needs: build-jvm\n    uses: \"./.github/workflows/build-python.yml\"\n\n  test-jvm:\n    name: \"Test JVM\"\n    needs: build-jvm\n    uses: \"./.github/workflows/test-jvm.yml\"\n  test-python:\n    name: \"Test Python\"\n    needs: build-jvm\n    uses: \"./.github/workflows/test-python.yml\"\n  test-snapshots-jvm:\n    name: \"Test Snapshots\"\n    needs: build-snapshots\n    uses: \"./.github/workflows/test-snapshots.yml\"\n  test-release:\n    name: \"Test Release\"\n    needs: build-jvm\n    uses: \"./.github/workflows/test-release.yml\"\n\n  check:\n    name: \"Check\"\n    needs: build-jvm\n    uses: \"./.github/workflows/check.yml\"\n\n  # A single job that succeeds if all jobs listed under 'needs' succeed.\n  # This allows to configure a single job as a required check.\n  # The 'needed' jobs then can be changed through pull-requests.\n  test_success:\n    name: \"Test success\"\n    if: always()\n    runs-on: ubuntu-latest\n    # the if clauses below have to reflect the number of jobs listed here\n    needs: [build-jvm, build-python, test-jvm, test-python, test-release]\n    env:\n      RESULTS: ${{ join(needs.*.result, ',') }}\n\n    steps:\n      - name: \"Success\"\n        # we expect all required jobs to have success result\n        if: env.RESULTS == 'success,success,success,success,success'\n        run: true\n        shell: bash\n      - name: \"Failure\"\n        # we expect all required jobs to have success result, fail otherwise\n        if: env.RESULTS != 'success,success,success,success,success'\n        run: false\n        shell: bash\n"
  },
  {
    "path": ".github/workflows/clear-caches.yaml",
    "content": "name: Clear caches\n\non:\n  workflow_dispatch:\n\npermissions:\n  actions: write\n\njobs:\n  clear-cache:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Clear caches\n        uses: actions/github-script@v7\n        with:\n          script: |\n            const caches = await github.paginate(\n              github.rest.actions.getActionsCacheList.endpoint.merge({\n                owner: context.repo.owner,\n                repo: context.repo.repo,\n              })\n            )\n            for (const cache of caches) {\n              console.log(cache)\n              github.rest.actions.deleteActionsCacheById({\n                owner: context.repo.owner,\n                repo: context.repo.repo,\n                cache_id: cache.id,\n              })\n            }\n\n"
  },
  {
    "path": ".github/workflows/prepare-release.yml",
    "content": "name: Prepare release\n\non:\n  workflow_dispatch:\n    inputs:\n      github_release_latest:\n        description: 'Make the created GitHub release the latest'\n        required: false\n        default: true\n        type: boolean\n\njobs:\n  get-version:\n    name: Get version\n    runs-on: ubuntu-latest\n    outputs:\n      release-tag: ${{ steps.versions.outputs.release-tag }}\n      is-snapshot: ${{ steps.versions.outputs.is-snapshot }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0\n\n      - name: Get versions\n        id: versions\n        run: |\n          # get release version\n          version=$(grep --max-count=1 \"<version>.*</version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\" -e \"s/-SNAPSHOT//\" -e \"s/-[0-9.]+//g\")\n          is_snapshot=$(if grep -q \"<version>.*-SNAPSHOT</version>\" pom.xml; then echo \"true\"; else echo \"false\"; fi)\n\n          # share versions\n          echo \"release-tag=v${version}\" >> \"$GITHUB_OUTPUT\"\n          echo \"is-snapshot=$is_snapshot\" >> \"$GITHUB_OUTPUT\"\n\n  prepare-release:\n    name: Prepare release\n    runs-on: ubuntu-latest\n    if: ( ! github.event.repository.fork )\n    needs: get-version\n    # secrets are provided by environment\n    environment:\n      name: tagged\n      url: 'https://github.com/G-Research/spark-extension?version=${{ needs.get-version.outputs.release-tag }}'\n\n    steps:\n      - name: Create GitHub App token\n        uses: actions/create-github-app-token@v2\n        id: app-token\n        with:\n          app-id: ${{ vars.APP_ID }}\n          private-key: ${{ secrets.PRIVATE_KEY }}\n          # required to push to a branch\n          permission-contents: write\n\n      - name: Get GitHub App User ID\n        id: get-user-id\n        run: echo \"user-id=$(gh api \"/users/${{ steps.app-token.outputs.app-slug }}[bot]\" --jq .id)\" >> \"$GITHUB_OUTPUT\"\n        env:\n          GH_TOKEN: ${{ steps.app-token.outputs.token }}\n\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          token: ${{ steps.app-token.outputs.token }}\n          fetch-depth: 0\n\n      - name: Check branch setup\n        run: |\n          # Check branch setup\n          if [[ \"$GITHUB_REF\" != \"refs/heads/master\" ]] && [[ \"$GITHUB_REF\" != \"refs/heads/master-\"* ]]\n          then\n            echo \"This workflow must be run on master or master-* branch, not $GITHUB_REF\"\n            exit 1\n          fi\n\n      - name: Tag and bump version\n        if: needs.get-version.outputs.is-snapshot\n        run: |\n          # check for unreleased entry in CHANGELOG.md\n          readarray -t changes < <(grep -A 100 \"^## \\[UNRELEASED\\] - YYYY-MM-DD\" CHANGELOG.md | grep -B 100 --max-count=1 -E \"^## \\[[0-9.]+\\]\" | grep \"^-\")\n          if [ ${#changes[@]} -eq 0 ]\n          then\n            echo \"Did not find any changes in CHANGELOG.md under '## [UNRELEASED] - YYYY-MM-DD'\"\n            exit 1\n          fi\n\n          # get latest and release version\n          latest=$(grep --max-count=1 \"<version>.*</version>\" README.md | sed -E -e \"s/\\s*<[^>]+>//g\" -e \"s/-[0-9.]+//g\")\n          version=$(grep --max-count=1 \"<version>.*</version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\" -e \"s/-SNAPSHOT//\" -e \"s/-[0-9.]+//g\")\n\n          # update changlog\n          echo \"Releasing ${#changes[@]} changes as version $version:\"\n          for (( i=0; i<${#changes[@]}; i++ )); do echo 
\"${changes[$i]}\" ; done\n          sed -i \"s/## \\[UNRELEASED\\] - YYYY-MM-DD/## [$version] - $(date +%Y-%m-%d)/\" CHANGELOG.md\n          sed -i -e \"s/$latest-/$version-/g\" -e \"s/$latest\\./$version./g\" README.md PYSPARK-DEPS.md python/README.md\n          ./set-version.sh $version\n\n          # configure git so we can commit changes\n          git config --global user.name '${{ steps.app-token.outputs.app-slug }}[bot]'\n          git config --global user.email '${{ steps.get-user-id.outputs.user-id }}+${{ steps.app-token.outputs.app-slug }}[bot]@users.noreply.github.com'\n\n          # commit changes to local repo\n          echo \"Committing release to local git\"\n          git add pom.xml python/setup.py CHANGELOG.md README.md PYSPARK-DEPS.md python/README.md\n          git commit -m \"Releasing $version\"\n          git tag -a \"v${version}\" -m \"Release v${version}\"\n\n          # bump version\n          # define function to bump version\n          function next_version {\n            local version=$1\n            local branch=$2\n\n            patch=${version/*./}\n            majmin=${version%.${patch}}\n\n            if [[ $branch == \"master\" ]]\n            then\n              # minor version bump\n              if [[ $version != *\".0\" ]]\n              then\n                echo \"version is patch version, should be M.m.0: $version\" >&2\n                exit 1\n              fi\n              maj=${version/.*/}\n              min=${majmin#${maj}.}\n              next=${maj}.$((min+1)).0\n              echo \"$next\"\n            else\n              # patch version bump\n              next=${majmin}.$((patch+1))\n              echo \"$next\"\n            fi\n          }\n\n          # get next version\n          pkg_version=\"${version/-*/}\"\n          branch=$(git rev-parse --abbrev-ref HEAD)\n          next_pkg_version=\"$(next_version \"$pkg_version\" \"$branch\")\"\n\n          # bump the version\n          echo \"Bump version to $next_pkg_version\"\n          ./set-version.sh $next_pkg_version-SNAPSHOT\n\n          # commit changes to local repo\n          echo \"Committing release to local git\"\n          git commit -a -m \"Post-release version bump to $next_pkg_version\"\n\n          # push all commits and tag to origin\n          echo \"Pushing release commit and tag to origin\"\n          git push origin \"$GITHUB_REF_NAME\" \"v${version}\" --tags\n          # NOTE: This push will not trigger a CI as we are using GITHUB_TOKEN to push\n          # More info on: https://docs.github.com/en/actions/using-workflows/triggering-a-workflow#triggering-a-workflow-from-a-workflow\n\n  github-release:\n    name: Create GitHub release\n    runs-on: ubuntu-latest\n    needs:\n      - get-version\n      - prepare-release\n    permissions:\n      contents: write # required to create release\n\n    steps:\n      - name: Checkout release tag\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ needs.get-version.outputs.release-tag }}\n\n      - name: Extract release notes\n        id: release-notes\n        run: |\n          awk '/^## /{if(seen==1)exit; seen++} seen' CHANGELOG.md > ./release-notes.txt\n\n          # Grab release name\n          name=$(grep -m 1 \"^## \" CHANGELOG.md | sed \"s/^## //\")\n          echo \"release_name=$name\" >> $GITHUB_OUTPUT\n\n          # provide release notes file path as output\n          echo \"release_notes_path=release-notes.txt\" >> $GITHUB_OUTPUT\n\n      - name: Publish GitHub release\n        uses: 
ncipollo/release-action@2c591bcc8ecdcd2db72b97d6147f871fcd833ba5\n        id: github-release\n        with:\n          name: ${{ steps.release-notes.outputs.release_name }}\n          bodyFile: ${{ steps.release-notes.outputs.release_notes_path }}\n          makeLatest: ${{ inputs.github_release_latest }}\n          tag: ${{ needs.get-version.outputs.release-tag }}\n          token: ${{ github.token }}\n"
  },
  {
    "path": ".github/workflows/prime-caches.yml",
    "content": "name: Prime caches\n\non:\n  workflow_dispatch:\n\njobs:\n  prime:\n    name: Spark ${{ matrix.spark-compat-version }}.${{ matrix.spark-patch-version }}${{ matrix.spark-snapshot-version && '-SNAPSHOT' }} Scala ${{ matrix.scala-version }}\n    runs-on: ubuntu-latest\n\n    strategy:\n      fail-fast: false\n      # keep in-sync with .github/workflows/test-jvm.yml\n      matrix:\n        include:\n          - spark-compat-version: '3.2'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            spark-patch-version: '4'\n            hadoop-version: '2.7'\n          - spark-compat-version: '3.3'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            spark-patch-version: '4'\n            hadoop-version: '3'\n          - spark-compat-version: '3.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.17'\n            spark-patch-version: '4'\n            hadoop-version: '3'\n          - spark-compat-version: '3.5'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.18'\n            spark-patch-version: '8'\n            hadoop-version: '3'\n\n          - spark-compat-version: '3.2'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.5'\n            spark-patch-version: '4'\n            hadoop-version: '3.2'\n          - spark-compat-version: '3.3'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            spark-patch-version: '4'\n            hadoop-version: '3'\n          - spark-compat-version: '3.4'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            spark-patch-version: '4'\n            hadoop-version: '3'\n          - spark-compat-version: '3.5'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            spark-patch-version: '8'\n            hadoop-version: '3'\n          - spark-compat-version: '4.0'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.16'\n            spark-patch-version: '2'\n            java-compat-version: '17'\n            hadoop-version: '3'\n          - spark-compat-version: '4.1'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.17'\n            spark-patch-version: '1'\n            java-compat-version: '17'\n            hadoop-version: '3'\n          - spark-compat-version: '4.2'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.18'\n            spark-patch-version: '0-preview3'\n            java-compat-version: '17'\n            hadoop-version: '3'\n\n          - spark-compat-version: '3.2'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            spark-patch-version: '5'\n            spark-snapshot-version: true\n            hadoop-version: '2.7'\n          - spark-compat-version: '3.3'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            spark-patch-version: '5'\n            spark-snapshot-version: true\n            hadoop-version: '3'\n          - spark-compat-version: '3.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.17'\n            spark-patch-version: '5'\n            spark-snapshot-version: true\n            hadoop-version: '3'\n          - spark-compat-version: '3.5'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.18'\n            spark-patch-version: '9'\n            
spark-snapshot-version: true\n            hadoop-version: '3'\n\n          - spark-compat-version: '3.2'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.5'\n            spark-patch-version: '5'\n            spark-snapshot-version: true\n            hadoop-version: '3.2'\n          - spark-compat-version: '3.3'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            spark-patch-version: '5'\n            spark-snapshot-version: true\n            hadoop-version: '3'\n          - spark-compat-version: '3.4'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            spark-patch-version: '5'\n            spark-snapshot-version: true\n            hadoop-version: '3'\n          - spark-compat-version: '3.5'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            spark-patch-version: '9'\n            spark-snapshot-version: true\n            hadoop-version: '3'\n          - spark-compat-version: '4.0'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.16'\n            spark-patch-version: '3'\n            spark-snapshot-version: true\n            hadoop-version: '3'\n          - spark-compat-version: '4.1'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.17'\n            spark-patch-version: '2'\n            spark-snapshot-version: true\n            hadoop-version: '3'\n          - spark-compat-version: '4.2'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.18'\n            spark-patch-version: '0'\n            spark-snapshot-version: true\n            hadoop-version: '3'\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n\n      - name: Prime caches\n        uses: ./.github/actions/prime-caches\n        with:\n          spark-version: ${{ matrix.spark-compat-version }}.${{ matrix.spark-patch-version }}${{ matrix.spark-snapshot-version && '-SNAPSHOT' }}\n          scala-version: ${{ matrix.scala-version }}\n          spark-compat-version: ${{ matrix.spark-compat-version }}\n          scala-compat-version: ${{ matrix.scala-compat-version }}\n          hadoop-version: ${{ matrix.hadoop-version }}\n          java-compat-version: '8'\n"
  },
  {
    "path": ".github/workflows/publish-release.yml",
    "content": "name: Publish release\n\non:\n  workflow_dispatch:\n    inputs:\n      versions:\n        required: true\n        type: string\n        description: 'Example: {\"include\": [{\"params\": {\"spark-version\": \"4.0.0\",\"scala-version\": \"2.13.16\"}}]}'\n        default: |\n          {\n            \"include\": [\n              {\"params\": {\"spark-version\": \"3.2.4\", \"scala-version\": \"2.12.15\", \"java-compat-version\": \"8\"}},\n              {\"params\": {\"spark-version\": \"3.3.4\", \"scala-version\": \"2.12.15\", \"java-compat-version\": \"8\"}},\n              {\"params\": {\"spark-version\": \"3.4.4\", \"scala-version\": \"2.12.17\", \"java-compat-version\": \"8\"}},\n              {\"params\": {\"spark-version\": \"3.5.8\", \"scala-version\": \"2.12.18\", \"java-compat-version\": \"8\"}},\n              {\"params\": {\"spark-version\": \"3.2.4\", \"scala-version\": \"2.13.5\", \"java-compat-version\": \"8\"}},\n              {\"params\": {\"spark-version\": \"3.3.4\", \"scala-version\": \"2.13.8\", \"java-compat-version\": \"8\"}},\n              {\"params\": {\"spark-version\": \"3.4.4\", \"scala-version\": \"2.13.8\", \"java-compat-version\": \"8\"}},\n              {\"params\": {\"spark-version\": \"3.5.8\", \"scala-version\": \"2.13.8\", \"java-compat-version\": \"8\"}},\n              {\"params\": {\"spark-version\": \"4.0.2\", \"scala-version\": \"2.13.16\", \"java-compat-version\": \"17\"}},\n              {\"params\": {\"spark-version\": \"4.1.1\", \"scala-version\": \"2.13.17\", \"java-compat-version\": \"17\"}}\n            ]\n          }\n\nenv:\n  # PySpark 3 versions only work with Python 3.9\n  PYTHON_VERSION: \"3.9\"\n\njobs:\n  get-version:\n    name: Get version\n    runs-on: ubuntu-latest\n    outputs:\n      release-tag: ${{ steps.versions.outputs.release-tag }}\n      is-snapshot: ${{ steps.versions.outputs.is-snapshot }}\n    steps:\n      - name: Checkout release tag\n        uses: actions/checkout@v4\n\n      - name: Get versions\n        id: versions\n        run: |\n          # get release version\n          version=$(grep --max-count=1 \"<version>.*</version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\" -e \"s/-SNAPSHOT//\" -e \"s/-[0-9.]+//g\")\n          is_snapshot=$(if grep -q \"<version>.*-SNAPSHOT</version>\" pom.xml; then echo \"true\"; else echo \"false\"; fi)\n\n          # share versions\n          echo \"release-tag=v${version}\" >> \"$GITHUB_OUTPUT\"\n          echo \"is-snapshot=$is_snapshot\" >> \"$GITHUB_OUTPUT\"\n\n      - name: Check tag setup\n        run: |\n          # Check tag setup\n          if [[ \"$GITHUB_REF\" != \"refs/tags/v\"* ]]\n          then\n            echo \"This workflow must be run on a tag, not $GITHUB_REF\"\n            exit 1\n          fi\n\n          if [ \"${{ steps.versions.outputs.is-snapshot }}\" == \"true\" ]\n          then\n            echo \"This is a tagged SNAPSHOT version. 
This is not allowed for release!\"\n            exit 1\n          fi\n\n          if [ \"${{ github.ref_name }}\" != \"${{ steps.versions.outputs.release-tag }}\" ]\n          then\n            echo \"The version in the pom.xml is ${{ steps.versions.outputs.release-tag }}\"\n            echo \"This tag is ${{ github.ref_name }}, which is different!\"\n            exit 1\n          fi\n      - name: Show matrix\n        run: |\n          echo '${{ github.event.inputs.versions }}' | jq .\n\n  maven-release:\n    name: Publish maven release (Spark ${{ matrix.params.spark-version }}, Scala ${{ matrix.params.scala-version }})\n    runs-on: ubuntu-latest\n    needs: get-version\n    if: ( ! github.event.repository.fork )\n    # secrets are provided by environment\n    environment:\n      name: release\n      # a different URL for each point in the matrix, but the same URLs across commits\n      url: 'https://github.com/G-Research/spark-extension?version=${{ needs.get-version.outputs.release-tag }}&spark=${{ matrix.params.spark-version }}&scala=${{ matrix.params.scala-version }}&package=maven'\n\n    permissions: {}\n    strategy:\n      fail-fast: false\n      matrix: ${{ fromJson(github.event.inputs.versions) }}\n\n    steps:\n      - name: Checkout release tag\n        uses: actions/checkout@v4\n\n      - name: Set up JDK and publish to Maven Central\n        uses: actions/setup-java@3a4f6e1af504cf6a31855fa899c6aa5355ba6c12  # v4.7.0\n        with:\n          java-version: ${{ matrix.params.java-compat-version }}\n          distribution: 'corretto'\n          server-id: central\n          server-username: MAVEN_USERNAME\n          server-password: MAVEN_PASSWORD\n          gpg-private-key: ${{ secrets.MAVEN_GPG_PRIVATE_KEY }}\n          gpg-passphrase: MAVEN_GPG_PASSPHRASE\n\n      - name: Inspect GPG\n        run: gpg -k\n\n      - name: Restore Maven packages cache\n        id: cache-maven\n        uses: actions/cache/restore@v4\n        with:\n          path: ~/.m2/repository\n          key: ${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-${{ hashFiles('pom.xml') }}\n          restore-keys: |\n            ${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-${{ hashFiles('pom.xml') }}\n            ${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-\n\n      - name: Publish maven artifacts\n        id: publish-maven\n        run: |\n          ./set-version.sh ${{ matrix.params.spark-version }} ${{ matrix.params.scala-version }}\n          mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true\n        env:\n          MAVEN_USERNAME: ${{ secrets.MAVEN_USERNAME }}\n          MAVEN_PASSWORD: ${{ secrets.MAVEN_PASSWORD }}\n          MAVEN_GPG_PASSPHRASE: ${{ secrets.MAVEN_GPG_PASSPHRASE }}\n\n  pypi-release:\n    name: Publish PyPi release (Spark ${{ matrix.params.spark-version }}, Scala ${{ matrix.params.scala-version }})\n    runs-on: ubuntu-latest\n    needs: get-version\n    if: ( ! 
github.event.repository.fork )\n    # secrets are provided by environment\n    environment:\n      name: release\n      # a different URL for each point in the matrix, but the same URLs across commits\n      url: 'https://github.com/G-Research/spark-extension?version=${{ needs.get-version.outputs.release-tag }}&spark=${{ matrix.params.spark-version }}&scala=${{ matrix.params.scala-version }}&package=pypi'\n\n    permissions:\n      id-token: write # required for PyPI publish\n    strategy:\n      fail-fast: false\n      matrix: ${{ fromJson(github.event.inputs.versions) }}\n\n    steps:\n      - name: Checkout release tag\n        uses: actions/checkout@v4\n\n      - name: Set up JDK\n        uses: actions/setup-java@3a4f6e1af504cf6a31855fa899c6aa5355ba6c12  # v4.7.0\n        with:\n          java-version: ${{ matrix.params.java-compat-version }}\n          distribution: 'corretto'\n\n      - uses: actions/setup-python@v5\n        with:\n          python-version: ${{ env.PYTHON_VERSION }}\n\n      - name: Restore Maven packages cache\n        id: cache-maven\n        uses: actions/cache/restore@v4\n        with:\n          path: ~/.m2/repository\n          key: ${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-${{ hashFiles('pom.xml') }}\n          restore-keys: |\n            ${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-${{ hashFiles('pom.xml') }}\n            ${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-\n\n      - name: Build maven artifacts\n        id: maven\n        if: startsWith(matrix.params.spark-version, '3.') && startsWith(matrix.params.scala-version, '2.12.') || startsWith(matrix.params.spark-version, '4.') && startsWith(matrix.params.scala-version, '2.13.')\n        run: |\n          ./set-version.sh ${{ matrix.params.spark-version }} ${{ matrix.params.scala-version }}\n          mvn clean package -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true\n\n      - name: Prepare PyPi package\n        id: prepare-pypi-package\n        if: steps.maven.outcome == 'success'\n        run: |\n          ./build-whl.sh\n\n      - name: Publish package distributions to PyPI\n        uses: pypa/gh-action-pypi-publish@release/v1\n        if: steps.prepare-pypi-package.outcome == 'success'\n        with:\n          packages-dir: python/dist\n          skip-existing: true\n          verbose: true\n"
  },
  {
    "path": ".github/workflows/publish-snapshot.yml",
    "content": "name: Publish snapshot\n\non:\n  workflow_dispatch:\n  push:\n    branches: [\"master\"]\n\nenv:\n  PYTHON_VERSION: \"3.10\"\n\njobs:\n  check-version:\n    name: Check SNAPSHOT version\n    if: ( ! github.event.repository.fork )\n    runs-on: ubuntu-latest\n    permissions: {}\n    outputs:\n      is-snapshot: ${{ steps.check.outputs.is-snapshot }}\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Check if this is a SNAPSHOT version\n        id: check\n        run: |\n          # check is snapshot version\n          if grep -q \"<version>.*-SNAPSHOT</version>\" pom.xml\n          then\n            echo \"Version in pom IS a SNAPSHOT version\"\n            echo \"is-snapshot=true\" >> \"$GITHUB_OUTPUT\"\n          else\n            echo \"Version in pom is NOT a SNAPSHOT version\"\n            echo \"is-snapshot=false\" >> \"$GITHUB_OUTPUT\"\n          fi\n\n  snapshot:\n    name: Snapshot Spark ${{ matrix.params.spark-version }} Scala ${{ matrix.params.scala-version }}\n    needs: check-version\n    # when we release from master, this workflow will see a commit that does not have a SNAPSHOT version\n    # we want this workflow to skip over that commit\n    if: needs.check-version.outputs.is-snapshot == 'true'\n    runs-on: ubuntu-latest\n    # secrets are provided by environment\n    environment:\n      name: snapshot\n      # a different URL for each point in the matrix, but the same URLs accross commits\n      url: 'https://github.com/G-Research/spark-extension?spark=${{ matrix.params.spark-version }}&scala=${{ matrix.params.scala-version }}&snapshot'\n    permissions: {}\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - params: {\"spark-version\": \"3.2.4\", \"scala-version\": \"2.12.15\", \"scala-compat-version\": \"2.12\", \"java-compat-version\": \"8\"}\n          - params: {\"spark-version\": \"3.3.4\", \"scala-version\": \"2.12.15\", \"scala-compat-version\": \"2.12\", \"java-compat-version\": \"8\"}\n          - params: {\"spark-version\": \"3.4.4\", \"scala-version\": \"2.12.17\", \"scala-compat-version\": \"2.12\", \"java-compat-version\": \"8\"}\n          - params: {\"spark-version\": \"3.5.8\", \"scala-version\": \"2.12.18\", \"scala-compat-version\": \"2.12\", \"java-compat-version\": \"8\"}\n          - params: {\"spark-version\": \"3.2.4\", \"scala-version\": \"2.13.5\", \"scala-compat-version\": \"2.13\", \"java-compat-version\": \"8\"}\n          - params: {\"spark-version\": \"3.3.4\", \"scala-version\": \"2.13.8\", \"scala-compat-version\": \"2.13\", \"java-compat-version\": \"8\"}\n          - params: {\"spark-version\": \"3.4.4\", \"scala-version\": \"2.13.8\", \"scala-compat-version\": \"2.13\", \"java-compat-version\": \"8\"}\n          - params: {\"spark-version\": \"3.5.8\", \"scala-version\": \"2.13.8\", \"scala-compat-version\": \"2.13\", \"java-compat-version\": \"8\"}\n          - params: {\"spark-version\": \"4.0.2\", \"scala-version\": \"2.13.16\", \"scala-compat-version\": \"2.13\", \"java-compat-version\": \"17\"}\n          - params: {\"spark-version\": \"4.1.1\", \"scala-version\": \"2.13.17\", \"scala-compat-version\": \"2.13\", \"java-compat-version\": \"17\"}\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Set up JDK and publish to Maven Central\n        uses: actions/setup-java@3a4f6e1af504cf6a31855fa899c6aa5355ba6c12  # v4.7.0\n        with:\n          java-version: ${{ matrix.params.java-compat-version 
}}\n          distribution: 'corretto'\n          server-id: central\n          server-username: MAVEN_USERNAME\n          server-password: MAVEN_PASSWORD\n          gpg-private-key: ${{ secrets.MAVEN_GPG_PRIVATE_KEY }}\n          gpg-passphrase: MAVEN_GPG_PASSPHRASE\n\n      - name: Inspect GPG\n        run: gpg -k\n\n      - uses: actions/setup-python@v5\n        with:\n          python-version: ${{ env.PYTHON_VERSION }}\n\n      - name: Restore Maven packages cache\n        id: cache-maven\n        uses: actions/cache/restore@v4\n        with:\n          path: ~/.m2/repository\n          key: ${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-${{ hashFiles('pom.xml') }}\n          restore-keys: |\n            ${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-${{ hashFiles('pom.xml') }}\n            ${{ runner.os }}-mvn-build-${{ matrix.params.spark-version }}-${{ matrix.params.scala-version }}-\n\n      - name: Publish snapshot\n        run: |\n          ./set-version.sh ${{ matrix.params.spark-version }} ${{ matrix.params.scala-version }}\n          mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true\n        env:\n          MAVEN_USERNAME: ${{ secrets.MAVEN_USERNAME }}\n          MAVEN_PASSWORD: ${{ secrets.MAVEN_PASSWORD }}\n          MAVEN_GPG_PASSPHRASE: ${{ secrets.MAVEN_GPG_PASSPHRASE }}\n\n      - name: Prepare PyPi package to test snapshot\n        if: startsWith(matrix.params.scala-version, '2.12.')\n        run: |\n          # Build whl\n          ./build-whl.sh\n\n      - name: Restore Spark Binaries cache\n        uses: actions/cache/restore@v4\n        with:\n          path: ~/spark\n          key: ${{ runner.os }}-spark-binaries-${{ matrix.params.spark-version }}-${{ matrix.params.scala-compat-version }}\n          restore-keys: |\n            ${{ runner.os }}-spark-binaries-${{ matrix.params.spark-version }}-${{ matrix.params.scala-compat-version }}\n\n      - name: Rename Spark Binaries cache\n        run: |\n          mv ~/spark ./spark-${{ matrix.params.spark-version }}-${{ matrix.params.scala-compat-version }}\n\n      - name: Test snapshot\n        id: test-package\n        run: |\n          # Test the snapshot (needs whl)\n          ./test-release.sh\n"
  },
  {
    "path": ".github/workflows/test-jvm.yml",
    "content": "name: Test JVM\n\non:\n  workflow_call:\n\njobs:\n  test:\n    name: Test (Spark ${{ matrix.spark-compat-version }}.${{ matrix.spark-patch-version }} Scala ${{ matrix.scala-version }})\n    runs-on: ubuntu-latest\n\n    strategy:\n      fail-fast: false\n      # keep in-sync with .github/workflows/prime-caches.yml\n      matrix:\n        include:\n          - spark-compat-version: '3.2'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            spark-patch-version: '4'\n            java-compat-version: '8'\n            hadoop-version: '2.7'\n          - spark-compat-version: '3.3'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            spark-patch-version: '4'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-compat-version: '3.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.17'\n            spark-patch-version: '4'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-compat-version: '3.5'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.18'\n            spark-patch-version: '7'\n            java-compat-version: '8'\n            hadoop-version: '3'\n\n          - spark-compat-version: '3.2'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.5'\n            spark-patch-version: '4'\n            java-compat-version: '8'\n            hadoop-version: '3.2'\n          - spark-compat-version: '3.3'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            spark-patch-version: '4'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-compat-version: '3.4'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            spark-patch-version: '4'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-compat-version: '3.5'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            spark-patch-version: '7'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-compat-version: '4.0'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.16'\n            spark-patch-version: '2'\n            java-compat-version: '17'\n            hadoop-version: '3'\n          - spark-compat-version: '4.1'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.17'\n            spark-patch-version: '1'\n            java-compat-version: '17'\n            hadoop-version: '3'\n          - spark-compat-version: '4.2'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.18'\n            spark-patch-version: '0-preview3'\n            java-compat-version: '17'\n            hadoop-version: '3'\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n\n      - name: Test\n        uses: ./.github/actions/test-jvm\n        env:\n          CI_SLOW_TESTS: 1\n        with:\n          spark-version: ${{ matrix.spark-compat-version }}.${{ matrix.spark-patch-version }}\n          scala-version: ${{ matrix.scala-version }}\n          spark-compat-version: ${{ matrix.spark-compat-version }}\n          spark-archive-url: ${{ matrix.spark-archive-url }}\n          scala-compat-version: ${{ matrix.scala-compat-version }}\n          java-compat-version: ${{ matrix.java-compat-version }}\n          hadoop-version: ${{ 
matrix.hadoop-version }}\n"
  },
  {
    "path": ".github/workflows/test-python.yml",
    "content": "name: Test Python\n\non:\n  workflow_call:\n\njobs:\n  # pyspark is not available for snapshots or scala other than 2.12\n  # we would have to compile spark from sources for this, not worth it\n  test:\n    name: Test (Spark ${{ matrix.spark-version }} Scala ${{ matrix.scala-version }} Python ${{ matrix.python-version }})\n    runs-on: ubuntu-latest\n\n    strategy:\n      fail-fast: false\n      matrix:\n        spark-compat-version: ['3.2', '3.3', '3.4', '3.5', '4.0']\n        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']\n\n        include:\n          - spark-compat-version: '3.2'\n            spark-version: '3.2.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            java-compat-version: '8'\n            hadoop-version: '2.7'\n          - spark-compat-version: '3.3'\n            spark-version: '3.3.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-compat-version: '3.4'\n            spark-version: '3.4.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.17'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-compat-version: '3.5'\n            spark-version: '3.5.8'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.18'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-compat-version: '4.0'\n            spark-version: '4.0.2'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.16'\n            java-compat-version: '17'\n            hadoop-version: '3'\n          - spark-compat-version: '4.1'\n            spark-version: '4.1.1'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.17'\n            java-compat-version: '17'\n            hadoop-version: '3'\n            python-version: '3.10'\n          - spark-compat-version: '4.2'\n            spark-version: '4.2.0-preview3'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.18'\n            java-compat-version: '17'\n            hadoop-version: '3'\n            python-version: '3.10'\n\n        exclude:\n          - spark-compat-version: '3.2'\n            python-version: '3.10'\n          - spark-compat-version: '3.2'\n            python-version: '3.11'\n          - spark-compat-version: '3.2'\n            python-version: '3.12'\n          - spark-compat-version: '3.2'\n            python-version: '3.13'\n\n          - spark-compat-version: '3.3'\n            python-version: '3.11'\n          - spark-compat-version: '3.3'\n            python-version: '3.12'\n          - spark-compat-version: '3.3'\n            python-version: '3.13'\n\n          - spark-compat-version: '3.4'\n            python-version: '3.12'\n          - spark-compat-version: '3.4'\n            python-version: '3.13'\n\n          - spark-compat-version: '3.5'\n            python-version: '3.12'\n          - spark-compat-version: '3.5'\n            python-version: '3.13'\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n\n      - name: Test\n        uses: ./.github/actions/test-python\n        with:\n          spark-version: ${{ matrix.spark-version }}\n          scala-version: ${{ matrix.scala-version }}\n          spark-compat-version: ${{ matrix.spark-compat-version }}\n          spark-archive-url: ${{ matrix.spark-archive-url }}\n          spark-package-repo: 
${{ matrix.spark-package-repo }}\n          scala-compat-version: ${{ matrix.scala-compat-version }}\n          java-compat-version: ${{ matrix.java-compat-version }}\n          hadoop-version: ${{ matrix.hadoop-version }}\n          python-version: ${{ matrix.python-version }}\n"
  },
  {
    "path": ".github/workflows/test-release.yml",
    "content": "name: Test release\n\non:\n  workflow_call:\n\njobs:\n  test:\n    name: Test Release Spark ${{ matrix.spark-version }} Scala ${{ matrix.scala-version }}\n    runs-on: ubuntu-latest\n\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - spark-compat-version: '3.2'\n            spark-version: '3.2.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            java-compat-version: '8'\n            hadoop-version: '2.7'\n            python-version: '3.9'\n          - spark-compat-version: '3.3'\n            spark-version: '3.3.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            java-compat-version: '8'\n            hadoop-version: '3'\n            python-version: '3.10'\n          - spark-compat-version: '3.4'\n            spark-version: '3.4.4'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.17'\n            java-compat-version: '8'\n            hadoop-version: '3'\n            python-version: '3.11'\n          - spark-compat-version: '3.5'\n            spark-version: '3.5.8'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.18'\n            java-compat-version: '8'\n            hadoop-version: '3'\n            python-version: '3.11'\n\n          - spark-compat-version: '3.2'\n            spark-version: '3.2.4'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.5'\n            java-compat-version: '8'\n            hadoop-version: '3.2'\n          - spark-compat-version: '3.3'\n            spark-version: '3.3.4'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-compat-version: '3.4'\n            spark-version: '3.4.4'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-compat-version: '3.5'\n            spark-version: '3.5.8'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            java-compat-version: '8'\n            hadoop-version: '3'\n          - spark-compat-version: '4.0'\n            spark-version: '4.0.2'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.16'\n            java-compat-version: '17'\n            hadoop-version: '3'\n            python-version: '3.13'\n          - spark-compat-version: '4.1'\n            spark-version: '4.1.1'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.17'\n            java-compat-version: '17'\n            hadoop-version: '3'\n            python-version: '3.13'\n          - spark-compat-version: '4.2'\n            spark-version: '4.2.0-preview3'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.18'\n            java-compat-version: '17'\n            hadoop-version: '3'\n            python-version: '3.13'\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n\n      - name: Test\n        uses: ./.github/actions/test-release\n        with:\n          spark-version: ${{ matrix.spark-version }}\n          scala-version: ${{ matrix.scala-version }}\n          spark-compat-version: ${{ matrix.spark-compat-version }}\n          spark-archive-url: ${{ matrix.spark-archive-url }}\n          scala-compat-version: ${{ matrix.scala-compat-version }}\n          java-compat-version: ${{ 
matrix.java-compat-version }}\n          hadoop-version: ${{ matrix.hadoop-version }}\n          python-version: ${{ matrix.python-version }}\n\n"
  },
  {
    "path": ".github/workflows/test-results.yml",
    "content": "name: Test Results\n\non:\n  workflow_run:\n    workflows: [\"CI\"]\n    types:\n      - completed\npermissions: {}\n\njobs:\n  publish-test-results:\n    name: Publish Test Results\n    runs-on: ubuntu-latest\n    if: github.event.workflow_run.conclusion != 'skipped'\n    permissions:\n      checks: write\n      pull-requests: write\n\n    steps:\n      - name: Download and Extract Artifacts\n        uses: dawidd6/action-download-artifact@09f2f74827fd3a8607589e5ad7f9398816f540fe\n        with:\n          run_id: ${{ github.event.workflow_run.id }}\n          name: \"^Event File$| Test Results \"\n          name_is_regexp: true\n          path: artifacts\n\n      - name: Publish Test Results\n        uses: EnricoMi/publish-unit-test-result-action@v2\n        with:\n          commit: ${{ github.event.workflow_run.head_sha }}\n          event_file: artifacts/Event File/event.json\n          event_name: ${{ github.event.workflow_run.event }}\n          files: \"artifacts/* Test Results*/**/*.xml\"\n"
  },
  {
    "path": ".github/workflows/test-snapshots.yml",
    "content": "name: Test Snapshots\n\non:\n  workflow_call:\n\njobs:\n  test:\n    name: Test (Spark ${{ matrix.spark-version }} Scala ${{ matrix.scala-version }})\n    runs-on: ubuntu-latest\n\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - spark-compat-version: '3.2'\n            spark-version: '3.2.5-SNAPSHOT'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            java-compat-version: '8'\n          - spark-compat-version: '3.3'\n            spark-version: '3.3.5-SNAPSHOT'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.15'\n            java-compat-version: '8'\n          - spark-compat-version: '3.4'\n            spark-version: '3.4.5-SNAPSHOT'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.17'\n            java-compat-version: '8'\n          - spark-compat-version: '3.5'\n            spark-version: '3.5.9-SNAPSHOT'\n            scala-compat-version: '2.12'\n            scala-version: '2.12.18'\n            java-compat-version: '8'\n\n          - spark-compat-version: '3.2'\n            spark-version: '3.2.5-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.5'\n            java-compat-version: '8'\n          - spark-compat-version: '3.3'\n            spark-version: '3.3.5-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            java-compat-version: '8'\n          - spark-compat-version: '3.4'\n            spark-version: '3.4.5-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            java-compat-version: '8'\n          - spark-compat-version: '3.5'\n            spark-version: '3.5.9-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.8'\n            java-compat-version: '8'\n          - spark-compat-version: '4.0'\n            spark-version: '4.0.3-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.16'\n            java-compat-version: '17'\n          - spark-compat-version: '4.1'\n            spark-version: '4.1.2-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.17'\n            java-compat-version: '17'\n          - spark-compat-version: '4.1'\n            spark-version: '4.2.0-SNAPSHOT'\n            scala-compat-version: '2.13'\n            scala-version: '2.13.18'\n            java-compat-version: '17'\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n\n      - name: Test\n        uses: ./.github/actions/test-jvm\n        env:\n          CI_SLOW_TESTS: 1\n        with:\n          spark-version: ${{ matrix.spark-version }}\n          scala-version: ${{ matrix.scala-version }}\n          spark-compat-version: ${{ matrix.spark-compat-version }}-SNAPSHOT\n          scala-compat-version: ${{ matrix.scala-compat-version }}\n          java-compat-version: ${{ matrix.java-compat-version }}\n"
  },
  {
    "path": ".gitignore",
    "content": "# use glob syntax.\nsyntax: glob\n*.ser\n*.class\n*~\n*.bak\n#*.off\n*.old\n\n# eclipse conf file\n.settings\n.classpath\n.project\n.manager\n.scala_dependencies\n\n# idea\n.idea\n*.iml\n\n# building\ntarget\nbuild\nnull\ntmp*\ntemp*\ndist\ntest-output\nbuild.log\n\n# other scm\n.svn\n.CVS\n.hg*\n\n# switch to regexp syntax.\n#  syntax: regexp\n#  ^\\.pc/\n\n#SHITTY output not in target directory\nbuild.log\n\n# project specific\npython/**/__pycache__\nspark-*\n.cache\n"
  },
  {
    "path": ".scalafmt.conf",
    "content": "version = 3.7.17\nrunner.dialect = scala213\nrewrite.trailingCommas.style = keep\ndocstrings.style = Asterisk\nmaxColumn = 120\n\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "# Changelog\nAll notable changes to this project will be documented in this file.\n\nThe format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).\n\n## [2.15.0] - 2025-12-13\n\n### Added\n- Support encrypted parquet files (#324)\n\n### Changed\n- Remove support for Spark 3.0 and Spark 3.1 (#332)\n- Make all undocumented unintended public API parts private (#331)\n- Reading Parquet metadata can use Parquet Hadoop version different to version coming with Spark (#330)\n\n## [2.14.2] - 2025-07-21\n\n### Changed\n- Fixed release process (#320)\n\n## [2.14.1] - 2025-07-17\n\n### Changed\n- Fixed release process (#319)\n\n## [2.14.0] - 2025-07-17\n\n### Added\n- Support for Spark 4.0 (#269, #272, #293)\n\n### Changed\n- Improve backticks (#265)\n\n  New: This escapes backticks that already exist in column names.\n\n  Change: This does not quote columns that only contain letters, numbers\n  and underscores, which were quoted before.\n- Move Python dependencies into `setup.py`, build jar from `setup.py` (#301)\n\n## [2.13.0] - 2024-11-04\n\n### Fixes\n- Support diff for Spark Connect implemened via PySpark Dataset API (#251)\n\n### Added\n- Add ignore columns to diff in Python API (#252)\n- Check that the Java / Scala package is installed when needed by Python (#250)\n\n## [2.12.0] - 2024-04-26\n\n### Fixes\n\n- Diff change column should respect comparators (#238)\n\n### Changed\n\n- Make create_temporary_dir work with pyspark-extension only (#222).\n  This allows [installing PIP packages and Poetry projects](PYSPARK-DEPS.md)\n  via pure Python spark-extension package (Maven package not required any more).\n- Add map diff comparator to Python API (#226)\n\n## [2.11.0] - 2024-01-04\n\n### Added\n\n- Add count_null aggregate function (#206)\n- Support reading parquet schema (#208)\n- Add more columns to reading parquet metadata (#209, #211)\n- Provide groupByKey shortcuts for groupBy.as (#213)\n- Allow to install PIP packages into PySpark job (#215)\n- Allow to install Poetry projects into PySpark job (#216)\n\n## [2.10.0] - 2023-09-27\n\n### Fixed\n\n- Update setup.py to include parquet methods in python package (#191)\n\n### Added\n\n- Add --statistics option to diff app (#189)\n- Add --filter option to diff app (#190)\n\n## [2.9.0] - 2023-08-23\n\n### Added\n\n- Add key order sensitive map comparator (#187)\n\n### Changed\n\n- Use dataset encoder rather than implicit value encoder for implicit dataset extension class (#183)\n\n### Fixed\n\n- Fix key-sensitivity in map comparator (#186)\n\n## [2.8.0] - 2023-05-24\n\n### Added\n\n- Add method to set and automatically unset Spark job description. (#172)\n- Add column function that converts between .Net (C#, F#, Visual Basic) `DateTime.Ticks` and Spark timestamp / Unix epoch timestamps. (#153)\n\n## [2.7.0] - 2023-05-05\n\n### Added\n\n- Spark app to diff files or tables and write result back to file or table. (#160)\n- Add null value count to `parquetBlockColumns` and `parquet_block_columns`. (#162)\n- Add `parallelism` argument to Parquet metadata methods. (#164)\n\n### Changed\n\n- Change data type of column name in `parquetBlockColumns` and `parquet_block_columns` to array of strings.\n  Cast to string to get earlier behaviour (string column name). (#162)\n\n## [2.6.0] - 2023-04-11\n\n### Added\n\n-  Add reader for parquet metadata. (#154)\n\n## [2.5.0] - 2023-03-23\n\n### Added\n\n- Add whitespace agnostic diff comparator. (#137)\n- Add Python whl package build. 
(#151)\n\n## [2.4.0] - 2022-12-08\n\n### Added\n\n- Allow for custom diff equality. (#127)\n\n### Fixed\n\n- Fix Python API calling into Scala code. (#132)\n\n## [2.3.0] - 2022-10-26\n\n### Added\n\n- Add diffWith to Scala, Java and Python Diff API. (#109)\n\n### Changed\n\n- Diff similar Datasets with ignoreColumns. Before, only similar DataFrame could be diffed with ignoreColumns. (#111)\n\n### Fixed\n\n- Cache before writing via partitionedBy to work around SPARK-40588. Unpersist via UnpersistHandle. (#124)\n\n## [2.2.0] - 2022-07-21\n\n### Added\n- Add (global) row numbers transformation to Scala, Java and Python API. (#97)\n\n### Removed\n- Removed support for Python 3.6\n\n## [2.1.0] - 2022-04-07\n\n### Added\n- Add sorted group methods to Dataset. (#76)\n\n## [2.0.0] - 2021-10-29\n\n### Added\n- Add support for Spark 3.2 and Scala 2.13.\n- Support to ignore columns in diff API. (#63)\n\n### Removed\n- Removed support for Spark 2.4.\n\n## [1.3.3] - 2020-12-17\n\n### Added\n- Add support for Spark 3.1.\n\n## [1.3.2] - 2020-12-17\n\n### Changed\n- Refine conditional transformation helper methods.\n\n## [1.3.1] - 2020-12-10\n\n### Changed\n- Refine conditional transformation helper methods.\n\n## [1.3.0] - 2020-12-07\n\n### Added\n- Add transformation to compute histogram. (#26)\n- Add conditional transformation helper methods. (#27)\n- Add partitioned writing helpers that simplify writing optimally ordered partitioned data. (#29)\n\n## [1.2.0] - 2020-10-06\n\n### Added\n- Add diff modes (#22): column-by-column, side-by-side, left and right side diff modes.\n- Add sparse mode (#23): diff DataFrame contains only changed values.\n\n## [1.1.0] - 2020-08-24\n\n### Added\n- Add Python API for Diff transformation.\n- Add change column to Diff transformation providing column names of all changed columns in a row.\n- Add fluent methods to change immutable diff options.\n- Add `backticks` method to handle column names that contain dots (`.`).\n\n## [1.0.0] - 2020-03-12\n\n### Added\n- Add Diff transformation for Datasets.\n"
  },
  {
    "path": "CONDITIONAL.md",
    "content": "# DataFrame Transformations\n\nThe Spark `Dataset` API allows for chaining transformations as in the following example:\n\n```scala\nds.where($\"id\" === 1)\n  .withColumn(\"state\", lit(\"new\"))\n  .orderBy($\"timestamp\")\n```\n\nWhen you define additional transformation functions, the `Dataset` API allows you to\nalso fluently call into those:\n\n```scala\ndef transformation(df: DataFrame): DataFrame = df.distinct\n\nds.transform(transformation)\n```\n\nHere are some methods that extend this principle to conditional calls.\n\n## Conditional Transformations\n\nYou can run a transformation after checking a condition with a chain of fluent transformation calls:\n\n```scala\nimport uk.co.gresearch._\n\nval condition = true\n\nval result =\n  ds.where($\"id\" === 1)\n    .withColumn(\"state\", lit(\"new\"))\n    .when(condition).call(transformation)\n    .orderBy($\"timestamp\")\n```\n\nrather than\n\n```scala\nval condition = true\n\nval filteredDf = ds.where($\"id\" === 1)\n                   .withColumn(\"state\", lit(\"new\"))\nval condDf = if (condition) ds.call(transformation) else ds\nval result = ds.orderBy($\"timestamp\")\n```\n\nIn case you need an else transformation as well, try:\n\n```scala\nimport uk.co.gresearch._\n\nval condition = true\n\nval result =\n  ds.where($\"id\" === 1)\n    .withColumn(\"state\", lit(\"new\"))\n    .on(condition).either(transformation).or(other)\n    .orderBy($\"timestamp\")\n```\n\n## Fluent and conditional functions elsewhere\n\nThe same fluent notation works for instances other than `Dataset` or `DataFrame`, e.g.\nfor the `DataFrameWriter`:\n\n```scala\ndef writeData[T](writer: DataFrameWriter[T]): Unit = { ... }\n\nds.write\n  .when(compress).call(_.option(\"compression\", \"gzip\"))\n  .call(writeData)\n```\n"
  },
  {
    "path": "DIFF.md",
    "content": "# Spark Diff\n\nAdd the following `import` to your Scala code:\n\n```scala\nimport uk.co.gresearch.spark.diff._\n```\n\nor this `import` to your Python code:\n\n```python\n# noinspection PyUnresolvedReferences\nfrom gresearch.spark.diff import *\n```\n\nThis adds a `diff` transformation to `Dataset` and `DataFrame` that computes the differences between two datasets / dataframes,\ni.e. which rows of one dataset / dataframes to _add_, _delete_ or _change_ to get to the other dataset / dataframes.\n\nFor example, in Scala\n\n```scala\nval left = Seq((1, \"one\"), (2, \"two\"), (3, \"three\")).toDF(\"id\", \"value\")\nval right = Seq((1, \"one\"), (2, \"Two\"), (4, \"four\")).toDF(\"id\", \"value\")\n```\n\nor in Python:\n\n```python\nleft = spark.createDataFrame([(1, \"one\"), (2, \"two\"), (3, \"three\")], [\"id\", \"value\"])\nright = spark.createDataFrame([(1, \"one\"), (2, \"Two\"), (4, \"four\")], [\"id\", \"value\"])\n```\n\ndiffing becomes as easy as:\n\n```scala\nleft.diff(right).show()\n```\n\n|diff |id   |value  |\n|:---:|:---:|:-----:|\n|    N|    1|    one|\n|    D|    2|    two|\n|    I|    2|    Two|\n|    D|    3|  three|\n|    I|    4|   four|\n\nWith columns that provide unique identifiers per row (here `id`), the diff looks like:\n\n```scala\nleft.diff(right, \"id\").show()\n```\n\n|diff |id   |left_value|right_value|\n|:---:|:---:|:--------:|:---------:|\n|    N|    1|       one|        one|\n|    C|    2|       two|        Two|\n|    D|    3|     three|     *null*|\n|    I|    4|    *null*|       four|\n\n\nEquivalent alternative is this hand-crafted transformation (Scala)\n\n```scala\nleft.withColumn(\"exists\", lit(1)).as(\"l\")\n  .join(right.withColumn(\"exists\", lit(1)).as(\"r\"),\n    $\"l.id\" <=> $\"r.id\",\n    \"fullouter\")\n  .withColumn(\"diff\",\n    when($\"l.exists\".isNull, \"I\").\n      when($\"r.exists\".isNull, \"D\").\n      when(!($\"l.value\" <=> $\"r.value\"), \"C\").\n      otherwise(\"N\"))\n  .show()\n```\n\nStatistics on the differences can be obtained by\n\n```scala\nleft.diff(right, \"id\").groupBy(\"diff\").count().show()\n```\n\n|diff  |count  |\n|:----:|:-----:|\n|     N|      1|\n|     I|      1|\n|     D|      1|\n|     C|      1|\n\nThe `diff` transformation can optionally provide a *change column* that lists all non-id column names that have changed.\nThis column is an array of strings and only set for `\"N\"` and `\"C\"`action rows; it is *null* for `\"I\"` and `\"D\"`action rows.\n\n|diff |changes|id   |left_value|right_value|\n|:---:|:-----:|:---:|:--------:|:---------:|\n|    N|     []|    1|       one|        one|\n|    C|[value]|    2|       two|        Two|\n|    D| *null*|    3|     three|     *null*|\n|    I| *null*|    4|    *null*|       four|\n\n## Features\n\nThis `diff` transformation provides the following features:\n* id columns are optional\n* provides typed `diffAs` and `diffWith` transformations\n* supports *null* values in id and non-id columns\n* detects *null* value insertion / deletion\n* [configurable](#configuring-diff) via `DiffOptions`:\n  * diff column name (default: `\"diff\"`), if default name exists in diff result schema\n  * diff action labels (defaults: `\"N\"`, `\"I\"`, `\"D\"`, `\"C\"`), allows custom diff notation,<br/> e.g. Unix diff left-right notation (<, >) or git before-after format (+, -, -+)\n  * [custom equality operators](#comparators-equality) (e.g. 
double comparison with epsilon threshold)\n  * [different diff result formats](#diffing-modes)\n  * [sparse diffing mode](#sparse-mode)\n* optionally provides a *change column* that lists all non-id column names that have changed (only for `\"C\"` action rows)\n* guarantees that no duplicate columns exist in the result, throws a readable exception otherwise\n\n## Configuring Diff\n\nDiffing can be configured via an optional `DiffOptions` instance (see [Methods](#methods) below).\n\n|option              |default  |description|\n|--------------------|:-------:|-----------|\n|`diffColumn`        |`\"diff\"` |The 'diff column' provides the action or diff value encoding if the respective row has been inserted, changed, deleted or has not been changed at all.|\n|`leftColumnPrefix`  |`\"left\"` |Non-id columns of the 'left' dataset are prefixed with this prefix.|\n|`rightColumnPrefix` |`\"right\"`|Non-id columns of the 'right' dataset are prefixed with this prefix.|\n|`insertDiffValue`   |`\"I\"`    |Inserted rows are marked with this string in the 'diff column'.|\n|`changeDiffValue`   |`\"C\"`    |Changed rows are marked with this string in the 'diff column'.|\n|`deleteDiffValue`   |`\"D\"`    |Deleted rows are marked with this string in the 'diff column'.|\n|`nochangeDiffValue` |`\"N\"`    |Unchanged rows are marked with this string in the 'diff column'.|\n|`changeColumn`      |*none*   |An array with the names of all columns that have changed values is provided in this column (only for unchanged and changed rows, *null* otherwise).|\n|`diffMode`          |`DiffModes.Default`|Configures the diff output format. For details see [Diff Modes](#diff-modes) section below.|\n|`sparseMode`        |`false`  |When `true`, only values that have changed are provided on left and right side, `null` is used for un-changed values.|\n|`defaultComparator` |`DiffComparators.default()`|The default equality for all value columns.|\n|`dataTypeComparators`|_empty_ |Map from data types to comparators.|\n|`columnNameComparators`|_empty_|Map from column names to comparators.|\n\nEither construct an instance via the constructor …\n\n```scala\n// Scala\nimport uk.co.gresearch.spark.diff.{DiffOptions, DiffMode}\nval options = DiffOptions(\"d\", \"l\", \"r\", \"i\", \"c\", \"d\", \"n\", Some(\"changes\"), DiffMode.Default, false)\n```\n\n```python\n# Python\nfrom gresearch.spark.diff import DiffOptions, DiffMode\noptions = DiffOptions(\"d\", \"l\", \"r\", \"i\", \"c\", \"d\", \"n\", \"changes\", DiffMode.Default, False)\n```\n\n… or via the `.with*` methods. The former requires most options to be specified, whereas the latter\nonly requires the ones that deviate from the default. 
And it is more readable.\n\nStart from the default options `DiffOptions.default` and customize as follows:\n\n```scala\n// Scala\nimport uk.co.gresearch.spark.diff.{DiffOptions, DiffMode, DiffComparators}\n\nval options = DiffOptions.default\n  .withDiffColumn(\"d\")\n  .withLeftColumnPrefix(\"l\")\n  .withRightColumnPrefix(\"r\")\n  .withInsertDiffValue(\"i\")\n  .withChangeDiffValue(\"c\")\n  .withDeleteDiffValue(\"d\")\n  .withNochangeDiffValue(\"n\")\n  .withChangeColumn(\"changes\")\n  .withDiffMode(DiffMode.Default)\n  .withSparseMode(true)\n  .withDefaultComparator(DiffComparators.epsilon(0.001))\n  .withComparator(DiffComparators.epsilon(0.001), DoubleType)\n  .withComparator(DiffComparators.epsilon(0.001), \"float_column\")\n```\n\n```python\n# Python\nfrom pyspark.sql.types import DoubleType\nfrom gresearch.spark.diff import DiffOptions, DiffMode, DiffComparators\n\noptions = DiffOptions() \\\n  .with_diff_column(\"d\") \\\n  .with_left_column_prefix(\"l\") \\\n  .with_right_column_prefix(\"r\") \\\n  .with_insert_diff_value(\"i\") \\\n  .with_change_diff_value(\"c\") \\\n  .with_delete_diff_value(\"d\") \\\n  .with_nochange_diff_value(\"n\") \\\n  .with_change_column(\"changes\") \\\n  .with_diff_mode(DiffMode.Default) \\\n  .with_sparse_mode(True) \\\n  .with_default_comparator(DiffComparators.epsilon(0.01)) \\\n  .with_data_type_comparator(DiffComparators.epsilon(0.001), DoubleType()) \\\n  .with_column_name_comparator(DiffComparators.epsilon(0.001), \"float_column\")\n```\n\n### Diffing Modes\n\nThe result of the diff transformation can have the following formats:\n\n- *column by column*: The non-id columns are arranged column by column, i.e. for each non-id column\n                      there are two columns next to each other in the diff result, one from the left\n                      and one from the right dataset. This is useful to easily compare the values\n                      for each column.\n- *side by side*: The non-id columns from the left and right dataset are arranged side by side,\n                  i.e. first there are all columns from the left dataset, then from the right one.\n                  This is useful to visually compare the datasets as a whole, especially in conjunction\n                  with the sparse mode.\n- *left side*: Only the columns of the left dataset are present in the diff output. This mode\n               provides the left dataset as is, annotated with diff action and optional changed column names. \n- *right side*: Only the columns of the right dataset are present in the diff output. 
This mode\n                provides the right dataset as given, as well as the diff action that has been applied to it.\n                This serves as a patch that, applied to the left dataset, results in the right dataset.\n\nWith the following two datasets `left` and `right`:\n\n```scala\ncase class Value(id: Int, value: Option[String], label: Option[String])\n\nval left = Seq(\n  Value(1, Some(\"one\"), None),\n  Value(2, Some(\"two\"), Some(\"number two\")),\n  Value(3, Some(\"three\"), Some(\"number three\")),\n  Value(4, Some(\"four\"), Some(\"number four\")),\n  Value(5, Some(\"five\"), Some(\"number five\")),\n).toDS\n\nval right = Seq(\n  Value(1, Some(\"one\"), Some(\"one\")),\n  Value(2, Some(\"Two\"), Some(\"number two\")),\n  Value(3, Some(\"Three\"), Some(\"number Three\")),\n  Value(4, Some(\"four\"), Some(\"number four\")),\n  Value(6, Some(\"six\"), Some(\"number six\")),\n).toDS\n```\n\nthe diff modes produce the following outputs:\n\n#### Column by Column\n\n|diff |id   |left_value|right_value|left_label  |right_label |\n|:---:|:---:|:--------:|:---------:|:----------:|:----------:|\n|C    |1    |one       |one        |*null*      |one         |\n|C    |2    |two       |Two        |number two  |number two  |\n|C    |3    |three     |Three      |number three|number Three|\n|N    |4    |four      |four       |number four |number four |\n|D    |5    |five      |null       |number five |*null*      |\n|I    |6    |*null*    |six        |*null*      |number six  |\n\n#### Side by Side\n\n|diff |id   |left_value|left_label  |right_value|right_label |\n|:---:|:---:|:--------:|:----------:|:---------:|:----------:|\n|C    |1    |one       |*null*      |one        |one         |\n|C    |2    |two       |number two  |Two        |number two  |\n|C    |3    |three     |number three|Three      |number Three|\n|N    |4    |four      |number four |four       |number four |\n|D    |5    |five      |number five |null       |*null*      |\n|I    |6    |*null*    |*null*      |six        |number six  |\n\n#### Left Side\n\n|diff |id   |value|label       |\n|:---:|:---:|:---:|:----------:|\n|C    |1    |one  |null        |\n|C    |2    |two  |number two  |\n|C    |3    |three|number three|\n|N    |4    |four |number four |\n|D    |5    |five |number five |\n|I    |6    |null |null        |\n\n#### Right Side\n\n|diff |id   |value|label       |\n|:---:|:---:|:---:|:----------:|\n|C    |1    |one  |one         |\n|C    |2    |Two  |number two  |\n|C    |3    |Three|number Three|\n|N    |4    |four |number four |\n|D    |5    |null |null        |\n|I    |6    |six  |number six  |\n\n### Sparse Mode\n\nThe diff modes above can be combined with sparse mode. 
In sparse mode, only values that differ between\nthe two datasets are in the diff result; all other values are `null`.\n\nThe [Column by Column](#column-by-column) example above looks as follows in sparse mode:\n\n|diff |id   |left_value|right_value|left_label  |right_label |\n|:---:|:---:|:--------:|:---------:|:----------:|:----------:|\n|C    |1    |null      |null       |null        |one         |\n|C    |2    |two       |Two        |null        |null        |\n|C    |3    |three     |Three      |number three|number Three|\n|N    |4    |null      |null       |null        |null        |\n|D    |5    |five      |null       |number five |null        |\n|I    |6    |null      |six        |null        |number six  |\n\n\n### Comparators (Equality)\n\nValues are compared for equality with the default `<=>` operator, which considers values\nequal when both sides are `null`, or both sides are not `null` and equal.\n\nThe following alternative comparators are provided:\n\n|Comparator|Description|\n|:---------|:----------|\n|`DiffComparators.epsilon(epsilon)`|Two values are equal when they are at most `epsilon` apart.<br/><br/>The comparator can be configured to use `epsilon` as an absolute (`.asAbsolute()`) threshold, or as relative (`.asRelative()`) to the larger value. Further, the threshold itself can be considered equal (`.asInclusive()`) or not equal (`.asExclusive()`):<ul><li>`DiffComparators.epsilon(epsilon).asAbsolute().asInclusive()`:<br/>`x` and `y` are equal iff `abs(x - y) ≤ epsilon`</li><li>`DiffComparators.epsilon(epsilon).asAbsolute().asExclusive()`:<br/>`x` and `y` are equal iff `abs(x - y) < epsilon`</li><li>`DiffComparators.epsilon(epsilon).asRelative().asInclusive()`:<br/>`x` and `y` are equal iff `abs(x - y) ≤ epsilon * max(abs(x), abs(y))`</li><li>`DiffComparators.epsilon(epsilon).asRelative().asExclusive()`:<br/>`x` and `y` are equal iff `abs(x - y) < epsilon * max(abs(x), abs(y))`</li></ul>|\n|`DiffComparators.string()`|Two `StringType` values are compared while ignoring white space differences. For this comparison, sequences of whitespaces are collapsed into single whitespaces, and leading and trailing whitespaces are removed. With `DiffComparators.string(false)`, string values are compared with the default comparator.|\n|`DiffComparators.duration(duration)`|Two `DateType` or `TimestampType` values are equal when they are at most `duration` apart. That duration is an instance of `java.time.Duration`.<br/><br/>The comparator can be configured to consider `duration` as equal (`.asInclusive()`) or not equal (`.asExclusive()`):<ul><li>`DiffComparators.duration(duration).asInclusive()`:<br/>`x` and `y` are equal iff `x - y ≤ duration`</li><li>`DiffComparators.duration(duration).asExclusive()`:<br/>`x` and `y` are equal iff `x - y < duration`</li></ul>|\n|`DiffComparators.map[K,V](keyOrderSensitive)` (Scala only)<br/>`DiffComparators.map(keyType, valueType, keyOrderSensitive)`|Two `Map[K,V]` values are equal when they match in all their keys and values. 
\n\n### Comparators (Equality)\n\nValues are compared for equality with the default `<=>` operator, which considers values\nequal when both sides are `null`, or both sides are not `null` and equal.\n\nThe following alternative comparators are provided:\n\n|Comparator|Description|\n|:---------|:----------|\n|`DiffComparators.epsilon(epsilon)`|Two values are equal when they are at most `epsilon` apart.<br/><br/>The comparator can be configured to use `epsilon` as an absolute (`.asAbsolute()`) threshold, or as relative (`.asRelative()`) to the larger value. Further, the threshold itself can be considered equal (`.asInclusive()`) or not equal (`.asExclusive()`):<ul><li>`DiffComparators.epsilon(epsilon).asAbsolute().asInclusive()`:<br/>`x` and `y` are equal iff `abs(x - y) ≤ epsilon`</li><li>`DiffComparators.epsilon(epsilon).asAbsolute().asExclusive()`:<br/>`x` and `y` are equal iff `abs(x - y) < epsilon`</li><li>`DiffComparators.epsilon(epsilon).asRelative().asInclusive()`:<br/>`x` and `y` are equal iff `abs(x - y) ≤ epsilon * max(abs(x), abs(y))`</li><li>`DiffComparators.epsilon(epsilon).asRelative().asExclusive()`:<br/>`x` and `y` are equal iff `abs(x - y) < epsilon * max(abs(x), abs(y))`</li></ul>|\n|`DiffComparators.string()`|Two `StringType` values are compared while ignoring whitespace differences. For this comparison, sequences of whitespaces are collapsed into single whitespaces, and leading and trailing whitespaces are removed. With `DiffComparators.string(false)`, string values are compared with the default comparator.|\n|`DiffComparators.duration(duration)`|Two `DateType` or `TimestampType` values are equal when they are at most `duration` apart. That duration is an instance of `java.time.Duration`.<br/><br/>The comparator can be configured to consider `duration` as equal (`.asInclusive()`) or not equal (`.asExclusive()`):<ul><li>`DiffComparators.duration(duration).asInclusive()`:<br/>`x` and `y` are equal iff `x - y ≤ duration`</li><li>`DiffComparators.duration(duration).asExclusive()`:<br/>`x` and `y` are equal iff `x - y < duration`</li></ul>|\n|`DiffComparators.map[K,V](keyOrderSensitive)` (Scala only)<br/>`DiffComparators.map(keyType, valueType, keyOrderSensitive)`|Two `Map[K,V]` values are equal when they match in all their keys and values. With `keyOrderSensitive=true`, the order of the keys matters; with `keyOrderSensitive=false` (default), the order of keys is ignored.|\n\nAn example:\n\n    val left = Seq((1, 1.0), (2, 2.0), (3, 3.0)).toDF(\"id\", \"value\")\n    val right = Seq((1, 1.0), (2, 2.02), (3, 3.05)).toDF(\"id\", \"value\")\n    left.diff(right, \"id\").show()\n\n|diff| id|left_value|right_value|\n|----|---|----------|-----------|\n|   N|  1|       1.0|        1.0|\n|   C|  2|       2.0|       2.02|\n|   C|  3|       3.0|       3.05|\n\nThe second and third rows are considered `\"C\"`hanged because `2.0 != 2.02` and `3.0 != 3.05`, respectively.\n\nWith an inclusive relative epsilon of 1%, `2.0` and `2.02` are considered equal, while `3.0` and `3.05` are still not equal:\n\n    val options = DiffOptions.default\n      .withComparator(DiffComparators.epsilon(0.01).asRelative().asInclusive(), DoubleType)\n    left.diff(right, options, \"id\").show()\n\n|diff| id|left_value|right_value|\n|----|---|----------|-----------|\n|   N|  1|       1.0|        1.0|\n|   N|  2|       2.0|       2.02|\n|   C|  3|       3.0|       3.05|\n\nThe user can provide custom comparator implementations by implementing `scala.math.Equiv[T]`\nor `uk.co.gresearch.spark.diff.DiffComparator`:\n\n    val intEquiv: Equiv[Int] = (x: Int, y: Int) => x == null && y == null || x != null && y != null && x.equals(y)\n    val anyEquiv: Equiv[Any] = (x: Any, y: Any) => x == null && y == null || x != null && y != null && x.equals(y)\n\n    val comparator: DiffComparator = (left: Column, right: Column) => left <=> right\n\n    import spark.implicits._\n\n    val options = DiffOptions.default\n      .withComparator(intEquiv)\n      .withComparator(anyEquiv, LongType, DoubleType)\n      .withComparator(anyEquiv, \"column1\", \"column2\")\n\n      .withComparator(comparator, StringType, FloatType)\n      .withComparator(comparator, \"column3\", \"column4\")\n\n\n## Methods (Scala)\n\nAll Scala methods come in two variants, one without (as shown below) and one with an `options: DiffOptions` argument.\n\n* `def diff(other: Dataset[T], idColumns: String*): DataFrame`\n* `def diff[U](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): DataFrame`\n\n\n* `def diffAs[V](other: Dataset[T], idColumns: String*)(implicit diffEncoder: Encoder[V]): Dataset[V]`\n* `def diffAs[U, V](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String])(implicit diffEncoder: Encoder[V]): Dataset[V]`\n* `def diffAs[V](other: Dataset[T], diffEncoder: Encoder[V], idColumns: String*): Dataset[V]`\n* `def diffAs[U, V](other: Dataset[U], diffEncoder: Encoder[V], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[V]`\n\n\n* `def diffWith(other: Dataset[T], idColumns: String*): Dataset[(String, T, T)]`\n* `def diffWith[U](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[(String, T, U)]`
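\n\nFor instance, `diffAs` can return a typed `Dataset` instead of a `DataFrame`. A minimal sketch for the `Value` datasets shown earlier, assuming default options (the `DiffAs` case class is hypothetical and simply mirrors the column-by-column diff schema):\n\n```scala\nimport uk.co.gresearch.spark.diff._\nimport spark.implicits._\n\n// hypothetical case class matching the diff schema of the Value example datasets\ncase class DiffAs(diff: String, id: Int,\n                  left_value: Option[String], right_value: Option[String],\n                  left_label: Option[String], right_label: Option[String])\n\nleft.diffAs[DiffAs](right, \"id\").show()  // a Dataset[DiffAs] instead of a DataFrame\n```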
\n\n## Methods (Java)\n\n* `Dataset<Row> Diff.of[T](Dataset<T> left, Dataset<T> right, String... idColumns)`\n* `Dataset<Row> Diff.of[T, U](Dataset<T> left, Dataset<U> right, List<String> idColumns, List<String> ignoreColumns)`\n\n\n* `Dataset<V> Diff.ofAs[T, V](Dataset<T> left, Dataset<T> right, Encoder<V> diffEncoder, String... idColumns)`\n* `Dataset<V> Diff.ofAs[T, U, V](Dataset<T> left, Dataset<U> right, Encoder<V> diffEncoder, List<String> idColumns, List<String> ignoreColumns)`\n\n\n* `Dataset<Tuple3<String, T, T>> Diff.ofWith[T](Dataset<T> left, Dataset<T> right, String... idColumns)`\n* `Dataset<Tuple3<String, T, U>> Diff.ofWith[T, U](Dataset<T> left, Dataset<U> right, List<String> idColumns, List<String> ignoreColumns)`\n\nGiven a `DiffOptions`, a customized `Differ` can be instantiated as `Differ differ = new Differ(options)`:\n\n* `Dataset<Row> Differ.diff[T](Dataset<T> left, Dataset<T> right, String... idColumns)`\n* `Dataset<Row> Differ.diff[T, U](Dataset<T> left, Dataset<U> right, List<String> idColumns, List<String> ignoreColumns)`\n\n\n* `Dataset<V> Differ.diffAs[T, V](Dataset<T> left, Dataset<T> right, Encoder<V> diffEncoder, String... idColumns)`\n* `Dataset<V> Differ.diffAs[T, U, V](Dataset<T> left, Dataset<U> right, Encoder<V> diffEncoder, List<String> idColumns, List<String> ignoreColumns)`\n\n\n* `Dataset<Tuple3<String, T, T>> Differ.diffWith[T](Dataset<T> left, Dataset<T> right, String... idColumns)`\n* `Dataset<Tuple3<String, T, U>> Differ.diffWith[T, U](Dataset<T> left, Dataset<U> right, List<String> idColumns, List<String> ignoreColumns)`\n\n## Methods (Python)\n\n* `def diff(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame`\n* `def diff(self: DataFrame, other: DataFrame, id_columns: List[str], ignore_columns: List[str]) -> DataFrame`\n* `def diff(self: DataFrame, other: DataFrame, options: DiffOptions, *id_columns: str) -> DataFrame`\n* `def diff(self: DataFrame, other: DataFrame, options: DiffOptions, id_columns: List[str], ignore_columns: List[str]) -> DataFrame`\n* `def diffwith(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame`\n* `def diffwith(self: DataFrame, other: DataFrame, id_columns: List[str], ignore_columns: List[str]) -> DataFrame`\n* `def diffwith(self: DataFrame, other: DataFrame, options: DiffOptions, *id_columns: str) -> DataFrame`\n* `def diffwith(self: DataFrame, other: DataFrame, options: DiffOptions, id_columns: List[str], ignore_columns: List[str]) -> DataFrame`\n\n## Diff Spark application\n\nThere is also a Spark application that can be used to create a diff DataFrame.\n
The application reads two DataFrames\n`left` and `right` from files or tables, executes the diff transformation and writes the result DataFrame to a file or table.\nThe Diff app can be run via `spark-submit`:\n\n```shell\n# Scala 2.12\nspark-submit --packages com.github.scopt:scopt_2.12:4.1.0 spark-extension_2.12-2.7.0-3.4.jar --help\n\n# Scala 2.13\nspark-submit --packages com.github.scopt:scopt_2.13:4.1.0 spark-extension_2.13-2.7.0-3.4.jar --help\n```\n\n```\nSpark Diff app (2.10.0-3.4)\n\nUsage: spark-extension_2.13-2.10.0-3.4.jar [options] left right diff\n\n  left                     file path (requires format option) or table name to read left dataframe\n  right                    file path (requires format option) or table name to read right dataframe\n  diff                     file path (requires format option) or table name to write diff dataframe\n\nExamples:\n\n  - Diff CSV files 'left.csv' and 'right.csv' and write result into CSV file 'diff.csv':\n    spark-submit --packages com.github.scopt:scopt_2.13:4.1.0 spark-extension_2.13-2.10.0-3.4.jar --format csv left.csv right.csv diff.csv\n\n  - Diff CSV file 'left.csv' with Parquet file 'right.parquet' with id column 'id', and write result into Hive table 'diff':\n    spark-submit --packages com.github.scopt:scopt_2.13:4.1.0 spark-extension_2.13-2.10.0-3.4.jar --left-format csv --right-format parquet --hive --id id left.csv right.parquet diff\n\nSpark session\n  --master <master>        Spark master (local, yarn, ...), not needed with spark-submit\n  --app-name <app-name>    Spark application name\n  --hive                   enable Hive support to read from and write to Hive tables\n\nInput and output\n  -f, --format <format>    input and output file format (csv, json, parquet, ...)\n  --left-format <format>   left input file format (csv, json, parquet, ...)\n  --right-format <format>  right input file format (csv, json, parquet, ...)\n  --output-format <formt>  output file format (csv, json, parquet, ...)\n\n  -s, --schema <schema>    input schema\n  --left-schema <schema>   left input schema\n  --right-schema <schema>  right input schema\n\n  --left-option:key=val    left input option\n  --right-option:key=val   right input option\n  --output-option:key=val  output option\n\n  --id <name>              id column name\n  --ignore <name>          ignore column name\n  --save-mode <save-mode>  save mode for writing output (Append, Overwrite, ErrorIfExists, Ignore, default ErrorIfExists)\n  --filter <filter>        Filters for rows with these diff actions, with default diffing options use 'N', 'I', 'D', or 'C' (see 'Diffing options' section)\n  --statistics             Only output statistics on how many rows exist per diff action (see 'Diffing options' section)\n\nDiffing options\n  --diff-column <name>     column name for diff column (default 'diff')\n  --left-prefix <prefix>   prefix for left column names (default 'left')\n  --right-prefix <prefix>  prefix for right column names (default 'right')\n  --insert-value <value>   value for insertion (default 'I')\n  --change-value <value>   value for change (default 'C')\n  --delete-value <value>   value for deletion (default 'D')\n  --no-change-value <val>  value for no change (default 'N')\n  --change-column <name>   column name for change column (default is no such column)\n  --diff-mode <mode>       diff mode (ColumnByColumn, SideBySide, LeftSide, RightSide, default ColumnByColumn)\n  --sparse                 enable sparse diff\n\nGeneral\n  --help                   prints this 
usage text\n```\n\n### Examples\n\nDiff CSV files `left.csv` and `right.csv` and write result into CSV file `diff.csv`:\n```shell\nspark-submit --packages com.github.scopt:scopt_2.13:4.1.0 spark-extension_2.13-2.7.0-3.4.jar --format csv left.csv right.csv diff.csv\n```\n\nDiff CSV file `left.csv` with Parquet file `right.parquet` with id column `id`, and write result into Hive table `diff`:\n```shell\nspark-submit --packages com.github.scopt:scopt_2.13:4.1.0 spark-extension_2.13-2.7.0-3.4.jar --left-format csv --right-format parquet --hive --id id left.csv right.parquet diff\n```\n"
  },
  {
    "path": "GROUPS.md",
    "content": "# Sorted Groups\n\nSpark provides the ability to group rows by an arbitrary key,\nwhile then providing an iterator for each of these groups.\nThis allows to iterate over groups that are too large to fit into memory:\n\n```scala\nimport org.apache.spark.sql.Dataset\n\nimport spark.implicits._\n\ncase class Val(id: Int, seq: Int, value: Double)\n\nval ds: Dataset[Val] = Seq(\n  Val(1, 1, 1.1),\n  Val(1, 2, 1.2),\n  Val(1, 3, 1.3),\n\n  Val(2, 1, 2.1),\n  Val(2, 2, 2.2),\n  Val(2, 3, 2.3),\n\n  Val(3, 1, 3.1)\n).reverse.toDS().repartition(3).cache()\n\n// order of iterator IS NOT guaranteed\nds.groupByKey(v => v.id)\n  .flatMapGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, v._1.seq, v._1.value)))\n  .toDF(\"key\", \"index\", \"seq\", \"value\")\n  .show(false)\n\n+---+-----+---+-----+\n|key|index|seq|value|\n+---+-----+---+-----+\n|1  |0    |3  |1.3  |\n|1  |1    |2  |1.2  |\n|1  |2    |1  |1.1  |\n|2  |0    |1  |2.1  |\n|2  |1    |3  |2.3  |\n|2  |2    |2  |2.2  |\n|3  |0    |1  |3.1  |\n+---+-----+---+-----+\n```\n\nHowever, we have no control over the order of the group iterators.\nIf we want the iterators to be ordered according to `seq`, we can do the following:\n\n```scala\nimport uk.co.gresearch.spark._\n\n// the group key $\"id\" needs an ordering\nimplicit val ordering: Ordering.Int.type = Ordering.Int\n\n// order of iterator IS guaranteed\nds.groupBySorted($\"id\")($\"seq\")\n  .flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, v._1.seq, v._1.value)))\n  .toDF(\"key\", \"index\", \"seq\", \"value\")\n  .show(false)\n\n+---+-----+---+-----+\n|key|index|seq|value|\n+---+-----+---+-----+\n|1  |0    |1  |1.1  |\n|1  |1    |2  |1.2  |\n|1  |2    |3  |1.3  |\n|2  |0    |1  |2.1  |\n|2  |1    |2  |2.2  |\n|2  |2    |3  |2.3  |\n|3  |0    |1  |3.1  |\n+---+-----+---+-----+\n```\n\nNow, iterators are ordered according to `seq`, which is proven by the value of `index`,\nthat has been generated by `it.zipWithIndex`.\n\nInstead of column expressions, we can also use lambdas to define group key and group order:\n```scala\nds.groupByKeySorted(v => v.id)(v => v.seq)\n  .flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, v._1.seq, v._1.value)))\n  .toDF(\"key\", \"index\", \"seq\", \"value\")\n  .show(false)\n```\n\n**Note:** Using lambdas here hides from Spark which columns we use for grouping and sorting.\nQuery optimization cannot improve partitioning and sorting in this case. Use column expressions when possible.\n"
  },
  {
    "path": "HISTOGRAM.md",
    "content": "# Histogram\n\nFor a table `df` like\n\n|user   |score|\n|:-----:|:---:|\n|Alice  |101  |\n|Alice  |221  |\n|Alice  |211  |\n|Alice  |176  |\n|Bob    |276  |\n|Bob    |232  |\n|Bon    |258  |\n|Charlie|221  |\n\nyou can compute the histogram for each user\n\n|user   |≤100 |≤200 |>200 |\n|:-----:|:---:|:---:|:---:|\n|Alice  |0    |2    |2    |\n|Bob    |0    |0    |3    |\n|Charlie|0    |0    |1    |\n\nas follows:\n\n    df.withColumn(\"≤100\", when($\"score\" <= 100, 1).otherwise(0))\n      .withColumn(\"≤200\", when($\"score\" > 100 && $\"score\" <= 200, 1).otherwise(0))\n      .withColumn(\">200\", when($\"score\" > 200, 1).otherwise(0))\n      .groupBy($\"user\")\n      .agg(\n        sum($\"≤100\").as(\"≤100\"),\n        sum($\"≤200\").as(\"≤200\"),\n        sum($\">200\").as(\">200\")\n      )\n      .orderBy($\"user\")\n\nEquivalent to that query is:\n\n    import uk.co.gresearch.spark._\n\n    df.histogram(Seq(100, 200), $\"score\", $\"user\").orderBy($\"user\")\n\nThe first argument is a sequence of thresholds, the second argument provides the value column.\nThe subsequent arguments refer to the aggregation columns (`groupBy`). Only aggregation columns\nwill be in the result DataFrame.\n\nIn Java, call:\n\n    import uk.co.gresearch.spark.Histogram;\n\n    Histogram.of(df, Arrays.asList(100, 200), new Column(\"score\")), new Column(\"user\")).orderBy($\"user\")\n\nIn Python, call:\n\n    import gresearch.spark\n\n    df.histogram([100, 200], 'user').orderBy('user')\n\nNote that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server).\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "MAINTAINERS.md",
    "content": "## Current maintainers of the project\n\n| Maintainer             | GitHub ID                                               |\n| ---------------------- | ------------------------------------------------------- |\n| Enrico Minack          | [EnricoMi](https://github.com/EnricoMi)                 |\n"
  },
  {
    "path": "PARQUET.md",
    "content": "# Parquet Metadata\n\nThe structure of Parquet files (the metadata, not the data stored in Parquet) can be inspected similar to [parquet-tools](https://pypi.org/project/parquet-tools/)\nor [parquet-cli](https://pypi.org/project/parquet-cli/)\nby reading from a simple Spark data source.\n\nParquet metadata can be read on [file level](#parquet-file-metadata),\n[schema level](#parquet-file-schema),\n[row group level](#parquet-block--rowgroup-metadata),\n[column chunk level](#parquet-block-column-metadata) and\n[Spark Parquet partition level](#parquet-partition-metadata).\nMultiple files can be inspected at once.\n\nAny location that can be read by Spark (`spark.read.parquet(…)`) can be inspected.\nThis means the path can point to a single Parquet file, a directory with Parquet files,\nor multiple paths separated by a comma (`,`). Paths can contain wildcards like `*`.\nMultiple files will be inspected in parallel and distributed by Spark.\nNo actual rows or values will be read from the Parquet files, only metadata, which is very fast.\nThis allows to inspect Parquet files that have different schemata with one `spark.read` operation.\n\nFirst, import the new Parquet metadata data sources:\n\n```scala\n// Scala\nimport uk.co.gresearch.spark.parquet._\n```\n\n```python\n# Python\nimport gresearch.spark.parquet\n```\n\nThen, the following metadata become available:\n\n## Parquet file metadata\n\nRead the metadata of Parquet files into a Dataframe:\n\n```scala\n// Scala\nspark.read.parquetMetadata(\"/path/to/parquet\").show()\n```\n```python\n# Python\nspark.read.parquet_metadata(\"/path/to/parquet\").show()\n```\n```\n+-------------+------+---------------+-----------------+----+-------+------+-----+--------------------+--------------------+-----------+--------------------+\n|     filename|blocks|compressedBytes|uncompressedBytes|rows|columns|values|nulls|           createdBy|              schema| encryption|           keyValues|\n+-------------+------+---------------+-----------------+----+-------+------+-----+--------------------+--------------------+-----------+--------------------+\n|file1.parquet|     1|           1268|             1652| 100|      2|   200|    0|parquet-mr versio...|message spark_sch...|UNENCRYPTED|{org.apache.spark...|\n|file2.parquet|     2|           2539|             3302| 200|      2|   400|    0|parquet-mr versio...|message spark_sch...|UNENCRYPTED|{org.apache.spark...|\n+-------------+------+---------------+-----------------+----+-------+------+-----+--------------------+--------------------+-----------+--------------------+\n```\n\nThe Dataframe provides the following per-file information:\n\n|column            |type  | description                                                                    |\n|:-----------------|:----:|:-------------------------------------------------------------------------------|\n|filename          |string| The Parquet file name                                                          |\n|blocks            |int   | Number of blocks / RowGroups in the Parquet file                               |\n|compressedBytes   |long  | Number of compressed bytes of all blocks                                       |\n|uncompressedBytes |long  | Number of uncompressed bytes of all blocks                                     |\n|rows              |long  | Number of rows in the file                                                     |\n|columns           |int   | Number of columns in the file                                                  
|\n|values            |long  | Number of values in the file                                                   |\n|nulls             |long  | Number of null values in the file                                              |\n|createdBy         |string| The createdBy string of the Parquet file, e.g. library used to write the file  |\n|schema            |string| The schema                                                                     |\n|encryption        |string| The encryption (requires `org.apache.parquet:parquet-hadoop:1.12.4` and above) |\n|keyValues         |string-to-string map| Key-value data of the file                                       |\n\n## Parquet file schema\n\nRead the schema of Parquet files into a Dataframe:\n\n```scala\n// Scala\nspark.read.parquetSchema(\"/path/to/parquet\").show()\n```\n```python\n# Python\nspark.read.parquet_schema(\"/path/to/parquet\").show()\n```\n```\n+------------+----------+------------------+----------+------+------+----------------+--------------------+-----------+-------------+------------------+------------------+------------------+\n|    filename|columnName|        columnPath|repetition|  type|length|    originalType|         logicalType|isPrimitive|primitiveType|    primitiveOrder|maxDefinitionLevel|maxRepetitionLevel|\n+------------+----------+------------------+----------+------+------+----------------+--------------------+-----------+-------------+------------------+------------------+------------------+\n|file.parquet|         a|               [a]|  REQUIRED| INT64|     0|            NULL|                NULL|       true|        INT64|TYPE_DEFINED_ORDER|                 0|                 0|\n|file.parquet|         x|            [b, x]|  REQUIRED| INT32|     0|            NULL|                NULL|       true|        INT32|TYPE_DEFINED_ORDER|                 1|                 0|\n|file.parquet|         y|            [b, y]|  REQUIRED|DOUBLE|     0|            NULL|                NULL|       true|       DOUBLE|TYPE_DEFINED_ORDER|                 1|                 0|\n|file.parquet|         z|            [b, z]|  OPTIONAL| INT64|     0|TIMESTAMP_MICROS|TIMESTAMP(MICROS,...|       true|        INT64|TYPE_DEFINED_ORDER|                 2|                 0|\n|file.parquet|   element|[c, list, element]|  OPTIONAL|BINARY|     0|            UTF8|              STRING|       true|       BINARY|TYPE_DEFINED_ORDER|                 3|                 1|\n+------------+----------+------------------+----------+------+------+----------------+--------------------+-----------+-------------+------------------+------------------+------------------+\n```\n\nThe Dataframe provides the following per-file information:\n\n|column            |     type     | description                                                                       |\n|:-----------------|:------------:|:----------------------------------------------------------------------------------|\n|filename          |    string    | The Parquet file name                                                             |\n|columnName        |    string    | The column name                                                                   |\n|columnPath        | string array | The column path                                                                   |\n|repetition        |    string    | The repetition                                                                    |\n|type              |    string    | The data type                                                                     
|\n|length            |     int      | The length of the type                                                            |\n|originalType      |   string     | The original type (requires `org.apache.parquet:parquet-hadoop:1.11.0` and above) |\n|isPrimitive       |   boolean    | True if type is primitive                                                         |\n|primitiveType     |    string    | The primitive type                                                                |\n|primitiveOrder    |    string    | The order of the primitive type                                                   |\n|maxDefinitionLevel|     int      | The max definition level                                                          |\n|maxRepetitionLevel|     int      | The max repetition level                                                          |\n\n## Parquet block / RowGroup metadata\n\nRead the metadata of Parquet blocks / RowGroups into a Dataframe:\n\n```scala\n// Scala\nspark.read.parquetBlocks(\"/path/to/parquet\").show()\n```\n```python\n# Python\nspark.read.parquet_blocks(\"/path/to/parquet\").show()\n```\n```\n+-------------+-----+----------+---------------+-----------------+----+-------+------+-----+\n|     filename|block|blockStart|compressedBytes|uncompressedBytes|rows|columns|values|nulls|\n+-------------+-----+----------+---------------+-----------------+----+-------+------+-----+\n|file1.parquet|    1|         4|           1269|             1651| 100|      2|   200|    0|\n|file2.parquet|    1|         4|           1268|             1652| 100|      2|   200|    0|\n|file2.parquet|    2|      1273|           1270|             1651| 100|      2|   200|    0|\n+-------------+-----+----------+---------------+-----------------+----+-------+------+-----+\n```\n\n|column            |type  |description                                    |\n|:-----------------|:----:|:----------------------------------------------|\n|filename          |string|The Parquet file name                          |\n|block             |int   |Block / RowGroup number starting at 1          |\n|blockStart        |long  |Start position of the block in the Parquet file|\n|compressedBytes   |long  |Number of compressed bytes in block            |\n|uncompressedBytes |long  |Number of uncompressed bytes in block          |\n|rows              |long  |Number of rows in block                        |\n|columns           |int   |Number of columns in block                     |\n|values            |long  |Number of values in block                      |\n|nulls             |long  |Number of null values in block                 |\n\n## Parquet block column metadata\n\nRead the metadata of Parquet block columns into a Dataframe:\n\n```scala\n// Scala\nspark.read.parquetBlockColumns(\"/path/to/parquet\").show()\n```\n```python\n# Python\nspark.read.parquet_block_columns(\"/path/to/parquet\").show()\n```\n```\n+-------------+-----+------+------+-------------------+-------------------+--------------------+------------------+-----------+---------------+-----------------+------+-----+\n|     filename|block|column| codec|               type|          encodings|            minValue|          maxValue|columnStart|compressedBytes|uncompressedBytes|values|nulls|\n+-------------+-----+------+------+-------------------+-------------------+--------------------+------------------+-----------+---------------+-----------------+------+-----+\n|file1.parquet|    1|  [id]|SNAPPY|  required int64 id|[BIT_PACKED, PLAIN]|                   0|         
       99|          4|            437|              826|   100|    0|\n|file1.parquet|    1| [val]|SNAPPY|required double val|[BIT_PACKED, PLAIN]|0.005067503372006343|0.9973357672164814|        441|            831|              826|   100|    0|\n|file2.parquet|    1|  [id]|SNAPPY|  required int64 id|[BIT_PACKED, PLAIN]|                 100|               199|          4|            438|              825|   100|    0|\n|file2.parquet|    1| [val]|SNAPPY|required double val|[BIT_PACKED, PLAIN]|0.010617521596503865| 0.999189783846449|        442|            831|              826|   100|    0|\n|file2.parquet|    2|  [id]|SNAPPY|  required int64 id|[BIT_PACKED, PLAIN]|                 200|               299|       1273|            440|              826|   100|    0|\n|file2.parquet|    2| [val]|SNAPPY|required double val|[BIT_PACKED, PLAIN]|0.011277044401634018| 0.970525681750662|       1713|            830|              825|   100|    0|\n+-------------+-----+------+------+-------------------+-------------------+--------------------+------------------+-----------+---------------+-----------------+------+-----+\n```\n\n| column            |     type      | description                                                                                       |\n|:------------------|:-------------:|:--------------------------------------------------------------------------------------------------|\n| filename          |    string     | The Parquet file name                                                                             |\n| block             |      int      | Block / RowGroup number starting at 1                                                             |\n| column            | array<string> | Block / RowGroup column name                                                                      |\n| codec             |    string     | The coded used to compress the block column values                                                |\n| type              |    string     | The data type of the block column                                                                 |\n| encodings         | array<string> | Encodings of the block column                                                                     |\n| isEncrypted       |    boolean    | Whether block column is encrypted (requires `org.apache.parquet:parquet-hadoop:1.12.3` and above) |\n| minValue          |    string     | Minimum value of this column in this block                                                        |\n| maxValue          |    string     | Maximum value of this column in this block                                                        |\n| columnStart       |     long      | Start position of the block column in the Parquet file                                            |\n| compressedBytes   |     long      | Number of compressed bytes of this block column                                                   |\n| uncompressedBytes |     long      | Number of uncompressed bytes of this block column                                                 |\n| values            |     long      | Number of values in this block column                                                             |\n| nulls             |     long      | Number of null values in this block column                                                        |\n\n## Parquet partition metadata\n\nRead the metadata of how Spark partitions Parquet files into a Dataframe:\n\n```scala\n// 
Scala\nspark.read.parquetPartitions(\"/path/to/parquet\").show()\n```\n```python\n# Python\nspark.read.parquet_partitions(\"/path/to/parquet\").show()\n```\n```\n+---------+-----+----+------+------+---------------+-----------------+----+-------+------+-----+-------------+----------+\n|partition|start| end|length|blocks|compressedBytes|uncompressedBytes|rows|columns|values|nulls|     filename|fileLength|\n+---------+-----+----+------+------+---------------+-----------------+----+-------+------+-----+-------------+----------+\n|        1|    0|1024|  1024|     1|           1268|             1652| 100|      2|   200|    0|file1.parquet|      1930|\n|        2| 1024|1930|   906|     0|              0|                0|   0|      0|     0|    0|file1.parquet|      1930|\n|        3|    0|1024|  1024|     1|           1269|             1651| 100|      2|   200|    0|file2.parquet|      3493|\n|        4| 1024|2048|  1024|     1|           1270|             1651| 100|      2|   200|    0|file2.parquet|      3493|\n|        5| 2048|3072|  1024|     0|              0|                0|   0|      0|     0|    0|file2.parquet|      3493|\n|        6| 3072|3493|   421|     0|              0|                0|   0|      0|     0|    0|file2.parquet|      3493|\n+---------+-----+----+------+------+---------------+-----------------+----+-------+------+-----+-------------+----------+\n```\n\n|column           |type  |description                                               |\n|:----------------|:----:|:---------------------------------------------------------|\n|partition        |int   |The Spark partition id                                    |\n|start            |long  |The start position of the partition                       |\n|end              |long  |The end position of the partition                         |\n|length           |long  |The length of the partition                               |\n|blocks           |int   |The number of Parquet blocks / RowGroups in this partition|\n|compressedBytes  |long  |The number of compressed bytes in this partition          |\n|uncompressedBytes|long  |The number of uncompressed bytes in this partition        |\n|rows             |long  |The number of rows in this partition                      |\n|columns          |int   |The number of columns in this partition                   |\n|values           |long  |The number of values in this partition                    |\n|nulls            |long  |The number of null values in this partition               |\n|filename         |string|The Parquet file name                                     |\n|fileLength       |long  |The length of the Parquet file                            |\n\n## Performance\n\nRetrieving Parquet metadata is parallelized and distributed by Spark. 
The result Dataframe\nhas as many partitions as there are Parquet files in the given `path`, but at most\n`spark.sparkContext.defaultParallelism` partitions.\n\nEach result partition reads Parquet metadata from its Parquet files sequentially,\nwhile partitions are executed in parallel (depending on the number of Spark cores of your Spark job).\n\nYou can control the number of partitions via the `parallelism` parameter:\n\n```scala\n// Scala\nspark.read.parquetMetadata(100, \"/path/to/parquet\")\nspark.read.parquetSchema(100, \"/path/to/parquet\")\nspark.read.parquetBlocks(100, \"/path/to/parquet\")\nspark.read.parquetBlockColumns(100, \"/path/to/parquet\")\nspark.read.parquetPartitions(100, \"/path/to/parquet\")\n```\n```python\n# Python\nspark.read.parquet_metadata(\"/path/to/parquet\", parallelism=100)\nspark.read.parquet_schema(\"/path/to/parquet\", parallelism=100)\nspark.read.parquet_blocks(\"/path/to/parquet\", parallelism=100)\nspark.read.parquet_block_columns(\"/path/to/parquet\", parallelism=100)\nspark.read.parquet_partitions(\"/path/to/parquet\", parallelism=100)\n```\n\n## Encryption\n\nReading [encrypted Parquet is supported](https://spark.apache.org/docs/latest/sql-data-sources-parquet.html#columnar-encryption).\nFiles encrypted with a [plaintext footer](https://github.com/apache/parquet-format/blob/master/Encryption.md#55-plaintext-footer-mode)\ncan be read without any encryption keys, while encrypted Parquet metadata is then shown as `NULL` values in the result Dataframe.\nEncrypted Parquet files with an encrypted footer require the footer encryption key only. No column encryption keys are needed.\n\n## Known Issues\n\nNote that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server).\n"
  },
  {
    "path": "PARTITIONING.md",
    "content": "# Partitioned Writing\n\nIf you have ever used `Dataset[T].write.partitionBy`, here is how you can minimize the number of\nwritten files and obtain same-size files.\n\nSpark has two different concepts both referred to as partitioning. Central to Spark is the\nconcept of how a `Dataset[T]` is split into partitions where a Spark worker processes\na single partition at a time. This is the fundamental concept of how Spark scales with data.\n\nWhen writing a `Dataset` `ds` to a file-based storage, that output file is actually a directory:\n\n<!--\nimport java.sql.Timestamp\nimport java.sql.Timestamp\n\ncase class Value(id: Int, ts: Timestamp, property: String, value: String)\nval ds = Seq(\n  Value(1, Timestamp.valueOf(\"2020-07-01 12:00:00\"), \"label\", \"one\"),\n  Value(1, Timestamp.valueOf(\"2020-07-02 12:00:00\"), \"descr\", \"number one\"),\n  Value(1, Timestamp.valueOf(\"2020-07-03 12:00:00\"), \"label\", \"ONE\"),\n  Value(2, Timestamp.valueOf(\"2020-07-01 12:00:00\"), \"label\", \"two\"),\n  Value(2, Timestamp.valueOf(\"2020-07-03 12:00:00\"), \"label\", \"TWO\"),\n  Value(2, Timestamp.valueOf(\"2020-07-04 12:00:00\"), \"descr\", \"number two\"),\n  Value(3, Timestamp.valueOf(\"2020-07-03 12:00:00\"), \"label\", \"THREE\"),\n  Value(3, Timestamp.valueOf(\"2020-07-03 12:00:00\"), \"descr\", \"number three\"),\n  Value(4, Timestamp.valueOf(\"2020-07-01 12:00:00\"), \"label\", \"four\"),\n  Value(4, Timestamp.valueOf(\"2020-07-03 12:00:00\"), \"descr\", \"number four\"),\n  Value(5, Timestamp.valueOf(\"2020-07-01 12:00:00\"), \"label\", \"five\"),\n  Value(5, Timestamp.valueOf(\"2020-07-03 12:00:00\"), \"descr\", \"number five\"),\n  Value(6, Timestamp.valueOf(\"2020-07-01 12:00:00\"), \"label\", \"six\"),\n  Value(6, Timestamp.valueOf(\"2020-07-01 12:00:00\"), \"descr\", \"number six\"),\n).toDS()\n-->\n\n```scala\nds.write.csv(\"file.csv\")\n```\n\nThe directory structure looks like:\n\n    file.csv\n    file.csv/part-00000-7d34816f-bb53-4f44-ab9d-a62d570e5de0-c000.csv\n    file.csv/part-00001-7d34816f-bb53-4f44-ab9d-a62d570e5de0-c000.csv\n    file.csv/part-00002-7d34816f-bb53-4f44-ab9d-a62d570e5de0-c000.csv\n    file.csv/part-00003-7d34816f-bb53-4f44-ab9d-a62d570e5de0-c000.csv\n    file.csv/part-00004-7d34816f-bb53-4f44-ab9d-a62d570e5de0-c000.csv\n    file.csv/_SUCCESS\n\nWhen writing, the output can be `partitionBy` one or more columns of the `Dataset`.\nFor each distinct `value` in that column `col` an individual sub-directory is created in your output path.\nThe name is of the format `col=value`. Inside the sub-directory, multiple partitions exists,\nall containing only data where column `col` has value `value`. 
\n\n    file.csv/property=descr/part-00001-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv\n    file.csv/property=descr/part-00002-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv\n    file.csv/property=descr/part-00003-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv\n    file.csv/property=descr/part-00004-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv\n    file.csv/property=label/part-00001-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv\n    file.csv/property=label/part-00002-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv\n    file.csv/property=label/part-00003-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv\n    file.csv/property=label/part-00004-8eb44de1-2c33-4f95-a1be-8d1b4e35eb4a.c000.csv\n    file.csv/_SUCCESS\n\nData that is not organized before writing ends up with the same number of files\nin each of the sub-directories, even if some sub-directories contain only a fraction of\nthe number of rows of others. What you would like instead is fewer files in smaller\nand more files in larger partition sub-directories. Further, all files should have\nroughly the same number of rows.\n\nFor this, you first have to range-partition the `Dataset` according to your partition columns.\n\n    ds.repartitionByRange($\"property\", $\"id\")\n      .write\n      .partitionBy(\"property\")\n      .csv(\"file.csv\")\n\nThis organizes the data optimally for partitioned writing by column `property`.\n\n    file.csv/property=descr/part-00000-6317db5e-5161-41f1-8227-ffeaf06a3e41.c000.csv\n    file.csv/property=descr/part-00001-6317db5e-5161-41f1-8227-ffeaf06a3e41.c000.csv\n    file.csv/property=label/part-00002-6317db5e-5161-41f1-8227-ffeaf06a3e41.c000.csv\n    file.csv/property=label/part-00003-6317db5e-5161-41f1-8227-ffeaf06a3e41.c000.csv\n    file.csv/property=label/part-00004-6317db5e-5161-41f1-8227-ffeaf06a3e41.c000.csv\n    file.csv/_SUCCESS\n\nThis brings all rows with the same values in the `property` and `id` columns into the same file.\n\nIf you need each file to be further sorted by additional columns, e.g. `ts`, you can do this with `sortWithinPartitions`.\n\n    ds.repartitionByRange($\"property\", $\"id\")\n      .sortWithinPartitions($\"property\", $\"id\", $\"ts\")\n      .cache    // this is needed for Spark 3.0 to 3.3 with AQE enabled: SPARK-40588\n      .write\n      .partitionBy(\"property\")\n      .csv(\"file.csv\")\n\nSometimes you want to write-partition by some expression that is not a column of your data,\ne.g. the date representation of the `ts` column.\n\n    ds.withColumn(\"date\", $\"ts\".cast(DateType))\n      .repartitionByRange($\"date\", $\"id\")\n      .sortWithinPartitions($\"date\", $\"id\", $\"ts\")\n      .cache    // this is needed for Spark 3.0 to 3.3 with AQE enabled: SPARK-40588\n      .write\n      .partitionBy(\"date\")\n      .csv(\"file.csv\")\n\nAll of the above constructs can be replaced with a single meaningful operation:\n\n    ds.writePartitionedBy(Seq($\"ts\".cast(DateType).as(\"date\")), Seq($\"id\"), Seq($\"ts\"))\n      .csv(\"file.csv\")\n\nFor Spark 3.0 to 3.3 with AQE enabled (see [SPARK-40588](https://issues.apache.org/jira/browse/SPARK-40588)),\n`writePartitionedBy` has to cache an internally created DataFrame. This can be unpersisted after writing\nis finished.\n
Provide an `UnpersistHandle` for this purpose:\n\n    val unpersist = UnpersistHandle()\n\n    ds.writePartitionedBy(…, unpersistHandle = Some(unpersist))\n      .csv(\"file.csv\")\n\n    unpersist()\n\nMore details about this issue can be found [here](https://www.gresearch.co.uk/blog/article/guaranteeing-in-partition-order-for-partitioned-writing-in-apache-spark/).\n\n<!--\n# Other Approaches\n\nproblems with `repartition()` instead of `repartitionByRange()`\nproblems with `repartitionByRange(cols).write.partitionBy(cols)`\n-->"
  },
  {
    "path": "PYSPARK-DEPS.md",
    "content": "# PySpark dependencies\n\nUsing PySpark on a cluster requires all cluster nodes to have those Python packages installed that are required by the PySpark job.\nSuch a deployment can be cumbersome, especially when running in an interactive notebook.\n\nThe `spark-extension` package allows installing Python packages programmatically by the PySpark application itself (PySpark ≥ 3.1.0).\nThese packages are only accessible by that PySpark application, and they are removed on calling `spark.stop()`.\n\nEither install the `spark-extension` Maven package, or the `pyspark-extension` PyPi package (on the driver only),\nas described [here](README.md#using-spark-extension).\n\n## Installing packages with `pip`\n\nPython packages can be installed with `pip` as follows:\n\n```python\n# noinspection PyUnresolvedReferences\nfrom gresearch.spark import *\n\nspark.install_pip_package(\"pandas\", \"pyarrow\")\n```\n\nAbove example installs PIP packages `pandas` and `pyarrow` via `pip`. Method `install_pip_package` takes any `pip` command line argument:\n\n```python\n# install packages with version specs\nspark.install_pip_package(\"pandas==1.4.3\", \"pyarrow~=8.0.0\")\n\n# install packages from package sources (e.g. git clone https://github.com/pandas-dev/pandas.git)\nspark.install_pip_package(\"./pandas/\")\n\n# install packages from git repo\nspark.install_pip_package(\"git+https://github.com/pandas-dev/pandas.git@main\")\n\n# use a pip cache directory to cache downloaded and built whl files\nspark.install_pip_package(\"pandas\", \"pyarrow\", \"--cache-dir\", \"/home/user/.cache/pip\")\n\n# use an alternative index url (other than https://pypi.org/simple)\nspark.install_pip_package(\"pandas\", \"pyarrow\", \"--index-url\", \"https://artifacts.company.com/pypi/simple\")\n\n# install pip packages quietly (only disables output of PIP)\nspark.install_pip_package(\"pandas\", \"pyarrow\", \"--quiet\")\n```\n\n## Installing Python projects with Poetry\n\nPython projects can be installed from sources, including their dependencies, using [Poetry](https://python-poetry.org/):\n\n```python\n# noinspection PyUnresolvedReferences\nfrom gresearch.spark import *\n\nspark.install_poetry_project(\"../my-poetry-project/\", poetry_python=\"../venv-poetry/bin/python\")\n```\n\n## Example\n\nThis example uses `install_pip_package` in a Spark standalone cluster.\n\nFirst checkout the example code:\n\n```shell\ngit clone https://github.com/G-Research/spark-extension.git\ncd spark-extension/examples/python-deps\n```\n\nBuild a Docker image based on the official Spark release:\n```shell\ndocker build -t spark-extension-example-docker .\n```\n\nStart the example Spark standalone cluster consisting of a Spark master and one worker:\n```shell\ndocker compose -f docker-compose.yml up -d\n```\n\nRun the `example.py` Spark application on the example cluster:\n```shell\ndocker exec spark-master spark-submit --master spark://master:7077 --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5 /example/example.py\n```\nThe `--packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5` argument\ntells `spark-submit` to add the `spark-extension` Maven package to the Spark job.\n\nAlternatively, install the `pyspark-extension` PyPi package via `pip install` and remove the `--packages` argument from `spark-submit`:\n```shell\ndocker exec spark-master pip install --user pyspark_extension==2.11.1.3.5\ndocker exec spark-master spark-submit --master spark://master:7077 /example/example.py\n```\n\nThis output proves that 
PySpark could call into the function `func`, which only works when Pandas and PyArrow are installed:\n```\n+---+\n| id|\n+---+\n|  0|\n|  1|\n|  2|\n+---+\n```\n\nTest that `spark.install_pip_package(\"pandas\", \"pyarrow\")` is really required by this example by removing this line from `example.py` …\n```diff\n from pyspark.sql import SparkSession\n\n def main():\n     spark = SparkSession.builder.appName(\"spark_app\").getOrCreate()\n\n     def func(df):\n         return df\n\n     from gresearch.spark import install_pip_package\n\n-    spark.install_pip_package(\"pandas\", \"pyarrow\")\n     spark.range(0, 3, 1, 5).mapInPandas(func, \"id long\").show()\n\n if __name__ == \"__main__\":\n     main()\n```\n\n… and running the `spark-submit` command again. The example does not work anymore,\nbecause the Pandas and PyArrow packages are missing from the driver:\n```\nTraceback (most recent call last):\n  File \"/opt/spark/python/lib/pyspark.zip/pyspark/sql/pandas/utils.py\", line 27, in require_minimum_pandas_version\nModuleNotFoundError: No module named 'pandas'\n```\n\nFinally, shut down the example cluster:\n```shell\ndocker compose -f docker-compose.yml down\n```\n\n## Known Issues\n\nNote that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server).\n"
  },
  {
    "path": "README.md",
    "content": "# Spark Extension\n\nThis project provides extensions to the [Apache Spark project](https://spark.apache.org/) in Scala and Python:\n\n**[Diff](DIFF.md):** A `diff` transformation and application for `Dataset`s that computes the differences between\ntwo datasets, i.e. which rows to _add_, _delete_ or _change_ to get from one dataset to the other.\n\n**[SortedGroups](GROUPS.md):** A `groupByKey` transformation that groups rows by a key while providing\na **sorted** iterator for each group. Similar to `Dataset.groupByKey.flatMapGroups`, but with order guarantees\nfor the iterator.\n\n**[Histogram](HISTOGRAM.md) [<sup>[*]</sup>](#spark-connect-server):** A `histogram` transformation that computes the histogram DataFrame for a value column.\n\n**[Global Row Number](ROW_NUMBER.md) [<sup>[*]</sup>](#spark-connect-server):** A `withRowNumbers` transformation that provides the global row number w.r.t.\nthe current order of the Dataset, or any given order. In contrast to the existing SQL function `row_number`, which\nrequires a window spec, this transformation provides the row number across the entire Dataset without scaling problems.\n\n**[Partitioned Writing](PARTITIONING.md):** The `writePartitionedBy` action writes your `Dataset` partitioned and\nefficiently laid out with a single operation.\n\n**[Inspect Parquet files](PARQUET.md) [<sup>[*]</sup>](#spark-connect-server):** The structure of Parquet files (the metadata, not the data stored in Parquet) can be inspected similar to [parquet-tools](https://pypi.org/project/parquet-tools/)\nor [parquet-cli](https://pypi.org/project/parquet-cli/) by reading from a simple Spark data source.\nThis simplifies identifying why some Parquet files cannot be split by Spark into scalable partitions.\n\n**[Install Python packages into PySpark job](PYSPARK-DEPS.md) [<sup>[*]</sup>](#spark-connect-server):** Install Python dependencies via PIP or Poetry programatically into your running PySpark job (PySpark ≥ 3.1.0):\n\n```python\n# noinspection PyUnresolvedReferences\nfrom gresearch.spark import *\n\n# using PIP\nspark.install_pip_package(\"pandas==1.4.3\", \"pyarrow\")\nspark.install_pip_package(\"-r\", \"requirements.txt\")\n\n# using Poetry\nspark.install_poetry_project(\"../my-poetry-project/\", poetry_python=\"../venv-poetry/bin/python\")\n```\n\n**[Fluent method call](CONDITIONAL.md):** `T.call(transformation: T => R): R`: Turns a transformation `T => R`,\nthat is not part of `T` into a fluent method call on `T`. This allows writing fluent code like:\n\n```scala\nimport uk.co.gresearch._\n\ni.doThis()\n .doThat()\n .call(transformation)\n .doMore()\n```\n\n**[Fluent conditional method call](CONDITIONAL.md):** `T.when(condition: Boolean).call(transformation: T => T): T`:\nPerform a transformation fluently only if the given condition is true.\nThis allows writing fluent code like:\n\n```scala\nimport uk.co.gresearch._\n\ni.doThis()\n .doThat()\n .when(condition).call(transformation)\n .doMore()\n```\n\n**[Shortcut for groupBy.as](https://github.com/G-Research/spark-extension/pull/213#issue-2032837105)**: Calling `Dataset.groupBy(Column*).as[K, T]`\nshould be preferred over calling `Dataset.groupByKey(V => K)` whenever possible. 
The former allows Catalyst to exploit\nexisting partitioning and ordering of the Dataset, while the latter hides from Catalyst which columns are used to create the keys.\nThis can incur a significant performance penalty.\n\n<details>\n<summary>Details:</summary>\n\nThe new column-expression-based `groupByKey[K](Column*)` method makes it easier to group by a column expression key. Instead of\n\n    ds.groupBy($\"id\").as[Int, V]\n\nuse:\n\n    ds.groupByKey[Int]($\"id\")\n</details>\n\n**Backticks:** `backticks(string: String, strings: String*): String`: Encloses the given column name with backticks (`` ` ``) when needed.\nThis is a handy way to ensure column names with special characters like dots (`.`) work with `col()` or `select()`.\n\n**Count null values:** `count_null(e: Column)`: an aggregation function like `count` that counts null values in column `e`.\nThis is equivalent to calling `count(when(e.isNull, lit(1)))`.\n\n**.Net DateTime.Ticks[<sup>[*]</sup>](#spark-connect-server):** Convert .Net (C#, F#, Visual Basic) `DateTime.Ticks` into Spark timestamps, seconds and nanoseconds.\n\n<details>\n<summary>Available methods:</summary>\n\n```scala\n// Scala\ndotNetTicksToTimestamp(Column): Column       // returns timestamp as TimestampType\ndotNetTicksToUnixEpoch(Column): Column       // returns Unix epoch seconds as DecimalType\ndotNetTicksToUnixEpochNanos(Column): Column  // returns Unix epoch nanoseconds as LongType\n```\n\nThe reverse is provided by (all return `LongType` .Net ticks):\n```scala\n// Scala\ntimestampToDotNetTicks(Column): Column\nunixEpochToDotNetTicks(Column): Column\nunixEpochNanosToDotNetTicks(Column): Column\n```\n\nThese methods are also available in Python:\n```python\n# Python\ndotnet_ticks_to_timestamp(column_or_name)         # returns timestamp as TimestampType\ndotnet_ticks_to_unix_epoch(column_or_name)        # returns Unix epoch seconds as DecimalType\ndotnet_ticks_to_unix_epoch_nanos(column_or_name)  # returns Unix epoch nanoseconds as LongType\n\ntimestamp_to_dotnet_ticks(column_or_name)\nunix_epoch_to_dotnet_ticks(column_or_name)\nunix_epoch_nanos_to_dotnet_ticks(column_or_name)\n```\n</details>\n\n**Spark temporary directory[<sup>[*]</sup>](#spark-connect-server)**: Create a temporary directory that will be removed on Spark application shutdown.\n\n<details>\n<summary>Examples:</summary>\n\nScala:\n```scala\nimport uk.co.gresearch.spark.createTemporaryDir\n\nval dir = createTemporaryDir(\"prefix\")\n```\n\nPython:\n```python\n# noinspection PyUnresolvedReferences\nfrom gresearch.spark import *\n\ndir = spark.create_temporary_dir(\"prefix\")\n```\n</details>\n\n**Spark job description[<sup>[*]</sup>](#spark-connect-server):** Set Spark job description for all Spark jobs within a context.\n\n<details>\n<summary>Examples:</summary>\n\n```scala\nimport uk.co.gresearch.spark._\n\nimplicit val session: SparkSession = spark\n\nwithJobDescription(\"parquet file\") {\n  val df = spark.read.parquet(\"data.parquet\")\n  val count = appendJobDescription(\"count\") {\n    df.count\n  }\n  appendJobDescription(\"write\") {\n    df.write.csv(\"data.csv\")\n  }\n}\n```\n\n| Without job description  | With job description |\n|:---:|:---:|\n| ![](without-job-description.png \"Spark job without description in UI\") | ![](with-job-description.png \"Spark job with description in UI\") |\n\nNote that setting a description in one thread while calling the action (e.g. 
`.count`) in a different thread\ndoes not work, unless the different thread is spawned from the current thread _after_ the description has been set.\n\nWorking example with parallel collections:\n\n```scala\nimport java.util.concurrent.ForkJoinPool\nimport scala.collection.parallel.CollectionConverters.seqIsParallelizable\nimport scala.collection.parallel.ForkJoinTaskSupport\n\nval files = Seq(\"data1.csv\", \"data2.csv\").par\n\nval counts = withJobDescription(\"Counting rows\") {\n  // new thread pool required to spawn new threads from this thread\n  // so that the job description is actually used\n  files.tasksupport = new ForkJoinTaskSupport(new ForkJoinPool())\n  files.map(filename => spark.read.csv(filename).count).sum\n}(spark)\n```\n</details>\n\n## Using Spark Extension\n\nThe `spark-extension` package is available for all Spark 3.2, 3.3, 3.4 and 3.5 versions.\nThe package version has the following semantics: `spark-extension_{SCALA_COMPAT_VERSION}-{VERSION}-{SPARK_COMPAT_VERSION}`:\n\n- `SCALA_COMPAT_VERSION`: Scala binary compatibility (minor) version. Available are `2.12` and `2.13`.\n- `SPARK_COMPAT_VERSION`: Apache Spark binary compatibility (minor) version. Available are `3.2`, `3.3`, `3.4`, `3.5` and `4.0`.\n- `VERSION`: The package version, e.g. `2.14.0`.\n\n### SBT\n\nAdd this line to your `build.sbt` file:\n\n```sbt\nlibraryDependencies += \"uk.co.gresearch.spark\" %% \"spark-extension\" % \"2.15.0-3.5\"\n```\n\n### Maven\n\nAdd this dependency to your `pom.xml` file:\n\n```xml\n<dependency>\n  <groupId>uk.co.gresearch.spark</groupId>\n  <artifactId>spark-extension_2.12</artifactId>\n  <version>2.15.0-3.5</version>\n</dependency>\n```\n\n### Gradle\n\nAdd this dependency to your `build.gradle` file:\n\n```groovy\ndependencies {\n    implementation \"uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5\"\n}\n```\n\n### Spark Submit\n\nSubmit your Spark app with the Spark Extension dependency (version ≥1.1.0) as follows:\n\n```shell script\nspark-submit --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5 [jar]\n```\n\nNote: Pick the right Scala version (here 2.12) and Spark version (here 3.5) depending on your Spark version.\n\n### Spark Shell\n\nLaunch a Spark Shell with the Spark Extension dependency (version ≥1.1.0) as follows:\n\n```shell script\nspark-shell --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5\n```\n\nNote: Pick the right Scala version (here 2.12) and Spark version (here 3.5) depending on your Spark Shell version.\n\n### Python\n\n#### PySpark API\n\nStart a PySpark session with the Spark Extension dependency (version ≥1.1.0) as follows:\n\n```python\nfrom pyspark.sql import SparkSession\n\nspark = SparkSession \\\n    .builder \\\n    .config(\"spark.jars.packages\", \"uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5\") \\\n    .getOrCreate()\n```\n\nNote: Pick the right Scala version (here 2.12) and Spark version (here 3.5) depending on your PySpark version.\n\n#### PySpark REPL\n\nLaunch the Python Spark REPL with the Spark Extension dependency (version ≥1.1.0) as follows:\n\n```shell script\npyspark --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5\n```\n\nNote: Pick the right Scala version (here 2.12) and Spark version (here 3.5) depending on your PySpark version.\n\n#### PySpark `spark-submit`\n\nRun your Python scripts that use PySpark via `spark-submit`:\n\n```shell script\nspark-submit --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5 [script.py]\n```\n\nNote: Pick the right Scala 
version (here 2.12) and Spark version (here 3.5) depending on your Spark version.\n\n#### PyPi package (local Spark cluster only)\n\nYou may want to install the `pyspark-extension` python package from PyPi into your development environment.\nThis provides you code completion, typing and test capabilities during your development phase.\n\nRunning your Python application on a Spark cluster will still require one of the above ways\nto add the Scala package to the Spark environment.\n\n```shell script\npip install pyspark-extension==2.15.0.3.5\n```\n\nNote: Pick the right Spark version (here 3.5) depending on your PySpark version.\n\n### Your favorite Data Science notebook\n\nThere are plenty of [Data Science notebooks](https://datasciencenotebook.org/) around. To use this library,\nadd **a jar dependency** to your notebook using these **Maven coordinates**:\n\n    uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.5\n\nOr [download the jar](https://mvnrepository.com/artifact/uk.co.gresearch.spark/spark-extension) and place it\non a filesystem where it is accessible by the notebook, and reference that jar file directly.\n\nCheck the documentation of your favorite notebook to learn how to add jars to your Spark environment.\n\n## Known issues\n### Spark Connect Server\n\nMost features are not supported **in Python** in conjunction with a [Spark Connect server](https://spark.apache.org/docs/latest/spark-connect-overview.html).\nThis also holds for Databricks Runtime environment 13.x and above. Details can be found [in this blog](https://semyonsinchenko.github.io/ssinchenko/post/how-databricks-14x-breaks-3dparty-compatibility/).\n\nCalling any of those features when connected to a Spark Connect server will raise this error:\n\n    This feature is not supported for Spark Connect.\n\nUse a classic connection to a Spark cluster instead.\n\n## Build\n\nYou can build this project against different versions of Spark and Scala.\n\n### Switch Spark and Scala version\n\nIf you want to build for a Spark or Scala version different to what is defined in the `pom.xml` file, then run\n\n```shell script\nsh set-version.sh [SPARK-VERSION] [SCALA-VERSION]\n```\n\nFor example, switch to Spark 3.5.0 and Scala 2.13.8 by running `sh set-version.sh 3.5.0 2.13.8`.\n\n### Build the Scala project\n\nThen execute `mvn package` to create a jar from the sources. 
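For example, combining the version switch above with the packaging step might look like this (the version numbers are just the example values from above):\n\n```shell script\nsh set-version.sh 3.5.0 2.13.8\nmvn package\n```\n\n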
The jar can be found in `target/`.\n\n## Testing\n\nRun the Scala tests via `mvn test`.\n\n### Setup Python environment\n\nIn order to run the Python tests, set up a Python environment as follows:\n\n```shell script\nvirtualenv -p python3 venv\nsource venv/bin/activate\npip install python/[test]\n```\n\n### Run Python tests\n\nRun the Python tests via `env PYTHONPATH=python/test python -m pytest python/test`.\n\n### Build Python package\n\nRun the following commands in the project root directory to create a whl from the sources:\n\n```shell script\npip install build\npython -m build python/\n```\n\nIt can be found in `python/dist/`.\n\n## Publications\n\n- ***Guaranteeing in-partition order for partitioned-writing in Apache Spark**, Enrico Minack, 20/01/2023*:<br/>https://www.gresearch.com/blog/article/guaranteeing-in-partition-order-for-partitioned-writing-in-apache-spark/\n- ***Un-pivot, sorted groups and many bug fixes: Celebrating the first Spark 3.4 release**, Enrico Minack, 21/03/2023*:<br/>https://www.gresearch.com/blog/article/un-pivot-sorted-groups-and-many-bug-fixes-celebrating-the-first-spark-3-4-release/\n- ***A PySpark bug makes co-grouping with window function partition-key-order-sensitive**, Enrico Minack, 29/03/2023*:<br/>https://www.gresearch.com/blog/article/a-pyspark-bug-makes-co-grouping-with-window-function-partition-key-order-sensitive/\n- ***Spark’s groupByKey should be avoided – and here’s why**, Enrico Minack, 13/06/2023*:<br/>https://www.gresearch.com/blog/article/sparks-groupbykey-should-be-avoided-and-heres-why/\n- ***Inspecting Parquet files with Spark**, Enrico Minack, 28/07/2023*:<br/>https://www.gresearch.com/blog/article/parquet-files-know-your-scaling-limits/\n- ***Enhancing Spark’s UI with Job Descriptions**, Enrico Minack, 12/12/2023*:<br/>https://www.gresearch.com/blog/article/enhancing-sparks-ui-with-job-descriptions/\n- ***PySpark apps with dependencies: Managing Python dependencies in code**, Enrico Minack, 24/01/2024*:<br/>https://www.gresearch.com/news/pyspark-apps-with-dependencies-managing-python-dependencies-in-code/\n- ***Observing Spark Aggregates: Cheap Metrics from Datasets**, Enrico Minack, 06/02/2024*:<br/>https://www.gresearch.com/news/observing-spark-aggregates-cheap-metrics-from-datasets-2/\n\n## Security\n\nPlease see our [security policy](https://github.com/G-Research/spark-extension/blob/master/SECURITY.md) for details on reporting security vulnerabilities.\n"
  },
  {
    "path": "RELEASE.md",
    "content": "# Releasing Spark Extension\n\nThis provides instructions on how to release a version of `spark-extension`. We release this library\nfor a number of Spark and Scala environments, but all from the same git tag. Release for the environment\nthat is set in the `pom.xml` and create a tag. On success, release from that tag for all other environments\nas described below.\n\nUse the `release.sh` script to test and release all versions. Or execute the following steps manually.\n\n## Testing master for all environments\n\nThe following steps release a snapshot and test it. Test all versions listed [further down](#releasing-master-for-other-environments).\n\n- Set the version with `./set-version.sh`, e.g. `./set-version.sh 3.4.0 2.12.17`\n- Release a snapshot (make sure the version in the `pom.xml` file ends with `SNAPSHOT`): `mvn clean deploy`\n- Test the released snapshot: `./test-release.sh`\n\n## Releasing from master\n\nFollow this procedure to release a new version:\n\n- Add a new entry to `CHANGELOG.md` listing all notable changes of this release.\n  Use the heading `## [VERSION] - YYYY-MM-dd`, e.g. `## [1.1.0] - 2020-03-12`.\n- Remove the `-SNAPSHOT` suffix from the version, e.g. `./set-version 1.1.0`.\n- Update the versions in the `README.md` and `python/README.md` file to the version of your `pom.xml` to reflect the latest version,\n  e.g. replace all `1.0.0-3.1` with `1.1.0-3.1` and `1.0.0.3.1` with `1.1.0.3.1`, respectively.\n- Commit the change to your local git repository, use a commit message like `Releasing 1.1.0`. Do not push to github yet.\n- Tag that commit with a version tag like `v1.1.0` and message like `Release v1.1.0`. Do not push to github yet.\n- Release the version with `mvn clean deploy`. This will be put into a staging repository and not automatically released (due to `<autoReleaseAfterClose>false</autoReleaseAfterClose>` in your [`pom.xml`](pom.xml) file).\n- Inspect and test the staged version. Use `./test-release.sh` or the `spark-examples` project for that. If you are happy with everything:\n  - Push the commit and tag to origin.\n  - Release the package with `mvn nexus-staging:release`.\n  - Bump the version to the next [minor version](https://semver.org/) and append the `-SNAPSHOT` suffix again: `./set-version 1.2.0-SNAPSHOT`.\n  - Commit this change to your local git repository, use a commit message like `Post-release version bump to 1.2.0`.\n  - Push all local commits to origin.\n- Otherwise drop it with `mvn nexus-staging:drop`. Remove the last two commits from your local history.\n\n## Releasing master for other environments\n\nOnce you have released the new version, release from the same tag for all other Spark and Scala environments as well:\n- Release for these environments, one of these has been released above, that should be the tagged version:\n\n|Spark|Scala|\n|:----|:----|\n|3.2  |2.12.15 and 2.13.5|\n|3.3  |2.12.15 and 2.13.8|\n|3.4  |2.12.17 and 2.13.8|\n|3.5  |2.12.17 and 2.13.8|\n- Always use the latest Spark version per Spark minor version\n- Release process:\n  - Checkout the release tag, e.g. `git checkout v1.0.0`\n  - Set the version in the `pom.xml` file via `set-version.sh`, e.g. `./set-version.sh 3.4.0 2.12.17`\n  - Review the `pom.xml` file changes: `git diff pom.xml`\n  - Release the version with `mvn clean deploy`\n  - Inspect and test the staged version. 
Use `./test-release.sh` or the `spark-examples` project for that.\n    - If you are happy with everything, release the package with `mvn nexus-staging:release`.\n    - Otherwise drop it with `mvn nexus-staging:drop`.\n- Revert the changes done to the `pom.xml` file: `git checkout pom.xml`\n\n## Releasing a bug-fix version\n\nA bug-fix version needs to be released from a [minor-version branch](https://semver.org/), e.g. `branch-1.1`.\n\n### Create a bug-fix branch\n\nIf there is no bug-fix branch yet, create it:\n\n- Create such a branch from the respective [minor-version tag](https://semver.org/), e.g. create minor version branch `branch-1.1` from tag `v1.1.0`.\n- Bump the version to the next [patch version](https://semver.org/) in `pom.xml` and append the `-SNAPSHOT` suffix again, e.g. `1.1.0` → `1.1.1-SNAPSHOT`.\n- Commit this change to your local git repository, use a commit message like `Post-release version bump to 1.1.1`.\n- Push this commit to origin.\n\nMerge your bug fixes into this branch as you would normally do for master; use PRs for that.\n\n### Release from a bug-fix branch\n\nThis is very similar to [releasing from master](#releasing-from-master),\nbut the version increment occurs on [patch level](https://semver.org/):\n\n- Add a new entry to `CHANGELOG.md` listing all notable changes of this release.\n  Use the heading `## [VERSION] - YYYY-MM-dd`, e.g. `## [1.1.1] - 2020-03-12`.\n- Remove the `-SNAPSHOT` suffix from the version, e.g. `./set-version.sh 1.1.1`.\n- Update the versions in the `README.md` and `python/README.md` files to the version of your `pom.xml` to reflect the latest version,\n  e.g. replace all `1.1.0-3.1` with `1.1.1-3.1` and `1.1.0.3.1` with `1.1.1.3.1`, respectively.\n- Commit the change to your local git repository, use a commit message like `Releasing 1.1.1`. Do not push to GitHub yet.\n- Tag that commit with a version tag like `v1.1.1` and message like `Release v1.1.1`. Do not push to GitHub yet.\n- Release the version with `mvn clean deploy`. This will be put into a staging repository and not automatically released (due to `<autoReleaseAfterClose>false</autoReleaseAfterClose>` in your [`pom.xml`](pom.xml) file).\n- Inspect and test the staged version. Use `./test-release.sh` or the `spark-examples` project for that. If you are happy with everything:\n  - Push the commit and tag to origin.\n  - Release the package with `mvn nexus-staging:release`.\n  - Bump the version to the next [patch version](https://semver.org/) and append the `-SNAPSHOT` suffix again: `./set-version.sh 1.1.2-SNAPSHOT`.\n  - Commit this change to your local git repository, use a commit message like `Post-release version bump to 1.1.2`.\n  - Push all local commits to origin.\n- Otherwise drop it with `mvn nexus-staging:drop`. Remove the last two commits from your local history.\n\nConsider releasing the bug-fix version for other environments as well. See the [above](#releasing-master-for-other-environments) section for details.\n"
  },
  {
    "path": "ROW_NUMBER.md",
    "content": "# Global Row Number\n\nSpark provides the [SQL function `row_number`](https://spark.apache.org/docs/latest/api/sql/index.html#row_number),\nwhich assigns each row a consecutive number, starting from 1. This function works on a [Window](https://spark.apache.org/docs/latest/api/scala/org/apache/spark/sql/expressions/Window.html).\nAssigning a row number over the entire Dataset will load the entire dataset into a single partition / executor.\nThis does not scale.\n\nSpark extensions provide the `Dataset` transformation `withRowNumbers`, which assigns a global row number while scaling:\n\n```scala\nval df = Seq((1, \"one\"), (2, \"TWO\"), (2, \"two\"), (3, \"three\")).toDF(\"id\", \"value\")\ndf.show()\n// +---+-----+\n// | id|value|\n// +---+-----+\n// |  1|  one|\n// |  2|  TWO|\n// |  2|  two|\n// |  3|three|\n// +---+-----+\n\nimport uk.co.gresearch.spark._\n\ndf.withRowNumbers().show()\n// +---+-----+----------+\n// | id|value|row_number|\n// +---+-----+----------+\n// |  1|  one|         1|\n// |  2|  two|         2|\n// |  2|  TWO|         3|\n// |  3|three|         4|\n// +---+-----+----------+\n```\n\nIn Java:\n```java\nimport uk.co.gresearch.spark.RowNumbers;\n\nRowNumbers.of(df).show();\n// +---+-----+----------+\n// | id|value|row_number|\n// +---+-----+----------+\n// |  1|  one|         1|\n// |  2|  two|         2|\n// |  2|  TWO|         3|\n// |  3|three|         4|\n// +---+-----+----------+\n```\n\nIn Python:\n```python\nimport gresearch.spark\n\ndf.with_row_numbers().show()\n# +---+-----+----------+\n# | id|value|row_number|\n# +---+-----+----------+\n# |  1|  one|         1|\n# |  2|  two|         2|\n# |  2|  TWO|         3|\n# |  3|three|         4|\n# +---+-----+----------+\n```\n\n## Row number order\nRow numbers are assigned in the current order of the Dataset. 
If you want a specific order, provide columns as follows:\n\n```scala\ndf.withRowNumbers($\"id\".desc, $\"value\").show()\n// +---+-----+----------+\n// | id|value|row_number|\n// +---+-----+----------+\n// |  3|three|         1|\n// |  2|  TWO|         2|\n// |  2|  two|         3|\n// |  1|  one|         4|\n// +---+-----+----------+\n```\n\nIn Java:\n```java\nRowNumbers.withOrderColumns(df.col(\"id\").desc(), df.col(\"value\")).of(df).show();\n// +---+-----+----------+\n// | id|value|row_number|\n// +---+-----+----------+\n// |  3|three|         1|\n// |  2|  TWO|         2|\n// |  2|  two|         3|\n// |  1|  one|         4|\n// +---+-----+----------+\n```\n\nIn Python:\n```python\ndf.with_row_numbers(order=[df.id.desc(), df.value]).show()\n# +---+-----+----------+\n# | id|value|row_number|\n# +---+-----+----------+\n# |  3|three|         1|\n# |  2|  TWO|         2|\n# |  2|  two|         3|\n# |  1|  one|         4|\n# +---+-----+----------+\n```\n\n## Row number column name\n\nThe column name that contains the row number can be changed by providing the `rowNumberColumnName` argument:\n\n```scala\ndf.withRowNumbers(rowNumberColumnName=\"row\").show()\n// +---+-----+---+\n// | id|value|row|\n// +---+-----+---+\n// |  1|  one|  1|\n// |  2|  TWO|  2|\n// |  2|  two|  3|\n// |  3|three|  4|\n// +---+-----+---+\n```\n\nIn Java:\n```java\nRowNumbers.withRowNumberColumnName(\"row\").of(df).show();\n// +---+-----+---+\n// | id|value|row|\n// +---+-----+---+\n// |  1|  one|  1|\n// |  2|  TWO|  2|\n// |  2|  two|  3|\n// |  3|three|  4|\n// +---+-----+---+\n```\n\nIn Python:\n```python\ndf.with_row_numbers(row_number_column_name='row').show()\n# +---+-----+---+\n# | id|value|row|\n# +---+-----+---+\n# |  1|  one|  1|\n# |  2|  TWO|  2|\n# |  2|  two|  3|\n# |  3|three|  4|\n# +---+-----+---+\n```\n\n## Cached / persisted intermediate Dataset\n\nThe `withRowNumbers` transformation requires the input Dataset to be\n[cached](https://spark.apache.org/docs/latest/api/scala/org/apache/spark/sql/Dataset.html#cache():Dataset.this.type) /\n[persisted](https://spark.apache.org/docs/latest/api/scala/org/apache/spark/sql/Dataset.html#persist(newLevel:org.apache.spark.storage.StorageLevel):Dataset.this.type),\nafter adding an intermediate column. 
You can specify the level of persistence through the `storageLevel` parameter.\n\n```scala\nimport org.apache.spark.storage.StorageLevel\n\nval dfWithRowNumbers = df.withRowNumbers(storageLevel=StorageLevel.DISK_ONLY)\n```\n\nIn Java:\n```java\nimport org.apache.spark.storage.StorageLevel;\n\nDataset<Row> dfWithRowNumbers = RowNumbers.withStorageLevel(StorageLevel.DISK_ONLY()).of(df);\n```\n\nIn Python:\n```python\nfrom pyspark.storagelevel import StorageLevel\n\ndf_with_row_numbers = df.with_row_numbers(storage_level=StorageLevel.DISK_ONLY)\n```\n\n## Un-persist intermediate Dataset\n\nIf you want control over when to un-persist this intermediate Dataset, you can provide an `UnpersistHandle` and call it\nwhen you are done with the result Dataset:\n\n```scala\nimport uk.co.gresearch.spark.UnpersistHandle\n\nval unpersist = UnpersistHandle()\nval dfWithRowNumbers = df.withRowNumbers(unpersistHandle=unpersist)\n\n// after you are done with dfWithRowNumbers you may want to call unpersist()\nunpersist(blocking=false)\n```\n\nIn Java:\n```java\nimport uk.co.gresearch.spark.UnpersistHandle;\n\nUnpersistHandle unpersist = new UnpersistHandle();\nDataset<Row> dfWithRowNumbers = RowNumbers.withUnpersistHandle(unpersist).of(df);\n\n// after you are done with dfWithRowNumbers you may want to call unpersist()\nunpersist.apply(true);\n```\n\nIn Python:\n```python\nunpersist = spark.unpersist_handle()\ndf_with_row_numbers = df.with_row_numbers(unpersist_handle=unpersist)\n\n# after you are done with df_with_row_numbers you may want to call unpersist()\nunpersist(blocking=True)\n```\n\n## Spark warning\n\nYou will notice that Spark logs the following warning:\n\n```\nWindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n```\nThis warning is unavoidable, because `withRowNumbers` has to pull information about the initial partitions into a single partition.\nFortunately, only 12 bytes per input partition are required, so this amount of data usually fits into a single partition and the warning can safely be ignored.\n\n## Known issues\n\nNote that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server).\n"
  },
  {
    "path": "SECURITY.md",
    "content": "# Security and Coordinated Vulnerability Disclosure Policy\n\nThis project appreciates and encourages coordinated disclosure of security vulnerabilities. We prefer that you use the GitHub reporting mechanism to privately report vulnerabilities. Under the main repository's security tab, click \"Report a vulnerability\" to open the advisory form.\n\nIf you are unable to report it via GitHub, have received no response after repeated attempts, or have other security related questions, please contact security@gr-oss.io and mention this project in the subject line."
  },
  {
    "path": "build-whl.sh",
    "content": "#!/bin/bash\n\nset -eo pipefail\n\nbase=$(cd \"$(dirname \"$0\")\"; pwd)\n\nversion=$(grep --max-count=1 \"<version>.*</version>\" \"$base/pom.xml\" | sed -E -e \"s/\\s*<[^>]+>//g\")\nartifact_id=$(grep --max-count=1 \"<artifactId>.*</artifactId>\" \"$base/pom.xml\" | sed -E -e \"s/\\s*<[^>]+>//g\")\n\nrm -rf \"$base/python/pyspark/jars/$artifact_id-*.jar\"\n\npip install build\npython -m build \"$base/python/\"\n\n# check for missing modules in whl file\npyversion=${version/SNAPSHOT/dev0}\npyversion=${pyversion//-/.}\n\nmissing=\"$(diff <(cd \"$base/python\"; find gresearch -type f | grep -v \".pyc$\" | sort) <(unzip -l \"$base/python/dist/pyspark_extension-${pyversion}-*.whl\" | tail -n +4 | head -n -2 | sed -E -e \"s/^ +//\" -e \"s/ +/ /g\" | cut -d \" \" -f 4- | sort) | grep \"^<\" || true)\"\nif [ -n \"$missing\" ]\nthen\n  echo \"These files are missing from the whl file:\"\n  echo \"$missing\"\n  exit 1\nfi\n\njars=$(unzip -l \"$base/python/dist/pyspark_extension-${pyversion}-*.whl\" | grep \".jar\" | wc -l)\nif [ $jars -ne 1 ]\nthen\n  echo \"Expected exactly one jar in whl file, but $jars found!\"\n  exit 1\nfi\n"
  },
  {
    "path": "bump-version.sh",
    "content": "#!/bin/bash\n#\n# Copyright 2020 G-Research\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Script to prepare release, see RELEASE.md for details\n\nset -e -o pipefail\n\n# check for clean git status\nreadarray -t git_status < <(git status -s --untracked-files=no 2>/dev/null)\nif [ ${#git_status[@]} -gt 0 ]\nthen\n  echo \"There are pending git changes:\"\n  for (( i=0; i<${#git_status[@]}; i++ )); do echo \"${git_status[$i]}\" ; done\n  exit 1\nfi\n\nfunction next_version {\n  local version=$1\n  local branch=$2\n\n  patch=${version/*./}\n  majmin=${version%.${patch}}\n\n  if [[ $branch == \"master\" ]]\n  then\n    # minor version bump\n    if [[ $version != *\".0\" ]]\n    then\n      echo \"version is patch version, should be M.m.0: $version\" >&2\n      exit 1\n    fi\n    maj=${version/.*/}\n    min=${majmin#${maj}.}\n    next=${maj}.$((min+1)).0\n    echo \"$next\"\n  else\n    # patch version bump\n    next=${majmin}.$((patch+1))\n    echo \"$next\"\n  fi\n}\n\n# get release and next version\nversion=$(grep --max-count=1 \"<version>.*</version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\")\npkg_version=\"${version/-*/}\"\nbranch=$(git rev-parse --abbrev-ref HEAD)\nnext_pkg_version=\"$(next_version \"$pkg_version\" \"$branch\")\"\n\n# bump the version\necho \"Bump version to $next_pkg_version\"\n./set-version.sh $next_pkg_version-SNAPSHOT\n\n# commit changes to local repo\necho\necho \"Committing release to local git\"\ngit commit -a -m \"Post-release version bump to $next_pkg_version\"\ngit show HEAD\necho\n\n# push version bump to origin\necho \"Press <ENTER> to push commit to origin\"\nread\n\necho \"Pushing release commit to origin\"\ngit push origin \"master\"\necho\n"
  },
  {
    "path": "examples/python-deps/Dockerfile",
    "content": "FROM apache/spark:3.5.0\n\nENV PATH=\"${PATH}:/opt/spark/bin\"\n\nUSER root\nRUN mkdir -p /home/spark; chown spark:spark /home/spark\nUSER spark\n"
  },
  {
    "path": "examples/python-deps/docker-compose.yml",
    "content": "version: \"3\"\nservices:\n  master:\n    container_name: spark-master\n    image: spark-extension-example-docker\n    command: /opt/spark/bin/spark-class org.apache.spark.deploy.master.Master -h master\n    environment:\n      MASTER: spark://master:7077\n      SPARK_PUBLIC_DNS: localhost\n      SPARK_MASTER_WEBUI_PORT: 8080\n      PYSPARK_PYTHON: python${PYTHON_VERSION:-3.8}\n      PYSPARK_DRIVER_PYTHON: python${PYTHON_VERSION:-3.8}\n    expose:\n      - 7077\n    ports:\n      - 4040:4040\n      - 8080:8080\n    volumes:\n      - ./:/example\n\n  worker:\n    container_name: spark-worker\n    image: spark-extension-example-docker\n    command: /opt/spark/bin/spark-class org.apache.spark.deploy.worker.Worker spark://master:7077\n    environment:\n      SPARK_WORKER_CORES: 1\n      SPARK_WORKER_MEMORY: 1g\n      SPARK_WORKER_PORT: 8881\n      SPARK_WORKER_WEBUI_PORT: 8081\n      SPARK_PUBLIC_DNS: localhost\n    links:\n      - master\n    ports:\n      - 8081:8081\n\n"
  },
  {
    "path": "examples/python-deps/example.py",
    "content": "from pyspark.sql import SparkSession\n\ndef main():\n    spark = SparkSession.builder.appName(\"spark_app\").getOrCreate()\n\n    def func(df):\n        return df\n\n    from gresearch.spark import install_pip_package\n\n    spark.install_pip_package(\"pandas\", \"pyarrow\")\n    spark.range(0, 3, 1, 5).mapInPandas(func, \"id long\").show()\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "pom.xml",
    "content": "<project xmlns=\"http://maven.apache.org/POM/4.0.0\" xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\" xsi:schemaLocation=\"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd\">\n  <modelVersion>4.0.0</modelVersion>\n  <groupId>uk.co.gresearch.spark</groupId>\n  <artifactId>spark-extension_2.13</artifactId>\n  <version>2.16.0-3.5-SNAPSHOT</version>\n  <name>Spark Extension</name>\n  <description>A library that provides useful extensions to Apache Spark.</description>\n  <inceptionYear>2020</inceptionYear>\n  <url>https://github.com/G-Research</url>\n  <licenses>\n    <license>\n      <name>Apache 2.0 License</name>\n      <url>http://www.apache.org/licenses/LICENSE-2.0.html</url>\n      <distribution>repo</distribution>\n    </license>\n  </licenses>\n  <scm>\n    <connection>scm:git:git://github.com/g-research/spark-extension.git</connection>\n    <developerConnection>scm:git:ssh://github.com:g-research/spark-extension.git</developerConnection>\n    <url>https://github.com/g-research/spark-extension/tree/${project.scm.tag}</url>\n    <tag>master</tag>\n  </scm>\n  <developers>\n    <developer>\n      <id>EnricoMi</id>\n      <name>Enrico Minack</name>\n      <email>github@enrico.minack.dev</email>\n    </developer>\n  </developers>\n  <issueManagement>\n    <system>GitHub Issues</system>\n    <url>https://github.com/G-Research/spark-extension/issues</url>\n  </issueManagement>\n\n  <properties>\n    <java.version>1.8</java.version>\n    <maven.compiler.source>${java.version}</maven.compiler.source>\n    <maven.compiler.target>${java.version}</maven.compiler.target>\n    <encoding>UTF-8</encoding>\n    <project.version>${project.version}</project.version>\n    <scala.major.version>2</scala.major.version>\n    <scala.minor.version>13</scala.minor.version>\n    <scala.patch.version>8</scala.patch.version>\n    <scala.compat.version>${scala.major.version}.${scala.minor.version}</scala.compat.version>\n    <scala.version>${scala.compat.version}.${scala.patch.version}</scala.version>\n    <spark.major.version>3</spark.major.version>\n    <spark.minor.version>5</spark.minor.version>\n    <spark.patch.version>1</spark.patch.version>\n    <spark.compat.version>${spark.major.version}.${spark.minor.version}</spark.compat.version>\n    <spark.version>${spark.compat.version}.${spark.patch.version}</spark.version>\n  </properties>\n\n  <dependencies>\n    <dependency>\n      <groupId>org.scala-lang</groupId>\n      <artifactId>scala-library</artifactId>\n      <version>${scala.version}</version>\n    </dependency>\n\n    <!-- pulls in spark-core, which is also a dependency, but referencing spark-core might break transitive dependencies -->\n    <!-- see https://github.com/apache/spark/pull/40933#issuecomment-1536609310 -->\n    <dependency>\n      <groupId>org.apache.spark</groupId>\n      <artifactId>spark-sql_${scala.compat.version}</artifactId>\n      <version>${spark.version}</version>\n      <scope>provided</scope>\n      <exclusions>\n        <!-- we exclude parquet-hadoop and all its dependencies, as we depend on our own version -->\n        <exclusion>\n          <groupId>org.apache.parquet</groupId>\n          <artifactId>*</artifactId>\n        </exclusion>\n        <exclusion>\n          <groupId>io.airlift</groupId>\n          <artifactId>aircompressor</artifactId>\n        </exclusion>\n        <exclusion>\n          <groupId>org.xerial.snappy</groupId>\n          <artifactId>snappy-java</artifactId>\n        </exclusion>\n        <exclusion>\n 
         <groupId>org.slf4j</groupId>\n          <artifactId>slf4j-api</artifactId>\n        </exclusion>\n      </exclusions>\n    </dependency>\n\n    <dependency>\n      <groupId>org.apache.spark</groupId>\n      <artifactId>spark-catalyst_${scala.compat.version}</artifactId>\n      <version>${spark.version}</version>\n      <scope>provided</scope>\n    </dependency>\n\n    <dependency>\n      <groupId>org.apache.spark</groupId>\n      <artifactId>spark-hive_${scala.compat.version}</artifactId>\n      <version>${spark.version}</version>\n      <scope>provided</scope>\n    </dependency>\n\n    <!-- our own version of parquet-hadoop, which is more recent than what earlier Spark versions depends on -->\n    <dependency>\n      <groupId>org.apache.parquet</groupId>\n      <artifactId>parquet-hadoop</artifactId>\n      <version>1.16.0</version>\n      <exclusions>\n        <exclusion>\n          <groupId>commons-pool</groupId>\n          <artifactId>commons-pool</artifactId>\n        </exclusion>\n        <exclusion>\n          <groupId>javax.annotation</groupId>\n          <artifactId>javax.annotation-api</artifactId>\n        </exclusion>\n        <exclusion>\n          <groupId>com.github.luben</groupId>\n          <artifactId>zstd-jni</artifactId>\n        </exclusion>\n      </exclusions>\n      <scope>provided</scope>\n    </dependency>\n\n    <dependency>\n      <groupId>com.github.scopt</groupId>\n      <artifactId>scopt_${scala.compat.version}</artifactId>\n      <!-- keep DIFF.md section \"Diff Spark application\" synced with this value -->\n      <version>4.1.0</version>\n    </dependency>\n\n    <!-- Test -->\n    <dependency>\n      <groupId>org.apache.spark</groupId>\n      <artifactId>spark-catalyst_${scala.compat.version}</artifactId>\n      <version>${spark.version}</version>\n      <classifier>tests</classifier>\n      <scope>test</scope>\n    </dependency>\n\n    <dependency>\n      <groupId>junit</groupId>\n      <artifactId>junit</artifactId>\n      <version>4.13.2</version>\n      <scope>test</scope>\n    </dependency>\n\n    <dependency>\n      <groupId>org.scalatest</groupId>\n      <artifactId>scalatest_${scala.compat.version}</artifactId>\n      <version>3.3.0-SNAP4</version>\n      <scope>test</scope>\n    </dependency>\n\n    <dependency>\n      <groupId>org.scalatestplus</groupId>\n      <artifactId>scalatestplus-junit_${scala.compat.version}</artifactId>\n      <version>1.0.0-M2</version>\n      <scope>test</scope>\n    </dependency>\n\n    <dependency>\n      <groupId>org.apache.parquet</groupId>\n      <artifactId>parquet-hadoop</artifactId>\n      <version>1.16.0</version>\n      <classifier>tests</classifier>\n      <scope>test</scope>\n      <exclusions>\n        <exclusion>\n          <groupId>*</groupId>\n          <artifactId>*</artifactId>\n        </exclusion>\n      </exclusions>\n    </dependency>\n  </dependencies>\n\n  <repositories>\n    <!-- the Maven central repository for releases -->\n    <repository>\n      <id>central</id>\n      <name>Maven Central</name>\n      <layout>default</layout>\n      <url>https://repo1.maven.org/maven2</url>\n      <releases>\n        <enabled>true</enabled>\n        <updatePolicy>never</updatePolicy>\n      </releases>\n      <snapshots>\n        <enabled>false</enabled>\n      </snapshots>\n    </repository>\n    <!-- the official source for Spark releases, mirrored to Maven central -->\n    <repository>\n      <id>apache releases</id>\n      <name>Apache Releases</name>\n      
<url>https://repository.apache.org/content/repositories/releases/</url>\n      <releases>\n        <enabled>true</enabled>\n        <updatePolicy>never</updatePolicy>\n      </releases>\n      <snapshots>\n        <enabled>false</enabled>\n      </snapshots>\n    </repository>\n    <!-- required to resolve Spark snapshot versions -->\n    <repository>\n      <id>apache snapshots</id>\n      <name>Apache Snapshots</name>\n      <url>https://repository.apache.org/snapshots/</url>\n      <releases>\n        <enabled>false</enabled>\n      </releases>\n      <snapshots>\n        <enabled>true</enabled>\n        <updatePolicy>daily</updatePolicy>\n      </snapshots>\n    </repository>\n    <!-- required to resolve Spark release candidates -->\n    <!-- update temporary url and enable releases when needed-->\n    <repository>\n      <id>apache release candidate</id>\n      <name>Apache staging</name>\n      <url>https://repository.apache.org/content/repositories/orgapachespark-1478/</url>\n      <releases>\n        <enabled>false</enabled>\n      </releases>\n      <snapshots>\n        <enabled>false</enabled>\n      </snapshots>\n    </repository>\n  </repositories>\n\n  <build>\n    <sourceDirectory>src/main/scala</sourceDirectory>\n    <testSourceDirectory>src/test/java</testSourceDirectory>\n\n    <resources>\n      <resource>\n        <directory>python</directory>\n        <includes>\n          <include>gresearch/**/*.py</include>\n        </includes>\n      </resource>\n    </resources>\n    <testResources>\n      <testResource>\n        <directory>src/test/resources</directory>\n      </testResource>\n    </testResources>\n\n    <plugins>\n      <plugin>\n        <groupId>org.codehaus.mojo</groupId>\n        <artifactId>build-helper-maven-plugin</artifactId>\n        <version>3.5.0</version>\n        <executions>\n          <execution>\n            <id>spark-version-sources</id>\n            <phase>generate-sources</phase>\n            <goals>\n              <goal>add-source</goal>\n            </goals>\n            <configuration>\n              <sources>\n                <source>src/main/scala-spark-${spark.compat.version}</source>\n              </sources>\n            </configuration>\n          </execution>\n          <execution>\n            <id>spark-version-test-sources</id>\n            <phase>generate-test-sources</phase>\n            <goals>\n              <goal>add-test-source</goal>\n            </goals>\n            <configuration>\n              <sources>\n                <source>src/test/scala-spark-${spark.major.version}</source>\n              </sources>\n            </configuration>\n          </execution>\n        </executions>\n      </plugin>\n      <plugin>\n        <groupId>org.codehaus.mojo</groupId>\n        <artifactId>properties-maven-plugin</artifactId>\n        <version>1.2.1</version>\n        <executions>\n          <execution>\n            <phase>generate-resources</phase>\n            <goals>\n              <goal>write-project-properties</goal>\n            </goals>\n            <configuration>\n              <outputFile>${project.build.outputDirectory}/spark-extension-build.properties</outputFile>\n            </configuration>\n          </execution>\n        </executions>\n      </plugin>\n      <plugin>\n        <groupId>org.scala-tools</groupId>\n        <artifactId>maven-scala-plugin</artifactId>\n        <version>2.15.2</version>\n        <executions>\n          <execution>\n            <goals>\n              <goal>compile</goal>\n              
<goal>testCompile</goal>\n            </goals>\n            <configuration>\n              <args>\n                <arg>-dependencyfile</arg>\n                <arg>${project.build.directory}/.scala_dependencies</arg>\n              </args>\n            </configuration>\n          </execution>\n        </executions>\n      </plugin>\n      <plugin>\n        <groupId>org.apache.maven.plugins</groupId>\n        <artifactId>maven-jar-plugin</artifactId>\n        <version>3.3.0</version>\n        <configuration>\n          <archive>\n            <manifest>\n              <addClasspath>true</addClasspath>\n              <!-- add Diff app as main class -->\n              <mainClass>uk.co.gresearch.spark.diff.App</mainClass>\n            </manifest>\n          </archive>\n        </configuration>\n      </plugin>\n      <!-- scalafmt -->\n      <plugin>\n        <groupId>com.diffplug.spotless</groupId>\n        <artifactId>spotless-maven-plugin</artifactId>\n        <version>2.30.0</version>\n        <configuration>\n          <scala>\n            <scalafmt>\n              <version>3.7.17</version>\n              <file>${project.basedir}/.scalafmt.conf</file>\n            </scalafmt>\n          </scala>\n        </configuration>\n        <executions>\n          <execution>\n            <!-- Runs in compile phase to fail fast in case of formatting issues.-->\n            <id>spotless-check</id>\n            <phase>compile</phase>\n            <goals>\n              <goal>check</goal>\n            </goals>\n          </execution>\n        </executions>\n      </plugin>\n      <!-- run tests -->\n      <plugin>\n        <groupId>org.apache.maven.plugins</groupId>\n        <artifactId>maven-surefire-plugin</artifactId>\n        <version>3.1.2</version>\n        <configuration>\n          <skipTests>false</skipTests>\n          <includes>\n            <include>**/*Tests.class</include>\n            <include>**/*Suite.class</include>\n          </includes>\n        </configuration>\n      </plugin>\n      <!-- packaging -->\n      <plugin>\n        <groupId>org.apache.maven.plugins</groupId>\n        <artifactId>maven-source-plugin</artifactId>\n        <version>3.3.0</version>\n        <executions>\n          <execution>\n            <id>attach-sources</id>\n            <goals>\n              <goal>jar-no-fork</goal>\n            </goals>\n          </execution>\n        </executions>\n      </plugin>\n      <plugin>\n        <groupId>net.alchim31.maven</groupId>\n        <artifactId>scala-maven-plugin</artifactId>\n        <version>4.8.1</version>\n        <executions>\n          <execution>\n            <id>attach-javadocs</id>\n            <goals>\n              <goal>doc-jar</goal>\n            </goals>\n          </execution>\n        </executions>\n      </plugin>\n      <!-- run integration tests -->\n      <plugin>\n        <groupId>org.apache.maven.plugins</groupId>\n        <artifactId>maven-failsafe-plugin</artifactId>\n        <version>3.3.0</version>\n        <configuration>\n          <additionalClasspathElements>\n            <additionalClasspathElement>${project.build.directory}/${project.build.finalName}.jar</additionalClasspathElement>\n          </additionalClasspathElements>\n          <includes>\n            <include>**/*Tests.class</include>\n            <include>**/*Suite.class</include>\n          </includes>\n          <reportsDirectory>${project.build.directory}/surefire-integration-reports/</reportsDirectory>\n        </configuration>\n        <executions>\n          
<execution>\n            <goals>\n              <goal>integration-test</goal>\n              <goal>verify</goal>\n            </goals>\n            <configuration>\n              <environmentVariables>\n                <CI_INTEGRATION_TEST>true</CI_INTEGRATION_TEST>\n              </environmentVariables>\n            </configuration>\n          </execution>\n        </executions>\n      </plugin>\n      <!-- publishing -->\n      <plugin>\n        <groupId>org.sonatype.central</groupId>\n        <artifactId>central-publishing-maven-plugin</artifactId>\n        <version>0.8.0</version>\n        <extensions>true</extensions>\n        <configuration>\n          <publishingServerId>central</publishingServerId>\n          <autoPublish>true</autoPublish>\n          <waitUntil>published</waitUntil>\n        </configuration>\n      </plugin>\n      <plugin>\n        <groupId>org.apache.maven.plugins</groupId>\n        <artifactId>maven-gpg-plugin</artifactId>\n        <version>3.1.0</version>\n        <executions>\n          <execution>\n            <id>sign-artifacts</id>\n            <phase>verify</phase>\n            <goals>\n              <goal>sign</goal>\n            </goals>\n          </execution>\n        </executions>\n      </plugin>\n    </plugins>\n  </build>\n\n  <reporting>\n    <plugins>\n      <plugin>\n        <groupId>org.apache.maven.plugins</groupId>\n        <artifactId>maven-surefire-report-plugin</artifactId>\n        <version>3.1.2</version>\n      </plugin>\n    </plugins>\n  </reporting>\n\n</project>\n"
  },
  {
    "path": "python/README.md",
    "content": "# Spark Extension\n\nThis project provides extensions to the [Apache Spark project](https://spark.apache.org/) in Scala and Python:\n\n**[Diff](https://github.com/G-Research/spark-extension/blob/v2.13.0/DIFF.md):** A `diff` transformation and application for `Dataset`s that computes the differences between\ntwo datasets, i.e. which rows to _add_, _delete_ or _change_ to get from one dataset to the other.\n\n**[Histogram](https://github.com/G-Research/spark-extension/blob/v2.13.0/HISTOGRAM.md):** A `histogram` transformation that computes the histogram DataFrame for a value column.\n\n**[Global Row Number](https://github.com/G-Research/spark-extension/blob/v2.13.0/ROW_NUMBER.md):** A `withRowNumbers` transformation that provides the global row number w.r.t.\nthe current order of the Dataset, or any given order. In contrast to the existing SQL function `row_number`, which\nrequires a window spec, this transformation provides the row number across the entire Dataset without scaling problems.\n\n**[Inspect Parquet files](https://github.com/G-Research/spark-extension/blob/v2.13.0/PARQUET.md):** The structure of Parquet files (the metadata, not the data stored in Parquet) can be inspected similar to [parquet-tools](https://pypi.org/project/parquet-tools/)\nor [parquet-cli](https://pypi.org/project/parquet-cli/) by reading from a simple Spark data source.\nThis simplifies identifying why some Parquet files cannot be split by Spark into scalable partitions.\n\n**[Install Python packages into PySpark job](https://github.com/G-Research/spark-extension/blob/v2.13.0/PYSPARK-DEPS.md):** Install Python dependencies via PIP or Poetry programatically into your running PySpark job (PySpark ≥ 3.1.0):\n\n```python\n# noinspection PyUnresolvedReferences\nfrom gresearch.spark import *\n\n# using PIP\nspark.install_pip_package(\"pandas==1.4.3\", \"pyarrow\")\nspark.install_pip_package(\"-r\", \"requirements.txt\")\n\n# using Poetry\nspark.install_poetry_project(\"../my-poetry-project/\", poetry_python=\"../venv-poetry/bin/python\")\n```\n\n**Count null values:** `count_null(e: Column)`: an aggregation function like `count` that counts null values in column `e`.\nThis is equivalent to calling `count(when(e.isNull, lit(1)))`.\n\n**.Net DateTime.Ticks:** Convert .Net (C#, F#, Visual Basic) `DateTime.Ticks` into Spark timestamps, seconds and nanoseconds.\n\n<details>\n<summary>Available methods:</summary>\n\n```python\ndotnet_ticks_to_timestamp(column_or_name)         # returns timestamp as TimestampType\ndotnet_ticks_to_unix_epoch(column_or_name)        # returns Unix epoch seconds as DecimalType\ndotnet_ticks_to_unix_epoch_nanos(column_or_name)  # returns Unix epoch nanoseconds as LongType\n```\n\nThe reverse is provided by (all return `LongType` .Net ticks):\n```python\ntimestamp_to_dotnet_ticks(column_or_name)\nunix_epoch_to_dotnet_ticks(column_or_name)\nunix_epoch_nanos_to_dotnet_ticks(column_or_name)\n```\n</details>\n\n**Spark temporary directory**: Create a temporary directory that will be removed on Spark application shutdown.\n\n<details>\n<summary>Example:</summary>\n\n```python\n# noinspection PyUnresolvedReferences\nfrom gresearch.spark import *\n\ndir = spark.create_temporary_dir(\"prefix\")\n```\n</details>\n\n**Spark job description:** Set Spark job description for all Spark jobs within a context.\n\n<details>\n<summary>Example:</summary>\n\n```python\nfrom gresearch.spark import job_description, append_job_description\n\nwith job_description(\"parquet file\"):\n    df = 
spark.read.parquet(\"data.parquet\")\n    with append_job_description(\"count\"):\n        count = df.count\n    with append_job_description(\"write\"):\n        df.write.csv(\"data.csv\")\n```\n</details>\n\nFor details, see the [README.md](https://github.com/G-Research/spark-extension#spark-extension) at the project homepage.\n\n## Using Spark Extension\n\n#### PyPi package (local Spark cluster only)\n\nYou may want to install the `pyspark-extension` python package from PyPi into your development environment.\nThis provides you code completion, typing and test capabilities during your development phase.\n\nRunning your Python application on a Spark cluster will still require one of the ways below\nto add the Scala package to the Spark environment.\n\n```shell script\npip install pyspark-extension==2.15.0.3.4\n```\n\nNote: Pick the right Spark version (here 3.4) depending on your PySpark version.\n\n#### PySpark API\n\nStart a PySpark session with the Spark Extension dependency (version ≥1.1.0) as follows:\n\n```python\nfrom pyspark.sql import SparkSession\n\nspark = SparkSession \\\n    .builder \\\n    .config(\"spark.jars.packages\", \"uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.4\") \\\n    .getOrCreate()\n```\n\nNote: Pick the right Scala version (here 2.12) and Spark version (here 3.4) depending on your PySpark version.\n\n#### PySpark REPL\n\nLaunch the Python Spark REPL with the Spark Extension dependency (version ≥1.1.0) as follows:\n\n```shell script\npyspark --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.4\n```\n\nNote: Pick the right Scala version (here 2.12) and Spark version (here 3.4) depending on your PySpark version.\n\n#### PySpark `spark-submit`\n\nRun your Python scripts that use PySpark via `spark-submit`:\n\n```shell script\nspark-submit --packages uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.4 [script.py]\n```\n\nNote: Pick the right Scala version (here 2.12) and Spark version (here 3.4) depending on your Spark version.\n\n### Your favorite Data Science notebook\n\nThere are plenty of [Data Science notebooks](https://datasciencenotebook.org/) around. To use this library,\nadd **a jar dependency** to your notebook using these **Maven coordinates**:\n\n    uk.co.gresearch.spark:spark-extension_2.12:2.15.0-3.4\n\nOr [download the jar](https://mvnrepository.com/artifact/uk.co.gresearch.spark/spark-extension) and place it\non a filesystem where it is accessible by the notebook, and reference that jar file directly.\n\nCheck the documentation of your favorite notebook to learn how to add jars to your Spark environment.\n\n"
  },
  {
    "path": "python/gresearch/__init__.py",
    "content": "#  Copyright 2020 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\n"
  },
  {
    "path": "python/gresearch/spark/__init__.py",
    "content": "#  Copyright 2020 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nimport os\nimport re\nimport shutil\nimport subprocess\nimport sys\nimport tempfile\nimport time\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import Any, Union, List, Optional, Mapping, Iterable, TYPE_CHECKING\n\nfrom py4j.java_gateway import JVMView, JavaObject\nfrom pyspark import __version__\nfrom pyspark.context import SparkContext\nfrom pyspark.files import SparkFiles\nfrom pyspark.sql import DataFrame, DataFrameReader, SQLContext\nfrom pyspark.sql.column import Column\nfrom pyspark.sql.context import SQLContext\nfrom pyspark import SparkConf\nfrom pyspark.sql.functions import col, count, lit, when\nfrom pyspark.sql.session import SparkSession\nfrom pyspark.storagelevel import StorageLevel\n\nif __version__.startswith('4.'):\n    from pyspark.sql.classic.column import _to_java_column\nelse:\n    from pyspark.sql.column import _to_java_column\n\ntry:\n    from pyspark.sql.connect.column import Column as ConnectColumn\n    from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame\n    from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader\n    from pyspark.sql.connect.session import SparkSession as ConnectSparkSession\n    has_connect = True\nexcept ImportError:\n    has_connect = False\n\nif TYPE_CHECKING:\n    from pyspark.sql._typing import ColumnOrName\n\n_java_pkg_is_installed: Optional[bool] = None\n\n\n_column_types = (Column,)\n_dataframe_types = (DataFrame,)\nif has_connect:\n    _column_types += (ConnectColumn, )\n    _dataframe_types += (ConnectDataFrame, )\n_column_types_and_str = (str,) + _column_types\n\n\ndef _is_column(obj: Any) -> bool:\n    return isinstance(obj, _column_types)\n\n\ndef _is_column_or_str(obj: Any) -> bool:\n    return isinstance(obj, _column_types_and_str)\n\n\ndef _is_dataframe(obj: Any) -> bool:\n    return isinstance(obj, _dataframe_types)\n\n\ndef _check_java_pkg_is_installed(jvm: JVMView) -> bool:\n    \"\"\"Check that the Java / Scala package is installed.\"\"\"\n    try:\n        jvm.uk.co.gresearch.spark.__getattr__(\"package$\").__getattr__(\"MODULE$\").VersionString()\n        return True\n    except TypeError as e:\n        print(e.args)\n        return False\n    except:\n        # any other exception indicate some problem, be safe and do not fail fast here\n        return True\n\n\ndef _get_jvm(obj: Any) -> JVMView:\n    \"\"\"\n    Provides easy access to the JVMView provided by Spark, and raises meaningful error message if that is not available.\n    Also checks that the Java / Scala package is accessible via this JVMView.\n    \"\"\"\n    if obj is None:\n        if SparkContext._active_spark_context is None:\n            raise RuntimeError(\"This method must be called inside an active Spark session\")\n        else:\n            raise ValueError(\"Cannot provide access to JVM from None\")\n\n    if has_connect and isinstance(obj, (ConnectDataFrame, 
ConnectDataFrameReader, ConnectSparkSession)):\n        raise RuntimeError('This feature is not supported for Spark Connect. Please use a classic Spark client. '\n                           'https://github.com/G-Research/spark-extension#spark-connect-server')\n\n    if isinstance(obj, DataFrame):\n        jvm = _get_jvm(obj._sc)\n    elif isinstance(obj, DataFrameReader):\n        jvm = _get_jvm(obj._spark)\n    elif isinstance(obj, SparkSession):\n        jvm = _get_jvm(obj.sparkContext)\n    elif isinstance(obj, (SparkContext, SQLContext)):\n        jvm = obj._jvm\n    else:\n        raise RuntimeError(f'Unsupported class: {type(obj)}')\n\n    global _java_pkg_is_installed\n    if _java_pkg_is_installed is None:\n        _java_pkg_is_installed = _check_java_pkg_is_installed(jvm)\n    if not _java_pkg_is_installed:\n        raise RuntimeError(\"Java / Scala package not found! You need to add the Maven spark-extension package \"\n                           \"to your PySpark environment: https://github.com/G-Research/spark-extension#python\")\n\n    return jvm\n\n\ndef _to_seq(jvm: JVMView, list: List[Any]) -> JavaObject:\n    array = jvm.java.util.ArrayList(list)\n    return jvm.scala.collection.JavaConverters.asScalaIteratorConverter(array.iterator()).asScala().toSeq()\n\n\ndef _to_map(jvm: JVMView, map: Mapping[Any, Any]) -> JavaObject:\n    return jvm.scala.collection.JavaConverters.mapAsScalaMap(map)\n\n\ndef backticks(*name_parts: str) -> str:\n    for np in name_parts:\n        assert isinstance(np, str), np\n    return '.'.join([f'`{part}`'\n                     if '.' in part and not part.startswith('`') and not part.endswith('`')\n                     else part\n                     for part in name_parts])\n\n\ndef distinct_prefix_for(existing: List[str]) -> str:\n    assert isinstance(existing, Iterable)\n    for e in existing:\n        assert isinstance(e, str), e\n\n    # count number of suffix _ for each existing column name\n    length = 1\n    if existing:\n        length = max([len(name) - len(name.lstrip('_')) for name in existing]) + 1\n    # return string with one more _ than that\n    return '_' * length\n\n\ndef handle_configured_case_sensitivity(column_name: str, case_sensitive: bool) -> str:\n    \"\"\"\n    Produces a column name that considers configured case-sensitivity of column names. 
When case sensitivity is\n    deactivated, the given column name is lower-cased; otherwise it is returned unchanged.\n    \"\"\"\n    assert isinstance(column_name, str), column_name\n    assert isinstance(case_sensitive, bool), case_sensitive\n\n    if case_sensitive:\n        return column_name\n    return column_name.lower()\n\n\ndef list_contains_case_sensitivity(column_names: Iterable[str], columnName: str, case_sensitive: bool) -> bool:\n    assert isinstance(column_names, Iterable), column_names\n    for cn in column_names:\n        assert isinstance(cn, str), cn\n    assert isinstance(columnName, str), columnName\n    assert isinstance(case_sensitive, bool), case_sensitive\n\n    return handle_configured_case_sensitivity(columnName, case_sensitive) in [handle_configured_case_sensitivity(c, case_sensitive) for c in column_names]\n\n\ndef list_filter_case_sensitivity(column_names: Iterable[str], filter: Iterable[str], case_sensitive: bool) -> List[str]:\n    assert isinstance(column_names, Iterable), column_names\n    for cn in column_names:\n        assert isinstance(cn, str), cn\n    assert isinstance(filter, Iterable), filter\n    for f in filter:\n        assert isinstance(f, str), f\n    assert isinstance(case_sensitive, bool), case_sensitive\n\n    filter_set = {handle_configured_case_sensitivity(f, case_sensitive) for f in filter}\n    return [c for c in column_names if handle_configured_case_sensitivity(c, case_sensitive) in filter_set]\n\n\ndef list_diff_case_sensitivity(column_names: Iterable[str], other: Iterable[str], case_sensitive: bool) -> List[str]:\n    assert isinstance(column_names, Iterable), column_names\n    for cn in column_names:\n        assert isinstance(cn, str), cn\n    assert isinstance(other, Iterable), other\n    for o in other:\n        assert isinstance(o, str), o\n    assert isinstance(case_sensitive, bool), case_sensitive\n\n    other_set = {handle_configured_case_sensitivity(o, case_sensitive) for o in other}\n    return [c for c in column_names if handle_configured_case_sensitivity(c, case_sensitive) not in other_set]\n\n\ndef dotnet_ticks_to_timestamp(tick_column: Union[str, Column]) -> Column:\n    \"\"\"\n    Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be\n    convertible to a number (e.g. string, int, long). The Spark timestamp type does not support\n    nanoseconds, so the last digit of the timestamp (1/10 of a microsecond) is lost.\n    {{{\n      df.select(col(\"ticks\"), dotNetTicksToTimestamp(\"ticks\").alias(\"timestamp\")).show(false)\n    }}}\n    +------------------+--------------------------+\n    |ticks             |timestamp                 |\n    +------------------+--------------------------+\n    |638155413748959318|2023-03-27 21:16:14.895931|\n    +------------------+--------------------------+\n\n    Note: the example timestamp loses its trailing digit 8 (8/10 of a microsecond). 
Use `dotNetTicksToUnixEpoch` to\n    preserve the full precision of the tick timestamp.\n\n    https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n\n    :param tick_column: column with a tick value (str or Column)\n    :return: timestamp column\n    \"\"\"\n    if not _is_column_or_str(tick_column):\n        raise ValueError(f\"Given column must be a column name (str) or column instance (Column): {type(tick_column)}\")\n\n    jvm = _get_jvm(SparkContext._active_spark_context)\n    func = jvm.uk.co.gresearch.spark.__getattr__(\"package$\").__getattr__(\"MODULE$\").dotNetTicksToTimestamp\n    return Column(func(_to_java_column(tick_column)))\n\n\ndef dotnet_ticks_to_unix_epoch(tick_column: Union[str, Column]) -> Column:\n    \"\"\"\n    Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch decimal. The input column must be\n    convertible to a number (e.g. string, int, long). The full precision of the tick timestamp\n    is preserved (1/10 of a microsecond).\n\n    Example:\n    {{{\n      df.select(col(\"ticks\"), dotNetTicksToUnixEpoch(\"ticks\").alias(\"timestamp\")).show(false)\n    }}}\n\n    +------------------+--------------------+\n    |ticks             |timestamp           |\n    +------------------+--------------------+\n    |638155413748959318|1679944574.895931800|\n    +------------------+--------------------+\n\n    https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n\n    :param tick_column: column with a tick value (str or Column)\n    :return: Unix epoch column\n    \"\"\"\n    if not _is_column_or_str(tick_column):\n        raise ValueError(f\"Given column must be a column name (str) or column instance (Column): {type(tick_column)}\")\n\n    jvm = _get_jvm(SparkContext._active_spark_context)\n    func = jvm.uk.co.gresearch.spark.__getattr__(\"package$\").__getattr__(\"MODULE$\").dotNetTicksToUnixEpoch\n    return Column(func(_to_java_column(tick_column)))\n\n\ndef dotnet_ticks_to_unix_epoch_nanos(tick_column: Union[str, Column]) -> Column:\n    \"\"\"\n    Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch nanoseconds. The input column must be\n    convertible to a number (e.g. string, int, long). 
The full precision of the tick timestamp\n    is preserved (1/10 of a microsecond).\n\n    Example:\n    {{{\n      df.select(col(\"ticks\"), dotNetTicksToUnixEpoch(\"ticks\").alias(\"timestamp\")).show(false)\n    }}}\n\n    +------------------+-------------------+\n    |ticks             |timestamp          |\n    +------------------+-------------------+\n    |638155413748959318|1679944574895931800|\n    +------------------+-------------------+\n\n    https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n\n    :param tick_column: column with a tick value (str or Column)\n    :return: Unix epoch column\n    \"\"\"\n    if not _is_column_or_str(tick_column):\n        raise ValueError(f\"Given column must be a column name (str) or column instance (Column): {type(tick_column)}\")\n\n    jvm = _get_jvm(SparkContext._active_spark_context)\n    func = jvm.uk.co.gresearch.spark.__getattr__(\"package$\").__getattr__(\"MODULE$\").dotNetTicksToUnixEpochNanos\n    return Column(func(_to_java_column(tick_column)))\n\n\ndef timestamp_to_dotnet_ticks(timestamp_column: Union[str, Column]) -> Column:\n    \"\"\"\n    Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp.\n    The input column must be of TimestampType.\n\n    Example:\n    {{{\n      df.select(col(\"timestamp\"), timestampToDotNetTicks(\"timestamp\").alias(\"ticks\")).show(false)\n    }}}\n\n    +--------------------------+------------------+\n    |timestamp                 |ticks             |\n    +--------------------------+------------------+\n    |2023-03-27 21:16:14.895931|638155413748959310|\n    +--------------------------+------------------+\n\n    https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n\n    :param timestamp_column: column with a timestamp value\n    :return: tick value column\n    \"\"\"\n    if not _is_column_or_str(timestamp_column):\n        raise ValueError(f\"Given column must be a column name (str) or column instance (Column): {type(timestamp_column)}\")\n\n    jvm = _get_jvm(SparkContext._active_spark_context)\n    func = jvm.uk.co.gresearch.spark.__getattr__(\"package$\").__getattr__(\"MODULE$\").timestampToDotNetTicks\n    return Column(func(_to_java_column(timestamp_column)))\n\n\ndef unix_epoch_to_dotnet_ticks(unix_column: Union[str, Column]) -> Column:\n    \"\"\"\n    Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp.\n    The input column must represent a numerical unix epoch timestamp, e.g. 
long, double, string or decimal.\n    The input must not be of TimestampType, as that may be interpreted incorrectly.\n    Use `timestampToDotNetTicks` for TimestampType columns instead.\n\n    Example:\n    {{{\n      df.select(col(\"unix\"), unixEpochToDotNetTicks(\"unix\").alias(\"ticks\")).show(false)\n    }}}\n\n    +-----------------------------+------------------+\n    |unix                         |ticks             |\n    +-----------------------------+------------------+\n    |1679944574.895931234000000000|638155413748959312|\n    +-----------------------------+------------------+\n\n    https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n\n    :param unix_column: column with a unix epoch value\n    :return: tick value column\n    \"\"\"\n    if not _is_column_or_str(unix_column):\n        raise ValueError(f\"Given column must be a column name (str) or column instance (Column): {type(unix_column)}\")\n\n    jvm = _get_jvm(SparkContext._active_spark_context)\n    func = jvm.uk.co.gresearch.spark.__getattr__(\"package$\").__getattr__(\"MODULE$\").unixEpochToDotNetTicks\n    return Column(func(_to_java_column(unix_column)))\n\n\ndef unix_epoch_nanos_to_dotnet_ticks(unix_column: Union[str, Column]) -> Column:\n    \"\"\"\n    Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp.\n    The .Net ticks timestamp does not support the two lowest nanosecond digits,\n    so only a 1/10 of a microsecond is the smallest resolution.\n    The input column must represent a numerical unix epoch nanoseconds timestamp,\n    e.g. long, double, string or decimal.\n\n    Example:\n    {{{\n      df.select(col(\"unix_nanos\"), unixEpochNanosToDotNetTicks(\"unix_nanos\").alias(\"ticks\")).show(false)\n    }}}\n\n    +-------------------+------------------+\n    |unix_nanos         |ticks             |\n    +-------------------+------------------+\n    |1679944574895931234|638155413748959312|\n    +-------------------+------------------+\n\n    https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n\n    :param unix_column: column with a unix epoch value\n    :return: tick value column\n    \"\"\"\n    if not _is_column_or_str(unix_column):\n        raise ValueError(f\"Given column must be a column name (str) or column instance (Column): {type(unix_column)}\")\n\n    jvm = _get_jvm(SparkContext._active_spark_context)\n    func = jvm.uk.co.gresearch.spark.__getattr__(\"package$\").__getattr__(\"MODULE$\").unixEpochNanosToDotNetTicks\n    return Column(func(_to_java_column(unix_column)))\n\n\ndef count_null(e: \"ColumnOrName\") -> Column:\n    \"\"\"\n    Aggregate function: returns the number of items in a group that are not null.\n\n    Parameters\n    ----------\n    col : :class:`~pyspark.sql.Column` or str target column to compute on.\n\n    Returns\n    -------\n    :class:`~pyspark.sql.Column`\n        column for computed results.\n    \"\"\"\n    if isinstance(e, str):\n        e = col(e)\n    if not _is_column(e):\n        raise ValueError(f\"Given column must be a column name (str) or column instance (Column): {type(e)}\")\n\n    return count(when(e.isNull(), lit(1)))\n\n\ndef histogram(self: DataFrame,\n              thresholds: List[Union[int, float]],\n              value_column: str,\n              *aggregate_columns: str) -> DataFrame:\n\n    if len(thresholds) == 0:\n        t = 'Int'\n    else:\n        t = type(thresholds[0])\n        if t == int:\n            t = 'Int'\n        elif t == float:\n            t = 'Double'\n        
else:\n            raise ValueError('thresholds must be int or floats: {}'.format(t))\n\n    jvm = _get_jvm(self)\n    col = jvm.org.apache.spark.sql.functions.col\n    value_column = col(value_column)\n    aggregate_columns = [col(column) for column in aggregate_columns]\n\n    hist = jvm.uk.co.gresearch.spark.Histogram\n    jdf = hist.of(self._jdf, _to_seq(jvm, thresholds), value_column, _to_seq(jvm, aggregate_columns))\n    return DataFrame(jdf, self.session_or_ctx())\n\n\nDataFrame.histogram = histogram\nif has_connect:\n    ConnectDataFrame.histogram = histogram\n\n\nclass UnpersistHandle:\n    def __init__(self, handle):\n        self._handle = handle\n\n    def __call__(self, blocking: Optional[bool] = None):\n        if self._handle is not None:\n            if blocking is None:\n                self._handle.apply()\n            else:\n                self._handle.apply(blocking)\n\n\ndef unpersist_handle(self: SparkSession) -> UnpersistHandle:\n    jvm = _get_jvm(self)\n    handle = jvm.uk.co.gresearch.spark.UnpersistHandle()\n    return UnpersistHandle(handle)\n\n\nSparkSession.unpersist_handle = unpersist_handle\n\n\ndef _get_sort_cols(df: DataFrame, order: Union[str, Column, List[Union[str, Column]]], ascending: Union[bool, List[bool]]):\n    if __version__.startswith('3.'):\n        # pyspark<4\n        return df._sort_cols([order], {'ascending': ascending})\n\n    # pyspark>=4\n    _cols = df._preapare_cols_for_sort(col, [order], {\"ascending\": ascending})\n    return df._jseq(_cols, _to_java_column)\n\n\ndef with_row_numbers(self: DataFrame,\n                     storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK,\n                     unpersist_handle: Optional[UnpersistHandle] = None,\n                     row_number_column_name: str = \"row_number\",\n                     order: Union[str, Column, List[Union[str, Column]]] = [],\n                     ascending: Union[bool, List[bool]] = True) -> DataFrame:\n    jvm = _get_jvm(self)\n    jsl = self._sc._getJavaStorageLevel(storage_level)\n    juho = jvm.uk.co.gresearch.spark.UnpersistHandle\n    juh = unpersist_handle._handle if unpersist_handle else juho.Noop()\n    jcols = _get_sort_cols(self, order, ascending) if not isinstance(order, list) or order else jvm.PythonUtils.toSeq([])\n\n    row_numbers = jvm.uk.co.gresearch.spark.RowNumbers\n    jdf = row_numbers \\\n        .withRowNumberColumnName(row_number_column_name) \\\n        .withStorageLevel(jsl) \\\n        .withUnpersistHandle(juh) \\\n        .withOrderColumns(jcols) \\\n        .of(self._jdf)\n\n    return DataFrame(jdf, self.session_or_ctx())\n\n\nDataFrame.with_row_numbers = with_row_numbers\nif has_connect:\n    ConnectDataFrame.with_row_numbers = with_row_numbers\n\n\ndef session(self: DataFrame) -> SparkSession:\n    return self.sparkSession if hasattr(self, 'sparkSession') else self.sql_ctx.sparkSession\n\n\ndef session_or_ctx(self: DataFrame) -> Union[SparkSession, SQLContext]:\n    return self.sparkSession if hasattr(self, 'sparkSession') else self.sql_ctx\n\n\nDataFrame.session = session\nDataFrame.session_or_ctx = session_or_ctx\nif has_connect:\n    ConnectDataFrame.session = session\n    ConnectDataFrame.session_or_ctx = session_or_ctx\n\n\ndef set_description(description: Optional[str], if_not_set: bool = False):\n    if description is not None:\n        assert isinstance(description, str), description\n    assert isinstance(if_not_set, bool), if_not_set\n\n    context = SparkContext._active_spark_context\n    jvm = _get_jvm(context)\n   
 spark_package = jvm.uk.co.gresearch.spark.__getattr__(\"package$\").__getattr__(\"MODULE$\")\n    return spark_package.setJobDescription(description, if_not_set, context._jsc.sc())\n\n\n@contextmanager\ndef job_description(description: str, if_not_set: bool = False):\n    \"\"\"\n    Adds a job description to all Spark jobs started within this context.\n    The current Job description is restored after leaving the context.\n\n    Usage example:\n\n    >>> from gresearch.spark import job_description\n    >>>\n    >>> with job_description(\"parquet file\"):\n    ...     df = spark.read.parquet(\"data.parquet\")\n    ...     count = df.count\n\n    With ``if_not_set = True``, the description is only set if no job description is set yet.\n\n    Any modification to the job description within the context is reverted on exit,\n    even if `if_not_set = True`.\n\n    :param description: job description\n    :param if_not_set: job description is only set if no description is set yet\n    \"\"\"\n    earlier = set_description(description, if_not_set)\n    try:\n        yield\n    finally:\n        set_description(earlier)\n\n\ndef append_description(extra_description: str, separator: str = \" - \"):\n    assert isinstance(extra_description, str), extra_description\n    assert isinstance(separator, str), separator\n\n    context = SparkContext._active_spark_context\n    jvm = _get_jvm(context)\n    spark_package = jvm.uk.co.gresearch.spark.__getattr__(\"package$\").__getattr__(\"MODULE$\")\n    return spark_package.appendJobDescription(extra_description, separator, context._jsc.sc())\n\n\n@contextmanager\ndef append_job_description(extra_description: str, separator: str = \" - \"):\n    \"\"\"\n    Appends a job description to all Spark jobs started within this context.\n    The current Job description is extended by the separator and the extra description\n    on entering the context, and restored after leaving the context.\n\n    Usage example:\n\n    >>> from gresearch.spark import append_job_description\n    >>>\n    >>> with append_job_description(\"parquet file\"):\n    ...     df = spark.read.parquet(\"data.parquet\")\n    ...     with append_job_description(\"count\"):\n    ...         
count = df.count()\n\n    Any modification to the job description within the context is reverted on exit.\n\n    :param extra_description: job description to be appended\n    :param separator: separator used when appending description\n    \"\"\"\n    earlier = append_description(extra_description, separator)\n    try:\n        yield\n    finally:\n        set_description(earlier)\n\n\ndef create_temporary_dir(spark: Union[SparkSession, SparkContext], prefix: str) -> str:\n    \"\"\"\n    Create a temporary directory in a location (driver temp dir) that will be deleted on Spark application shutdown.\n    :param spark: spark session or context\n    :param prefix: prefix string of temporary directory name\n    :return: absolute path of temporary directory\n    \"\"\"\n    jvm = _get_jvm(spark)\n    root_dir = jvm.org.apache.spark.SparkFiles.getRootDirectory()\n    return tempfile.mkdtemp(prefix=prefix, dir=root_dir)\n\n\nSparkSession.create_temporary_dir = create_temporary_dir\nSparkContext.create_temporary_dir = create_temporary_dir\n\nif has_connect:\n    ConnectSparkSession.create_temporary_dir = create_temporary_dir\n\n\ndef install_pip_package(spark: Union[SparkSession, SparkContext], *package_or_pip_option: str) -> None:\n    if __version__.startswith('2.') or __version__.startswith('3.0.'):\n        raise NotImplementedError(f'Not supported for PySpark {__version__}')\n\n    for option in package_or_pip_option:\n        assert isinstance(option, str), option\n\n    # just here to assert JVM is accessible\n    _get_jvm(spark)\n\n    if isinstance(spark, SparkSession):\n        spark = spark.sparkContext\n\n    # create temporary directory for packages, inside a directory which will be deleted on spark application shutdown\n    id = f\"spark-extension-pip-pkgs-{time.time()}\"\n    dir = spark.create_temporary_dir(f\"{id}-\")\n\n    # install packages via pip install\n    # it is best to run pip as a separate process rather than calling into the pip module\n    # https://pip.pypa.io/en/stable/user_guide/#using-pip-from-your-program\n    subprocess.check_call([sys.executable, '-m', 'pip', \"install\"] + list(package_or_pip_option) + [\"--target\", dir])\n\n    # zip packages and remove directory\n    zip = shutil.make_archive(dir, \"zip\", dir)\n    shutil.rmtree(dir)\n\n    # register zip file as archive, and add as python source\n    # once support for Spark 3.0 is dropped, replace with spark.addArchive()\n    spark._jsc.sc().addArchive(zip + \"#\" + id)\n    spark._python_includes.append(id)\n    sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), id))\n\n\nSparkSession.install_pip_package = install_pip_package\nSparkContext.install_pip_package = install_pip_package\n\nif has_connect:\n    ConnectSparkSession.install_pip_package = install_pip_package\n\n\ndef install_poetry_project(spark: Union[SparkSession, SparkContext],\n                           *project: str,\n                           poetry_python: Optional[str] = None,\n                           pip_args: Optional[List[str]] = None) -> None:\n    import logging\n    logger = logging.getLogger()\n\n    # spark.install_pip_package has this limitation, and it is used by this method\n    # and we want to fail quickly here\n    if __version__.startswith('2.') or __version__.startswith('3.0.'):\n        raise NotImplementedError(f'Not supported for PySpark {__version__}')\n\n    for p in project:\n        assert isinstance(p, str), p\n\n    if poetry_python is not None:\n        assert isinstance(poetry_python, str), 
poetry_python\n    if pip_args is not None:\n        for pa in pip_args:\n            assert isinstance(pa, str), pa\n\n    # just here to assert JVM is accessible\n    _get_jvm(spark)\n\n    if isinstance(spark, SparkSession):\n        spark = spark.sparkContext\n    if poetry_python is None:\n        poetry_python = sys.executable\n    if pip_args is None:\n        pip_args = []\n\n    def check_and_log_poetry(proc: subprocess.CompletedProcess) -> List[str]:\n        stdout = proc.stdout.decode('utf-8').splitlines(keepends=False)\n        for line in stdout:\n            logger.info(f\"poetry: {line}\")\n\n        stderr = proc.stderr.decode('utf-8').splitlines(keepends=False)\n        for line in stderr:\n            logger.error(f\"poetry: {line}\")\n\n        if proc.returncode != 0:\n            raise RuntimeError(f'Poetry process terminated with exit code {proc.returncode}')\n\n        return stdout\n\n    def build_wheel(project: Path) -> Path:\n        logger.info(f\"Running poetry using {poetry_python}\")\n\n        # make sure the virtual env for this project exists, otherwise we won't get to see the build whl file in stdout\n        proc = subprocess.run([\n            poetry_python, '-m', 'poetry',\n            'env', 'use',\n            '--directory', str(project.absolute()),\n            sys.executable\n        ], capture_output=True)\n        check_and_log_poetry(proc)\n\n        # build the whl file\n        proc = subprocess.run([\n            poetry_python, '-m', 'poetry',\n            'build',\n            '--verbose',\n            '--no-interaction',\n            '--format', 'wheel',\n            '--directory', str(project.absolute())\n        ], capture_output=True)\n        stdout = check_and_log_poetry(proc)\n\n        # first matching line is taken to extract whl file name\n        whl_pattern = \"^  - Built (.*.whl)$\"\n        for line in stdout:\n            if match := re.match(whl_pattern, line):\n                return project.joinpath('dist', match.group(1))\n\n        raise RuntimeError(f'Could not find wheel file name in poetry output, was looking for \"{whl_pattern}\"')\n\n    wheels = [build_wheel(Path(path)) for path in project]\n\n    # install wheels via pip\n    spark.install_pip_package(*[str(whl.absolute()) for whl in wheels] + pip_args)\n\n\nSparkSession.install_poetry_project = install_poetry_project\nSparkContext.install_poetry_project = install_poetry_project\n\nif has_connect:\n    ConnectSparkSession.install_poetry_project = install_poetry_project\n"
  },
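A brief, hedged usage sketch of the helpers defined in python/gresearch/spark above. It is illustrative only: it assumes a classic (non-Connect) Spark session with the spark-extension Maven package on the classpath, and the DataFrame, the column name "ticks", and its value are made up for this example.

# Usage sketch for the gresearch.spark helpers above (assumes the spark-extension
# jar is on the Spark classpath; the data below is illustrative only).
from pyspark.sql import SparkSession

from gresearch.spark import count_null, dotnet_ticks_to_timestamp, job_description

spark = SparkSession.builder.getOrCreate()

with job_description("dotnet ticks example"):
    # a single made-up .Net tick value; 638155413748959318 ~ 2023-03-27 21:16:14.895931
    df = spark.createDataFrame([(638155413748959318,)], ["ticks"])
    df.select(dotnet_ticks_to_timestamp("ticks").alias("timestamp")).show(truncate=False)
    df.select(count_null("ticks").alias("null_ticks")).show()

# importing gresearch.spark also patches DataFrame, e.g. with_row_numbers:
df.with_row_numbers(order="ticks").show()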
  {
    "path": "python/gresearch/spark/diff/__init__.py",
    "content": "#  Copyright 2020 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\nimport dataclasses\nfrom dataclasses import dataclass\nfrom enum import Enum\nfrom functools import reduce\nfrom typing import Optional, Dict, Mapping, Any, Callable, List, Tuple, Union, Iterable, overload\n\nfrom py4j.java_gateway import JavaObject, JVMView\nfrom pyspark.sql import DataFrame, Column\nfrom pyspark.sql.functions import col, lit, when, concat, coalesce, array, struct\nfrom pyspark.sql.types import DataType, StructField, ArrayType\n\nfrom gresearch.spark import _get_jvm, _to_seq, _to_map, backticks, distinct_prefix_for, \\\n    handle_configured_case_sensitivity, list_contains_case_sensitivity, list_filter_case_sensitivity, list_diff_case_sensitivity, \\\n    has_connect, _is_dataframe\nfrom gresearch.spark.diff.comparator import DiffComparator, DiffComparators, DefaultDiffComparator\n\n\ntry:\n    # There is a chance users use the Python code contained in the jvm package with Spark\n    # without ever pip installing the whl package and thus lacking dependencies like this\n    #from typing_extensions import deprecated\n    raise ImportError()\nexcept ImportError:\n    from typing import TypeVar\n\n    _T = TypeVar(\"_T\")\n\n    class deprecated:\n        def __init__(self, msg: str) -> None:\n            self.msg = msg\n\n        def __call__(self, func: _T) -> _T:\n            import warnings\n\n            def deprecated_func(*args, **kwargs):\n                warnings.warn(self.msg, DeprecationWarning, stacklevel=2)\n                return func(*args, **kwargs)\n\n            return deprecated_func\n\n\nclass DiffMode(Enum):\n    ColumnByColumn = \"ColumnByColumn\"\n    SideBySide = \"SideBySide\"\n    LeftSide = \"LeftSide\"\n    RightSide = \"RightSide\"\n\n    # should be in sync with default defined in Java\n    Default = ColumnByColumn\n\n    def _to_java(self, jvm: JVMView) -> JavaObject:\n        return jvm.uk.co.gresearch.spark.diff.DiffMode.withNameOption(self.name).get()\n\n\n@dataclass(frozen=True)\nclass DiffOptions:\n    \"\"\"\n    Configuration class for diffing Datasets.\n\n    :param diff_column: name of the diff column\n    :type diff_column: str\n    :param left_column_prefix: prefix of columns from the left Dataset\n    :type left_column_prefix: str\n    :param right_column_prefix: prefix of columns from the right Dataset\n    :type right_column_prefix: str\n    :param insert_diff_value: value in diff column for inserted rows\n    :type insert_diff_value: str\n    :param change_diff_value: value in diff column for changed rows\n    :type change_diff_value: str\n    :param delete_diff_value: value in diff column for deleted rows\n    :type delete_diff_value: str\n    :param nochange_diff_value: value in diff column for un-changed rows\n    :type nochange_diff_value: str\n    :param change_column: name of change column\n    :type change_column: str\n    :param diff_mode: diff mode\n    :type diff_mode: DiffMode\n    :param sparse_mode: sparse 
mode\n    :type sparse_mode: bool\n    \"\"\"\n    diff_column: str = 'diff'\n    left_column_prefix: str = 'left'\n    right_column_prefix: str = 'right'\n    insert_diff_value: str = 'I'\n    change_diff_value: str = 'C'\n    delete_diff_value: str = 'D'\n    nochange_diff_value: str = 'N'\n    change_column: Optional[str] = None\n    diff_mode: DiffMode = DiffMode.Default\n    sparse_mode: bool = False\n    default_comparator: DiffComparator = DefaultDiffComparator()\n    data_type_comparators: Dict[DataType, DiffComparator] = dataclasses.field(default_factory=lambda: dict())\n    column_name_comparators: Dict[str, DiffComparator] = dataclasses.field(default_factory=lambda: dict())\n\n    def with_diff_column(self, diff_column: str) -> 'DiffOptions':\n        \"\"\"\n        Fluent method to change the diff column name.\n        Returns a new immutable DiffOptions instance with the new diff column name.\n\n        :param diff_column: new diff column name\n        :type diff_column: str\n        :return: new immutable DiffOptions instance\n        :rtype: DiffOptions\n        \"\"\"\n        assert isinstance(diff_column, str), diff_column\n        return dataclasses.replace(self, diff_column=diff_column)\n\n    def with_left_column_prefix(self, left_column_prefix: str) -> 'DiffOptions':\n        \"\"\"\n        Fluent method to change the prefix of columns from the left Dataset.\n        Returns a new immutable DiffOptions instance with the new column prefix.\n\n        :param left_column_prefix: new column prefix\n        :type left_column_prefix: str\n        :return: new immutable DiffOptions instance\n        :rtype: DiffOptions\n        \"\"\"\n        assert isinstance(left_column_prefix, str), left_column_prefix\n        return dataclasses.replace(self, left_column_prefix=left_column_prefix)\n\n    def with_right_column_prefix(self, right_column_prefix: str) -> 'DiffOptions':\n        \"\"\"\n        Fluent method to change the prefix of columns from the right Dataset.\n        Returns a new immutable DiffOptions instance with the new column prefix.\n\n        :param right_column_prefix: new column prefix\n        :type right_column_prefix: str\n        :return: new immutable DiffOptions instance\n        :rtype: DiffOptions\n        \"\"\"\n        assert isinstance(right_column_prefix, str), right_column_prefix\n        return dataclasses.replace(self, right_column_prefix=right_column_prefix)\n\n    def with_insert_diff_value(self, insert_diff_value: str) -> 'DiffOptions':\n        \"\"\"\n        Fluent method to change the value of inserted rows in the diff column.\n        Returns a new immutable DiffOptions instance with the new diff value.\n\n        :param insert_diff_value: new diff value\n        :type insert_diff_value: str\n        :return: new immutable DiffOptions instance\n        :rtype: DiffOptions\n        \"\"\"\n        assert isinstance(insert_diff_value, str), insert_diff_value\n        return dataclasses.replace(self, insert_diff_value=insert_diff_value)\n\n    def with_change_diff_value(self, change_diff_value: str) -> 'DiffOptions':\n        \"\"\"\n        Fluent method to change the value of changed rows in the diff column.\n        Returns a new immutable DiffOptions instance with the new diff value.\n\n        :param change_diff_value: new diff column name\n        :type change_diff_value: str\n        :return: new immutable DiffOptions instance\n        :rtype: DiffOptions\n        \"\"\"\n        assert isinstance(change_diff_value, str), 
change_diff_value\n        return dataclasses.replace(self, change_diff_value=change_diff_value)\n\n    def with_delete_diff_value(self, delete_diff_value: str) -> 'DiffOptions':\n        \"\"\"\n        Fluent method to change the value of deleted rows in the diff column.\n        Returns a new immutable DiffOptions instance with the new diff value.\n\n        :param delete_diff_value: new diff column name\n        :type delete_diff_value: str\n        :return: new immutable DiffOptions instance\n        :rtype: DiffOptions\n        \"\"\"\n        assert isinstance(delete_diff_value, str), delete_diff_value\n        return dataclasses.replace(self, delete_diff_value=delete_diff_value)\n\n    def with_nochange_diff_value(self, nochange_diff_value: str) -> 'DiffOptions':\n        \"\"\"\n        Fluent method to change the value of un-changed rows in the diff column.\n        Returns a new immutable DiffOptions instance with the new diff value.\n\n        :param nochange_diff_value: new diff column name\n        :type nochange_diff_value: str\n        :return: new immutable DiffOptions instance\n        :rtype: DiffOptions\n        \"\"\"\n        assert isinstance(nochange_diff_value, str), nochange_diff_value\n        return dataclasses.replace(self, nochange_diff_value=nochange_diff_value)\n\n    def with_change_column(self, change_column: str) -> 'DiffOptions':\n        \"\"\"\n        Fluent method to change the change column name.\n        Returns a new immutable DiffOptions instance with the new change column name.\n\n        :param change_column: new change column name\n        :type change_column: str\n        :return: new immutable DiffOptions instance\n        :rtype: DiffOptions\n        \"\"\"\n        assert isinstance(change_column, str), change_column\n        return dataclasses.replace(self, change_column=change_column)\n\n    def without_change_column(self) -> 'DiffOptions':\n        \"\"\"\n        Fluent method to remove change column.\n        Returns a new immutable DiffOptions instance without a change column.\n\n        :return: new immutable DiffOptions instance\n        :rtype: DiffOptions\n        \"\"\"\n        return dataclasses.replace(self, change_column=None)\n\n    def with_diff_mode(self, diff_mode: DiffMode) -> 'DiffOptions':\n        \"\"\"\n        Fluent method to change the diff mode.\n        Returns a new immutable DiffOptions instance with the new diff mode.\n\n        :param diff_mode: new diff mode\n        :type diff_mode: DiffMode\n        :return: new immutable DiffOptions instance\n        :rtype: DiffOptions\n        \"\"\"\n        assert isinstance(diff_mode, DiffMode), diff_mode\n        return dataclasses.replace(self, diff_mode=diff_mode)\n\n    def with_sparse_mode(self, sparse_mode: bool) -> 'DiffOptions':\n        \"\"\"\n        Fluent method to change the sparse mode.\n        Returns a new immutable DiffOptions instance with the new sparse mode.\n\n        :param sparse: new sparse mode\n        :type sparse: bool\n        :return: new immutable DiffOptions instance\n        :rtype: DiffOptions\n        \"\"\"\n        assert isinstance(sparse_mode, bool), sparse_mode\n        return dataclasses.replace(self, sparse_mode=sparse_mode)\n\n    def with_default_comparator(self, comparator: DiffComparator) -> 'DiffOptions':\n        assert isinstance(comparator, DiffComparator), comparator\n        return dataclasses.replace(self, default_comparator=comparator)\n\n    def with_data_type_comparator(self, comparator: DiffComparator, 
*data_type: DataType) -> 'DiffOptions':\n        assert isinstance(comparator, DiffComparator), comparator\n        for dt in data_type:\n            assert isinstance(dt, DataType), dt\n\n        existing_data_types = {dt.simpleString() for dt in data_type if dt in self.data_type_comparators.keys()}\n        if existing_data_types:\n            existing_data_types = sorted(list(existing_data_types))\n            raise ValueError(f'A comparator for data type{\"s\" if len(existing_data_types) > 1 else \"\"} '\n                             f'{\", \".join(existing_data_types)} exists already.')\n\n        data_type_comparators = self.data_type_comparators.copy()\n        data_type_comparators.update({dt: comparator for dt in data_type})\n        return dataclasses.replace(self, data_type_comparators=data_type_comparators)\n\n    def with_column_name_comparator(self, comparator: DiffComparator, *column_name: str) -> 'DiffOptions':\n        assert isinstance(comparator, DiffComparator), comparator\n        for cn in column_name:\n            assert isinstance(cn, str), cn\n\n        existing_column_names = {cn for cn in column_name if cn in self.column_name_comparators.keys()}\n        if existing_column_names:\n            existing_column_names = sorted(list(existing_column_names))\n            raise ValueError(f'A comparator for column name{\"s\" if len(existing_column_names) > 1 else \"\"} '\n                             f'{\", \".join(existing_column_names)} exists already.')\n\n        column_name_comparators = self.column_name_comparators.copy()\n        column_name_comparators.update({dt: comparator for dt in column_name})\n        return dataclasses.replace(self, column_name_comparators=column_name_comparators)\n\n    def comparator_for(self, column: StructField) -> DiffComparator:\n        assert isinstance(column, StructField), column\n        cmp = self.column_name_comparators.get(column.name)\n        if cmp is None:\n            cmp = self.data_type_comparators.get(column.dataType)\n        if cmp is None:\n            cmp = self.default_comparator\n        return cmp\n\n\nclass Differ:\n    \"\"\"\n    Differ class to diff two Datasets. See Differ.of(…) for details.\n\n    :param options: options for the diffing process\n    :type options: DiffOptions\n    \"\"\"\n    def __init__(self, options: DiffOptions = None):\n        self._options = options or DiffOptions()\n\n    @overload\n    def diff(self, left: DataFrame, right: DataFrame, *id_columns: str) -> DataFrame: ...\n\n    @overload\n    def diff(self, left: DataFrame, right: DataFrame, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...\n\n    def diff(self, left: DataFrame, right: DataFrame, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame:\n        \"\"\"\n        Returns a new DataFrame that contains the differences between the two DataFrames.\n\n        Both DataFrames must contain the same set of column names and data types.\n        The order of columns in the two DataFrames is not important as columns are compared based on the\n        name, not the position.\n\n        Optional id columns are used to uniquely identify rows to compare. If values in any non-id\n        column are differing between the two DataFrames, then that row is marked as `\"C\"`hange\n        and `\"N\"`o-change otherwise. Rows of the right DataFrame, that do not exist in the left DataFrame\n        (w.r.t. the values in the id columns) are marked as `\"I\"`nsert. 
And rows of the left DataFrame,\n        that do not exist in the right DataFrame are marked as `\"D\"`elete.\n\n        If no id columns are given, all columns are considered id columns. Then, no `\"C\"`hange rows\n        will appear, as all changes will exist as respective `\"D\"`elete and `\"I\"`nsert.\n\n        Values in optional ignore columns are not compared but included in the output DataFrame.\n\n        The returned DataFrame has the `diff` column as the first column. This holds the `\"N\"`, `\"C\"`,\n        `\"I\"` or `\"D\"` strings. The id columns follow, then the non-id columns (all remaining columns).\n\n        .. code-block:: python\n\n          df1 = spark.createDataFrame([(1, \"one\"), (2, \"two\"), (3, \"three\")], [\"id\", \"value\"])\n          df2 = spark.createDataFrame([(1, \"one\"), (2, \"Two\"), (4, \"four\")], [\"id\", \"value\"])\n       \n          differ.diff(df1, df2).show()\n       \n          // output:\n          // +----+---+-----+\n          // |diff| id|value|\n          // +----+---+-----+\n          // |   N|  1|  one|\n          // |   D|  2|  two|\n          // |   I|  2|  Two|\n          // |   D|  3|three|\n          // |   I|  4| four|\n          // +----+---+-----+\n       \n          differ.diff(df1, df2, \"id\").show()\n       \n          // output:\n          // +----+---+----------+-----------+\n          // |diff| id|left_value|right_value|\n          // +----+---+----------+-----------+\n          // |   N|  1|       one|        one|\n          // |   C|  2|       two|        Two|\n          // |   D|  3|     three|       null|\n          // |   I|  4|      null|       four|\n          // +----+---+----------+-----------+\n\n        The id columns are in order as given to the method. If no id columns are given then all\n        columns of this DataFrame are id columns and appear in the same order. 
The remaining non-id\n        columns are in the order of this DataFrame.\n\n        :param left: left DataFrame\n        :type left: DataFrame\n        :param right: right DataFrame\n        :type right: DataFrame\n        :param id_or_ignore_columns: either id column names or two lists of column names,\n               first the id column names, second the ignore column names\n        :type id_or_ignore_columns: str | Iterable[str]\n        :return: the diff DataFrame\n        :rtype: DataFrame\n        \"\"\"\n        assert _is_dataframe(left), left\n        assert _is_dataframe(right), right\n        assert isinstance(id_or_ignore_columns, (str, Iterable)), id_or_ignore_columns\n\n        if len(id_or_ignore_columns) == 2 and all(isinstance(lst, Iterable) and not isinstance(lst, str) for lst in id_or_ignore_columns):\n            id_columns, ignore_columns = id_or_ignore_columns\n            if any(not isinstance(id, str) for id in id_columns):\n                raise ValueError(f\"The id_columns must all be strings: {', '.join(type(id).__name__ for id in id_columns)}\")\n            if any(not isinstance(ignore, str) for ignore in ignore_columns):\n                raise ValueError(f\"The ignore_columns must all be strings: {', '.join(type(ignore).__name__ for ignore in ignore_columns)}\")\n        elif all(isinstance(lst, str) for lst in id_or_ignore_columns):\n            id_columns, ignore_columns = (id_or_ignore_columns, [])\n        else:\n            raise ValueError(f\"The id_or_ignore_columns argument must either all be strings or exactly two iterables of strings: {', '.join(type(e).__name__ for e in id_or_ignore_columns)}\")\n\n        return self._do_diff(left, right, id_columns, ignore_columns)\n\n    @staticmethod\n    def _columns_of_side(df: DataFrame, id_columns: List[str], side_prefix: str) -> List[Column]:\n        prefix = side_prefix + '_'\n        return [col(c) if c in id_columns else col(c).alias(c.replace(prefix, \"\"))\n                for c in df.columns if c in id_columns or c.startswith(side_prefix)]\n\n    @overload\n    def diffwith(self, left: DataFrame, right: DataFrame, *id_columns: str) -> DataFrame: ...\n\n    @overload\n    def diffwith(self, left: DataFrame, right: DataFrame, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...\n\n    def diffwith(self, left: DataFrame, right: DataFrame, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame:\n        \"\"\"\n        Returns a new DataFrame that contains the differences between the two DataFrames\n        as tuples of type `(String, Row, Row)`.\n\n        See `diff(left: DataFrame, right: DataFrame, *id_columns: str)`.\n\n        :param left: left DataFrame\n        :type left: DataFrame\n        :param right: right DataFrame\n        :type right: DataFrame\n        :param id_or_ignore_columns: either id column names or two lists of column names,\n               first the id column names, second the ignore column names\n        :type id_or_ignore_columns: str | Iterable[str]\n        :return: the diff DataFrame\n        :rtype: DataFrame\n        \"\"\"\n        assert _is_dataframe(left), left\n        assert _is_dataframe(right), right\n        assert isinstance(id_or_ignore_columns, (str, Iterable)), id_or_ignore_columns\n\n        if len(id_or_ignore_columns) == 2 and all(isinstance(lst, Iterable) and not isinstance(lst, str) for lst in id_or_ignore_columns):\n            id_columns, ignore_columns = id_or_ignore_columns\n            if any(not isinstance(id, str) for id in id_columns):\n                raise 
ValueError(f\"The id_columns must all be strings: {', '.join(type(id).__name__ for id in id_columns)}\")\n            if any(not isinstance(ignore, str) for ignore in ignore_columns):\n                raise ValueError(f\"The ignore_columns must all be strings: {', '.join(type(ignore).__name__ for ignore in ignore_columns)}\")\n        elif all(isinstance(lst, str) for lst in id_or_ignore_columns):\n            id_columns, ignore_columns = (id_or_ignore_columns, [])\n        else:\n            raise ValueError(f\"The id_or_ignore_columns argument must either all be strings or exactly two iterables of strings: {', '.join(type(e).__name__ for e in id_or_ignore_columns)}\")\n\n        diff = self._do_diff(left, right, id_columns, ignore_columns)\n        left_columns = self._columns_of_side(diff, id_columns, self._options.left_column_prefix)\n        right_columns = self._columns_of_side(diff, id_columns, self._options.right_column_prefix)\n        diff_column = col(self._options.diff_column)\n\n        left_struct = when(diff_column == self._options.insert_diff_value, lit(None)) \\\n            .otherwise(struct(*left_columns)) \\\n            .alias(self._options.left_column_prefix)\n        right_struct = when(diff_column == self._options.delete_diff_value, lit(None)) \\\n            .otherwise(struct(*right_columns)) \\\n            .alias(self._options.right_column_prefix)\n        return diff.select(diff_column, left_struct, right_struct)\n\n    def _check_schema(self, left: DataFrame, right: DataFrame, id_columns: List[str], ignore_columns: List[str], case_sensitive: bool):\n        def require(result: bool, message: str) -> None:\n            if not result:\n                raise ValueError(message)\n\n        require(\n            len(left.columns) == len(set(left.columns)) and len(right.columns) == len(set(right.columns)),\n            f\"The datasets have duplicate columns.\\n\" +\n            f\"Left column names: {', '.join(left.columns)}\\n\" +\n            f\"Right column names: {', '.join(right.columns)}\")\n\n        left_non_ignored = list_diff_case_sensitivity(left.columns, ignore_columns, case_sensitive)\n        right_non_ignored = list_diff_case_sensitivity(right.columns, ignore_columns, case_sensitive)\n\n        except_ignored_columns_msg = ' except ignored columns' if ignore_columns else ''\n\n        require(\n            len(left_non_ignored) == len(right_non_ignored),\n            \"The number of columns doesn't match.\\n\" +\n            f\"Left column names{except_ignored_columns_msg} ({len(left_non_ignored)}): {', '.join(left_non_ignored)}\\n\" +\n            f\"Right column names{except_ignored_columns_msg} ({len(right_non_ignored)}): {', '.join(right_non_ignored)}\"\n        )\n\n        require(len(left_non_ignored) > 0, f\"The schema{except_ignored_columns_msg} must not be empty\")\n\n        # column types must match but we ignore the nullability of columns\n        left_fields = {handle_configured_case_sensitivity(field.name, case_sensitive): field.dataType\n                       for field in left.schema.fields\n                       if not list_contains_case_sensitivity(ignore_columns, field.name, case_sensitive)}\n        right_fields = {handle_configured_case_sensitivity(field.name, case_sensitive): field.dataType\n                        for field in right.schema.fields\n                        if not list_contains_case_sensitivity(ignore_columns, field.name, case_sensitive)}\n        left_extra_schema = set(left_fields.items()) - 
set(right_fields.items())\n        right_extra_schema = set(right_fields.items()) - set(left_fields.items())\n        require(\n            len(left_extra_schema) == 0 and len(right_extra_schema) == 0,\n            \"The datasets do not have the same schema.\\n\" +\n            f\"Left extra columns: {', '.join([f'{f} ({t.typeName()})' for f, t in sorted(list(left_extra_schema))])}\\n\" +\n            f\"Right extra columns: {', '.join([f'{f} ({t.typeName()})' for f, t in sorted(list(right_extra_schema))])}\")\n\n        columns = left_non_ignored\n        pk_columns = id_columns or columns\n        non_pk_columns = list_diff_case_sensitivity(columns, pk_columns, case_sensitive)\n        missing_id_columns = list_diff_case_sensitivity(pk_columns, columns, case_sensitive)\n        require(\n            len(missing_id_columns) == 0,\n            f\"Some id columns do not exist: {', '.join(missing_id_columns)} missing among {', '.join(columns)}\"\n        )\n\n        missing_ignore_columns = list_diff_case_sensitivity(ignore_columns, left.columns + right.columns, case_sensitive)\n        require(\n            len(missing_ignore_columns) == 0,\n            f\"Some ignore columns do not exist: {', '.join(missing_ignore_columns)} \" +\n            f\"missing among {', '.join(sorted(list(set(left_non_ignored + right_non_ignored))))}\"\n        )\n\n        require(\n            not list_contains_case_sensitivity(pk_columns, self._options.diff_column, case_sensitive),\n            f\"The id columns must not contain the diff column name '{self._options.diff_column}': {', '.join(pk_columns)}\"\n        )\n        require(\n            self._options.change_column is None or not list_contains_case_sensitivity(pk_columns, self._options.change_column, case_sensitive),\n            f\"The id columns must not contain the change column name '{self._options.change_column}': {', '.join(pk_columns)}\"\n        )\n        diff_value_columns = self._get_diff_value_columns(pk_columns, non_pk_columns, left, right, ignore_columns, case_sensitive)\n        diff_value_columns = {n for n, t in diff_value_columns}\n\n        if self._options.diff_mode in [DiffMode.LeftSide, DiffMode.RightSide]:\n            require(\n                not list_contains_case_sensitivity(diff_value_columns, self._options.diff_column, case_sensitive),\n                f\"The {'left' if self._options.diff_mode == DiffMode.LeftSide else 'right'} \" +\n                f\"non-id columns must not contain the diff column name '{self._options.diff_column}': \" +\n                f\"{', '.join(list_diff_case_sensitivity((left if self._options.diff_mode == DiffMode.LeftSide else right).columns, id_columns, case_sensitive))}\"\n            )\n\n            require(\n                self._options.change_column is None or not list_contains_case_sensitivity(diff_value_columns, self._options.change_column, case_sensitive),\n                f\"The {'left' if self._options.diff_mode == DiffMode.LeftSide else 'right'} \" +\n                f\"non-id columns must not contain the change column name '{self._options.change_column}': \" +\n                f\"{', '.join(list_diff_case_sensitivity((left if self._options.diff_mode == DiffMode.LeftSide else right).columns, id_columns, case_sensitive))}\"\n            )\n        else:\n            require(\n                not list_contains_case_sensitivity(diff_value_columns, self._options.diff_column, case_sensitive),\n                f\"The column prefixes '{self._options.left_column_prefix}' and 
'{self._options.right_column_prefix}', \" +\n                f\"together with these non-id columns must not produce the diff column name '{self._options.diff_column}': \" +\n                f\"{', '.join(non_pk_columns)}\"\n            )\n\n            require(\n                self._options.change_column is None or not list_contains_case_sensitivity(diff_value_columns, self._options.change_column, case_sensitive),\n                f\"The column prefixes '{self._options.left_column_prefix}' and '{self._options.right_column_prefix}', \" +\n                f\"together with these non-id columns must not produce the change column name '{self._options.change_column}': \" +\n                f\"{', '.join(non_pk_columns)}\"\n            )\n\n            require(\n                all(not list_contains_case_sensitivity(pk_columns, c, case_sensitive) for c in diff_value_columns),\n                f\"The column prefixes '{self._options.left_column_prefix}' and '{self._options.right_column_prefix}', \" +\n                f\"together with these non-id columns must not produce any id column name '{', '.join(pk_columns)}': \" +\n                f\"{', '.join(non_pk_columns)}\"\n            )\n\n    def _get_change_column(self,\n                           exists_column_name: str,\n                           value_columns_with_comparator: List[Tuple[str, DiffComparator]],\n                           left: DataFrame,\n                           right: DataFrame) -> Optional[Column]:\n        if self._options.change_column is None:\n            return None\n        if not self._options.change_column:\n            return array().cast(ArrayType(StringType(), containsNull=False)).alias(self._options.change_column)\n        return when(left[exists_column_name].isNull() | right[exists_column_name].isNull(), lit(None)) \\\n            .otherwise(\n                concat(*[when(cmp.equiv(left[c], right[c]), array()).otherwise(array(lit(c)))\n                         for (c, cmp) in value_columns_with_comparator])) \\\n            .alias(self._options.change_column)\n\n    def _do_diff(self, left: DataFrame, right: DataFrame, id_columns: List[str], ignore_columns: List[str]) -> DataFrame:\n        case_sensitive = left.session().conf.get(\"spark.sql.caseSensitive\") == \"true\"\n        self._check_schema(left, right, id_columns, ignore_columns, case_sensitive)\n\n        columns = list_diff_case_sensitivity(left.columns, ignore_columns, case_sensitive)\n        pk_columns = id_columns or columns\n        value_columns = list_diff_case_sensitivity(columns, pk_columns, case_sensitive)\n        value_struct_fields = {f.name: f for f in left.schema.fields}\n        value_columns_with_comparator = [(c, self._options.comparator_for(value_struct_fields[c])) for c in value_columns]\n\n        exists_column_name = distinct_prefix_for(left.columns) + \"exists\"\n        left_with_exists = left.withColumn(exists_column_name, lit(1))\n        right_with_exists = right.withColumn(exists_column_name, lit(1))\n        join_condition = reduce(lambda l, r: l & r,\n                                [left_with_exists[c].eqNullSafe(right_with_exists[c])\n                                 for c in pk_columns])\n        un_changed = reduce(lambda l, r: l & r,\n                            [cmp.equiv(left_with_exists[c], right_with_exists[c])\n                             for (c, cmp) in value_columns_with_comparator],\n                            lit(True))\n        change_condition = ~un_changed\n\n        diff_action_column = \\\n         
   when(left_with_exists[exists_column_name].isNull(), lit(self._options.insert_diff_value)) \\\n            .when(right_with_exists[exists_column_name].isNull(), lit(self._options.delete_diff_value)) \\\n            .when(change_condition, lit(self._options.change_diff_value)) \\\n            .otherwise(lit(self._options.nochange_diff_value)) \\\n            .alias(self._options.diff_column)\n\n        diff_columns = [c[1] for c in self._get_diff_columns(pk_columns, value_columns, left, right, ignore_columns, case_sensitive)]\n        # turn this column into a list of one or none column so we can easily concat it below with diffActionColumn and diffColumns\n        change_column = self._get_change_column(exists_column_name, value_columns_with_comparator, left_with_exists, right_with_exists)\n        change_columns = [change_column] if change_column is not None else []\n\n        return left_with_exists \\\n            .join(right_with_exists, join_condition, \"fullouter\") \\\n            .select(*([diff_action_column] + change_columns + diff_columns))\n\n    def _get_diff_id_columns(self, pk_columns: List[str],\n                                left: DataFrame,\n                                right: DataFrame) -> List[Tuple[str, Column]]:\n        return [(c, coalesce(left[c], right[c]).alias(c)) for c in pk_columns]\n\n    def _get_diff_value_columns(self, pk_columns: List[str],\n                          value_columns: List[str],\n                          left: DataFrame,\n                          right: DataFrame,\n                          ignore_columns: List[str],\n                          case_sensitive: bool) -> List[Tuple[str, Column]]:\n        left_value_columns = list_filter_case_sensitivity(left.columns, value_columns, case_sensitive)\n        right_value_columns = list_filter_case_sensitivity(right.columns, value_columns, case_sensitive)\n\n        left_non_pk_columns = list_diff_case_sensitivity(left.columns, pk_columns, case_sensitive)\n        right_non_pk_columns = list_diff_case_sensitivity(right.columns, pk_columns, case_sensitive)\n\n        left_ignored_columns = list_filter_case_sensitivity(left.columns, ignore_columns, case_sensitive)\n        right_ignored_columns = list_filter_case_sensitivity(right.columns, ignore_columns, case_sensitive)\n        left_values = {handle_configured_case_sensitivity(c, case_sensitive): (c, when(~(left[c].eqNullSafe(right[c])), left[c]) if self._options.sparse_mode else left[c]) for c in left_non_pk_columns}\n        right_values = {handle_configured_case_sensitivity(c, case_sensitive): (c, when(~(left[c].eqNullSafe(right[c])), right[c]) if self._options.sparse_mode else right[c]) for c in right_non_pk_columns}\n\n        def alias(prefix: Optional[str], values: Dict[str, Tuple[str, Column]]) -> Callable[[str], Tuple[str, Column]]:\n            def func(name: str) -> (str, Column):\n                name, column = values[handle_configured_case_sensitivity(name, case_sensitive)]\n                alias = name if prefix is None else f'{prefix}_{name}'\n                return alias, column.alias(alias)\n\n            return func\n\n        def alias_left(name: str) -> (str, Column):\n            return alias(self._options.left_column_prefix, left_values)(name)\n\n        def alias_right(name: str) -> (str, Column):\n            return alias(self._options.right_column_prefix, right_values)(name)\n\n        prefixed_left_ignored_columns = [alias_left(c) for c in left_ignored_columns]\n        prefixed_right_ignored_columns = 
[alias_right(c) for c in right_ignored_columns]\n\n        if self._options.diff_mode == DiffMode.ColumnByColumn:\n            non_id_columns = \\\n                [c for vc in value_columns for c in [alias_left(vc), alias_right(vc)]] + \\\n                [c for ic in ignore_columns for c in (\n                   ([alias_left(ic)] if list_contains_case_sensitivity(left_ignored_columns, ic, case_sensitive) else []) +\n                   ([alias_right(ic)] if list_contains_case_sensitivity(right_ignored_columns, ic, case_sensitive) else [])\n                )]\n        elif self._options.diff_mode == DiffMode.SideBySide:\n            non_id_columns = \\\n                [alias_left(c) for c in left_value_columns] + prefixed_left_ignored_columns + \\\n                [alias_right(c) for c in right_value_columns] + prefixed_right_ignored_columns\n        elif self._options.diff_mode == DiffMode.LeftSide:\n            non_id_columns = \\\n                [alias(None, left_values)(c) for c in value_columns] +\\\n                [alias(None, left_values)(c) for c in left_ignored_columns]\n        elif self._options.diff_mode == DiffMode.RightSide:\n            non_id_columns = \\\n                [alias(None, right_values)(c) for c in value_columns] + \\\n                [alias(None, right_values)(c) for c in right_ignored_columns]\n        else:\n            raise RuntimeError(f'Unsupported diff mode: {self._options.diff_mode}')\n\n        return non_id_columns\n\n    def _get_diff_columns(self, pk_columns: List[str],\n                          value_columns: List[str],\n                          left: DataFrame,\n                          right: DataFrame,\n                          ignore_columns: List[str],\n                          case_sensitive: bool) -> List[Tuple[str, Column]]:\n        return self._get_diff_id_columns(pk_columns, left, right) + \\\n               self._get_diff_value_columns(pk_columns, value_columns, left, right, ignore_columns, case_sensitive)\n\n\n@overload\ndef diff(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame: ...\n\n\n@overload\ndef diff(self: DataFrame, other: DataFrame, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...\n\n\n@overload\ndef diff(self: DataFrame, other: DataFrame, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame: ...\n\n\n@overload\ndef diff(self: DataFrame, other: DataFrame, options: DiffOptions, *id_columns: str) -> DataFrame: ...\n\n\n@overload\ndef diff(self: DataFrame, other: DataFrame, options: DiffOptions, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...\n\n\n@overload\ndef diff(self: DataFrame, other: DataFrame, options: DiffOptions, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame: ...\n\n\ndef diff(self: DataFrame, other: DataFrame, *options_or_id_or_ignore_columns: Union[DiffOptions, str, Iterable[str]]) -> DataFrame:\n    \"\"\"\n    Returns a new DataFrame that contains the differences between this and the other DataFrame.\n    Both DataFrames must contain the same set of column names and data types.\n    The order of columns in the two DataFrames is not important as one column is compared to the\n    column with the same name of the other DataFrame, not the column with the same position.\n\n    Optional options allow for customizing diffing behaviour and diff result schema.\n\n    Optional id columns are used to uniquely identify rows to compare. 
If values in any non-id\n    column are differing between this and the other DataFrame, then that row is marked as `\"C\"`hange\n    and `\"N\"`o-change otherwise. Rows of the other DataFrame, that do not exist in this DataFrame\n    (w.r.t. the values in the id columns) are marked as `\"I\"`nsert. And rows of this DataFrame, that\n    do not exist in the other DataFrame are marked as `\"D\"`elete.\n\n    If no id columns are given, all columns are considered id columns. Then, no `\"C\"`hange rows\n    will appear, as all changes will exist as respective `\"D\"`elete and `\"I\"`nsert.\n\n    Values in optional ignore columns are not compared but included in the output DataFrame.\n\n    The returned DataFrame has the `diff` column as the first column. This holds the `\"N\"`, `\"C\"`,\n    `\"I\"` or `\"D\"` strings. The id columns follow, then the non-id columns (all remaining columns).\n\n    .. code-block:: python\n\n      df1 = spark.createDataFrame([(1, \"one\"), (2, \"two\"), (3, \"three\")], [\"id\", \"value\"])\n      df2 = spark.createDataFrame([(1, \"one\"), (2, \"Two\"), (4, \"four\")], [\"id\", \"value\"])\n\n      df1.diff(df2).show()\n\n      // output:\n      // +----+---+-----+\n      // |diff| id|value|\n      // +----+---+-----+\n      // |   N|  1|  one|\n      // |   D|  2|  two|\n      // |   I|  2|  Two|\n      // |   D|  3|three|\n      // |   I|  4| four|\n      // +----+---+-----+\n\n      df1.diff(df2, \"id\").show()\n\n      // output:\n      // +----+---+----------+-----------+\n      // |diff| id|left_value|right_value|\n      // +----+---+----------+-----------+\n      // |   N|  1|       one|        one|\n      // |   C|  2|       two|        Two|\n      // |   D|  3|     three|       null|\n      // |   I|  4|      null|       four|\n      // +----+---+----------+-----------+\n\n    The id columns are in order as given to the method. If no id columns are given then all\n    columns of this DataFrame are id columns and appear in the same order. 
The remaining non-id\n    columns are in the order of this DataFrame.\n\n    :param other: right DataFrame\n    :type other: DataFrame\n    :param options: optional diff options\n    :type options: DiffOptions\n    :param id_columns: id columns\n    :type id_columns: str\n    :param ignore_columns: optional ignored columns\n    :type ignore_columns: str\n    :param id_or_ignore_columns: either id column names or two lists of column names,\n           first the id column names, second the ignore column names\n    :type id_or_ignore_columns: str\n    :return: the diff DataFrame\n    :rtype: DataFrame\n    \"\"\"\n    if any(isinstance(i, DiffOptions) for i in options_or_id_or_ignore_columns):\n        options = options_or_id_or_ignore_columns[0]\n        if not isinstance(options, DiffOptions):\n            raise ValueError(\"Diff options must be given as second argument\")\n        id_or_ignore_columns = options_or_id_or_ignore_columns[1:]\n        return Differ(options).diff(self, other, *id_or_ignore_columns)\n\n    id_or_ignore_columns = options_or_id_or_ignore_columns\n    return Differ().diff(self, other, *id_or_ignore_columns)\n\n\n@overload\ndef diffwith(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame: ...\n\n\n@overload\ndef diffwith(self: DataFrame, other: DataFrame, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...\n\n\n@overload\ndef diffwith(self: DataFrame, other: DataFrame, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame: ...\n\n\n@overload\ndef diffwith(self: DataFrame, other: DataFrame, options: DiffOptions, *id_columns: str) -> DataFrame: ...\n\n\n@overload\ndef diffwith(self: DataFrame, other: DataFrame, options: DiffOptions, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...\n\n\n@overload\ndef diffwith(self: DataFrame, other: DataFrame, options: DiffOptions, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame: ...\n\n\ndef diffwith(self: DataFrame, other: DataFrame, *options_or_id_or_ignore_columns: Union[DiffOptions, str, Iterable[str]]) -> DataFrame:\n    \"\"\"\n    Returns a new DataFrame that contains the differences between the two DataFrames\n    as tuples of type `(String, Row, Row)`.\n\n    See `diff(other: DataFrame, *options_or_id_or_ignore_columns: str)`.\n\n    :param other: right DataFrame\n    :type other: DataFrame\n    :param options: diff options\n    :type options: DiffOptions\n    :param id_columns: id columns\n    :type id_columns: str\n    :param ignore_columns: optional ignored columns\n    :type ignore_columns: str\n    :param id_or_ignore_columns: either id column names or two lists of column names,\n           first the id column names, second the ignore column names\n    :type id_or_ignore_columns: str\n    :return: the diff DataFrame\n    :rtype: DataFrame\n    \"\"\"\n    if any(isinstance(i, DiffOptions) for i in options_or_id_or_ignore_columns):\n        options = options_or_id_or_ignore_columns[0]\n        if not isinstance(options, DiffOptions):\n            raise ValueError(\"Diff options must be given as second argument\")\n        id_or_ignore_columns = options_or_id_or_ignore_columns[1:]\n        return Differ(options).diffwith(self, other, *id_or_ignore_columns)\n\n    id_or_ignore_columns = options_or_id_or_ignore_columns\n    return Differ().diffwith(self, other, *id_or_ignore_columns)\n\n\n@overload\ndef diff_with_options(self: DataFrame, other: DataFrame, 
options: DiffOptions, *id_columns: str) -> DataFrame: ...\n\n\n@overload\ndef diff_with_options(self: DataFrame, other: DataFrame, options: DiffOptions, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...\n\n\n@deprecated(\"Use diff with identical arguments instead\")\ndef diff_with_options(self: DataFrame, other: DataFrame, options: DiffOptions, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame:\n    \"\"\"\n    Returns a new DataFrame that contains the differences between this and the other DataFrame.\n\n    See `diff(other: DataFrame, *id_columns: str)`.\n\n    The schema of the returned DataFrame can be configured by the given `DiffOptions`.\n\n    :param other: right DataFrame\n    :type other: DataFrame\n    :param id_or_ignore_columns: either id column names or two lists of column names,\n           first the id column names, second the ignore column names\n    :type id_or_ignore_columns: str\n    :param options: diff options\n    :type options: DiffOptions\n    :return: the diff DataFrame\n    :rtype: DataFrame\n    \"\"\"\n    return Differ(options).diff(self, other, *id_or_ignore_columns)\n\n\n@overload\ndef diffwith_with_options(self: DataFrame, other: DataFrame, options: DiffOptions, *id_columns: str) -> DataFrame: ...\n\n\n@overload\ndef diffwith_with_options(self: DataFrame, other: DataFrame, options: DiffOptions, id_columns: Iterable[str], ignore_columns: Iterable[str]) -> DataFrame: ...\n\n\n@deprecated(\"Use diffwith with identical arguments instead\")\ndef diffwith_with_options(self: DataFrame, other: DataFrame, options: DiffOptions, *id_or_ignore_columns: Union[str, Iterable[str]]) -> DataFrame:\n    \"\"\"\n    Returns a new DataFrame that contains the differences between the two DataFrames\n    as tuples of type `(String, Row, Row)`.\n\n    See `diffwith(other: DataFrame, *id_columns: str)`.\n\n    The schema of the returned DataFrame can be configured by the given `DiffOptions`.\n\n    :param other: right DataFrame\n    :type other: DataFrame\n    :param options: diff options\n    :type options: DiffOptions\n    :param id_or_ignore_columns: either id column names or two lists of column names,\n           first the id column names, second the ignore column names\n    :type id_or_ignore_columns: str\n    :return: the diff DataFrame\n    :rtype: DataFrame\n    \"\"\"\n    return Differ(options).diffwith(self, other, *id_or_ignore_columns)\n\n\nDataFrame.diff = diff\nDataFrame.diffwith = diffwith\nDataFrame.diff_with_options = diff_with_options\nDataFrame.diffwith_with_options = diffwith_with_options\n\nif has_connect:\n    from gresearch.spark import ConnectDataFrame\n\n    ConnectDataFrame.diff = diff\n    ConnectDataFrame.diffwith = diffwith\n    ConnectDataFrame.diff_with_options = diff_with_options\n    ConnectDataFrame.diffwith_with_options = diffwith_with_options\n"
  },
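  {
    "path": "python/examples/diff_example.py",
    "content": "# Illustrative usage sketch, not part of the original sources: the file name examples/diff_example.py is hypothetical.\n# It only exercises the documented API from gresearch.spark.diff (DataFrame.diff, DiffOptions, DiffMode)\n# and assumes the pyspark-extension package is installed in the Python environment.\nfrom pyspark.sql import SparkSession\n\nfrom gresearch.spark.diff import DiffMode, DiffOptions  # importing the module patches DataFrame.diff\n\nspark = SparkSession.builder.appName(\"diff-example\").getOrCreate()\n\nleft = spark.createDataFrame([(1, \"one\"), (2, \"two\"), (3, \"three\")], [\"id\", \"value\"])\nright = spark.createDataFrame([(1, \"one\"), (2, \"Two\"), (4, \"four\")], [\"id\", \"value\"])\n\n# plain diff on the id column: one row per id with left_ and right_ value columns\nleft.diff(right, \"id\").show()\n\n# customised result schema: add a change column and switch to side-by-side layout\noptions = DiffOptions().with_change_column(\"changes\").with_diff_mode(DiffMode.SideBySide)\nleft.diff(right, options, \"id\").show()\n\nspark.stop()\n"
  },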
  {
    "path": "python/gresearch/spark/diff/comparator/__init__.py",
    "content": "#  Copyright 2022 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nimport abc\nimport dataclasses\nfrom dataclasses import dataclass\n\nfrom py4j.java_gateway import JVMView, JavaObject\nfrom pyspark.sql import Column\nfrom pyspark.sql.functions import abs, greatest, lit\nfrom pyspark.sql.types import DataType\n\nfrom gresearch.spark import _is_column\n\n\nclass DiffComparator(abc.ABC):\n    @abc.abstractmethod\n    def equiv(self, left: Column, right: Column) -> Column:\n        pass\n\n\nclass DiffComparators:\n    @staticmethod\n    def default() -> 'DefaultDiffComparator':\n        return DefaultDiffComparator()\n\n    @staticmethod\n    def nullSafeEqual() -> 'NullSafeEqualDiffComparator':\n        return NullSafeEqualDiffComparator()\n\n    @staticmethod\n    def epsilon(epsilon: float) -> 'EpsilonDiffComparator':\n        assert isinstance(epsilon, float), epsilon\n        return EpsilonDiffComparator(epsilon)\n\n    @staticmethod\n    def string(whitespace_agnostic: bool = True) -> 'StringDiffComparator':\n        assert isinstance(whitespace_agnostic, bool), whitespace_agnostic\n        return StringDiffComparator(whitespace_agnostic)\n\n    @staticmethod\n    def duration(duration: str) -> 'DurationDiffComparator':\n        assert isinstance(duration, str), duration\n        return DurationDiffComparator(duration)\n\n    @staticmethod\n    def map(key_type: DataType, value_type: DataType, key_order_sensitive: bool = False) -> 'MapDiffComparator':\n        assert isinstance(key_type, DataType), key_type\n        assert isinstance(value_type, DataType), value_type\n        assert isinstance(key_order_sensitive, bool), key_order_sensitive\n        return MapDiffComparator(key_type, value_type, key_order_sensitive)\n\n\nclass NullSafeEqualDiffComparator(DiffComparator):\n    def equiv(self, left: Column, right: Column) -> Column:\n        assert _is_column(left), left\n        assert _is_column(right), right\n        return left.eqNullSafe(right)\n\n\nclass DefaultDiffComparator(NullSafeEqualDiffComparator):\n    # for testing only\n    def _to_java(self, jvm: JVMView) -> JavaObject:\n        return jvm.uk.co.gresearch.spark.diff.DiffComparators.default()\n\n\n@dataclass(frozen=True)\nclass EpsilonDiffComparator(DiffComparator):\n    epsilon: float\n    relative: bool = True\n    inclusive: bool = True\n\n    def as_relative(self) -> 'EpsilonDiffComparator':\n        return dataclasses.replace(self, relative=True)\n\n    def as_absolute(self) -> 'EpsilonDiffComparator':\n        return dataclasses.replace(self, relative=False)\n\n    def as_inclusive(self) -> 'EpsilonDiffComparator':\n        return dataclasses.replace(self, inclusive=True)\n\n    def as_exclusive(self) -> 'EpsilonDiffComparator':\n        return dataclasses.replace(self, inclusive=False)\n\n    def equiv(self, left: Column, right: Column) -> Column:\n        assert _is_column(left), left\n        assert _is_column(right), right\n\n        threshold = greatest(abs(left), 
abs(right)) * self.epsilon if self.relative else lit(self.epsilon)\n\n        def inclusive_epsilon(diff: Column) -> Column:\n            return diff.__le__(threshold)\n\n        def exclusive_epsilon(diff: Column) -> Column:\n            return diff.__lt__(threshold)\n\n        in_epsilon = inclusive_epsilon if self.inclusive else exclusive_epsilon\n        return left.isNull() & right.isNull() | left.isNotNull() & right.isNotNull() & in_epsilon(abs(left - right))\n\n\n@dataclass(frozen=True)\nclass StringDiffComparator(DiffComparator):\n    whitespace_agnostic: bool\n\n    def equiv(self, left: Column, right: Column) -> Column:\n        assert _is_column(left), left\n        assert _is_column(right), right\n        return left.eqNullSafe(right)\n\n\n@dataclass(frozen=True)\nclass DurationDiffComparator(DiffComparator):\n    duration: str\n    inclusive: bool = True\n\n    def as_inclusive(self) -> 'DurationDiffComparator':\n        return dataclasses.replace(self, inclusive=True)\n\n    def as_exclusive(self) -> 'DurationDiffComparator':\n        return dataclasses.replace(self, inclusive=False)\n\n    def equiv(self, left: Column, right: Column) -> Column:\n        assert _is_column(left), left\n        assert _is_column(right), right\n        return left.eqNullSafe(right)\n\n\n@dataclass(frozen=True)\nclass MapDiffComparator(DiffComparator):\n    key_type: DataType\n    value_type: DataType\n    key_order_sensitive: bool\n\n    def equiv(self, left: Column, right: Column) -> Column:\n        assert _is_column(left), left\n        assert _is_column(right), right\n        return left.eqNullSafe(right)\n"
  },
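  {
    "path": "python/examples/comparator_example.py",
    "content": "# Illustrative usage sketch, not part of the original sources: the file name examples/comparator_example.py is hypothetical.\n# It shows that a DiffComparator yields a boolean Column via equiv(), using the EpsilonDiffComparator defined in\n# gresearch.spark.diff.comparator; it assumes the pyspark-extension package is installed.\nfrom pyspark.sql import SparkSession\nfrom pyspark.sql.functions import col\n\nfrom gresearch.spark.diff import DiffComparators\n\nspark = SparkSession.builder.appName(\"comparator-example\").getOrCreate()\n\ndf = spark.createDataFrame([(1.00, 1.02), (1.00, 2.00), (None, None)], [\"left\", \"right\"])\n\n# relative, inclusive epsilon of 5%: values within 5% of the larger magnitude are considered equivalent\ncmp = DiffComparators.epsilon(0.05).as_relative().as_inclusive()\ndf.withColumn(\"equivalent\", cmp.equiv(col(\"left\"), col(\"right\"))).show()\n\nspark.stop()\n"
  },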
  {
    "path": "python/gresearch/spark/parquet/__init__.py",
    "content": "#  Copyright 2023 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nfrom typing import Optional\n\nfrom py4j.java_gateway import JavaObject\nfrom pyspark.sql import DataFrameReader, DataFrame\n\nfrom gresearch.spark import _get_jvm, _to_seq\n\ntry:\n    from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader\n    has_connect = True\nexcept ImportError:\n    has_connect = False\n\n\ndef _jreader(reader: DataFrameReader) -> JavaObject:\n    jvm = _get_jvm(reader)\n    return jvm.uk.co.gresearch.spark.parquet.__getattr__(\"package$\").__getattr__(\"MODULE$\").ExtendedDataFrameReader(reader._jreader)\n\n\ndef parquet_metadata(self: DataFrameReader, *paths: str, parallelism: Optional[int] = None) -> DataFrame:\n    \"\"\"\n    Read the metadata of Parquet files into a Dataframe.\n\n    The returned DataFrame has as many partitions as specified via `parallelism`.\n    If not specified, there are as many partitions as there are Parquet files,\n    at most `spark.sparkContext.defaultParallelism` partitions.\n\n    This provides the following per-file information:\n    - filename (string): The file name\n    - blocks (int): Number of blocks / RowGroups in the Parquet file\n    - compressedBytes (long): Number of compressed bytes of all blocks\n    - uncompressedBytes (long): Number of uncompressed bytes of all blocks\n    - rows (long): Number of rows in the file\n    - columns (int): Number of rows in the file\n    - values (long): Number of values in the file\n    - nulls (long): Number of null values in the file\n    - createdBy (string): The createdBy string of the Parquet file, e.g. 
library used to write the file\n    - schema (string): The schema\n    - encryption (string): The encryption\n    - keyValues (string-to-string map): Key-value data of the file\n\n    :param self: a Spark DataFrameReader\n    :param paths: one or more paths to Parquet files or directories\n    :param parallelism: number of partitions of returned DataFrame\n    :return: dataframe with Parquet metadata\n    \"\"\"\n    jvm = _get_jvm(self)\n    if parallelism is None:\n        jdf = _jreader(self).parquetMetadata(_to_seq(jvm, list(paths)))\n    else:\n        jdf = _jreader(self).parquetMetadata(parallelism, _to_seq(jvm, list(paths)))\n    return DataFrame(jdf, self._spark)\n\n\ndef parquet_schema(self: DataFrameReader, *paths: str, parallelism: Optional[int] = None) -> DataFrame:\n    \"\"\"\n    Read the schema of Parquet files into a DataFrame.\n\n    The returned DataFrame has as many partitions as specified via `parallelism`.\n    If not specified, there are as many partitions as there are Parquet files,\n    at most `spark.sparkContext.defaultParallelism` partitions.\n\n    This provides the following per-file information:\n    - filename (string): The Parquet file name\n    - columnName (string): The column name\n    - columnPath (string array): The column path\n    - repetition (string): The repetition\n    - type (string): The data type\n    - length (int): The length of the type\n    - originalType (string): The original type\n    - isPrimitive (boolean): True if type is primitive\n    - primitiveType (string): The primitive type\n    - primitiveOrder (string): The order of the primitive type\n    - maxDefinitionLevel (int): The max definition level\n    - maxRepetitionLevel (int): The max repetition level\n\n    :param self: a Spark DataFrameReader\n    :param paths: one or more paths to Parquet files or directories\n    :param parallelism: number of partitions of returned DataFrame\n    :return: dataframe with Parquet metadata\n    \"\"\"\n    jvm = _get_jvm(self)\n    if parallelism is None:\n        jdf = _jreader(self).parquetSchema(_to_seq(jvm, list(paths)))\n    else:\n        jdf = _jreader(self).parquetSchema(parallelism, _to_seq(jvm, list(paths)))\n    return DataFrame(jdf, self._spark)\n\n\ndef parquet_blocks(self: DataFrameReader, *paths: str, parallelism: Optional[int] = None) -> DataFrame:\n    \"\"\"\n    Read the metadata of Parquet blocks into a DataFrame.\n\n    The returned DataFrame has as many partitions as specified via `parallelism`.\n    If not specified, there are as many partitions as there are Parquet files,\n    at most `spark.sparkContext.defaultParallelism` partitions.\n\n    This provides the following per-block information:\n    - filename (string): The file name\n    - block (int): Block / RowGroup number starting at 1\n    - blockStart (long): Start position of the block in the Parquet file\n    - compressedBytes (long): Number of compressed bytes in block\n    - uncompressedBytes (long): Number of uncompressed bytes in block\n    - rows (long): Number of rows in block\n    - columns (int): Number of columns in block\n    - values (long): Number of values in block\n    - nulls (long): Number of null values in block\n\n    :param self: a Spark DataFrameReader\n    :param paths: one or more paths to Parquet files or directories\n    :param parallelism: number of partitions of returned DataFrame\n    :return: dataframe with Parquet metadata\n    \"\"\"\n    jvm = _get_jvm(self)\n    if parallelism is None:\n        jdf = 
_jreader(self).parquetBlocks(_to_seq(jvm, list(paths)))\n    else:\n        jdf = _jreader(self).parquetBlocks(parallelism, _to_seq(jvm, list(paths)))\n    return DataFrame(jdf, self._spark)\n\n\ndef parquet_block_columns(self: DataFrameReader, *paths: str, parallelism: Optional[int] = None) -> DataFrame:\n    \"\"\"\n    Read the metadata of Parquet block columns into a DataFrame.\n\n    The returned DataFrame has as many partitions as specified via `parallelism`.\n    If not specified, there are as many partitions as there are Parquet files,\n    at most `spark.sparkContext.defaultParallelism` partitions.\n\n    This provides the following per-block-column information:\n    - filename (string): The file name\n    - block (int): Block / RowGroup number starting at 1\n    - column (array<string>): Block / RowGroup column name\n    - codec (string): The codec used to compress the block column values\n    - type (string): The data type of the block column\n    - encodings (array<string>): Encodings of the block column\n    - minValue (string): Minimum value of this column in this block\n    - maxValue (string): Maximum value of this column in this block\n    - columnStart (long): Start position of the block column in the Parquet file\n    - compressedBytes (long): Number of compressed bytes of this block column\n    - uncompressedBytes (long): Number of uncompressed bytes of this block column\n    - values (long): Number of values in this block column\n    - nulls (long): Number of null values in this block column\n\n    :param self: a Spark DataFrameReader\n    :param paths: one or more paths to Parquet files or directories\n    :param parallelism: number of partitions of returned DataFrame\n    :return: dataframe with Parquet metadata\n    \"\"\"\n    jvm = _get_jvm(self)\n    if parallelism is None:\n        jdf = _jreader(self).parquetBlockColumns(_to_seq(jvm, list(paths)))\n    else:\n        jdf = _jreader(self).parquetBlockColumns(parallelism, _to_seq(jvm, list(paths)))\n    return DataFrame(jdf, self._spark)\n\n\ndef parquet_partitions(self: DataFrameReader, *paths: str, parallelism: Optional[int] = None) -> DataFrame:\n    \"\"\"\n    Read the metadata of how Spark partitions Parquet files into a DataFrame.\n\n    The returned DataFrame has as many partitions as specified via `parallelism`.\n    If not specified, there are as many partitions as there are Parquet files,\n    at most `spark.sparkContext.defaultParallelism` partitions.\n\n    This provides the following per-partition information:\n    - partition (int): The Spark partition id\n    - partitionStart (long): The start position of the partition\n    - partitionEnd (long): The end position of the partition\n    - partitionLength (long): The length of the partition\n    - blocks (int): The number of Parquet blocks / RowGroups in this partition\n    - compressedBytes (long): The number of compressed bytes in this partition\n    - uncompressedBytes (long): The number of uncompressed bytes in this partition\n    - rows (long): The number of rows in this partition\n    - columns (int): The number of columns in this partition\n    - values (long): The number of values in this partition\n    - nulls (long): The number of null values in this partition\n    - filename (string): The Parquet file name\n    - fileLength (long): The length of the Parquet file\n\n    :param self: a Spark DataFrameReader\n    :param paths: one or more paths to Parquet files or directories\n    :param parallelism: number of partitions of 
returned DataFrame\n    :return: dataframe with Parquet metadata\n    \"\"\"\n    jvm = _get_jvm(self)\n    if parallelism is None:\n        jdf = _jreader(self).parquetPartitions(_to_seq(jvm, list(paths)))\n    else:\n        jdf = _jreader(self).parquetPartitions(parallelism, _to_seq(jvm, list(paths)))\n    return DataFrame(jdf, self._spark)\n\n\nDataFrameReader.parquet_metadata = parquet_metadata\nDataFrameReader.parquet_schema = parquet_schema\nDataFrameReader.parquet_blocks = parquet_blocks\nDataFrameReader.parquet_block_columns = parquet_block_columns\nDataFrameReader.parquet_partitions = parquet_partitions\n\nif has_connect:\n    ConnectDataFrameReader.parquet_metadata = parquet_metadata\n    ConnectDataFrameReader.parquet_schema = parquet_schema\n    ConnectDataFrameReader.parquet_blocks = parquet_blocks\n    ConnectDataFrameReader.parquet_block_columns = parquet_block_columns\n    ConnectDataFrameReader.parquet_partitions = parquet_partitions\n"
  },
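  {
    "path": "python/examples/parquet_metadata_example.py",
    "content": "# Illustrative usage sketch, not part of the original sources: the file name examples/parquet_metadata_example.py\n# and the output path /tmp/example-parquet are hypothetical. It exercises the documented DataFrameReader\n# extensions from gresearch.spark.parquet and assumes the spark-extension jar is available to the session\n# (e.g. via the pyspark-extension package).\nfrom pyspark.sql import SparkSession\n\nimport gresearch.spark.parquet  # noqa: F401  importing the module patches DataFrameReader\n\nspark = SparkSession.builder.appName(\"parquet-metadata-example\").getOrCreate()\n\n# write a small Parquet file so there is something to inspect\nspark.range(100).write.mode(\"overwrite\").parquet(\"/tmp/example-parquet\")\n\n# per-file metadata: blocks, compressed/uncompressed bytes, rows, columns, ...\nspark.read.parquet_metadata(\"/tmp/example-parquet\").show(truncate=False)\n\n# per-block and per-block-column views of the same files\nspark.read.parquet_blocks(\"/tmp/example-parquet\").show(truncate=False)\nspark.read.parquet_block_columns(\"/tmp/example-parquet\", parallelism=1).show(truncate=False)\n\nspark.stop()\n"
  },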
  {
    "path": "python/pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools\"]\nbuild-backend = \"setuptools.build_meta\"\n"
  },
  {
    "path": "python/pyspark/jars/.gitignore",
    "content": "# Ignore everything in this directory\n*\n# Except this file\n!.gitignore\n"
  },
  {
    "path": "python/setup.py",
    "content": "#!/usr/bin/env python3\n\n#  Copyright 2023 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nimport shutil\nimport subprocess\nimport sys\nfrom pathlib import Path\nfrom setuptools import setup\nfrom setuptools.command.sdist import sdist\n\n\njar_version = '2.16.0-3.5-SNAPSHOT'\nscala_version = '2.13.8'\nscala_compat_version = '.'.join(scala_version.split('.')[:2])\nspark_compat_version = jar_version.split('-')[1]\njar_file = f\"spark-extension_{scala_compat_version}-{jar_version}.jar\"\nversion = jar_version.replace('SNAPSHOT', 'dev0').replace('-', '.')\n\n# read the contents of the README.md file\nlong_description = (Path(__file__).parent / \"README.md\").read_text()\n\n\nclass custom_sdist(sdist):\n    def make_distribution(self):\n        # build jar file via mvn if it does not exist\n        # then copy the jar file from target/ into python/pyspark/jars/\n        project_root = Path(__file__).parent.parent\n        jar_src_path = project_root / \"target\" / jar_file\n        jar_dst_path = project_root / \"python\" / \"pyspark\" / \"jars\" / jar_file\n\n        if not jar_src_path.exists():\n            # first set version for scala sources\n            set_version_command = [\"./set-version.sh\", f\"{spark_compat_version}.0\", scala_version]\n            # then package Scala sources\n            mvn_command = [\"mvn\", \"--batch-mode\", \"package\", \"-Dspotless.check.skip\", \"-DskipTests\", \"-Dmaven.test.skip=true\"]\n\n            print(' '.join(set_version_command))\n            try:\n                subprocess.check_call(set_version_command, cwd=str(project_root.absolute()))\n            except OSError as e:\n                raise RuntimeError(f'setting versions failed: {e}')\n\n            print(f\"building {jar_src_path}\")\n            print(' '.join(mvn_command))\n            try:\n                subprocess.check_call(mvn_command, cwd=str(project_root.absolute()))\n            except OSError as e:\n                raise RuntimeError(f'mvn command failed: {e}')\n\n            if not jar_src_path.exists():\n                print(f\"Building jar file succeeded but file does still not exist: {jar_src_path}\")\n                sys.exit(1)\n\n        print(f\"copying {jar_src_path} -> {jar_dst_path}\")\n        jar_dst_path.parent.mkdir(exist_ok=True)\n        shutil.copy2(jar_src_path, jar_dst_path)\n        self._add_data_files([(\"pyspark.jars\", \"pyspark/jars\", \".\", [jar_file])])\n\n        sdist.make_distribution(self)\n\n\nsetup(\n    name=\"pyspark-extension\",\n    version=version,\n    description=\"A library that provides useful extensions to Apache Spark.\",\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n    author=\"Enrico Minack\",\n    author_email=\"github@enrico.minack.dev\",\n    url=\"https://github.com/G-Research/spark-extension\",\n    cmdclass={'sdist': custom_sdist},\n    install_requires=[\"typing_extensions\"],\n    extras_require={\n        \"test\": [\n        
    \"pandas>=1.0.5\",\n            \"py4j\",\n            \"pyarrow>=4.0.0\",\n            f\"pyspark~={spark_compat_version}.0\",\n            \"pytest\",\n            \"unittest-xml-reporting\",\n        ],\n    },\n    packages=[\n        \"gresearch\",\n        \"gresearch.spark\",\n        \"gresearch.spark.diff\",\n        \"gresearch.spark.diff.comparator\",\n        \"gresearch.spark.parquet\",\n        \"pyspark.jars\",\n    ],\n    include_package_data=False,\n    package_data={\n        \"pyspark.jars\": [jar_file],\n    },\n    license=\"http://www.apache.org/licenses/LICENSE-2.0.html\",\n    python_requires=\">=3.7\",\n    classifiers=[\n        \"Development Status :: 5 - Production/Stable\",\n        \"License :: OSI Approved :: Apache Software License\",\n        \"Programming Language :: Python :: 3\",\n        \"Programming Language :: Python :: 3.7\",\n        \"Programming Language :: Python :: 3.8\",\n        \"Programming Language :: Python :: 3.9\",\n        \"Programming Language :: Python :: 3.10\",\n        \"Programming Language :: Python :: 3.11\",\n        \"Programming Language :: Python :: 3.12\",\n        \"Programming Language :: Python :: 3.13\",\n        \"Programming Language :: Python :: Implementation :: CPython\",\n        \"Programming Language :: Python :: Implementation :: PyPy\",\n        \"Typing :: Typed\",\n    ],\n)\n"
  },
  {
    "path": "python/test/__init__.py",
    "content": "#  Copyright 2020 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\n"
  },
  {
    "path": "python/test/spark_common.py",
    "content": "#  Copyright 2020 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport unittest\nfrom contextlib import contextmanager\nfrom pathlib import Path\n\nfrom pyspark import SparkConf\nfrom pyspark.sql import SparkSession\n\nlogger = logging.getLogger()\nlogger.level = logging.INFO\n\n\n@contextmanager\ndef spark_session():\n    session = SparkTest.get_spark_session()\n    try:\n        yield session\n    finally:\n        session.stop()\n\n\nclass SparkTest(unittest.TestCase):\n\n    @staticmethod\n    def main(file: str):\n        if len(sys.argv) == 2:\n            # location to store test results provided, this requires package unittest-xml-reporting\n            import xmlrunner\n\n            unittest.main(\n                module=f'test.{Path(file).name[:-3]}',\n                testRunner=xmlrunner.XMLTestRunner(output=sys.argv[1]),\n                argv=sys.argv[:1],\n                # these make sure that some options that are not applicable\n                # remain hidden from the help menu.\n                failfast=False, buffer=False, catchbreak=False\n            )\n        else:\n            unittest.main()\n\n    @staticmethod\n    def get_pom_path() -> str:\n        paths = ['.', '..', os.path.join('..', '..')]\n        for path in paths:\n            if os.path.exists(os.path.join(path, 'pom.xml')):\n                return path\n        raise RuntimeError('Could not find path to pom.xml, looked here: {}'.format(', '.join(paths)))\n\n    @staticmethod\n    def get_spark_config(path) -> SparkConf:\n        master = 'local[2]'\n        conf = SparkConf().setAppName('unit test').setMaster(master)\n        return conf.setAll([\n            ('spark.ui.showConsoleProgress', 'false'),\n            ('spark.test.home', os.environ.get('SPARK_HOME')),\n            ('spark.locality.wait', '0'),\n            ('spark.driver.extraClassPath', '{}'.format(':'.join([\n                os.path.join(os.getcwd(), path, 'target', 'classes'),\n                os.path.join(os.getcwd(), path, 'target', 'test-classes'),\n            ]))),\n        ])\n\n    @classmethod\n    def get_spark_session(cls) -> SparkSession:\n        builder = SparkSession.builder\n\n        if 'TEST_SPARK_CONNECT_SERVER' in os.environ:\n            builder.remote(os.environ['TEST_SPARK_CONNECT_SERVER'])\n        elif 'PYSPARK_GATEWAY_PORT' in os.environ:\n            logging.info('Running inside existing Spark environment')\n        else:\n            logging.info('Setting up Spark environment')\n            # setting conf spark.pyspark.python does not work\n            os.environ['PYSPARK_PYTHON'] = sys.executable\n            path = cls.get_pom_path()\n            conf = cls.get_spark_config(path)\n            builder.config(conf=conf)\n\n        return builder.getOrCreate()\n\n    spark: SparkSession = None\n    is_spark_connect: bool = 'TEST_SPARK_CONNECT_SERVER' in os.environ\n\n    @classmethod\n    def setUpClass(cls):\n        
super(SparkTest, cls).setUpClass()\n        logging.info('launching Spark session')\n        cls.spark = cls.get_spark_session()\n\n    @classmethod\n    def tearDownClass(cls):\n        logging.info('stopping Spark session')\n        cls.spark.stop()\n        super(SparkTest, cls).tearDownClass()\n\n    @contextmanager\n    def sql_conf(self, pairs):\n        \"\"\"\n        Copied from pyspark/testing/sqlutils available from PySpark 3.5.0 and higher.\n        https://github.com/apache/spark/blob/v3.5.0/python/pyspark/testing/sqlutils.py#L171\n        http://www.apache.org/licenses/LICENSE-2.0\n\n        A convenient context manager to test some configuration specific logic. This sets\n        `value` to the configuration `key` and then restores it back when it exits.\n        \"\"\"\n        assert isinstance(pairs, dict), \"pairs should be a dictionary.\"\n        assert hasattr(self, \"spark\"), \"it should have 'spark' attribute, having a spark session.\"\n\n        keys = pairs.keys()\n        new_values = pairs.values()\n        old_values = [self.spark.conf.get(key, None) for key in keys]\n        for key, new_value in zip(keys, new_values):\n            self.spark.conf.set(key, new_value)\n        try:\n            yield\n        finally:\n            for key, old_value in zip(keys, old_values):\n                if old_value is None:\n                    self.spark.conf.unset(key)\n                else:\n                    self.spark.conf.set(key, old_value)\n"
  },
  {
    "path": "python/test/test_diff.py",
    "content": "#  Copyright 2020 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\nimport contextlib\nimport re\n\nfrom py4j.java_gateway import JavaObject\nfrom pyspark.sql import Row\nfrom pyspark.sql.functions import col, when, abs\nfrom pyspark.sql.types import IntegerType, LongType, StringType, DateType, StructField, StructType, FloatType, DoubleType\nfrom unittest import skipIf\n\nfrom gresearch.spark.diff import Differ, DiffOptions, DiffMode, DiffComparators, diffwith\nfrom spark_common import SparkTest\n\n\nclass DiffTest(SparkTest):\n\n    expected_diff = None\n\n    @contextlib.contextmanager\n    def assert_requirement(self, error_message: str):\n        with self.assertRaises(ValueError) as e:\n            yield\n        self.assertEqual((error_message,), e.exception.args)\n\n    @classmethod\n    def setUpClass(cls):\n        super(DiffTest, cls).setUpClass()\n\n        value_row = Row('id', 'val', 'label')\n        cls.left_df = cls.spark.createDataFrame([\n            value_row(1, 1.0, 'one'),\n            value_row(2, 2.0, 'two'),\n            value_row(3, 3.0, 'three'),\n            value_row(4, None, None),\n            value_row(5, 5.0, 'five'),\n            value_row(7, 7.0, 'seven'),\n        ])\n\n        cls.right_df = cls.spark.createDataFrame([\n            value_row(1, 1.1, 'one'),\n            value_row(2, 2.0, 'Two'),\n            value_row(3, 3.0, 'three'),\n            value_row(4, 4.0, 'four'),\n            value_row(5, None, None),\n            value_row(6, 6.0, 'six'),\n        ])\n\n        diff_row = Row('diff', 'id', 'left_val', 'right_val', 'left_label', 'right_label')\n        cls.expected_diff = [\n            diff_row('C', 1, 1.0, 1.1, 'one', 'one'),\n            diff_row('C', 2, 2.0, 2.0, 'two', 'Two'),\n            diff_row('N', 3, 3.0, 3.0, 'three', 'three'),\n            diff_row('C', 4, None, 4.0, None, 'four'),\n            diff_row('C', 5, 5.0, None, 'five', None),\n            diff_row('I', 6, None, 6.0, None, 'six'),\n            diff_row('D', 7, 7.0, None, 'seven', None),\n        ]\n        diff_change_row = Row('diff', 'change', 'id', 'left_val', 'right_val', 'left_label', 'right_label')\n        cls.expected_diff_change = [\n            diff_change_row('C', ['val'], 1, 1.0, 1.1, 'one', 'one'),\n            diff_change_row('C', ['label'], 2, 2.0, 2.0, 'two', 'Two'),\n            diff_change_row('N', [], 3, 3.0, 3.0, 'three', 'three'),\n            diff_change_row('C', ['val', 'label'], 4, None, 4.0, None, 'four'),\n            diff_change_row('C', ['val', 'label'], 5, 5.0, None, 'five', None),\n            diff_change_row('I', None, 6, None, 6.0, None, 'six'),\n            diff_change_row('D', None, 7, 7.0, None, 'seven', None),\n        ]\n        cls.expected_diff_reversed = [\n            diff_row('C', 1, 1.1, 1.0, 'one', 'one'),\n            diff_row('C', 2, 2.0, 2.0, 'Two', 'two'),\n            diff_row('N', 3, 3.0, 3.0, 'three', 'three'),\n            diff_row('C', 4, 4.0, None, 'four', None),\n            
diff_row('C', 5, None, 5.0, None, 'five'),\n            diff_row('D', 6, 6.0, None, 'six', None),\n            diff_row('I', 7, None, 7.0, None, 'seven'),\n        ]\n        cls.expected_diff_ignored = [\n            diff_row('C', 1, 1.0, 1.1, 'one', 'one'),\n            diff_row('N', 2, 2.0, 2.0, 'two', 'Two'),\n            diff_row('N', 3, 3.0, 3.0, 'three', 'three'),\n            diff_row('C', 4, None, 4.0, None, 'four'),\n            diff_row('C', 5, 5.0, None, 'five', None),\n            diff_row('I', 6, None, 6.0, None, 'six'),\n            diff_row('D', 7, 7.0, None, 'seven', None),\n        ]\n\n        diffwith_row = Row('diff', 'left', 'right')\n        cls.expected_diffwith = [\n            diffwith_row('C', value_row(1, 1.0, 'one'), value_row(1, 1.1, 'one')),\n            diffwith_row('C', value_row(2, 2.0, 'two'), value_row(2, 2.0, 'Two')),\n            diffwith_row('N', value_row(3, 3.0, 'three'), value_row(3, 3.0, 'three')),\n            diffwith_row('C', value_row(4, None, None), value_row(4, 4.0, 'four')),\n            diffwith_row('C', value_row(5, 5.0, 'five'), value_row(5, None, None)),\n            diffwith_row('I', None, value_row(6, 6.0, 'six')),\n            diffwith_row('D', value_row(7, 7.0, 'seven'), None),\n        ]\n        diffwith_with_options_row = Row('d', 'l_val', 'r_val')\n        cls.expected_diffwith_with_options = [\n            diffwith_with_options_row('c', value_row(1, 1.0, 'one'), value_row(1, 1.1, 'one')),\n            diffwith_with_options_row('c', value_row(2, 2.0, 'two'), value_row(2, 2.0, 'Two')),\n            diffwith_with_options_row('n', value_row(3, 3.0, 'three'), value_row(3, 3.0, 'three')),\n            diffwith_with_options_row('c', value_row(4, None, None), value_row(4, 4.0, 'four')),\n            diffwith_with_options_row('c', value_row(5, 5.0, 'five'), value_row(5, None, None)),\n            diffwith_with_options_row('i', None, value_row(6, 6.0, 'six')),\n            diffwith_with_options_row('r', value_row(7, 7.0, 'seven'), None),\n        ]\n        cls.expected_diffwith_ignored = [\n            diffwith_row('C', value_row(1, 1.0, 'one'), value_row(1, 1.1, 'one')),\n            diffwith_row('N', value_row(2, 2.0, 'two'), value_row(2, 2.0, 'Two')),\n            diffwith_row('N', value_row(3, 3.0, 'three'), value_row(3, 3.0, 'three')),\n            diffwith_row('C', value_row(4, None, None), value_row(4, 4.0, 'four')),\n            diffwith_row('C', value_row(5, 5.0, 'five'), value_row(5, None, None)),\n            diffwith_row('I', None, value_row(6, 6.0, 'six')),\n            diffwith_row('D', value_row(7, 7.0, 'seven'), None),\n        ]\n\n        diff_with_options_row = Row('d', 'id', 'l_val', 'r_val', 'l_label', 'r_label')\n        cls.expected_diff_with_options = [\n            diff_with_options_row('c', 1, 1.0, 1.1, 'one', 'one'),\n            diff_with_options_row('c', 2, 2.0, 2.0, 'two', 'Two'),\n            diff_with_options_row('n', 3, 3.0, 3.0, 'three', 'three'),\n            diff_with_options_row('c', 4, None, 4.0, None, 'four'),\n            diff_with_options_row('c', 5, 5.0, None, 'five', None),\n            diff_with_options_row('i', 6, None, 6.0, None, 'six'),\n            diff_with_options_row('r', 7, 7.0, None, 'seven', None),\n        ]\n        cls.expected_diff_with_options_ignored = [\n            diff_with_options_row('c', 1, 1.0, 1.1, 'one', 'one'),\n            diff_with_options_row('n', 2, 2.0, 2.0, 'two', 'Two'),\n            diff_with_options_row('n', 3, 3.0, 3.0, 'three', 'three'),\n            
diff_with_options_row('c', 4, None, 4.0, None, 'four'),\n            diff_with_options_row('c', 5, 5.0, None, 'five', None),\n            diff_with_options_row('i', 6, None, 6.0, None, 'six'),\n            diff_with_options_row('r', 7, 7.0, None, 'seven', None),\n        ]\n\n        diff_with_changes_row = Row('diff', 'changes', 'id', 'left_val', 'right_val', 'left_label', 'right_label')\n        cls.expected_diff_with_changes = [\n            diff_with_changes_row('C', ['val'], 1, 1.0, 1.1, 'one', 'one'),\n            diff_with_changes_row('C', ['label'], 2, 2.0, 2.0, 'two', 'Two'),\n            diff_with_changes_row('N', [], 3, 3.0, 3.0, 'three', 'three'),\n            diff_with_changes_row('C', ['val', 'label'], 4, None, 4.0, None, 'four'),\n            diff_with_changes_row('C', ['val', 'label'], 5, 5.0, None, 'five', None),\n            diff_with_changes_row('I', None, 6, None, 6.0, None, 'six'),\n            diff_with_changes_row('D', None, 7, 7.0, None, 'seven', None),\n        ]\n\n        cls.expected_diff_in_column_by_column_mode = cls.expected_diff\n\n        diff_in_side_by_side_mode_row = Row('diff', 'id', 'left_val', 'left_label', 'right_val', 'right_label')\n        cls.expected_diff_in_side_by_side_mode = [\n            diff_in_side_by_side_mode_row('C', 1, 1.0, 'one', 1.1, 'one'),\n            diff_in_side_by_side_mode_row('C', 2, 2.0, 'two', 2.0, 'Two'),\n            diff_in_side_by_side_mode_row('N', 3, 3.0, 'three', 3.0, 'three'),\n            diff_in_side_by_side_mode_row('C', 4, None, None, 4.0, 'four'),\n            diff_in_side_by_side_mode_row('C', 5, 5.0, 'five', None, None),\n            diff_in_side_by_side_mode_row('I', 6, None, None, 6.0, 'six'),\n            diff_in_side_by_side_mode_row('D', 7, 7.0, 'seven', None, None),\n        ]\n\n        diff_in_left_side_mode_row = Row('diff', 'id', 'left_val', 'left_label')\n        cls.expected_diff_in_left_side_mode = [\n            diff_in_left_side_mode_row('C', 1, 1.0, 'one'),\n            diff_in_left_side_mode_row('C', 2, 2.0, 'two'),\n            diff_in_left_side_mode_row('N', 3, 3.0, 'three'),\n            diff_in_left_side_mode_row('C', 4, None, None),\n            diff_in_left_side_mode_row('C', 5, 5.0, 'five'),\n            diff_in_left_side_mode_row('I', 6, None, None),\n            diff_in_left_side_mode_row('D', 7, 7.0, 'seven'),\n        ]\n\n        diff_in_right_side_mode_row = Row('diff', 'id', 'right_val', 'right_label')\n        cls.expected_diff_in_right_side_mode = [\n            diff_in_right_side_mode_row('C', 1, 1.1, 'one'),\n            diff_in_right_side_mode_row('C', 2, 2.0, 'Two'),\n            diff_in_right_side_mode_row('N', 3, 3.0, 'three'),\n            diff_in_right_side_mode_row('C', 4, 4.0, 'four'),\n            diff_in_right_side_mode_row('C', 5, None, None),\n            diff_in_right_side_mode_row('I', 6, 6.0, 'six'),\n            diff_in_right_side_mode_row('D', 7, None, None),\n        ]\n\n        diff_in_sparse_mode_row = Row('diff', 'id', 'left_val', 'right_val', 'left_label', 'right_label')\n        cls.expected_diff_in_sparse_mode = [\n            diff_in_sparse_mode_row('C', 1, 1.0, 1.1, None, None),\n            diff_in_sparse_mode_row('C', 2, None, None, 'two', 'Two'),\n            diff_in_sparse_mode_row('N', 3, None, None, None, None),\n            diff_in_sparse_mode_row('C', 4, None, 4.0, None, 'four'),\n            diff_in_sparse_mode_row('C', 5, 5.0, None, 'five', None),\n            diff_in_sparse_mode_row('I', 6, None, 6.0, None, 'six'),\n            
diff_in_sparse_mode_row('D', 7, 7.0, None, 'seven', None),\n        ]\n\n    def test_check_schema(self):\n        with self.subTest(\"duplicate columns\"):\n            with self.assert_requirement(\"The datasets have duplicate columns.\\n\"\n                                  \"Left column names: id, id\\nRight column names: id, id\"):\n                self.left_df.select(\"id\", \"id\").diff(self.right_df.select(\"id\", \"id\"), \"id\")\n\n        with self.subTest(\"case-sensitive id column\"):\n            with self.assert_requirement(\"Some id columns do not exist: ID missing among id, val, label\"):\n                with self.sql_conf({\"spark.sql.caseSensitive\": \"true\"}):\n                    self.left_df.diff(self.right_df, \"ID\")\n\n        left = self.left_df.withColumnRenamed(\"val\", \"diff\")\n        right = self.right_df.withColumnRenamed(\"val\", \"diff\")\n\n        with self.subTest(\"id column 'diff'\"):\n            with self.assert_requirement(\"The id columns must not contain the diff column name 'diff': id, diff, label\"):\n                left.diff(right)\n            with self.assert_requirement(\"The id columns must not contain the diff column name 'diff': diff\"):\n                left.diff(right, \"diff\")\n            with self.assert_requirement(\"The id columns must not contain the diff column name 'diff': diff, id\"):\n                left.diff(right, \"diff\", \"id\")\n\n            with self.sql_conf({\"spark.sql.caseSensitive\": \"false\"}):\n                with self.assert_requirement(\"The id columns must not contain the diff column name 'diff': Diff, id\"):\n                    left.withColumnRenamed(\"diff\", \"Diff\") \\\n                        .diff(right.withColumnRenamed(\"diff\", \"Diff\"), \"Diff\", \"id\")\n\n            with self.sql_conf({\"spark.sql.caseSensitive\": \"true\"}):\n                left.withColumnRenamed(\"diff\", \"Diff\") \\\n                    .diff(right.withColumnRenamed(\"diff\", \"Diff\"), \"Diff\", \"id\")\n\n        with self.subTest(\"non-id column 'diff\"):\n            actual = left.diff(right, \"id\").orderBy(\"id\")\n            expected_columns = [\"diff\", \"id\", \"left_diff\", \"right_diff\", \"left_label\", \"right_label\"]\n            self.assertEqual(actual.columns, expected_columns)\n            self.assertEqual(actual.collect(), self.expected_diff)\n\n        with self.subTest(\"non-id column produces diff column name\"):\n            options = DiffOptions() \\\n                .with_diff_column(\"a_val\") \\\n                .with_left_column_prefix(\"a\") \\\n                .with_right_column_prefix(\"b\")\n\n            with self.assert_requirement(\"The column prefixes 'a' and 'b', together with these non-id columns \" +\n                                  \"must not produce the diff column name 'a_val': val, label\"):\n                self.left_df.diff(self.right_df, options, \"id\")\n            with self.assert_requirement(\"The column prefixes 'a' and 'b', together with these non-id columns \" +\n                                  \"must not produce the diff column name 'b_val': val, label\"):\n                self.left_df.diff(self.right_df, options.with_diff_column(\"b_val\"), \"id\")\n\n        with self.subTest(\"non-id column would produce diff column name unless in left-side mode\"):\n            options = DiffOptions() \\\n                .with_diff_column(\"a_val\") \\\n                .with_left_column_prefix(\"a\") \\\n                .with_right_column_prefix(\"b\") \\\n           
     .with_diff_mode(DiffMode.LeftSide)\n            self.left_df.diff(self.right_df, options, \"id\")\n\n        with self.subTest(\"non-id column would produce diff column name unless in right-side mode\"):\n            options = DiffOptions() \\\n                .with_diff_column(\"b_val\") \\\n                .with_left_column_prefix(\"a\") \\\n                .with_right_column_prefix(\"b\") \\\n                .with_diff_mode(DiffMode.RightSide)\n            self.left_df.diff(self.right_df, options, \"id\")\n\n        with self.sql_conf({\"spark.sql.caseSensitive\": \"false\"}):\n            with self.subTest(\"case-insensitive non-id column produces diff column name\"):\n                options = DiffOptions() \\\n                    .with_diff_column(\"a_val\") \\\n                    .with_left_column_prefix(\"A\") \\\n                    .with_right_column_prefix(\"b\")\n                with self.assert_requirement(\"The column prefixes 'A' and 'b', together with these non-id columns \" +\n                                      \"must not produce the diff column name 'a_val': val, label\"):\n                    self.left_df.diff(self.right_df, options, \"id\")\n\n            with self.subTest(\"case-insensitive non-id column would produce diff column name unless in left-side mode\"):\n                options = DiffOptions() \\\n                    .with_diff_column(\"a_val\") \\\n                    .with_left_column_prefix(\"A\") \\\n                    .with_right_column_prefix(\"B\") \\\n                    .with_diff_mode(DiffMode.LeftSide)\n                self.left_df.diff(self.right_df, options, \"id\")\n\n            with self.subTest(\"case-insensitive non-id column would produce diff column name unless in right-side mode\"):\n                options = DiffOptions() \\\n                    .with_diff_column(\"b_val\") \\\n                    .with_left_column_prefix(\"A\") \\\n                    .with_right_column_prefix(\"B\") \\\n                    .with_diff_mode(DiffMode.RightSide)\n                self.left_df.diff(self.right_df, options, \"id\")\n\n        with self.sql_conf({\"spark.sql.caseSensitive\": \"true\"}):\n            with self.subTest(\"case-sensitive non-id column produces non-conflicting diff column name\"):\n                options = DiffOptions() \\\n                    .with_diff_column(\"a_val\") \\\n                    .with_left_column_prefix(\"A\") \\\n                    .with_right_column_prefix(\"B\") \\\n\n                actual = self.left_df.diff(self.right_df, options, \"id\").orderBy(\"id\")\n                expected_columns = [\"a_val\", \"id\", \"A_val\", \"B_val\", \"A_label\", \"B_label\"]\n                self.assertEqual(actual.columns, expected_columns)\n                self.assertEqual(actual.collect(), self.expected_diff)\n\n        left = self.left_df.withColumnRenamed(\"val\", \"change\")\n        right = self.right_df.withColumnRenamed(\"val\", \"change\")\n\n        with self.subTest(\"id column 'change'\"):\n            options = DiffOptions() \\\n                .with_change_column(\"change\")\n            with self.assert_requirement(\"The id columns must not contain the change column name 'change': id, change, label\"):\n                left.diff(right, options)\n            with self.assert_requirement(\"The id columns must not contain the change column name 'change': change\"):\n                left.diff(right, options, \"change\")\n            with self.assert_requirement(\"The id columns must not contain the change 
column name 'change': change, id\"):\n                left.diff(right, options, \"change\", \"id\")\n\n            with self.sql_conf({\"spark.sql.caseSensitive\": \"false\"}):\n                with self.assert_requirement(\"The id columns must not contain the change column name 'change': Change, id\"):\n                    left.withColumnRenamed(\"change\", \"Change\") \\\n                        .diff(right.withColumnRenamed(\"change\", \"Change\"), options, \"Change\", \"id\")\n\n            with self.sql_conf({\"spark.sql.caseSensitive\": \"true\"}):\n                left.withColumnRenamed(\"change\", \"Change\") \\\n                    .diff(right.withColumnRenamed(\"change\", \"Change\"), options, \"Change\", \"id\")\n\n        with self.subTest(\"non-id column 'change\"):\n            actual = left.diff(right, options, \"id\").orderBy(\"id\")\n            expected_columns = [\"diff\", \"change\", \"id\", \"left_change\", \"right_change\", \"left_label\", \"right_label\"]\n            diff_change_row = Row(*expected_columns)\n            expected_diff = [\n                diff_change_row('C', ['change'], 1, 1.0, 1.1, 'one', 'one'),\n                diff_change_row('C', ['label'], 2, 2.0, 2.0, 'two', 'Two'),\n                diff_change_row('N', [], 3, 3.0, 3.0, 'three', 'three'),\n                diff_change_row('C', ['change', 'label'], 4, None, 4.0, None, 'four'),\n                diff_change_row('C', ['change', 'label'], 5, 5.0, None, 'five', None),\n                diff_change_row('I', None, 6, None, 6.0, None, 'six'),\n                diff_change_row('D', None, 7, 7.0, None, 'seven', None),\n            ]\n            self.assertEqual(actual.columns, expected_columns)\n            self.assertEqual(actual.collect(), expected_diff)\n\n        with self.subTest(\"non-id column produces change column name\"):\n            options = DiffOptions() \\\n                .with_change_column(\"a_val\") \\\n                .with_left_column_prefix(\"a\") \\\n                .with_right_column_prefix(\"b\")\n            with self.assert_requirement(\"The column prefixes 'a' and 'b', together with these non-id columns \" +\n                                  \"must not produce the change column name 'a_val': val, label\"):\n                self.left_df.diff(self.right_df, options, \"id\")\n\n        with self.sql_conf({\"spark.sql.caseSensitive\": \"false\"}):\n            with self.subTest(\"case-insensitive non-id column produces change column name\"):\n                options = DiffOptions() \\\n                    .with_change_column(\"a_val\") \\\n                    .with_left_column_prefix(\"A\") \\\n                    .with_right_column_prefix(\"B\")\n                with self.assert_requirement(\"The column prefixes 'A' and 'B', together with these non-id columns \" +\n                                      \"must not produce the change column name 'a_val': val, label\"):\n                    self.left_df.diff(self.right_df, options, \"id\")\n\n        with self.sql_conf({\"spark.sql.caseSensitive\": \"true\"}):\n            with self.subTest(\"case-sensitive non-id column produces non-conflicting change column name\"):\n                options = DiffOptions() \\\n                    .with_change_column(\"a_val\") \\\n                    .with_left_column_prefix(\"A\") \\\n                    .with_right_column_prefix(\"B\")\n                actual = self.left_df.diff(self.right_df, options, \"id\").orderBy(\"id\")\n                expected_columns = [\"diff\", \"a_val\", \"id\", 
\"A_val\", \"B_val\", \"A_label\", \"B_label\"]\n                self.assertEqual(actual.columns, expected_columns)\n                self.assertEqual(actual.collect(), self.expected_diff_change)\n\n        left = self.left_df.select(col(\"id\").alias(\"first_id\"), col(\"val\").alias(\"id\"), \"label\")\n        right = self.right_df.select(col(\"id\").alias(\"first_id\"), col(\"val\").alias(\"id\"), \"label\")\n        with self.subTest(\"non-id column produces id column name\"):\n            options = DiffOptions() \\\n                .with_left_column_prefix(\"first\") \\\n                .with_right_column_prefix(\"second\")\n            with self.assert_requirement(\"The column prefixes 'first' and 'second', together with these non-id columns \" +\n                                  \"must not produce any id column name 'first_id': id, label\"):\n                left.diff(right, options, \"first_id\")\n\n        with self.sql_conf({\"spark.sql.caseSensitive\": \"false\"}):\n            with self.subTest(\"case-insensitive non-id column produces id column name\"):\n                options = DiffOptions() \\\n                    .with_left_column_prefix(\"FIRST\") \\\n                    .with_right_column_prefix(\"SECOND\")\n                with self.assert_requirement(\"The column prefixes 'FIRST' and 'SECOND', together with these non-id columns \" +\n                                      \"must not produce any id column name 'first_id': id, label\"):\n                    left.diff(right, options, \"first_id\")\n\n        with self.sql_conf({\"spark.sql.caseSensitive\": \"true\"}):\n            with self.subTest(\"case-sensitive non-id column produces non-conflicting id column name\"):\n                options = DiffOptions() \\\n                    .with_left_column_prefix(\"FIRST\") \\\n                    .with_right_column_prefix(\"SECOND\")\n                actual = left.diff(right, options, \"first_id\").orderBy(\"first_id\")\n                expected_columns = [\"diff\", \"first_id\", \"FIRST_id\", \"SECOND_id\", \"FIRST_label\", \"SECOND_label\"]\n                self.assertEqual(actual.columns, expected_columns)\n                self.assertEqual(actual.collect(), self.expected_diff)\n\n        with self.subTest(\"empty schema\"):\n            with self.assert_requirement(\"The schema must not be empty\"):\n                self.left_df.select().diff(self.right_df.select())\n\n        with self.subTest(\"empty schema after ignored columns\"):\n            with self.assert_requirement(\"The schema except ignored columns must not be empty\"):\n                self.left_df.select(\"id\", \"val\").diff(self.right_df.select(\"id\", \"label\"), [], [\"id\", \"val\", \"label\"])\n\n        with self.subTest(\"different types\"):\n            with self.assert_requirement(\"The datasets do not have the same schema.\\n\" +\n                                  \"Left extra columns: val (double)\\n\" +\n                                  \"Right extra columns: val (string)\"):\n                self.left_df.select(\"id\", \"val\").diff(self.right_df.select(\"id\", col(\"label\").alias(\"val\")))\n\n        with self.subTest(\"ignore columns with different types\"):\n            actual = self.left_df.select(\"id\", \"val\").diff(self.right_df.select(\"id\", col(\"label\").alias(\"val\")), [], [\"val\"])\n            expected_schema = [\n                (\"diff\", StringType()),\n                (\"id\", LongType()),\n                (\"left_val\", DoubleType()),\n                (\"right_val\", 
StringType()),\n            ]\n            self.assertEqual([(f.name, f.dataType) for f in actual.schema], expected_schema)\n\n        with self.subTest(\"diff with different column names\"):\n            with self.assert_requirement(\"The datasets do not have the same schema.\\n\" +\n                                  \"Left extra columns: val (double)\\n\" +\n                                  \"Right extra columns: label (string)\"):\n                self.left_df.select(\"id\", \"val\").diff(self.right_df.select(\"id\", \"label\"))\n\n        left = self.left_df.select(\"id\", \"val\", \"label\")\n        right = self.right_df.select(col(\"id\").alias(\"ID\"), col(\"val\").alias(\"VaL\"), \"label\")\n        with self.sql_conf({\"spark.sql.caseSensitive\": \"false\"}):\n            with self.subTest(\"case-insensitive column names\"):\n                actual = left.diff(right, \"id\").orderBy(\"id\")\n                reverse = right.diff(left, \"id\").orderBy(\"id\")\n                self.assertEqual(actual.columns, [\"diff\", \"id\", \"left_val\", \"right_VaL\", \"left_label\", \"right_label\"])\n                self.assertEqual(actual.collect(), self.expected_diff)\n                self.assertEqual(reverse.columns, [\"diff\", \"id\", \"left_VaL\", \"right_val\", \"left_label\", \"right_label\"])\n                self.assertEqual(reverse.collect(), self.expected_diff_reversed)\n\n        with self.sql_conf({\"spark.sql.caseSensitive\": \"true\"}):\n            with self.subTest(\"case-sensitive column names\"):\n                with self.assert_requirement(\"The datasets do not have the same schema.\\n\" +\n                                      \"Left extra columns: id (long), val (double)\\n\" +\n                                      \"Right extra columns: ID (long), VaL (double)\"):\n                    left.diff(right, \"id\")\n\n        with self.subTest(\"non-existing id column\"):\n            with self.assert_requirement(\"Some id columns do not exist: does not exists missing among id, val, label\"):\n                self.left_df.diff(self.right_df, \"does not exists\")\n\n        with self.subTest(\"different number of columns\"):\n            with self.assert_requirement(\"The number of columns doesn't match.\\n\" +\n                                  \"Left column names (2): id, val\\n\" +\n                                  \"Right column names (3): id, val, label\"):\n                self.left_df.select(\"id\", \"val\").diff(self.right_df, \"id\")\n\n        with self.subTest(\"different number of columns after ignoring columns\"):\n            left = self.left_df.select(\"id\", \"val\", col(\"label\").alias(\"meta\"))\n            right = self.right_df.select(\"id\", col(\"label\").alias(\"seq\"), \"val\")\n            with self.assert_requirement(\"The number of columns doesn't match.\\n\" +\n                                  \"Left column names except ignored columns (2): id, val\\n\" +\n                                  \"Right column names except ignored columns (3): id, seq, val\"):\n                left.diff(right, [\"id\"], [\"meta\"])\n\n        with self.subTest(\"diff column name in value columns in left-side diff mode\"):\n            options = DiffOptions().with_diff_column(\"val\").with_diff_mode(DiffMode.LeftSide)\n            with self.assert_requirement(\"The left non-id columns must not contain the diff column name 'val': val, label\"):\n                self.left_df.diff(self.right_df, options, \"id\")\n\n        with self.subTest(\"diff column name in value 
columns in right-side diff mode\"):\n            options = DiffOptions().with_diff_column(\"val\").with_diff_mode(DiffMode.RightSide)\n            with self.assert_requirement(\"The right non-id columns must not contain the diff column name 'val': val, label\"):\n                self.left_df.diff(self.right_df, options, \"id\")\n\n        with self.subTest(\"change column name in value columns in left-side diff mode\"):\n            options = DiffOptions().with_change_column(\"val\").with_diff_mode(DiffMode.LeftSide)\n            with self.assert_requirement(\"The left non-id columns must not contain the change column name 'val': val, label\"):\n                self.left_df.diff(self.right_df, options, \"id\")\n\n        with self.subTest(\"change column name in value columns in right-side diff mode\"):\n            options = DiffOptions().with_change_column(\"val\").with_diff_mode(DiffMode.RightSide)\n            with self.assert_requirement(\"The right non-id columns must not contain the change column name 'val': val, label\"):\n                self.left_df.diff(self.right_df, options, \"id\")\n\n    def test_dataframe_diff(self):\n        diff = self.left_df.diff(self.right_df, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff, diff)\n\n    def test_dataframe_diff_with_ids_ignored(self):\n        diff = self.left_df.diff(self.right_df, ['id'], ['label']).orderBy('id').collect()\n        self.assertEqual(self.expected_diff_ignored, diff)\n\n    def test_dataframe_diff_with_wrong_argument_types(self):\n        with self.subTest(\"id columns is not string\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: int\"):\n                self.left_df.diff(self.right_df, 1)\n        with self.subTest(\"one of two id columns is not string\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: str, int\"):\n                self.left_df.diff(self.right_df, \"id\", 1)\n        with self.subTest(\"one of three id columns is not string\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: str, int, str\"):\n                self.left_df.diff(self.right_df, \"id\", 1, \"val\")\n\n        with self.subTest(\"id columns is not list\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: int, list\"):\n                self.left_df.diff(self.right_df, 1, ['val'])\n        with self.subTest(\"one of id columns is not string\"):\n            with self.assert_requirement(\"The id_columns must all be strings: str, int\"):\n                self.left_df.diff(self.right_df, ['id', 1], ['val'])\n\n        with self.subTest(\"ignore columns is not list\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: list, int\"):\n                self.left_df.diff(self.right_df, ['id'], 1)\n        with self.subTest(\"one of ignore columns is not string\"):\n            with self.assert_requirement(\"The ignore_columns must all be 
strings: str, int\"):\n                self.left_df.diff(self.right_df, ['id'], ['val', 1])\n\n        with self.subTest(\"one list of string id columns\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: list\"):\n                self.left_df.diff(self.right_df, ['id'])\n\n        with self.subTest(\"three lists of string id and ignore columns\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: list, list, list\"):\n                self.left_df.diff(self.right_df, ['id'], ['val'], ['three'])\n\n        with self.subTest(\"options not second argument\"):\n            with self.assert_requirement(\"Diff options must be given as second argument\"):\n                self.left_df.diff(self.right_df, 'id', DiffOptions())\n\n    def test_dataframe_diffwith(self):\n        diff = self.left_df.diffwith(self.right_df, 'id').orderBy('id').collect()\n        self.assertSetEqual(set(self.expected_diffwith), set(diff))\n        self.assertEqual(len(self.expected_diffwith), len(diff))\n\n    def test_dataframe_diffwith_with_default_options(self):\n        diff = self.left_df.diffwith(self.right_df, DiffOptions(), 'id').orderBy('id').collect()\n        self.assertSetEqual(set(self.expected_diffwith), set(diff))\n        self.assertEqual(len(self.expected_diffwith), len(diff))\n\n    def test_dataframe_diffwith_with_options(self):\n        options = DiffOptions('d', 'l', 'r', 'i', 'c', 'r', 'n', None)\n        diff = self.left_df.diffwith(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertSetEqual(set(self.expected_diffwith_with_options), set(diff))\n        self.assertEqual(len(self.expected_diffwith_with_options), len(diff))\n\n    def test_dataframe_diffwith_with_ignored(self):\n        diff = self.left_df.diffwith(self.right_df, ['id'], ['label']).orderBy('id').collect()\n        self.assertSetEqual(set(self.expected_diffwith_ignored), set(diff))\n        self.assertEqual(len(self.expected_diffwith_ignored), len(diff))\n\n    def test_dataframe_diffwith_with_wrong_argument_types(self):\n        with self.subTest(\"id columns is not string\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: int\"):\n                self.left_df.diffwith(self.right_df, 1)\n        with self.subTest(\"one of two id columns is not string\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: str, int\"):\n                self.left_df.diffwith(self.right_df, \"id\", 1)\n        with self.subTest(\"one of three id columns is not string\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: str, int, str\"):\n                self.left_df.diffwith(self.right_df, \"id\", 1, \"val\")\n\n        with self.subTest(\"id columns is not list\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: int, 
list\"):\n                self.left_df.diffwith(self.right_df, 1, ['val'])\n        with self.subTest(\"one of id columns is not string\"):\n            with self.assert_requirement(\"The id_columns must all be strings: str, int\"):\n                self.left_df.diffwith(self.right_df, ['id', 1], ['val'])\n\n        with self.subTest(\"ignore columns is not list\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: list, int\"):\n                self.left_df.diffwith(self.right_df, ['id'], 1)\n        with self.subTest(\"one of ignore columns is not string\"):\n            with self.assert_requirement(\"The ignore_columns must all be strings: str, int\"):\n                self.left_df.diffwith(self.right_df, ['id'], ['val', 1])\n\n        with self.subTest(\"one list of string id columns\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: list\"):\n                self.left_df.diffwith(self.right_df, ['id'])\n\n        with self.subTest(\"three lists of string id and ignore columns\"):\n            with self.assert_requirement(\"The id_or_ignore_columns argument must either all be strings \"\n                                       \"or exactly two iterables of strings: list, list, list\"):\n                self.left_df.diffwith(self.right_df, ['id'], ['val'], ['three'])\n\n        with self.subTest(\"options not second argument\"):\n            with self.assert_requirement(\"Diff options must be given as second argument\"):\n                self.left_df.diffwith(self.right_df, 'id', DiffOptions())\n\n    def test_dataframe_diff_with_default_options(self):\n        diff = self.left_df.diff(self.right_df, DiffOptions(), 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff, diff)\n        diff = self.left_df.diff_with_options(self.right_df, DiffOptions(), 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff, diff)\n\n    def test_dataframe_diff_with_options(self):\n        options = DiffOptions('d', 'l', 'r', 'i', 'c', 'r', 'n', None)\n        diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_with_options, diff)\n        diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_with_options, diff)\n\n    def test_dataframe_diff_with_options_and_ignored(self):\n        options = DiffOptions('d', 'l', 'r', 'i', 'c', 'r', 'n', None)\n        diff = self.left_df.diff(self.right_df, options, ['id'], ['label']).orderBy('id').collect()\n        self.assertEqual(self.expected_diff_with_options_ignored, diff)\n        diff = self.left_df.diff_with_options(self.right_df, options, ['id'], ['label']).orderBy('id').collect()\n        self.assertEqual(self.expected_diff_with_options_ignored, diff)\n\n    def test_dataframe_diff_with_changes(self):\n        options = DiffOptions().with_change_column('changes')\n        diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_with_changes, diff)\n        diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_with_changes, diff)\n\n    def 
test_dataframe_diff_with_diff_mode_column_by_column(self):\n        options = DiffOptions().with_diff_mode(DiffMode.ColumnByColumn)\n        diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_column_by_column_mode, diff)\n        diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_column_by_column_mode, diff)\n\n    def test_dataframe_diff_with_diff_mode_side_by_side(self):\n        options = DiffOptions().with_diff_mode(DiffMode.SideBySide)\n        diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_side_by_side_mode, diff)\n        diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_side_by_side_mode, diff)\n\n    def test_dataframe_diff_with_diff_mode_left_side(self):\n        options = DiffOptions().with_diff_mode(DiffMode.LeftSide)\n        diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_left_side_mode, diff)\n        diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_left_side_mode, diff)\n\n    def test_dataframe_diff_with_diff_mode_right_side(self):\n        options = DiffOptions().with_diff_mode(DiffMode.RightSide)\n        diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_right_side_mode, diff)\n        diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_right_side_mode, diff)\n\n    def test_dataframe_diff_with_sparse_mode(self):\n        options = DiffOptions().with_sparse_mode(True)\n        diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_sparse_mode, diff)\n        diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_sparse_mode, diff)\n\n    def test_differ_diff(self):\n        diff = Differ().diff(self.left_df, self.right_df, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff, diff)\n\n    def test_differ_diffwith(self):\n        diff = Differ().diffwith(self.left_df, self.right_df, 'id').orderBy('id').collect()\n        self.assertSetEqual(set(self.expected_diffwith), set(diff))\n        self.assertEqual(len(self.expected_diffwith), len(diff))\n\n    def test_differ_diff_with_default_options(self):\n        options = DiffOptions()\n        diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff, diff)\n\n    def test_differ_diff_with_options(self):\n        options = DiffOptions('d', 'l', 'r', 'i', 'c', 'r', 'n', None)\n        diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_with_options, diff)\n\n    def test_differ_diff_with_changes(self):\n        options = DiffOptions().with_change_column('changes')\n        diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_with_changes, diff)\n\n    def 
test_differ_diff_in_diff_mode_column_by_column(self):\n        options = DiffOptions().with_diff_mode(DiffMode.ColumnByColumn)\n        diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_column_by_column_mode, diff)\n\n    def test_differ_diff_in_diff_mode_side_by_side(self):\n        options = DiffOptions().with_diff_mode(DiffMode.SideBySide)\n        diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_side_by_side_mode, diff)\n\n    def test_differ_diff_in_diff_mode_left_side(self):\n        options = DiffOptions().with_diff_mode(DiffMode.LeftSide)\n        diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_left_side_mode, diff)\n\n    def test_differ_diff_in_diff_mode_right_side(self):\n        options = DiffOptions().with_diff_mode(DiffMode.RightSide)\n        diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_right_side_mode, diff)\n\n    def test_differ_diff_with_sparse_mode(self):\n        options = DiffOptions().with_sparse_mode(True)\n        diff = Differ(options).diff(self.left_df, self.right_df, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff_in_sparse_mode, diff)\n\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM\")\n    def test_diff_options_default(self):\n        jvm = self.spark._jvm\n        joptions = jvm.uk.co.gresearch.spark.diff.DiffOptions.default()\n        options = DiffOptions()\n        for attr in options.__dict__.keys():\n            const = re.sub(r'_(.)', lambda match: match.group(1).upper(), attr)\n            expected = getattr(joptions, const)()\n            actual = getattr(options, attr)\n\n            if type(expected) == JavaObject:\n                class_name = re.sub(r'\\$.*$', '', expected.getClass().getName())\n                if class_name in ['scala.None']:  # how does the Some(?) 
look like?\n                    actual = 'Some({})'.format(actual) if actual is not None else 'None'\n                if class_name in ['scala.collection.immutable.Map', 'scala.collection.mutable.Map']:\n                    actual = f'Map({\", \".join(f\"{key} -> {value._to_java(jvm).toString()}\" for key, value in actual.items())})'\n                expected = expected.toString()\n\n            if attr in ['diff_mode', 'default_comparator']:\n                # does the Python default diff mode resolve to the same Java diff mode enum value?\n                # does the Python diff comparator resolve to the same Java diff comparator?\n                self.assertEqual(expected, actual._to_java(jvm).toString(), '{} == {} ?'.format(attr, const))\n            else:\n                self.assertEqual(expected, actual, '{} == {} ?'.format(attr, const))\n\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM\")\n    def test_diff_mode_consts(self):\n        jvm = self.spark._jvm\n        jmodes = jvm.uk.co.gresearch.spark.diff.DiffMode\n        modes = DiffMode\n        for attr in modes.__dict__.keys():\n            if attr[0] != '_':\n                actual = getattr(modes, attr)\n                if isinstance(actual, DiffMode) and actual != DiffMode.Default:\n                    expected = getattr(jmodes, attr)()\n                    self.assertEqual(expected.toString(), actual.name, actual.name)\n        self.assertIsNotNone(DiffMode.Default.name, jmodes.Default().toString())\n\n    def test_diff_options_comparator_for(self):\n        cmp1 = DiffComparators.default()\n        cmp2 = DiffComparators.epsilon(0.01)\n        cmp3 = DiffComparators.string()\n\n        opts = DiffOptions() \\\n            .with_column_name_comparator(cmp1, \"abc\", \"def\") \\\n            .with_data_type_comparator(cmp2, LongType()) \\\n            .with_default_comparator(cmp3)\n\n        self.assertEqual(opts.comparator_for(StructField(\"abc\", IntegerType())), cmp1)\n        self.assertEqual(opts.comparator_for(StructField(\"def\", LongType())), cmp1)\n        self.assertEqual(opts.comparator_for(StructField(\"ghi\", LongType())), cmp2)\n        self.assertEqual(opts.comparator_for(StructField(\"jkl\", IntegerType())), cmp3)\n\n    def test_diff_fluent_setters(self):\n        cmp1 = DiffComparators.default()\n        cmp2 = DiffComparators.epsilon(0.01)\n        cmp3 = DiffComparators.string()\n        cmp4 = DiffComparators.duration('PT24H')\n\n        default = DiffOptions()\n        options = default \\\n            .with_diff_column('d') \\\n            .with_left_column_prefix('l') \\\n            .with_right_column_prefix('r') \\\n            .with_insert_diff_value('i') \\\n            .with_change_diff_value('c') \\\n            .with_delete_diff_value('r') \\\n            .with_nochange_diff_value('n') \\\n            .with_change_column('c') \\\n            .with_diff_mode(DiffMode.SideBySide) \\\n            .with_sparse_mode(True) \\\n            .with_default_comparator(cmp1) \\\n            .with_data_type_comparator(cmp2, IntegerType()) \\\n            .with_data_type_comparator(cmp3, StringType()) \\\n            .with_column_name_comparator(cmp4, 'value')\n\n        self.assertEqual(options.diff_column, 'd')\n        self.assertEqual(options.left_column_prefix, 'l')\n        self.assertEqual(options.right_column_prefix, 'r')\n        self.assertEqual(options.insert_diff_value, 'i')\n        self.assertEqual(options.change_diff_value, 'c')\n        
self.assertEqual(options.delete_diff_value, 'r')\n        self.assertEqual(options.nochange_diff_value, 'n')\n        self.assertEqual(options.change_column, 'c')\n        self.assertEqual(options.diff_mode, DiffMode.SideBySide)\n        self.assertEqual(options.sparse_mode, True)\n        self.assertEqual(options.default_comparator, cmp1)\n        self.assertEqual(options.data_type_comparators, {IntegerType(): cmp2, StringType(): cmp3})\n        self.assertEqual(options.column_name_comparators, {'value': cmp4})\n\n        self.assertNotEqual(options.diff_column, default.diff_column)\n        self.assertNotEqual(options.left_column_prefix, default.left_column_prefix)\n        self.assertNotEqual(options.right_column_prefix, default.right_column_prefix)\n        self.assertNotEqual(options.insert_diff_value, default.insert_diff_value)\n        self.assertNotEqual(options.change_diff_value, default.change_diff_value)\n        self.assertNotEqual(options.delete_diff_value, default.delete_diff_value)\n        self.assertNotEqual(options.nochange_diff_value, default.nochange_diff_value)\n        self.assertNotEqual(options.change_column, default.change_column)\n        self.assertNotEqual(options.diff_mode, default.diff_mode)\n        self.assertNotEqual(options.sparse_mode, default.sparse_mode)\n\n        without_change = options.without_change_column()\n        self.assertEqual(without_change.diff_column, 'd')\n        self.assertEqual(without_change.left_column_prefix, 'l')\n        self.assertEqual(without_change.right_column_prefix, 'r')\n        self.assertEqual(without_change.insert_diff_value, 'i')\n        self.assertEqual(without_change.change_diff_value, 'c')\n        self.assertEqual(without_change.delete_diff_value, 'r')\n        self.assertEqual(without_change.nochange_diff_value, 'n')\n        self.assertIsNone(without_change.change_column)\n        self.assertEqual(without_change.diff_mode, DiffMode.SideBySide)\n        self.assertEqual(without_change.sparse_mode, True)\n\n    def test_diff_with_epsilon_comparator(self):\n        # relative inclusive epsilon\n        options = DiffOptions() \\\n            .with_column_name_comparator(DiffComparators.epsilon(0.1).as_relative().as_inclusive(), 'val')\n        diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()\n        expected = self.spark.createDataFrame(self.expected_diff) \\\n            .withColumn(\"diff\", when(col(\"id\") == 1, \"N\").otherwise(col(\"diff\"))) \\\n            .collect()\n        self.assertEqual(expected, diff)\n\n        # relative exclusive epsilon\n        options = DiffOptions() \\\n            .with_column_name_comparator(DiffComparators.epsilon(0.0909).as_relative().as_exclusive(), 'val')\n        diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(self.expected_diff, diff)\n\n        # absolute inclusive epsilon\n        options = DiffOptions() \\\n            .with_column_name_comparator(DiffComparators.epsilon(0.10000000000000009).as_absolute().as_inclusive(), 'val')\n        diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()\n        self.assertEqual(expected, diff)\n\n        # absolute exclusive epsilon\n        options = DiffOptions() \\\n            .with_column_name_comparator(DiffComparators.epsilon(0.10000000000000009).as_absolute().as_exclusive(), 'val')\n        diff = self.left_df.diff(self.right_df, options, 'id').orderBy('id').collect()\n        
self.assertEqual(self.expected_diff, diff)\n\n    def test_diff_options_with_duplicate_comparators(self):\n        options = DiffOptions() \\\n            .with_data_type_comparator(DiffComparators.default(), DateType(), IntegerType()) \\\n            .with_column_name_comparator(DiffComparators.default(), 'col1', 'col2')\n\n        with self.assertRaisesRegex(ValueError, \"A comparator for data type date exists already.\"):\n            options.with_data_type_comparator(DiffComparators.default(), DateType())\n\n        with self.assertRaisesRegex(ValueError, \"A comparator for data type int exists already.\"):\n            options.with_data_type_comparator(DiffComparators.default(), IntegerType())\n\n        with self.assertRaisesRegex(ValueError, \"A comparator for data types date, int exists already.\"):\n            options.with_data_type_comparator(DiffComparators.default(), DateType(), IntegerType())\n\n        with self.assertRaisesRegex(ValueError, \"A comparator for column name col1 exists already.\"):\n            options.with_column_name_comparator(DiffComparators.default(), 'col1')\n\n        with self.assertRaisesRegex(ValueError, \"A comparator for column name col2 exists already.\"):\n            options.with_column_name_comparator(DiffComparators.default(), 'col2')\n\n        with self.assertRaisesRegex(ValueError, \"A comparator for column names col1, col2 exists already.\"):\n            options.with_column_name_comparator(DiffComparators.default(), 'col1', 'col2')\n\n\nif __name__ == '__main__':\n    SparkTest.main(__file__)\n"
  },
  {
    "path": "python/test/test_histogram.py",
    "content": "#  Copyright 2020 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nfrom unittest import skipIf\n\nfrom spark_common import SparkTest\nimport gresearch.spark\n\n\n@skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by Historgam\")\nclass HistogramTest(SparkTest):\n\n    @classmethod\n    def setUpClass(cls):\n        super(HistogramTest, cls).setUpClass()\n\n        cls.df = cls.spark.createDataFrame([\n            (1, 1),\n            (1, 2),\n            (1, 10),\n            (2, -3),\n            (2, 5),\n            (3, 8),\n        ], ['id', 'value'])\n\n    def test_histogram_with_ints(self):\n        hist = self.df.histogram([-5, 0, 5], 'value', 'id').orderBy('id').collect()\n        self.assertEqual([\n            {'id': 1, '≤-5': 0, '≤0': 0, '≤5': 2, '>5': 1},\n            {'id': 2, '≤-5': 0, '≤0': 1, '≤5': 1, '>5': 0},\n            {'id': 3, '≤-5': 0, '≤0': 0, '≤5': 0, '>5': 1},\n        ], [row.asDict() for row in hist])\n\n    def test_histogram_with_floats(self):\n        hist = self.df.histogram([-5.0, 0.0, 5.0], 'value', 'id').orderBy('id').collect()\n        self.assertEqual([\n            {'id': 1, '≤-5.0': 0, '≤0.0': 0, '≤5.0': 2, '>5.0': 1},\n            {'id': 2, '≤-5.0': 0, '≤0.0': 1, '≤5.0': 1, '>5.0': 0},\n            {'id': 3, '≤-5.0': 0, '≤0.0': 0, '≤5.0': 0, '>5.0': 1},\n        ], [row.asDict() for row in hist])\n\n\nif __name__ == '__main__':\n    SparkTest.main(__file__)\n"
  },
  {
    "path": "python/test/test_job_description.py",
    "content": "#  Copyright 2023 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nfrom unittest import skipIf\n\nfrom pyspark import TaskContext, SparkContext\nfrom typing import Optional\n\nfrom spark_common import SparkTest\nfrom gresearch.spark import job_description, append_job_description\n\n\n@skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by JobDescription\")\nclass JobDescriptionTest(SparkTest):\n\n    def _assert_job_description(self, expected: Optional[str]):\n        def get_job_description_func(part):\n            def func(row):\n                return row.id, part, TaskContext.get().getLocalProperty(\"spark.job.description\")\n            return func\n\n        descriptions = self.spark.range(3, numPartitions=3).rdd \\\n            .mapPartitionsWithIndex(lambda part, it: map(get_job_description_func(part), it)) \\\n            .collect()\n        self.assertEqual(\n            [(0, 0, expected), (1, 1, expected), (2, 2, expected)],\n            descriptions\n        )\n\n    def setUp(self) -> None:\n        SparkContext._active_spark_context.setJobDescription(None)\n\n    def test_with_job_description(self):\n        self._assert_job_description(None)\n        with job_description(\"job description\"):\n            self._assert_job_description(\"job description\")\n            with job_description(\"inner job description\"):\n                self._assert_job_description(\"inner job description\")\n            self._assert_job_description(\"job description\")\n            with job_description(\"inner job description\", True):\n                self._assert_job_description(\"job description\")\n            self._assert_job_description(\"job description\")\n        self._assert_job_description(None)\n        with job_description(\"other job description\", True):\n            self._assert_job_description(\"other job description\")\n        self._assert_job_description(None)\n\n    def test_append_job_description(self):\n        self._assert_job_description(None)\n        with append_job_description(\"job\"):\n            self._assert_job_description(\"job\")\n            with append_job_description(\"description\"):\n                self._assert_job_description(\"job - description\")\n            self._assert_job_description(\"job\")\n            with append_job_description(\"description 2\", \" \"):\n                self._assert_job_description(\"job description 2\")\n            self._assert_job_description(\"job\")\n        self._assert_job_description(None)\n\n\nif __name__ == '__main__':\n    SparkTest.main(__file__)\n"
  },
  {
    "path": "python/test/test_jvm.py",
    "content": "#  Copyright 2024 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nfrom unittest import skipIf, skipUnless\n\nfrom pyspark.sql.functions import sum\n\nfrom gresearch.spark import _get_jvm, \\\n    dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \\\n    timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, \\\n    histogram, job_description, append_description\nfrom gresearch.spark.diff import *\nfrom gresearch.spark.parquet import *\nfrom spark_common import SparkTest\n\nEXPECTED_UNSUPPORTED_MESSAGE = \"This feature is not supported for Spark Connect. Please use a classic Spark client. \" \\\n                               \"https://github.com/G-Research/spark-extension#spark-connect-server\"\n\n\nclass PackageTest(SparkTest):\n    df = None\n\n    @classmethod\n    def setUpClass(cls):\n        super(PackageTest, cls).setUpClass()\n        cls.df = cls.spark.createDataFrame([(1, \"one\"), (2, \"two\"), (3, \"three\")], [\"id\", \"value\"])\n\n    @skipIf(SparkTest.is_spark_connect, \"Spark classic client tests\")\n    def test_get_jvm_classic(self):\n        for obj in [self.spark, self.spark.sparkContext, self.df, self.spark.read]:\n            with self.subTest(type(obj).__name__):\n                self.assertIsNotNone(_get_jvm(obj))\n\n        with self.subTest(\"Unsupported\"):\n            with self.assertRaises(RuntimeError) as e:\n                _get_jvm(object())\n            self.assertEqual((\"Unsupported class: <class 'object'>\", ), e.exception.args)\n\n    @skipUnless(SparkTest.is_spark_connect, \"Spark connect client tests\")\n    def test_get_jvm_connect(self):\n        for obj in [self.spark, self.df, self.spark.read]:\n            with self.subTest(type(obj).__name__):\n                with self.assertRaises(RuntimeError) as e:\n                    _get_jvm(obj)\n                self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)\n\n        with self.subTest(\"Unsupported\"):\n            with self.assertRaises(RuntimeError) as e:\n                _get_jvm(object())\n            self.assertEqual((\"Unsupported class: <class 'object'>\", ), e.exception.args)\n\n    @skipIf(SparkTest.is_spark_connect, \"Spark classic client tests\")\n    def test_get_jvm_check_java_pkg_is_installed(self):\n        from gresearch import spark\n\n        is_installed = spark._java_pkg_is_installed\n\n        try:\n            spark._java_pkg_is_installed = False\n            with self.assertRaises(RuntimeError) as e:\n                _get_jvm(self.spark)\n            self.assertEqual((\"Java / Scala package not found! 
You need to add the Maven spark-extension package \"\n                              \"to your PySpark environment: https://github.com/G-Research/spark-extension#python\", ), e.exception.args)\n        finally:\n            spark._java_pkg_is_installed = is_installed\n\n    @skipUnless(SparkTest.is_spark_connect, \"Spark connect client tests\")\n    def test_dotnet_ticks(self):\n        for label, func in {\n            'dotnet_ticks_to_timestamp': dotnet_ticks_to_timestamp,\n            'dotnet_ticks_to_unix_epoch': dotnet_ticks_to_unix_epoch,\n            'dotnet_ticks_to_unix_epoch_nanos': dotnet_ticks_to_unix_epoch_nanos,\n            'timestamp_to_dotnet_ticks': timestamp_to_dotnet_ticks,\n            'unix_epoch_to_dotnet_ticks': unix_epoch_to_dotnet_ticks,\n            'unix_epoch_nanos_to_dotnet_ticks': unix_epoch_nanos_to_dotnet_ticks,\n        }.items():\n            with self.subTest(label):\n                with self.assertRaises(RuntimeError) as e:\n                    func(\"id\")\n                self.assertEqual((\"This method must be called inside an active Spark session\", ), e.exception.args)\n\n    @skipUnless(SparkTest.is_spark_connect, \"Spark connect client tests\")\n    def test_histogram(self):\n        with self.assertRaises(RuntimeError) as e:\n            self.df.histogram([1, 10, 100], \"bin\", sum)\n        self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)\n\n    @skipUnless(SparkTest.is_spark_connect, \"Spark connect client tests\")\n    def test_with_row_numbers(self):\n        with self.assertRaises(RuntimeError) as e:\n            self.df.with_row_numbers()\n        self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)\n\n    @skipUnless(SparkTest.is_spark_connect, \"Spark connect client tests\")\n    def test_job_description(self):\n        with self.assertRaises(RuntimeError) as e:\n            with job_description(\"job description\"):\n                pass\n        self.assertEqual((\"This method must be called inside an active Spark session\", ), e.exception.args)\n\n        with self.assertRaises(RuntimeError) as e:\n            with append_description(\"job description\"):\n                pass\n        self.assertEqual((\"This method must be called inside an active Spark session\", ), e.exception.args)\n\n    @skipUnless(SparkTest.is_spark_connect, \"Spark connect client tests\")\n    def test_create_temp_dir(self):\n        with self.assertRaises(RuntimeError) as e:\n            self.spark.create_temporary_dir(\"prefix\")\n        self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)\n\n    @skipUnless(SparkTest.is_spark_connect, \"Spark connect client tests\")\n    def test_install_pip_package(self):\n        with self.assertRaises(RuntimeError) as e:\n            self.spark.install_pip_package(\"pytest\")\n        self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)\n\n    @skipUnless(SparkTest.is_spark_connect, \"Spark connect client tests\")\n    def test_install_poetry_project(self):\n        with self.assertRaises(RuntimeError) as e:\n            self.spark.install_poetry_project(\"./poetry-project\")\n        self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)\n\n    @skipUnless(SparkTest.is_spark_connect, \"Spark connect client tests\")\n    def test_parquet(self):\n        for label, func in {\n            'parquet_metadata': lambda dr: dr.parquet_metadata(\"file.parquet\"),\n            'parquet_schema': lambda dr: dr.parquet_schema(\"file.parquet\"),\n      
      'parquet_blocks': lambda dr: dr.parquet_blocks(\"file.parquet\"),\n            'parquet_block_columns': lambda dr: dr.parquet_block_columns(\"file.parquet\"),\n            'parquet_partitions': lambda dr: dr.parquet_partitions(\"file.parquet\"),\n        }.items():\n            with self.subTest(label):\n                with self.assertRaises(RuntimeError) as e:\n                    func(self.spark.read)\n                self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)\n\n\nif __name__ == '__main__':\n    SparkTest.main(__file__)\n"
  },
  {
    "path": "python/test/test_package.py",
    "content": "#  Copyright 2023 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\nimport datetime\nimport os\nfrom decimal import Decimal\nfrom subprocess import CalledProcessError\nfrom unittest import skipUnless, skipIf\n\nfrom pyspark import __version__, SparkContext\nfrom pyspark.sql import Row, SparkSession, SQLContext\nfrom pyspark.sql.functions import col, count\n\nfrom gresearch.spark import backticks, distinct_prefix_for, handle_configured_case_sensitivity, \\\n    list_contains_case_sensitivity, list_filter_case_sensitivity, list_diff_case_sensitivity, \\\n    dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \\\n    timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, count_null\nfrom spark_common import SparkTest\n\ntry:\n    from pyspark.sql.connect.session import SparkSession as ConnectSparkSession\n    has_connect = True\nexcept ImportError:\n    has_connect = False\n\nPOETRY_PYTHON_ENV = \"POETRY_PYTHON\"\nRICH_SOURCES_ENV = \"RICH_SOURCES\"\n\n\nclass PackageTest(SparkTest):\n\n    @classmethod\n    def setUpClass(cls):\n        super(PackageTest, cls).setUpClass()\n\n        cls.ticks = cls.spark.createDataFrame([\n            (1, 599266080000000000),\n            (2, 621355968000000000),\n            (3, 638155413748959308),\n            (4, 638155413748959309),\n            (5, 638155413748959310),\n            (6, 713589688368547758),\n            (7, 946723967999999999)\n        ], ['id', 'tick'])\n\n        cls.timestamps = cls.spark.createDataFrame([\n            (1, datetime.datetime(1900, 1, 1, tzinfo=datetime.timezone.utc).astimezone().replace(tzinfo=None)),\n            (2, datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc).astimezone().replace(tzinfo=None)),\n            (3, datetime.datetime(2023, 3, 27, 19, 16, 14, 895930, datetime.timezone.utc).astimezone().replace(tzinfo=None)),\n            (4, datetime.datetime(2023, 3, 27, 19, 16, 14, 895930, datetime.timezone.utc).astimezone().replace(tzinfo=None)),\n            (5, datetime.datetime(2023, 3, 27, 19, 16, 14, 895931, datetime.timezone.utc).astimezone().replace(tzinfo=None)),\n            (6, datetime.datetime(2262, 4, 11, 23, 47, 16, 854775, datetime.timezone.utc).astimezone().replace(tzinfo=None)),\n            (7, datetime.datetime(3001, 1, 19, 7, 59, 59, 999999, datetime.timezone.utc).astimezone().replace(tzinfo=None))\n        ], ['id', 'timestamp'])\n\n        cls.unix = cls.spark.createDataFrame([\n            (1, Decimal('-2208988800.000000000')),\n            (2, Decimal('0E-9')),\n            (3, Decimal('1679944574.895930800')),\n            (4, Decimal('1679944574.895930900')),\n            (5, Decimal('1679944574.895931000')),\n            (6, Decimal('9223372036.854775800')),\n            (7, Decimal('32536799999.999999900'))\n        ], ['id', 'unix'])\n\n        cls.unix_nanos = cls.spark.createDataFrame([\n            (1, -2208988800000000000),\n            (2, 0),\n            
(3, 1679944574895930800),\n            (4, 1679944574895930900),\n            (5, 1679944574895931000),\n            (6, 9223372036854775800),\n            (7, None)\n        ], ['id', 'unix_nanos'])\n\n        cls.ticks_from_timestamp = cls.spark.createDataFrame([\n            (1, 599266080000000000),\n            (2, 621355968000000000),\n            (3, 638155413748959300),\n            (4, 638155413748959300),\n            (5, 638155413748959310),\n            (6, 713589688368547750),\n            (7, 946723967999999990)\n        ], ['id', 'tick'])\n\n        cls.ticks_from_unix_nanos = cls.spark.createDataFrame([\n            (1, 599266080000000000),\n            (2, 621355968000000000),\n            (3, 638155413748959308),\n            (4, 638155413748959309),\n            (5, 638155413748959310),\n            (6, 713589688368547758),\n            (7, None)\n        ], ['id', 'tick'])\n\n    def compare_dfs(self, expected, actual):\n        print('expected')\n        expected.show(truncate=False)\n        print('actual')\n        actual.show(truncate=False)\n        self.assertEqual(\n            [row.asDict() for row in actual.collect()],\n            [row.asDict() for row in expected.collect()]\n        )\n\n    def test_backticks(self):\n        self.assertEqual(backticks(\"column\"), \"column\")\n        self.assertEqual(backticks(\"a.column\"), \"`a.column`\")\n        self.assertEqual(backticks(\"`a.column`\"), \"`a.column`\")\n        self.assertEqual(backticks(\"column\", \"a.field\"), \"column.`a.field`\")\n        self.assertEqual(backticks(\"a.column\", \"a.field\"), \"`a.column`.`a.field`\")\n        self.assertEqual(backticks(\"the.alias\", \"a.column\", \"a.field\"), \"`the.alias`.`a.column`.`a.field`\")\n\n    def test_distinct_prefix_for(self):\n        self.assertEqual(distinct_prefix_for([]), \"_\")\n        self.assertEqual(distinct_prefix_for([\"a\"]), \"_\")\n        self.assertEqual(distinct_prefix_for([\"abc\"]), \"_\")\n        self.assertEqual(distinct_prefix_for([\"a\", \"bc\", \"def\"]), \"_\")\n        self.assertEqual(distinct_prefix_for([\"_a\"]), \"__\")\n        self.assertEqual(distinct_prefix_for([\"_abc\"]), \"__\")\n        self.assertEqual(distinct_prefix_for([\"a\", \"_bc\", \"__def\"]), \"___\")\n\n    def test_handle_configured_case_sensitivity(self):\n        case_sensitive = False\n        with self.subTest(case_sensitive=case_sensitive):\n            self.assertEqual(handle_configured_case_sensitivity('abc', case_sensitive), 'abc')\n            self.assertEqual(handle_configured_case_sensitivity('AbC', case_sensitive), 'abc')\n            self.assertEqual(handle_configured_case_sensitivity('ABC', case_sensitive), 'abc')\n\n        case_sensitive = True\n        with self.subTest(case_sensitive=case_sensitive):\n            self.assertEqual(handle_configured_case_sensitivity('abc', case_sensitive), 'abc')\n            self.assertEqual(handle_configured_case_sensitivity('AbC', case_sensitive), 'AbC')\n            self.assertEqual(handle_configured_case_sensitivity('ABC', case_sensitive), 'ABC')\n\n    def test_list_contains_case_sensitivity(self):\n        the_list = ['abc', 'Def', 'GhI', 'JKL']\n        self.assertEqual(list_contains_case_sensitivity(the_list, 'a', case_sensitive=False), False)\n        self.assertEqual(list_contains_case_sensitivity(the_list, 'abc', case_sensitive=False), True)\n        self.assertEqual(list_contains_case_sensitivity(the_list, 'deF', case_sensitive=False), True)\n        
self.assertEqual(list_contains_case_sensitivity(the_list, 'JKL', case_sensitive=False), True)\n\n        self.assertEqual(list_contains_case_sensitivity(the_list, 'a', case_sensitive=True), False)\n        self.assertEqual(list_contains_case_sensitivity(the_list, 'abc', case_sensitive=True), True)\n        self.assertEqual(list_contains_case_sensitivity(the_list, 'deF', case_sensitive=True), False)\n        self.assertEqual(list_contains_case_sensitivity(the_list, 'JKL', case_sensitive=True), True)\n\n    def test_list_filter_case_sensitivity(self):\n        the_list = ['abc', 'Def', 'GhI', 'JKL']\n        self.assertEqual(list_filter_case_sensitivity(the_list, ['a'], case_sensitive=False), [])\n        self.assertEqual(list_filter_case_sensitivity(the_list, ['abc'], case_sensitive=False), ['abc'])\n        self.assertEqual(list_filter_case_sensitivity(the_list, ['deF'], case_sensitive=False), ['Def'])\n        self.assertEqual(list_filter_case_sensitivity(the_list, ['JKL'], case_sensitive=False), ['JKL'])\n\n        self.assertEqual(list_filter_case_sensitivity(the_list, ['a'], case_sensitive=True), [])\n        self.assertEqual(list_filter_case_sensitivity(the_list, ['abc'], case_sensitive=True), ['abc'])\n        self.assertEqual(list_filter_case_sensitivity(the_list, ['deF'], case_sensitive=True), [])\n        self.assertEqual(list_filter_case_sensitivity(the_list, ['JKL'], case_sensitive=True), ['JKL'])\n\n    def test_list_diff_case_sensitivity(self):\n        the_list = ['abc', 'Def', 'GhI', 'JKL']\n        self.assertEqual(list_diff_case_sensitivity(the_list, ['a'], case_sensitive=False), the_list)\n        self.assertEqual(list_diff_case_sensitivity(the_list, ['abc'], case_sensitive=False), ['Def', 'GhI', 'JKL'])\n        self.assertEqual(list_diff_case_sensitivity(the_list, ['deF'], case_sensitive=False), ['abc', 'GhI', 'JKL'])\n        self.assertEqual(list_diff_case_sensitivity(the_list, ['JKL'], case_sensitive=False), ['abc', 'Def', 'GhI'])\n\n        self.assertEqual(list_diff_case_sensitivity(the_list, ['a'], case_sensitive=True), the_list)\n        self.assertEqual(list_diff_case_sensitivity(the_list, ['abc'], case_sensitive=True), ['Def', 'GhI', 'JKL'])\n        self.assertEqual(list_diff_case_sensitivity(the_list, ['deF'], case_sensitive=True), the_list)\n        self.assertEqual(list_diff_case_sensitivity(the_list, ['JKL'], case_sensitive=True), ['abc', 'Def', 'GhI'])\n\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by dotnet ticks\")\n    def test_dotnet_ticks_to_timestamp(self):\n        for column in [\"tick\", self.ticks.tick]:\n            with self.subTest(column=column):\n                timestamps = self.ticks.withColumn(\"timestamp\", dotnet_ticks_to_timestamp(column)).orderBy('id')\n                expected = self.ticks.join(self.timestamps, \"id\").orderBy('id')\n                self.compare_dfs(expected, timestamps)\n\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by dotnet ticks\")\n    def test_dotnet_ticks_to_unix_epoch(self):\n        for column in [\"tick\", self.ticks.tick]:\n            with self.subTest(column=column):\n                timestamps = self.ticks.withColumn(\"unix\", dotnet_ticks_to_unix_epoch(column)).orderBy('id')\n                expected = self.ticks.join(self.unix, \"id\").orderBy('id')\n                self.compare_dfs(expected, timestamps)\n\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to 
the JVM, required by dotnet ticks\")\n    def test_dotnet_ticks_to_unix_epoch_nanos(self):\n        self.maxDiff = None\n        for column in [\"tick\", self.ticks.tick]:\n            with self.subTest(column=column):\n                timestamps = self.ticks.withColumn(\"unix_nanos\", dotnet_ticks_to_unix_epoch_nanos(column)).orderBy('id')\n                expected = self.ticks.join(self.unix_nanos, \"id\").orderBy('id')\n                self.compare_dfs(expected, timestamps)\n\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by dotnet ticks\")\n    def test_timestamp_to_dotnet_ticks(self):\n        if self.spark.version.startswith('3.0.'):\n            self.skipTest('timestamp_to_dotnet_ticks not supported by Spark 3.0')\n        for column in [\"timestamp\", self.timestamps.timestamp]:\n            with self.subTest(column=column):\n                timestamps = self.timestamps.withColumn(\"tick\", timestamp_to_dotnet_ticks(column)).orderBy('id')\n                expected = self.timestamps.join(self.ticks_from_timestamp, \"id\").orderBy('id')\n                self.compare_dfs(expected, timestamps)\n\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by dotnet ticks\")\n    def test_unix_epoch_dotnet_ticks(self):\n        for column in [\"unix\", self.unix.unix]:\n            with self.subTest(column=column):\n                timestamps = self.unix.withColumn(\"tick\", unix_epoch_to_dotnet_ticks(column)).orderBy('id')\n                expected = self.unix.join(self.ticks, \"id\").orderBy('id')\n                self.compare_dfs(expected, timestamps)\n\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by dotnet ticks\")\n    def test_unix_epoch_nanos_to_dotnet_ticks(self):\n        for column in [\"unix_nanos\", self.unix_nanos.unix_nanos]:\n            with self.subTest(column=column):\n                timestamps = self.unix_nanos.withColumn(\"tick\", unix_epoch_nanos_to_dotnet_ticks(column)).orderBy('id')\n                expected = self.unix_nanos.join(self.ticks_from_unix_nanos, \"id\").orderBy('id')\n                self.compare_dfs(expected, timestamps)\n\n    def test_count_null(self):\n        actual = self.unix_nanos.select(\n            count(\"id\").alias(\"ids\"),\n            count(col(\"unix_nanos\")).alias(\"nanos\"),\n            count_null(\"id\").alias(\"null_ids\"),\n            count_null(col(\"unix_nanos\")).alias(\"null_nanos\"),\n        ).collect()\n        self.assertEqual([Row(ids=7, nanos=6, null_ids=0, null_nanos=1)], actual)\n\n    def test_session(self):\n        self.assertIsNotNone(self.ticks.session())\n        self.assertIsInstance(self.ticks.session(), tuple(([SparkSession] + ([ConnectSparkSession] if has_connect else []))))\n\n    def test_session_or_ctx(self):\n        self.assertIsNotNone(self.ticks.session_or_ctx())\n        self.assertIsInstance(self.ticks.session_or_ctx(), tuple(([SparkSession, SQLContext] + ([ConnectSparkSession] if has_connect else []))))\n\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by create_temp_dir\")\n    def test_create_temp_dir(self):\n        from pyspark import SparkFiles\n\n        dir = self.spark.create_temporary_dir(\"prefix\")\n        self.assertTrue(dir.startswith(SparkFiles.getRootDirectory()))\n\n    @skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')\n    
@skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by install_pip_package\")\n    def test_install_pip_package(self):\n        self.spark.sparkContext.setLogLevel(\"INFO\")\n        with self.assertRaises(ImportError):\n            # noinspection PyPackageRequirements\n            import emoji\n            emoji.emojize(\"this test is :thumbs_up:\")\n\n        self.spark.install_pip_package(\"emoji\", '--cache', '.cache/pypi')\n\n        # noinspection PyPackageRequirements\n        import emoji\n        actual = emoji.emojize(\"this test is :thumbs_up:\")\n        expected = \"this test is 👍\"\n        self.assertEqual(expected, actual)\n\n        import pandas as pd\n        actual = self.spark.range(0, 10, 1, 10) \\\n            .mapInPandas(lambda it: [pd.DataFrame.from_dict({\"val\": [emoji.emojize(\":thumbs_up:\")]})], \"val string\") \\\n            .collect()\n        expected = [Row(\"👍\")] * 10\n        self.assertEqual(expected, actual)\n\n    @skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by install_pip_package\")\n    def test_install_pip_package_unknown_argument(self):\n        with self.assertRaises(CalledProcessError):\n            self.spark.install_pip_package(\"--unknown\", \"argument\")\n\n    @skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by install_pip_package\")\n    def test_install_pip_package_package_not_found(self):\n        with self.assertRaises(CalledProcessError):\n            self.spark.install_pip_package(\"pyspark-extension==abc\")\n\n    @skipUnless(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by install_pip_package\")\n    def test_install_pip_package_not_supported(self):\n        with self.assertRaises(NotImplementedError):\n            self.spark.install_pip_package(\"emoji\")\n\n    @skipIf(__version__.startswith('3.0.'), 'install_poetry_project not supported for Spark 3.0')\n    # provide an environment variable with path to the python binary of a virtual env that has poetry installed\n    @skipIf(POETRY_PYTHON_ENV not in os.environ, f'Environment variable {POETRY_PYTHON_ENV} pointing to '\n                                                 f'virtual env python with poetry required')\n    @skipIf(RICH_SOURCES_ENV not in os.environ, f'Environment variable {RICH_SOURCES_ENV} pointing to '\n                                                f'rich project sources required')\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by install_poetry_project\")\n    def test_install_poetry_project(self):\n        self.spark.sparkContext.setLogLevel(\"INFO\")\n        with self.assertRaises(ImportError):\n            # noinspection PyPackageRequirements\n            from rich.emoji import Emoji\n            thumbs_up = Emoji(\"thumbs_up\")\n\n        rich_path = os.environ[RICH_SOURCES_ENV]\n        poetry_python = os.environ[POETRY_PYTHON_ENV]\n        self.spark.install_poetry_project(\n            rich_path,\n            poetry_python=poetry_python,\n            pip_args=['--cache', '.cache/pypi']\n        )\n\n        # noinspection 
PyPackageRequirements\n        from rich.emoji import Emoji\n        thumbs_up = Emoji(\"thumbs_up\")\n        actual = thumbs_up.replace(\"this test is :thumbs_up:\")\n        expected = \"this test is 👍\"\n        self.assertEqual(expected, actual)\n\n        import pandas as pd\n        actual = self.spark.range(0, 10, 1, 10) \\\n            .mapInPandas(lambda it: [pd.DataFrame.from_dict({\"val\": [thumbs_up.replace(\":thumbs_up:\")]})], \"val string\") \\\n            .collect()\n        expected = [Row(\"👍\")] * 10\n        self.assertEqual(expected, actual)\n\n    @skipIf(__version__.startswith('3.0.'), 'install_poetry_project not supported for Spark 3.0')\n    # provide an environment variable with path to the python binary of a virtual env that has poetry installed\n    @skipIf(POETRY_PYTHON_ENV not in os.environ, f'Environment variable {POETRY_PYTHON_ENV} pointing to '\n                                                 f'virtual env python with poetry required')\n    @skipIf(RICH_SOURCES_ENV not in os.environ, f'Environment variable {RICH_SOURCES_ENV} pointing to '\n                                                f'rich project sources required')\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by install_poetry_project\")\n    def test_install_poetry_project_wrong_arguments(self):\n        rich_path = os.environ[RICH_SOURCES_ENV]\n        poetry_python = os.environ[POETRY_PYTHON_ENV]\n\n        with self.assertRaises(RuntimeError):\n            self.spark.install_poetry_project(\"non-existing-project\", poetry_python=poetry_python)\n        with self.assertRaises(FileNotFoundError):\n            self.spark.install_poetry_project(rich_path, poetry_python=\"non-existing-python\")\n\n    @skipUnless(__version__.startswith('3.0.'), 'install_poetry_project not supported for Spark 3.0')\n    @skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by install_poetry_project\")\n    def test_install_poetry_project_not_supported(self):\n        with self.assertRaises(NotImplementedError):\n            self.spark.install_poetry_project(\"./rich\")\n\n\nif __name__ == '__main__':\n    SparkTest.main(__file__)\n"
  },
  {
    "path": "python/test/test_parquet.py",
    "content": "#  Copyright 2023 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nfrom pathlib import Path\nfrom unittest import skipIf\n\nfrom spark_common import SparkTest\nimport gresearch.spark.parquet\n\n\n@skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by Parquet\")\nclass ParquetTest(SparkTest):\n\n    test_file = str((Path(__file__).parent.parent.parent / \"src\" / \"test\" / \"files\" / \"test.parquet\").resolve())\n\n    def test_parquet_metadata(self):\n        self.assertEqual(self.spark.read.parquet_metadata(self.test_file).count(), 2)\n        self.assertEqual(self.spark.read.parquet_metadata(self.test_file, self.test_file).count(), 2)\n        self.assertEqual(self.spark.read.parquet_metadata(self.test_file, parallelism=100).count(), 2)\n        self.assertEqual(self.spark.read.parquet_metadata(self.test_file, self.test_file, parallelism=100).count(), 2)\n\n    def test_parquet_schema(self):\n        self.assertEqual(self.spark.read.parquet_schema(self.test_file).count(), 4)\n        self.assertEqual(self.spark.read.parquet_schema(self.test_file, self.test_file).count(), 4)\n        self.assertEqual(self.spark.read.parquet_schema(self.test_file, parallelism=100).count(), 4)\n        self.assertEqual(self.spark.read.parquet_schema(self.test_file, self.test_file, parallelism=100).count(), 4)\n\n    def test_parquet_blocks(self):\n        self.assertEqual(self.spark.read.parquet_blocks(self.test_file).count(), 3)\n        self.assertEqual(self.spark.read.parquet_blocks(self.test_file, self.test_file).count(), 3)\n        self.assertEqual(self.spark.read.parquet_blocks(self.test_file, parallelism=100).count(), 3)\n        self.assertEqual(self.spark.read.parquet_blocks(self.test_file, self.test_file, parallelism=100).count(), 3)\n\n    def test_parquet_block_columns(self):\n        self.assertEqual(self.spark.read.parquet_block_columns(self.test_file).count(), 6)\n        self.assertEqual(self.spark.read.parquet_block_columns(self.test_file, self.test_file).count(), 6)\n        self.assertEqual(self.spark.read.parquet_block_columns(self.test_file, parallelism=100).count(), 6)\n        self.assertEqual(self.spark.read.parquet_block_columns(self.test_file, self.test_file, parallelism=100).count(), 6)\n\n    def test_parquet_partitions(self):\n        self.assertEqual(self.spark.read.parquet_partitions(self.test_file).count(), 2)\n        self.assertEqual(self.spark.read.parquet_partitions(self.test_file, self.test_file).count(), 2)\n        self.assertEqual(self.spark.read.parquet_partitions(self.test_file, parallelism=100).count(), 2)\n        self.assertEqual(self.spark.read.parquet_partitions(self.test_file, self.test_file, parallelism=100).count(), 2)\n\n\nif __name__ == '__main__':\n    SparkTest.main(__file__)\n"
  },
  {
    "path": "python/test/test_row_number.py",
    "content": "#  Copyright 2022 G-Research\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nfrom unittest import skipIf\n\nfrom pyspark.storagelevel import StorageLevel\n\nfrom spark_common import SparkTest\nimport gresearch.spark\n\n\n@skipIf(SparkTest.is_spark_connect, \"Spark Connect does not provide access to the JVM, required by RowNumber\")\nclass RowNumberTest(SparkTest):\n\n    @classmethod\n    def setUpClass(cls):\n        super(RowNumberTest, cls).setUpClass()\n\n        cls.df1 = cls.spark.createDataFrame([\n            (1, 'one'),\n            (2, 'two'),\n            (3, 'three'),\n            (4, 'four'),\n        ], ['id', 'value'])\n        cls.expected1 = [\n            {'id': 1, 'value': 'one', 'row_number': 1},\n            {'id': 2, 'value': 'two', 'row_number': 2},\n            {'id': 3, 'value': 'three', 'row_number': 3},\n            {'id': 4, 'value': 'four', 'row_number': 4},\n        ]\n        cls.expected1Desc = [\n            {'id': 1, 'value': 'one', 'row_number': 4},\n            {'id': 2, 'value': 'two', 'row_number': 3},\n            {'id': 3, 'value': 'three', 'row_number': 2},\n            {'id': 4, 'value': 'four', 'row_number': 1},\n        ]\n\n        cls.df2 = cls.spark.createDataFrame([\n            (1, 'one'),\n            (2, 'TWO'),\n            (2, 'two'),\n            (3, 'three'),\n        ], ['id', 'value'])\n        cls.expected2 = [\n            {'id': 1, 'value': 'one', 'row_number': 1},\n            {'id': 2, 'value': 'TWO', 'row_number': 2},\n            {'id': 2, 'value': 'two', 'row_number': 3},\n            {'id': 3, 'value': 'three', 'row_number': 4},\n        ]\n        cls.expected2Desc = [\n            {'id': 1, 'value': 'one', 'row_number': 4},\n            {'id': 2, 'value': 'TWO', 'row_number': 3},\n            {'id': 2, 'value': 'two', 'row_number': 2},\n            {'id': 3, 'value': 'three', 'row_number': 1},\n        ]\n\n    def test_row_numbers(self):\n        rows = self.df1.with_row_numbers().orderBy('id', 'value').collect()\n        self.assertEqual(self.expected1, [row.asDict() for row in rows])\n\n    def test_row_numbers_order_one_column(self):\n        for order in ['id', ['id'], self.df1.id, [self.df1.id]]:\n            with self.subTest(order=order):\n                rows = self.df1.with_row_numbers(order=order).orderBy('id', 'value').collect()\n                self.assertEqual(self.expected1, [row.asDict() for row in rows])\n\n    def test_row_numbers_order_two_columns(self):\n        for order in [['id', 'value'], [self.df2.id, self.df2.value]]:\n            with self.subTest(order=order):\n                rows = self.df2.with_row_numbers(order=order).orderBy('id', 'value').collect()\n                self.assertEqual(self.expected2, [row.asDict() for row in rows])\n\n    def test_row_numbers_order_not_asc_one_column(self):\n        for order in ['id', ['id'], self.df1.id, [self.df1.id]]:\n            with self.subTest(order=order):\n                rows = 
self.df1.with_row_numbers(order=order, ascending=False).orderBy('id', 'value').collect()\n                self.assertEqual(self.expected1Desc, [row.asDict() for row in rows])\n\n    def test_row_numbers_order_not_asc_two_columns(self):\n        for order in [['id', 'value'], [self.df2.id, self.df2.value]]:\n            with self.subTest(order=order):\n                rows = self.df2.with_row_numbers(order=order, ascending=False).orderBy('id', 'value').collect()\n                self.assertEqual(self.expected2Desc, [row.asDict() for row in rows])\n\n    def test_row_numbers_order_desc_one_column(self):\n        for order in [self.df1.id.desc(), [self.df1.id.desc()]]:\n            with self.subTest(order=order):\n                rows = self.df1.with_row_numbers(order=order).orderBy('id', 'value').collect()\n                self.assertEqual(self.expected1Desc, [row.asDict() for row in rows])\n\n    def test_row_numbers_order_desc_two_columns(self):\n        for order in [[self.df2.id.desc(), self.df2.value.desc()]]:\n            with self.subTest(order=order):\n                rows = self.df2.with_row_numbers(order=order).orderBy('id', 'value').collect()\n                self.assertEqual(self.expected2Desc, [row.asDict() for row in rows])\n\n    def test_row_numbers_unpersist(self):\n        for storage_level in [StorageLevel.MEMORY_AND_DISK, StorageLevel.MEMORY_ONLY, StorageLevel.DISK_ONLY]:\n            with self.subTest(storage_level=storage_level):\n                # make sure the cache is clear\n                jcm = self.spark._jsparkSession.sharedState().cacheManager()\n                jcm.clearCache()\n                self.assertTrue(jcm.isEmpty())\n\n                unpersist = self.spark.unpersist_handle()\n                self.df1.with_row_numbers(storage_level=storage_level, unpersist_handle=unpersist) \\\n                    .orderBy('id', 'value').collect()\n\n                # the cache should not be empty now\n                self.assertFalse(jcm.isEmpty())\n                unpersist(blocking=True)\n\n                # this should have removed the only DataFrame from the cache\n                self.assertTrue(jcm.isEmpty())\n\n                # calling unpersist again does not hurt, this time without blocking\n                unpersist()\n\n    def test_row_numbers_row_number_col_name(self):\n        rows = self.df1.with_row_numbers(row_number_column_name='row').orderBy('id', 'value').collect()\n        self.assertEqual([{'row' if k == 'row_number' else k: v for k, v in row.items()}\n                          for row in self.expected1],\n                         [row.asDict() for row in rows])\n\n\nif __name__ == '__main__':\n    SparkTest.main(__file__)\n"
  },
  {
    "path": "release.sh",
    "content": "#!/bin/bash\n#\n# Copyright 2020 G-Research\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Script to prepare release, see RELEASE.md for details\n\nset -euo pipefail\n\n# check for clean git status (except for CHANGELOG.md and release.sh)\nreadarray -t git_status < <(git status -s --untracked-files=no 2>/dev/null | grep -v -e \" CHANGELOG.md$\" -e \" release.sh$\")\nif [ ${#git_status[@]} -gt 0 ]\nthen\n  echo \"There are pending git changes:\"\n  for (( i=0; i<${#git_status[@]}; i++ )); do echo \"${git_status[$i]}\" ; done\n  exit 1\nfi\n\n# check for unreleased entry in CHANGELOG.md\nreadarray -t changes < <(grep -A 100 \"^## \\[UNRELEASED\\] - YYYY-MM-DD\" CHANGELOG.md | grep -B 100 --max-count=1 -E \"^## \\[[0-9.]+\\]\" | grep \"^-\")\nif [ ${#changes[@]} -eq 0 ]\nthen\n  echo \"Did not find any changes in CHANGELOG.md under '## [UNRELEASED] - YYYY-MM-DD'\"\n  exit 1\nfi\n\n# check this is a SNAPSHOT versions\nif ! grep -q \"<version>.*-SNAPSHOT</version>\" pom.xml\nthen\n  echo \"Version in pom is not a SNAPSHOT version, cannot test all versions\"\n  exit 1\nfi\n\n# check for existing cached SNAPSHOT jars\nversion=$(grep --max-count=1 \"<version>.*</version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\" -e \"s/-SNAPSHOT//\" -e \"s/-[0-9.]+//g\")\njars=$(find $HOME/.m2 $HOME/.ivy2 -name \"*spark-extension_*-$version-*-SNAPSHOT.jar\")\nif [[ -n \"$jars\" ]]\nthen\n  echo \"There are installed SNAPSHOT jars, these may interfere with release tests. 
These must be deleted first:\"\n  echo \"$jars\" | tr '\\n' ' '\n  echo\n  exit 1\nfi\n\n# testing all versions\nrm -rf metastore_db/ spark-warehouse/\n./set-version.sh 3.2.4 2.12.15; mvn clean deploy -Dsign; ./build-whl.sh; ./test-release.sh\n./set-version.sh 3.3.4 2.12.15; mvn clean deploy -Dsign; ./build-whl.sh; ./test-release.sh\n./set-version.sh 3.4.4 2.12.17; mvn clean deploy -Dsign; ./build-whl.sh; ./test-release.sh\n./set-version.sh 3.5.3 2.12.18; mvn clean deploy -Dsign; ./build-whl.sh; ./test-release.sh\nrm -rf python/dist\n\n./set-version.sh 3.2.4 2.13.5; mvn clean deploy -Dsign; ./test-release.sh\n./set-version.sh 3.3.4 2.13.8; mvn clean deploy -Dsign; ./test-release.sh\n./set-version.sh 3.4.4 2.13.8; mvn clean deploy -Dsign; ./test-release.sh\n./set-version.sh 3.5.3 2.13.8; mvn clean deploy -Dsign; ./test-release.sh\nrm -rf metastore_db/ spark-warehouse/\n\n# all SNAPSHOT versions build, test and complete the example, releasing\n\n# revert pom.xml and python/setup.py changes\ngit checkout pom.xml python/setup.py\n\n# get latest and release version\nlatest=$(grep --max-count=1 \"<version>.*</version>\" README.md | sed -E -e \"s/\\s*<[^>]+>//g\" -e \"s/-[0-9.]+//g\")\nversion=$(grep --max-count=1 \"<version>.*</version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\" -e \"s/-SNAPSHOT//\" -e \"s/-[0-9.]+//g\")\n\necho \"Releasing ${#changes[@]} changes as version $version:\"\nfor (( i=0; i<${#changes[@]}; i++ )); do echo \"${changes[$i]}\" ; done\n\nsed -i \"s/## \\[UNRELEASED\\] - YYYY-MM-DD/## [$version] - $(date +%Y-%m-%d)/\" CHANGELOG.md\nsed -i -e \"s/$latest-/$version-/g\" -e \"s/$latest\\./$version./g\" -e \"s/v$latest/v$version/g\" README.md PYSPARK-DEPS.md python/README.md\n./set-version.sh $version\n\n# commit changes to local repo\necho\necho \"Committing release to local git\"\ngit add pom.xml python/setup.py CHANGELOG.md README.md PYSPARK-DEPS.md python/README.md\ngit commit -m \"Releasing $version\"\ngit tag -a \"v${version}\" -m \"Release v${version}\"\n\necho \"Please inspect git changes:\"\ngit show HEAD\necho \"Press <ENTER> to push to origin\"\nread\n\necho \"Pushing release commit and tag to origin\"\ngit push origin master \"v${version}\"\necho\n\n# create release\necho \"Creating release packages\"\nmkdir -p python/pyspark/jars/\n./set-version.sh 3.2.4 2.12.15; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true; ./build-whl.sh\n./set-version.sh 3.3.4 2.12.15; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true; ./build-whl.sh\n./set-version.sh 3.4.4 2.12.17; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true; ./build-whl.sh\n./set-version.sh 3.5.3 2.12.18; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true; ./build-whl.sh\n\n./set-version.sh 3.2.4 2.13.5; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true\n./set-version.sh 3.3.4 2.13.8; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true\n./set-version.sh 3.4.4 2.13.8; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true\n./set-version.sh 3.5.3 2.13.8; mvn clean deploy -Dsign -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true\n\n# upload to test PyPi\npip install twine\ntwine check python/dist/*\npython3 -m twine upload --repository testpypi python/dist/*\n\necho \"Press <ENTER> to upload to PyPi\"\nread\n\n# upload to PyPi\npython3 -m twine upload python/dist/*\n\necho\n\ngit checkout pom.xml 
python/setup.py\n./bump-version.sh\n"
  },
  {
    "path": "set-version.sh",
    "content": "#!/bin/bash\n\nif [ $# -eq 1 ]\nthen\n    IFS=-\n    read version flavour <<< \"$1\"\n\n    echo \"setting version=$version${flavour:+ with }$flavour\"\n\n    sed -i -E \\\n        -e \"s%^(  <version>)[^-]+-([^-]+).*(</version>)$%\\1$version-\\2${flavour:+-}$flavour\\3%\" \\\n        pom.xml\n\n    version=$(grep -m 1 version pom.xml | sed \"s/\\s*<[^>]*>\\s*//g\")\n\n    sed -i -E \\\n        -e \"s/(jar_version *= *).*/\\1'$version'/\" \\\n        python/setup.py\nelif [ $# -eq 2 ]\nthen\n    spark=$1\n    scala=$2\n\n    spark_compat=${spark%.*}\n    scala_compat=${scala%.*}\n\n    spark_major=${spark_compat%.*}\n    scala_major=${scala_compat%.*}\n\n    spark_minor=${spark_compat/*./}\n    scala_minor=${scala_compat/*./}\n\n    spark_patch=${spark/*./}\n    scala_patch=${scala/*./}\n\n    echo \"setting spark=$spark and scala=$scala\"\n    sed -i -E \\\n        -e \"s%^(  <artifactId>)([^_]+)[_0-9.]+(</artifactId>)$%\\1\\2_${scala_compat}\\3%\" \\\n        -e \"s%^(  <version>)([^-]+)-[^-]+(.*</version>)$%\\1\\2-$spark_compat\\3%\" \\\n        -e \"s%^(    <scala.major.version>).+(</scala.major.version>)$%\\1${scala_major}\\2%\" \\\n        -e \"s%^(    <scala.minor.version>).+(</scala.minor.version>)$%\\1${scala_minor}\\2%\" \\\n        -e \"s%^(    <scala.patch.version>).+(</scala.patch.version>)$%\\1${scala_patch}\\2%\" \\\n        -e \"s%^(    <spark.major.version>).+(</spark.major.version>)$%\\1${spark_major}\\2%\" \\\n        -e \"s%^(    <spark.minor.version>).+(</spark.minor.version>)$%\\1${spark_minor}\\2%\" \\\n        -e \"s%^(    <spark.patch.version>).+(</spark.patch.version>)$%\\1${spark_patch}\\2%\" \\\n        pom.xml\n\n    version=$(grep -m 1 version pom.xml | sed \"s/\\s*<[^>]*>\\s*//g\")\n\n    sed -i -E \\\n        -e \"s/(jar_version *= *).*/\\1'$version'/\" \\\n        -e \"s/(scala_version *= *).*/\\1'$scala'/\" \\\n        python/setup.py\nelse\n    echo \"Provide the Spark-Extension version (e.g. 2.5.0 or 2.5.0-SNAPSHOT), or the Spark and Scala version\"\n    exit 1\nfi\n\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/package.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co\n\npackage object gresearch {\n\n  trait ConditionalCall[T] {\n    def call(f: T => T): T\n    def either[R](f: T => R): ConditionalCallOr[T, R]\n  }\n\n  trait ConditionalCallOr[T, R] {\n    def or(f: T => R): R\n  }\n\n  case class TrueCall[T](t: T) extends ConditionalCall[T] {\n    override def call(f: T => T): T = f(t)\n    override def either[R](f: T => R): ConditionalCallOr[T, R] = TrueCallOr[T, R](f(t))\n  }\n\n  case class FalseCall[T](t: T) extends ConditionalCall[T] {\n    override def call(f: T => T): T = t\n    override def either[R](f: T => R): ConditionalCallOr[T, R] = FalseCallOr[T, R](t)\n  }\n\n  case class TrueCallOr[T, R](r: R) extends ConditionalCallOr[T, R] {\n    override def or(f: T => R): R = r\n  }\n\n  case class FalseCallOr[T, R](t: T) extends ConditionalCallOr[T, R] {\n    override def or(f: T => R): R = f(t)\n  }\n\n  implicit class ExtendedAny[T](t: T) {\n\n    /**\n     * Allows to call a function on the decorated instance conditionally.\n     *\n     * This allows fluent code like\n     *\n     * {{{\n     * i.doThis()\n     *  .doThat()\n     *  .on(condition).call(function)\n     *  .on(condition).either(function1).or(function2)\n     *  .doMore()\n     * }}}\n     *\n     * rather than\n     *\n     * {{{\n     * val temp = i.doThis()\n     *             .doThat()\n     * val temp2 = if (condition) function(temp) else temp\n     * temp2.doMore()\n     * }}}\n     *\n     * which either needs many temporary variables or duplicate code.\n     *\n     * @param condition\n     *   condition\n     * @return\n     *   the function result\n     */\n    def on(condition: Boolean): ConditionalCall[T] = {\n      if (condition) TrueCall[T](t) else FalseCall[T](t)\n    }\n\n    /**\n     * Allows to call a function on the decorated instance conditionally. 
This is an alias for the `on` function.\n     *\n     * This allows fluent code like\n     *\n     * {{{\n     * i.doThis()\n     *  .doThat()\n     *  .when(condition).call(function)\n     *  .when(condition).either(function1).or(function2)\n     *  .doMore()\n     * }}}\n     *\n     * rather than\n     *\n     * {{{\n     * val temp = i.doThis()\n     *             .doThat()\n     * val temp2 = if (condition) function(temp) else temp\n     * temp2.doMore()\n     * }}}\n     *\n     * which either needs many temporary variables or duplicate code.\n     *\n     * @param condition\n     *   condition\n     * @return\n     *   the function result\n     */\n    def when(condition: Boolean): ConditionalCall[T] = on(condition)\n\n    /**\n     * Executes the given function on the decorated instance.\n     *\n     * This allows writing fluent code like\n     *\n     * {{{\n     * i.doThis()\n     *  .doThat()\n     *  .call(function)\n     *  .doMore()\n     * }}}\n     *\n     * rather than\n     *\n     * {{{\n     * function(\n     *   i.doThis()\n     *    .doThat()\n     * ).doMore()\n     * }}}\n     *\n     * where the effective sequence of operations is not clear.\n     *\n     * @param f\n     *   function\n     * @return\n     *   the function result\n     */\n    def call[R](f: T => R): R = f(t)\n  }\n\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/BuildVersion.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport java.util.Properties\n\n/**\n * Provides versions from build environment.\n */\ntrait BuildVersion {\n  val propertyFileName = \"spark-extension-build.properties\"\n\n  lazy val props: Properties = {\n    val properties = new Properties\n\n    val in = Option(Thread.currentThread().getContextClassLoader.getResourceAsStream(propertyFileName))\n    if (in.isEmpty) {\n      throw new RuntimeException(s\"Property file $propertyFileName not found in class path\")\n    }\n\n    in.foreach(properties.load)\n    properties\n  }\n\n  lazy val VersionString: String = props.getProperty(\"project.version\")\n\n  lazy val BuildSparkMajorVersion: Int = props.getProperty(\"spark.major.version\").toInt\n  lazy val BuildSparkMinorVersion: Int = props.getProperty(\"spark.minor.version\").toInt\n  lazy val BuildSparkPatchVersion: Int = props.getProperty(\"spark.patch.version\").split(\"-\").head.toInt\n  lazy val BuildSparkCompatVersionString: String = props.getProperty(\"spark.compat.version\")\n\n  lazy val BuildScalaMajorVersion: Int = props.getProperty(\"scala.major.version\").toInt\n  lazy val BuildScalaMinorVersion: Int = props.getProperty(\"scala.minor.version\").toInt\n  lazy val BuildScalaPatchVersion: Int = props.getProperty(\"scala.patch.version\").toInt\n  lazy val BuildScalaCompatVersionString: String = props.getProperty(\"scala.compat.version\")\n\n  val BuildSparkVersion: (Int, Int, Int) = (BuildSparkMajorVersion, BuildSparkMinorVersion, BuildSparkPatchVersion)\n  val BuildSparkCompatVersion: (Int, Int) = (BuildSparkMajorVersion, BuildSparkMinorVersion)\n\n  val BuildScalaVersion: (Int, Int, Int) = (BuildScalaMajorVersion, BuildScalaMinorVersion, BuildScalaPatchVersion)\n  val BuildScalaCompatVersion: (Int, Int) = (BuildScalaMajorVersion, BuildScalaMinorVersion)\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/Histogram.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.sql.functions.{sum, when}\nimport org.apache.spark.sql.{Column, DataFrame, Dataset}\nimport uk.co.gresearch.ExtendedAny\n\nimport scala.collection.JavaConverters\n\nobject Histogram {\n\n  /**\n   * Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in\n   * ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value\n   * in thresholds, there will be a column named s\"≤threshold\". There will also be a final column called\n   * s\">last_threshold\", that counts the remaining values that exceed the last threshold.\n   *\n   * @param df\n   *   dataset to compute histogram from\n   * @param thresholds\n   *   sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn\n   * @param valueColumn\n   *   histogram is computed for values of this column\n   * @param aggregateColumns\n   *   histogram is computed against these columns\n   * @tparam T\n   *   type of histogram thresholds\n   * @return\n   *   dataframe with aggregate and histogram columns\n   */\n  def of[D, T](df: Dataset[D], thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame = {\n    if (thresholds.isEmpty)\n      throw new IllegalArgumentException(\"Thresholds must not be empty\")\n\n    val bins = if (thresholds.length == 1) Seq.empty else thresholds.sliding(2).toSeq\n\n    if (bins.exists(s => s.head == s.last))\n      throw new IllegalArgumentException(s\"Thresholds must not contain duplicates: ${thresholds.mkString(\",\")}\")\n\n    df.toDF()\n      .withColumn(s\"≤${thresholds.head}\", when(valueColumn <= thresholds.head, 1).otherwise(0))\n      .call(bins.foldLeft(_) { case (df, bin) =>\n        df.withColumn(s\"≤${bin.last}\", when(valueColumn > bin.head && valueColumn <= bin.last, 1).otherwise(0))\n      })\n      .withColumn(s\">${thresholds.last}\", when(valueColumn > thresholds.last, 1).otherwise(0))\n      .groupBy(aggregateColumns: _*)\n      .agg(\n        Some(thresholds.head).map(t => sum(backticks(s\"≤$t\")).as(s\"≤$t\")).get,\n        thresholds.tail.map(t => sum(backticks(s\"≤$t\")).as(s\"≤$t\")) :+\n          sum(backticks(s\">${thresholds.last}\")).as(s\">${thresholds.last}\"): _*\n      )\n  }\n\n  /**\n   * Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in\n   * ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value\n   * in thresholds, there will be a column named s\"≤threshold\". 
There will also be a final column called\n   * s\">last_threshold\", that counts the remaining values that exceed the last threshold.\n   *\n   * @param df\n   *   dataset to compute histogram from\n   * @param thresholds\n   *   sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn\n   * @param valueColumn\n   *   histogram is computed for values of this column\n   * @param aggregateColumns\n   *   histogram is computed against these columns\n   * @tparam T\n   *   type of histogram thresholds\n   * @return\n   *   dataframe with aggregate and histogram columns\n   */\n  @scala.annotation.varargs\n  def of[D, T](\n      df: Dataset[D],\n      thresholds: java.util.List[T],\n      valueColumn: Column,\n      aggregateColumns: Column*\n  ): DataFrame =\n    of(df, JavaConverters.iterableAsScalaIterable(thresholds).toSeq, valueColumn, aggregateColumns: _*)\n\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/RowNumbers.scala",
    "content": "/*\n * Copyright 2023 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.sql.expressions.Window\nimport org.apache.spark.sql.{Column, DataFrame, Dataset, functions}\nimport org.apache.spark.sql.functions.{coalesce, col, lit, max, monotonically_increasing_id, spark_partition_id, sum}\nimport org.apache.spark.storage.StorageLevel\n\ncase class RowNumbersFunc(\n    rowNumberColumnName: String = \"row_number\",\n    storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,\n    unpersistHandle: UnpersistHandle = UnpersistHandle.Noop,\n    orderColumns: Seq[Column] = Seq.empty\n) {\n\n  def withRowNumberColumnName(rowNumberColumnName: String): RowNumbersFunc =\n    this.copy(rowNumberColumnName = rowNumberColumnName)\n\n  def withStorageLevel(storageLevel: StorageLevel): RowNumbersFunc =\n    this.copy(storageLevel = storageLevel)\n\n  def withUnpersistHandle(unpersistHandle: UnpersistHandle): RowNumbersFunc =\n    this.copy(unpersistHandle = unpersistHandle)\n\n  def withOrderColumns(orderColumns: Seq[Column]): RowNumbersFunc =\n    this.copy(orderColumns = orderColumns)\n\n  def of[D](df: Dataset[D]): DataFrame = {\n    if (\n      storageLevel.equals(\n        StorageLevel.NONE\n      ) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)\n    ) {\n      throw new IllegalArgumentException(s\"Storage level $storageLevel not supported with Spark 3.5.0 and above.\")\n    }\n\n    // define some column names that do not exist in ds\n    val prefix = distinctPrefixFor(df.columns)\n    val monoIdColumnName = prefix + \"mono_id\"\n    val partitionIdColumnName = prefix + \"partition_id\"\n    val localRowNumberColumnName = prefix + \"local_row_number\"\n    val maxLocalRowNumberColumnName = prefix + \"max_local_row_number\"\n    val cumRowNumbersColumnName = prefix + \"cum_row_numbers\"\n    val partitionOffsetColumnName = prefix + \"partition_offset\"\n\n    // if no order is given, we preserve existing order\n    val dfOrdered =\n      if (orderColumns.isEmpty) df.withColumn(monoIdColumnName, monotonically_increasing_id())\n      else df.orderBy(orderColumns: _*)\n    val order = if (orderColumns.isEmpty) Seq(col(monoIdColumnName)) else orderColumns\n\n    // add partition ids and local row numbers\n    val localRowNumberWindow = Window.partitionBy(partitionIdColumnName).orderBy(order: _*)\n    val dfWithPartitionId = dfOrdered\n      .withColumn(partitionIdColumnName, spark_partition_id())\n      .persist(storageLevel)\n    unpersistHandle.setDataFrame(dfWithPartitionId)\n    val dfWithLocalRowNumbers = dfWithPartitionId\n      .withColumn(localRowNumberColumnName, functions.row_number().over(localRowNumberWindow))\n\n    // compute row offset for the partitions\n    val cumRowNumbersWindow = Window\n      .orderBy(partitionIdColumnName)\n      .rowsBetween(Window.unboundedPreceding, Window.currentRow)\n    val partitionOffsets = dfWithLocalRowNumbers\n      
.groupBy(partitionIdColumnName)\n      .agg(max(localRowNumberColumnName).alias(maxLocalRowNumberColumnName))\n      .withColumn(cumRowNumbersColumnName, sum(maxLocalRowNumberColumnName).over(cumRowNumbersWindow))\n      .select(\n        col(partitionIdColumnName) + 1 as partitionIdColumnName,\n        col(cumRowNumbersColumnName).as(partitionOffsetColumnName)\n      )\n\n    // compute global row number by adding local row number with partition offset\n    val partitionOffsetColumn = coalesce(col(partitionOffsetColumnName), lit(0))\n    dfWithLocalRowNumbers\n      .join(partitionOffsets, Seq(partitionIdColumnName), \"left\")\n      .withColumn(rowNumberColumnName, col(localRowNumberColumnName) + partitionOffsetColumn)\n      .drop(monoIdColumnName, partitionIdColumnName, localRowNumberColumnName, partitionOffsetColumnName)\n  }\n\n}\n\nobject RowNumbers {\n  def default(): RowNumbersFunc = RowNumbersFunc()\n\n  def withRowNumberColumnName(rowNumberColumnName: String): RowNumbersFunc =\n    default().withRowNumberColumnName(rowNumberColumnName)\n\n  def withStorageLevel(storageLevel: StorageLevel): RowNumbersFunc =\n    default().withStorageLevel(storageLevel)\n\n  def withUnpersistHandle(unpersistHandle: UnpersistHandle): RowNumbersFunc =\n    default().withUnpersistHandle(unpersistHandle)\n\n  @scala.annotation.varargs\n  def withOrderColumns(orderColumns: Column*): RowNumbersFunc =\n    default().withOrderColumns(orderColumns)\n\n  def of[D](ds: Dataset[D]): DataFrame = default().of(ds)\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/SparkVersion.scala",
    "content": "/*\n * Copyright 2023 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.SPARK_VERSION_SHORT\n\n/**\n * Provides versions form runtime environment.\n */\ntrait SparkVersion {\n  private def SparkVersionSeq: Seq[Int] = SPARK_VERSION_SHORT.split('.').toSeq.map(_.toInt)\n\n  def SparkMajorVersion: Int = SparkVersionSeq.head\n  def SparkMinorVersion: Int = SparkVersionSeq(1)\n  def SparkPatchVersion: Int = SparkVersionSeq(2)\n\n  def SparkVersion: (Int, Int, Int) = (SparkMajorVersion, SparkMinorVersion, SparkPatchVersion)\n  def SparkCompatVersion: (Int, Int) = (SparkMajorVersion, SparkMinorVersion)\n  def SparkCompatVersionString: String = SparkVersionSeq.slice(0, 2).mkString(\".\")\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.sql.DataFrame\n\n/**\n * Handle to call `DataFrame.unpersist` on a `DataFrame` that is not known to the caller. The [[RowNumbers.of]]\n * constructs a `DataFrame` that is based ony an intermediate cached `DataFrame`, for witch `unpersist` must be called.\n * A provided [[UnpersistHandle]] allows to do that in user code.\n */\nclass UnpersistHandle {\n  var df: Option[DataFrame] = None\n\n  private[spark] def setDataFrame(dataframe: DataFrame): DataFrame = {\n    if (df.isDefined) throw new IllegalStateException(\"DataFrame has been set already, it cannot be reused.\")\n    this.df = Some(dataframe)\n    dataframe\n  }\n\n  def apply(): Unit = {\n    this.df.getOrElse(throw new IllegalStateException(\"DataFrame has to be set first\")).unpersist()\n  }\n\n  def apply(blocking: Boolean): Unit = {\n    this.df.getOrElse(throw new IllegalStateException(\"DataFrame has to be set first\")).unpersist(blocking)\n  }\n}\n\ncase class SilentUnpersistHandle() extends UnpersistHandle {\n  override def apply(): Unit = {\n    this.df.foreach(_.unpersist())\n  }\n\n  override def apply(blocking: Boolean): Unit = {\n    this.df.foreach(_.unpersist(blocking))\n  }\n}\n\ncase class NoopUnpersistHandle() extends UnpersistHandle {\n  override def setDataFrame(dataframe: DataFrame): DataFrame = dataframe\n  override def apply(): Unit = {}\n  override def apply(blocking: Boolean): Unit = {}\n}\n\nobject UnpersistHandle {\n  val Noop: NoopUnpersistHandle = NoopUnpersistHandle()\n  def apply(): UnpersistHandle = new UnpersistHandle()\n\n  def withUnpersist[T](blocking: Boolean = false)(func: UnpersistHandle => T): T = {\n    val handle = SilentUnpersistHandle()\n    try {\n      func(handle)\n    } finally {\n      handle(blocking)\n    }\n  }\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/App.scala",
    "content": "/*\n * Copyright 2023 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff\n\nimport org.apache.spark.sql.functions.col\nimport org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}\nimport scopt.OptionParser\nimport uk.co.gresearch._\n\nobject App {\n  // define available options\n  case class Options(\n      master: Option[String] = None,\n      appName: Option[String] = None,\n      hive: Boolean = false,\n      leftPath: Option[String] = None,\n      rightPath: Option[String] = None,\n      outputPath: Option[String] = None,\n      leftFormat: Option[String] = None,\n      rightFormat: Option[String] = None,\n      outputFormat: Option[String] = None,\n      leftSchema: Option[String] = None,\n      rightSchema: Option[String] = None,\n      leftOptions: Map[String, String] = Map.empty,\n      rightOptions: Map[String, String] = Map.empty,\n      outputOptions: Map[String, String] = Map.empty,\n      ids: Seq[String] = Seq.empty,\n      ignore: Seq[String] = Seq.empty,\n      saveMode: SaveMode = SaveMode.ErrorIfExists,\n      filter: Set[String] = Set.empty,\n      statistics: Boolean = false,\n      diffOptions: DiffOptions = DiffOptions.default\n  )\n\n  // read options from args\n  val programName = s\"spark-extension_${spark.BuildScalaCompatVersionString}-${spark.VersionString}.jar\"\n  val scop = s\"com.github.scopt:scopt_${spark.BuildScalaCompatVersionString}:4.1.0\"\n  val sparkSubmit = s\"spark-submit --packages $scop $programName\"\n  val parser: OptionParser[Options] = new scopt.OptionParser[Options](programName) {\n    head(s\"Spark Diff app (${spark.VersionString})\")\n    head()\n\n    arg[String](\"left\")\n      .required()\n      .valueName(\"<left-path>\")\n      .action((x, c) => c.copy(leftPath = Some(x)))\n      .text(\"file path (requires format option) or table name to read left dataframe\")\n\n    arg[String](\"right\")\n      .required()\n      .valueName(\"<right-path>\")\n      .action((x, c) => c.copy(rightPath = Some(x)))\n      .text(\"file path (requires format option) or table name to read right dataframe\")\n\n    arg[String](\"diff\")\n      .required()\n      .valueName(\"<diff-path>\")\n      .action((x, c) => c.copy(outputPath = Some(x)))\n      .text(\"file path (requires format option) or table name to write diff dataframe\")\n\n    note(\"\")\n    note(\"Examples:\")\n    note(\"\")\n    note(\"  - Diff CSV files 'left.csv' and 'right.csv' and write result into CSV file 'diff.csv':\")\n    note(s\"    $sparkSubmit --format csv left.csv right.csv diff.csv\")\n    note(\"\")\n    note(\"  - Diff CSV file 'left.csv' and Parquet file 'right.parquet' with id column 'id',\")\n    note(\"    and write result into Hive table 'diff':\")\n    note(s\"    $sparkSubmit --left-format csv --right-format parquet --hive --id id left.csv right.parquet diff\")\n\n    note(\"\")\n    note(\"Spark session\")\n    opt[String](\"master\")\n      .valueName(\"<master>\")\n      
.action((x, c) => c.copy(master = Some(x)))\n      .text(\"Spark master (local, yarn, ...), not needed with spark-submit\")\n    opt[String](\"app-name\")\n      .valueName(\"<app-name>\")\n      .action((x, c) => c.copy(appName = Some(x)))\n      .text(\"Spark application name\")\n      .withFallback(() => \"Diff App\")\n    opt[Unit](\"hive\")\n      .optional()\n      .action((_, c) => c.copy(hive = true))\n      .text(s\"enable Hive support to read from and write to Hive tables\")\n\n    note(\"\")\n    note(\"Input and output\")\n    opt[String]('f', \"format\")\n      .valueName(\"<format>\")\n      .action((x, c) =>\n        c.copy(\n          leftFormat = c.leftFormat.orElse(Some(x)),\n          rightFormat = c.rightFormat.orElse(Some(x)),\n          outputFormat = c.outputFormat.orElse(Some(x))\n        )\n      )\n      .text(\"input and output file format (csv, json, parquet, ...)\")\n    opt[String](\"left-format\")\n      .valueName(\"<format>\")\n      .action((x, c) => c.copy(leftFormat = Some(x)))\n      .text(\"left input file format (csv, json, parquet, ...)\")\n    opt[String](\"right-format\")\n      .valueName(\"<format>\")\n      .action((x, c) => c.copy(rightFormat = Some(x)))\n      .text(\"right input file format (csv, json, parquet, ...)\")\n    opt[String](\"output-format\")\n      .valueName(\"<format>\")\n      .action((x, c) => c.copy(outputFormat = Some(x)))\n      .text(\"output file format (csv, json, parquet, ...)\")\n\n    note(\"\")\n    opt[String]('s', \"schema\")\n      .valueName(\"<schema>\")\n      .action((x, c) =>\n        c.copy(\n          leftSchema = c.leftSchema.orElse(Some(x)),\n          rightSchema = c.rightSchema.orElse(Some(x))\n        )\n      )\n      .text(\"input schema\")\n    opt[String](\"left-schema\")\n      .valueName(\"<schema>\")\n      .action((x, c) => c.copy(leftSchema = Some(x)))\n      .text(\"left input schema\")\n    opt[String](\"right-schema\")\n      .valueName(\"<schema>\")\n      .action((x, c) => c.copy(rightSchema = Some(x)))\n      .text(\"right input schema\")\n\n    note(\"\")\n    opt[(String, String)](\"left-option\")\n      .unbounded()\n      .optional()\n      .keyValueName(\"key\", \"val\")\n      .action((x, c) => c.copy(leftOptions = c.leftOptions + (x._1 -> x._2)))\n      .text(\"left input option\")\n    opt[(String, String)](\"right-option\")\n      .unbounded()\n      .optional()\n      .keyValueName(\"key\", \"val\")\n      .action((x, c) => c.copy(rightOptions = c.rightOptions + (x._1 -> x._2)))\n      .text(\"right input option\")\n    opt[(String, String)](\"output-option\")\n      .unbounded()\n      .optional()\n      .keyValueName(\"key\", \"val\")\n      .action((x, c) => c.copy(outputOptions = c.outputOptions + (x._1 -> x._2)))\n      .text(\"output option\")\n\n    note(\"\")\n    opt[String](\"id\")\n      .unbounded()\n      .valueName(\"<name>\")\n      .action((x, c) => c.copy(ids = c.ids :+ x))\n      .text(s\"id column name\")\n    opt[String](\"ignore\")\n      .unbounded()\n      .valueName(\"<name>\")\n      .action((x, c) => c.copy(ignore = c.ignore :+ x))\n      .text(s\"ignore column name\")\n    opt[String](\"save-mode\")\n      .optional()\n      .valueName(\"<save-mode>\")\n      .action((x, c) => c.copy(saveMode = SaveMode.valueOf(x)))\n      .text(s\"save mode for writing output (${SaveMode.values().mkString(\", \")}, default ${Options().saveMode})\")\n    opt[String](\"filter\")\n      .unbounded()\n      .optional()\n      .valueName(\"<filter>\")\n      .action((x, 
c) => c.copy(filter = c.filter + x))\n      .text(\n        s\"Filters for rows with these diff actions, with default diffing options use 'N', 'I', 'D', or 'C' (see 'Diffing options' section)\"\n      )\n    opt[Unit](\"statistics\")\n      .optional()\n      .action((_, c) => c.copy(statistics = true))\n      .text(s\"Only output statistics on how many rows exist per diff action (see 'Diffing options' section)\")\n\n    note(\"\")\n    note(\"Diffing options\")\n    opt[String](\"diff-column\")\n      .optional()\n      .valueName(\"<name>\")\n      .action((x, c) => c.copy(diffOptions = c.diffOptions.copy(diffColumn = x)))\n      .text(s\"column name for diff column (default '${DiffOptions.default.diffColumn}')\")\n    opt[String](\"left-prefix\")\n      .optional()\n      .valueName(\"<prefix>\")\n      .action((x, c) => c.copy(diffOptions = c.diffOptions.copy(leftColumnPrefix = x)))\n      .text(s\"prefix for left column names (default '${DiffOptions.default.leftColumnPrefix}')\")\n    opt[String](\"right-prefix\")\n      .optional()\n      .valueName(\"<prefix>\")\n      .action((x, c) => c.copy(diffOptions = c.diffOptions.copy(rightColumnPrefix = x)))\n      .text(s\"prefix for right column names (default '${DiffOptions.default.rightColumnPrefix}')\")\n    opt[String](\"insert-value\")\n      .optional()\n      .valueName(\"<value>\")\n      .action((x, c) => c.copy(diffOptions = c.diffOptions.copy(insertDiffValue = x)))\n      .text(s\"value for insertion (default '${DiffOptions.default.insertDiffValue}')\")\n    opt[String](\"change-value\")\n      .optional()\n      .valueName(\"<value>\")\n      .action((x, c) => c.copy(diffOptions = c.diffOptions.copy(changeDiffValue = x)))\n      .text(s\"value for change (default '${DiffOptions.default.changeDiffValue}')\")\n    opt[String](\"delete-value\")\n      .optional()\n      .valueName(\"<value>\")\n      .action((x, c) => c.copy(diffOptions = c.diffOptions.copy(deleteDiffValue = x)))\n      .text(s\"value for deletion (default '${DiffOptions.default.deleteDiffValue}')\")\n    opt[String](\"no-change-value\")\n      .optional()\n      .valueName(\"<val>\")\n      .action((x, c) => c.copy(diffOptions = c.diffOptions.copy(nochangeDiffValue = x)))\n      .text(s\"value for no change (default '${DiffOptions.default.nochangeDiffValue}')\")\n    opt[String](\"change-column\")\n      .optional()\n      .valueName(\"<name>\")\n      .action((x, c) => c.copy(diffOptions = c.diffOptions.copy(changeColumn = Some(x))))\n      .text(s\"column name for change column (default is no such column)\")\n    opt[String](\"diff-mode\")\n      .optional()\n      .valueName(\"<mode>\")\n      .action((x, c) => c.copy(diffOptions = c.diffOptions.copy(diffMode = DiffMode.withName(x))))\n      .text(s\"diff mode (${DiffMode.values.mkString(\", \")}, default ${Options().diffOptions.diffMode})\")\n    opt[Unit](\"sparse\")\n      .optional()\n      .action((_, c) => c.copy(diffOptions = c.diffOptions.copy(sparseMode = true)))\n      .text(s\"enable sparse diff\")\n\n    note(\"\")\n    note(\"General\")\n    help(\"help\").text(\"prints this usage text\")\n  }\n\n  def read(\n      spark: SparkSession,\n      format: Option[String],\n      path: String,\n      schema: Option[String],\n      options: Map[String, String]\n  ): DataFrame =\n    spark.read\n      .when(format.isDefined)\n      .call(_.format(format.get))\n      .options(options)\n      .when(schema.isDefined)\n      .call(_.schema(schema.get))\n      .when(format.isDefined)\n      
.either(_.load(path))\n      .or(_.table(path))\n\n  def write(\n      df: DataFrame,\n      format: Option[String],\n      path: String,\n      options: Map[String, String],\n      saveMode: SaveMode,\n      filter: Set[String],\n      saveStats: Boolean,\n      diffOptions: DiffOptions\n  ): Unit =\n    df.when(filter.nonEmpty)\n      .call(_.where(col(diffOptions.diffColumn).isInCollection(filter)))\n      .when(saveStats)\n      .call(_.groupBy(diffOptions.diffColumn).count.orderBy(diffOptions.diffColumn))\n      .write\n      .when(format.isDefined)\n      .call(_.format(format.get))\n      .options(options)\n      .mode(saveMode)\n      .when(format.isDefined)\n      .either(_.save(path))\n      .or(_.saveAsTable(path))\n\n  def main(args: Array[String]): Unit = {\n    // parse options\n    val options = parser.parse(args, Options()) match {\n      case Some(options) => options\n      case None          => sys.exit(1)\n    }\n    val unknownFilters = options.filter.filter(filter => !options.diffOptions.diffValues.contains(filter))\n    if (unknownFilters.nonEmpty) {\n      throw new RuntimeException(\n        s\"Filter ${unknownFilters.mkString(\"'\", \"', '\", \"'\")} not allowed, \" +\n          s\"these are the configured diff values: ${options.diffOptions.diffValues.mkString(\"'\", \"', '\", \"'\")}\"\n      )\n    }\n\n    // create spark session\n    val spark = SparkSession\n      .builder()\n      .appName(options.appName.get)\n      .when(options.hive)\n      .call(_.enableHiveSupport())\n      .when(options.master.isDefined)\n      .call(_.master(options.master.get))\n      .getOrCreate()\n\n    // read and write\n    val left = read(spark, options.leftFormat, options.leftPath.get, options.leftSchema, options.leftOptions)\n    val right = read(spark, options.rightFormat, options.rightPath.get, options.rightSchema, options.rightOptions)\n    val diff = left.diff(right, options.diffOptions, options.ids, options.ignore)\n    write(\n      diff,\n      options.outputFormat,\n      options.outputPath.get,\n      options.outputOptions,\n      options.saveMode,\n      options.filter,\n      options.statistics,\n      options.diffOptions\n    )\n  }\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/Diff.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff\n\nimport org.apache.spark.sql._\nimport org.apache.spark.sql.functions._\nimport org.apache.spark.sql.types.{ArrayType, StringType}\nimport uk.co.gresearch.spark.diff.comparator.DiffComparator\nimport uk.co.gresearch.spark.{backticks, distinctPrefixFor}\n\nimport scala.collection.JavaConverters\n\n/**\n * Differ class to diff two Datasets. See Differ.of(…) for details.\n * @param options\n *   options for the diffing process\n */\nclass Differ(options: DiffOptions) {\n\n  private[diff] def checkSchema[T, U](\n      left: Dataset[T],\n      right: Dataset[U],\n      idColumns: Seq[String],\n      ignoreColumns: Seq[String]\n  ): Unit = {\n    require(\n      left.columns.length == left.columns.toSet.size &&\n        right.columns.length == right.columns.toSet.size,\n      \"The datasets have duplicate columns.\\n\" +\n        s\"Left column names: ${left.columns.mkString(\", \")}\\n\" +\n        s\"Right column names: ${right.columns.mkString(\", \")}\"\n    )\n\n    val leftNonIgnored = left.columns.diffCaseSensitivity(ignoreColumns)\n    val rightNonIgnored = right.columns.diffCaseSensitivity(ignoreColumns)\n\n    val exceptIgnoredColumnsMsg = if (ignoreColumns.nonEmpty) \" except ignored columns\" else \"\"\n\n    require(\n      leftNonIgnored.length == rightNonIgnored.length,\n      \"The number of columns doesn't match.\\n\" +\n        s\"Left column names$exceptIgnoredColumnsMsg (${leftNonIgnored.length}): ${leftNonIgnored.mkString(\", \")}\\n\" +\n        s\"Right column names$exceptIgnoredColumnsMsg (${rightNonIgnored.length}): ${rightNonIgnored.mkString(\", \")}\"\n    )\n\n    require(leftNonIgnored.length > 0, s\"The schema$exceptIgnoredColumnsMsg must not be empty\")\n\n    // column types must match but we ignore the nullability of columns\n    val leftFields = left.schema.fields\n      .filter(f => !ignoreColumns.containsCaseSensitivity(f.name))\n      .map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType)\n    val rightFields = right.schema.fields\n      .filter(f => !ignoreColumns.containsCaseSensitivity(f.name))\n      .map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType)\n    val leftExtraSchema = leftFields.diff(rightFields)\n    val rightExtraSchema = rightFields.diff(leftFields)\n    require(\n      leftExtraSchema.isEmpty && rightExtraSchema.isEmpty,\n      \"The datasets do not have the same schema.\\n\" +\n        s\"Left extra columns: ${leftExtraSchema.map(t => s\"${t._1} (${t._2})\").mkString(\", \")}\\n\" +\n        s\"Right extra columns: ${rightExtraSchema.map(t => s\"${t._1} (${t._2})\").mkString(\", \")}\"\n    )\n\n    val columns = leftNonIgnored\n    val pkColumns = if (idColumns.isEmpty) columns.toList else idColumns\n    val nonPkColumns = columns.diffCaseSensitivity(pkColumns)\n    val missingIdColumns = pkColumns.diffCaseSensitivity(columns)\n    require(\n      
missingIdColumns.isEmpty,\n      s\"Some id columns do not exist: ${missingIdColumns.mkString(\", \")} missing among ${columns.mkString(\", \")}\"\n    )\n\n    val missingIgnoreColumns = ignoreColumns.diffCaseSensitivity(left.columns).diffCaseSensitivity(right.columns)\n    require(\n      missingIgnoreColumns.isEmpty,\n      s\"Some ignore columns do not exist: ${missingIgnoreColumns.mkString(\", \")} \" +\n        s\"missing among ${(leftNonIgnored ++ rightNonIgnored).distinct.sorted.mkString(\", \")}\"\n    )\n\n    require(\n      !pkColumns.containsCaseSensitivity(options.diffColumn),\n      s\"The id columns must not contain the diff column name '${options.diffColumn}': ${pkColumns.mkString(\", \")}\"\n    )\n    require(\n      options.changeColumn.forall(!pkColumns.containsCaseSensitivity(_)),\n      s\"The id columns must not contain the change column name '${options.changeColumn.get}': ${pkColumns.mkString(\", \")}\"\n    )\n\n    val diffValueColumns = getDiffValueColumns(pkColumns, nonPkColumns, left, right, ignoreColumns).map(_._1)\n\n    if (Seq(DiffMode.LeftSide, DiffMode.RightSide).contains(options.diffMode)) {\n      require(\n        !diffValueColumns.containsCaseSensitivity(options.diffColumn),\n        s\"The ${if (options.diffMode == DiffMode.LeftSide) \"left\" else \"right\"} \" +\n          s\"non-id columns must not contain the diff column name '${options.diffColumn}': \" +\n          s\"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(\", \")}\"\n      )\n\n      require(\n        options.changeColumn.forall(!diffValueColumns.containsCaseSensitivity(_)),\n        s\"The ${if (options.diffMode == DiffMode.LeftSide) \"left\" else \"right\"} \" +\n          s\"non-id columns must not contain the change column name '${options.changeColumn.get}': \" +\n          s\"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(\", \")}\"\n      )\n    } else {\n      require(\n        !diffValueColumns.containsCaseSensitivity(options.diffColumn),\n        s\"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', \" +\n          s\"together with these non-id columns \" +\n          s\"must not produce the diff column name '${options.diffColumn}': \" +\n          s\"${nonPkColumns.mkString(\", \")}\"\n      )\n\n      require(\n        options.changeColumn.forall(!diffValueColumns.containsCaseSensitivity(_)),\n        s\"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', \" +\n          s\"together with these non-id columns \" +\n          s\"must not produce the change column name '${options.changeColumn.orNull}': \" +\n          s\"${nonPkColumns.mkString(\", \")}\"\n      )\n\n      require(\n        diffValueColumns.forall(!pkColumns.containsCaseSensitivity(_)),\n        s\"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', \" +\n          s\"together with these non-id columns \" +\n          s\"must not produce any id column name '${pkColumns.mkString(\"', '\")}': \" +\n          s\"${nonPkColumns.mkString(\", \")}\"\n      )\n    }\n  }\n\n  private def getChangeColumn(\n      existsColumnName: String,\n      valueColumnsWithComparator: Seq[(String, DiffComparator)],\n      left: Dataset[_],\n      right: Dataset[_]\n  ): Option[Column] = {\n    options.changeColumn\n      .map(changeColumn =>\n        when(left(existsColumnName).isNull || 
right(existsColumnName).isNull, lit(null))\n          .otherwise(\n            Some(valueColumnsWithComparator)\n              .filter(_.nonEmpty)\n              .map(columns =>\n                concat(\n                  columns\n                    .map { case (c, cmp) =>\n                      when(cmp.equiv(left(backticks(c)), right(backticks(c))), array()).otherwise(array(lit(c)))\n                    }: _*\n                )\n              )\n              .getOrElse(\n                array().cast(ArrayType(StringType, containsNull = false))\n              )\n          )\n          .as(changeColumn)\n      )\n  }\n\n  private[diff] def getDiffIdColumns[T, U](\n      pkColumns: Seq[String],\n      left: Dataset[T],\n      right: Dataset[U],\n  ): Seq[(String, Column)] = {\n    pkColumns.map(c => c -> coalesce(left(backticks(c)), right(backticks(c))).as(c))\n  }\n\n  private[diff] def getDiffValueColumns[T, U](\n      pkColumns: Seq[String],\n      valueColumns: Seq[String],\n      left: Dataset[T],\n      right: Dataset[U],\n      ignoreColumns: Seq[String]\n  ): Seq[(String, Column)] = {\n    val leftValueColumns = left.columns.filterIsInCaseSensitivity(valueColumns)\n    val rightValueColumns = right.columns.filterIsInCaseSensitivity(valueColumns)\n\n    val leftNonPkColumns = left.columns.diffCaseSensitivity(pkColumns)\n    val rightNonPkColumns = right.columns.diffCaseSensitivity(pkColumns)\n\n    val leftIgnoredColumns = left.columns.filterIsInCaseSensitivity(ignoreColumns)\n    val rightIgnoredColumns = right.columns.filterIsInCaseSensitivity(ignoreColumns)\n\n    val (leftValues, rightValues) = if (options.sparseMode) {\n      (\n        leftNonPkColumns\n          .map(c =>\n            (\n              handleConfiguredCaseSensitivity(c),\n              c -> when(not(left(backticks(c)) <=> right(backticks(c))), left(backticks(c)))\n            )\n          )\n          .toMap,\n        rightNonPkColumns\n          .map(c =>\n            (\n              handleConfiguredCaseSensitivity(c),\n              c -> when(not(left(backticks(c)) <=> right(backticks(c))), right(backticks(c)))\n            )\n          )\n          .toMap\n      )\n    } else {\n      (\n        leftNonPkColumns.map(c => (handleConfiguredCaseSensitivity(c), c -> left(backticks(c)))).toMap,\n        rightNonPkColumns.map(c => (handleConfiguredCaseSensitivity(c), c -> right(backticks(c)))).toMap,\n      )\n    }\n\n    def alias(prefix: Option[String], values: Map[String, (String, Column)])(name: String): (String, Column) = {\n      values(handleConfiguredCaseSensitivity(name)) match {\n        case (name, column) =>\n          val alias = prefix.map(p => s\"${p}_$name\").getOrElse(name)\n          alias -> column.as(alias)\n      }\n    }\n\n    def aliasLeft(name: String): (String, Column) = alias(Some(options.leftColumnPrefix), leftValues)(name)\n\n    def aliasRight(name: String): (String, Column) = alias(Some(options.rightColumnPrefix), rightValues)(name)\n\n    val prefixedLeftIgnoredColumns = leftIgnoredColumns.map(c => aliasLeft(c))\n    val prefixedRightIgnoredColumns = rightIgnoredColumns.map(c => aliasRight(c))\n\n    options.diffMode match {\n      case DiffMode.ColumnByColumn =>\n        valueColumns.flatMap(c =>\n          Seq(\n            aliasLeft(c),\n            aliasRight(c)\n          )\n        ) ++ ignoreColumns.flatMap(c =>\n          (if (leftIgnoredColumns.containsCaseSensitivity(c)) Seq(aliasLeft(c)) else Seq.empty) ++\n            (if 
(rightIgnoredColumns.containsCaseSensitivity(c)) Seq(aliasRight(c)) else Seq.empty)\n        )\n\n      case DiffMode.SideBySide =>\n        leftValueColumns.toSeq.map(c => aliasLeft(c)) ++ prefixedLeftIgnoredColumns ++\n          rightValueColumns.toSeq.map(c => aliasRight(c)) ++ prefixedRightIgnoredColumns\n\n      case DiffMode.LeftSide | DiffMode.RightSide =>\n        // in left-side / right-side mode, we do not prefix columns\n        (\n          if (options.diffMode == DiffMode.LeftSide) valueColumns.map(alias(None, leftValues))\n          else valueColumns.map(alias(None, rightValues))\n        ) ++ (\n          if (options.diffMode == DiffMode.LeftSide) leftIgnoredColumns.map(alias(None, leftValues))\n          else rightIgnoredColumns.map(alias(None, rightValues))\n        )\n    }\n  }\n\n  private[diff] def getDiffColumns[T, U](\n      pkColumns: Seq[String],\n      valueColumns: Seq[String],\n      left: Dataset[T],\n      right: Dataset[U],\n      ignoreColumns: Seq[String]\n  ): Seq[(String, Column)] = {\n    getDiffIdColumns(pkColumns, left, right) ++ getDiffValueColumns(pkColumns, valueColumns, left, right, ignoreColumns)\n  }\n\n  private def doDiff[T, U](\n      left: Dataset[T],\n      right: Dataset[U],\n      idColumns: Seq[String],\n      ignoreColumns: Seq[String] = Seq.empty\n  ): DataFrame = {\n    checkSchema(left, right, idColumns, ignoreColumns)\n\n    val columns = left.columns.diffCaseSensitivity(ignoreColumns).toList\n    val pkColumns = if (idColumns.isEmpty) columns else idColumns\n    val valueColumns = columns.diffCaseSensitivity(pkColumns)\n    val valueStructFields = left.schema.fields.map(f => f.name -> f).toMap\n    val valueColumnsWithComparator = valueColumns.map(c => c -> options.comparatorFor(valueStructFields(c)))\n\n    val existsColumnName = distinctPrefixFor(left.columns) + \"exists\"\n    val leftWithExists = left.withColumn(existsColumnName, lit(1))\n    val rightWithExists = right.withColumn(existsColumnName, lit(1))\n    val joinCondition =\n      pkColumns.map(c => leftWithExists(backticks(c)) <=> rightWithExists(backticks(c))).reduce(_ && _)\n    val unChanged = valueColumnsWithComparator\n      .map { case (c, cmp) =>\n        cmp.equiv(leftWithExists(backticks(c)), rightWithExists(backticks(c)))\n      }\n      .reduceOption(_ && _)\n\n    val changeCondition = not(unChanged.getOrElse(lit(true)))\n\n    val diffActionColumn =\n      when(leftWithExists(existsColumnName).isNull, lit(options.insertDiffValue))\n        .when(rightWithExists(existsColumnName).isNull, lit(options.deleteDiffValue))\n        .when(changeCondition, lit(options.changeDiffValue))\n        .otherwise(lit(options.nochangeDiffValue))\n        .as(options.diffColumn)\n\n    val diffColumns = getDiffColumns(pkColumns, valueColumns, left, right, ignoreColumns).map(_._2)\n    val changeColumn = getChangeColumn(existsColumnName, valueColumnsWithComparator, leftWithExists, rightWithExists)\n      // turn this column into a sequence of one or none column so we can easily concat it below with diffActionColumn and diffColumns\n      .map(Seq(_))\n      .getOrElse(Seq.empty[Column])\n\n    leftWithExists\n      .join(rightWithExists, joinCondition, \"fullouter\")\n      .select((diffActionColumn +: changeColumn) ++ diffColumns: _*)\n  }\n\n  /**\n   * Returns a new DataFrame that contains the differences between two Datasets of the same type `T`. Both Datasets must\n   * contain the same set of column names and data types. 
The order of columns in the two Datasets is not relevant as\n   * columns are compared based on the name, not the position.\n   *\n   * Optional `id` columns are used to uniquely identify rows to compare. If values in any non-id column are differing\n   * between two Datasets, then that row is marked as `\"C\"`hange and `\"N\"`o-change otherwise. Rows of the right Dataset,\n   * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `\"I\"`nsert. And rows of\n   * the left Dataset, that do not exist in the right Dataset are marked as `\"D\"`elete.\n   *\n   * If no id columns are given, all columns are considered id columns. Then, no `\"C\"`hange rows will appear, as all\n   * changes will exist as respective `\"D\"`elete and `\"I\"`nsert.\n   *\n   * The returned DataFrame has the `diff` column as the first column. This holds the `\"N\"`, `\"C\"`, `\"I\"` or `\"D\"`\n   * strings. The id columns follow, then the non-id columns (all remaining columns).\n   *\n   * {{{\n   *   val df1 = Seq((1, \"one\"), (2, \"two\"), (3, \"three\")).toDF(\"id\", \"value\")\n   *   val df2 = Seq((1, \"one\"), (2, \"Two\"), (4, \"four\")).toDF(\"id\", \"value\")\n   *\n   *   differ.diff(df1, df2).show()\n   *\n   *   // output:\n   *   // +----+---+-----+\n   *   // |diff| id|value|\n   *   // +----+---+-----+\n   *   // |   N|  1|  one|\n   *   // |   D|  2|  two|\n   *   // |   I|  2|  Two|\n   *   // |   D|  3|three|\n   *   // |   I|  4| four|\n   *   // +----+---+-----+\n   *\n   *   differ.diff(df1, df2, \"id\").show()\n   *\n   *   // output:\n   *   // +----+---+----------+-----------+\n   *   // |diff| id|left_value|right_value|\n   *   // +----+---+----------+-----------+\n   *   // |   N|  1|       one|        one|\n   *   // |   C|  2|       two|        Two|\n   *   // |   D|  3|     three|       null|\n   *   // |   I|  4|      null|       four|\n   *   // +----+---+----------+-----------+\n   *\n   * }}}\n   *\n   * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are\n   * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.\n   */\n  @scala.annotation.varargs\n  def diff[T](left: Dataset[T], right: Dataset[T], idColumns: String*): DataFrame =\n    doDiff(left, right, idColumns)\n\n  /**\n   * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both\n   * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The\n   * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the\n   * position.\n   *\n   * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing\n   * between two Datasets, then that row is marked as `\"C\"`hange and `\"N\"`o-change otherwise. Rows of the right Dataset,\n   * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `\"I\"`nsert. And rows of\n   * the left Dataset, that do not exist in the right Dataset are marked as `\"D\"`elete.\n   *\n   * If no id columns are given, all columns are considered id columns. 
Then, no `\"C\"`hange rows will appear, as all\n   * changes will exist as respective `\"D\"`elete and `\"I\"`nsert.\n   *\n   * Values in optional ignore columns are not compared but included in the output DataFrame.\n   *\n   * The returned DataFrame has the `diff` column as the first column. This holds the `\"N\"`, `\"C\"`, `\"I\"` or `\"D\"`\n   * strings. The id columns follow, then the non-id columns (all remaining columns).\n   *\n   * {{{\n   *   val df1 = Seq((1, \"one\"), (2, \"two\"), (3, \"three\")).toDF(\"id\", \"value\")\n   *   val df2 = Seq((1, \"one\"), (2, \"Two\"), (4, \"four\")).toDF(\"id\", \"value\")\n   *\n   *   differ.diff(df1, df2).show()\n   *\n   *   // output:\n   *   // +----+---+-----+\n   *   // |diff| id|value|\n   *   // +----+---+-----+\n   *   // |   N|  1|  one|\n   *   // |   D|  2|  two|\n   *   // |   I|  2|  Two|\n   *   // |   D|  3|three|\n   *   // |   I|  4| four|\n   *   // +----+---+-----+\n   *\n   *   differ.diff(df1, df2, Seq(\"id\")).show()\n   *\n   *   // output:\n   *   // +----+---+----------+-----------+\n   *   // |diff| id|left_value|right_value|\n   *   // +----+---+----------+-----------+\n   *   // |   N|  1|       one|        one|\n   *   // |   C|  2|       two|        Two|\n   *   // |   D|  3|     three|       null|\n   *   // |   I|  4|      null|       four|\n   *   // +----+---+----------+-----------+\n   *\n   * }}}\n   *\n   * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are\n   * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.\n   */\n  def diff[T, U](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): DataFrame =\n    doDiff(left, right, idColumns, ignoreColumns)\n\n  /**\n   * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both\n   * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The\n   * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the\n   * position.\n   *\n   * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing\n   * between two Datasets, then that row is marked as `\"C\"`hange and `\"N\"`o-change otherwise. Rows of the right Dataset,\n   * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `\"I\"`nsert. And rows of\n   * the left Dataset, that do not exist in the right Dataset are marked as `\"D\"`elete.\n   *\n   * If no id columns are given, all columns are considered id columns. Then, no `\"C\"`hange rows will appear, as all\n   * changes will exist as respective `\"D\"`elete and `\"I\"`nsert.\n   *\n   * Values in optional ignore columns are not compared but included in the output DataFrame.\n   *\n   * The returned DataFrame has the `diff` column as the first column. This holds the `\"N\"`, `\"C\"`, `\"I\"` or `\"D\"`\n   * strings. 
The id columns follow, then the non-id columns (all remaining columns).\n   *\n   * {{{\n   *   val df1 = Seq((1, \"one\"), (2, \"two\"), (3, \"three\")).toDF(\"id\", \"value\")\n   *   val df2 = Seq((1, \"one\"), (2, \"Two\"), (4, \"four\")).toDF(\"id\", \"value\")\n   *\n   *   differ.diff(df1, df2).show()\n   *\n   *   // output:\n   *   // +----+---+-----+\n   *   // |diff| id|value|\n   *   // +----+---+-----+\n   *   // |   N|  1|  one|\n   *   // |   D|  2|  two|\n   *   // |   I|  2|  Two|\n   *   // |   D|  3|three|\n   *   // |   I|  4| four|\n   *   // +----+---+-----+\n   *\n   *   differ.diff(df1, df2, Seq(\"id\")).show()\n   *\n   *   // output:\n   *   // +----+---+----------+-----------+\n   *   // |diff| id|left_value|right_value|\n   *   // +----+---+----------+-----------+\n   *   // |   N|  1|       one|        one|\n   *   // |   C|  2|       two|        Two|\n   *   // |   D|  3|     three|       null|\n   *   // |   I|  4|      null|       four|\n   *   // +----+---+----------+-----------+\n   *\n   * }}}\n   *\n   * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are\n   * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.\n   */\n  def diff[T, U](\n      left: Dataset[T],\n      right: Dataset[U],\n      idColumns: java.util.List[String],\n      ignoreColumns: java.util.List[String]\n  ): DataFrame = {\n    diff(\n      left,\n      right,\n      JavaConverters.iterableAsScalaIterable(idColumns).toSeq,\n      JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq\n    )\n  }\n\n  /**\n   * Returns a new Dataset that contains the differences between two Datasets of the same type `T`.\n   *\n   * See `diff(Dataset[T], Dataset[U], String*)`.\n   *\n   * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.\n   */\n  // no @scala.annotation.varargs here as implicit arguments are explicit in Java\n  // this signature is redundant to the other diffAs method in Java\n  def diffAs[T, U, V](left: Dataset[T], right: Dataset[T], idColumns: String*)(implicit\n      diffEncoder: Encoder[V]\n  ): Dataset[V] = {\n    diffAs(left, right, diffEncoder, idColumns: _*)\n  }\n\n  /**\n   * Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`.\n   *\n   * See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`.\n   *\n   * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.\n   */\n  def diffAs[T, U, V](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String])(implicit\n      diffEncoder: Encoder[V]\n  ): Dataset[V] = {\n    diffAs(left, right, diffEncoder, idColumns, ignoreColumns)\n  }\n\n  /**\n   * Returns a new Dataset that contains the differences between two Datasets of the same type `T`.\n   *\n   * See `diff(Dataset[T], Dataset[T], String*)`.\n   *\n   * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.\n   */\n  @scala.annotation.varargs\n  def diffAs[T, V](left: Dataset[T], right: Dataset[T], diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = {\n    diffAs(left, right, diffEncoder, idColumns, Seq.empty)\n  }\n\n  /**\n   * Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`.\n   *\n   * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.\n   *\n   * This requires an additional explicit `Encoder[V]` for the return 
type `Dataset[V]`.\n   */\n  def diffAs[T, U, V](\n      left: Dataset[T],\n      right: Dataset[U],\n      diffEncoder: Encoder[V],\n      idColumns: Seq[String],\n      ignoreColumns: Seq[String]\n  ): Dataset[V] = {\n    val nonIdColumns =\n      if (idColumns.isEmpty) Seq.empty\n      else left.columns.diffCaseSensitivity(idColumns).diffCaseSensitivity(ignoreColumns).toSeq\n    val encColumns = diffEncoder.schema.fields.map(_.name)\n    val diffColumns =\n      Seq(options.diffColumn) ++ getDiffColumns(idColumns, nonIdColumns, left, right, ignoreColumns).map(_._1)\n    val extraColumns = encColumns.diffCaseSensitivity(diffColumns)\n\n    require(\n      extraColumns.isEmpty,\n      s\"Diff encoder's columns must be part of the diff result schema, \" +\n        s\"these columns are unexpected: ${extraColumns.mkString(\", \")}\"\n    )\n\n    diff(left, right, idColumns, ignoreColumns).as[V](diffEncoder)\n  }\n\n  /**\n   * Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`.\n   *\n   * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.\n   *\n   * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.\n   */\n  def diffAs[T, U, V](\n      left: Dataset[T],\n      right: Dataset[U],\n      diffEncoder: Encoder[V],\n      idColumns: java.util.List[String],\n      ignoreColumns: java.util.List[String]\n  ): Dataset[V] = {\n    diffAs(\n      left,\n      right,\n      diffEncoder,\n      JavaConverters.iterableAsScalaIterable(idColumns).toSeq,\n      JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq\n    )\n  }\n\n  /**\n   * Returns a new Dataset that contains the differences between two Dataset of the same type `T` as tuples of type\n   * `(String, T, T)`.\n   *\n   * See `diff(Dataset[T], Dataset[T], String*)`.\n   */\n  @scala.annotation.varargs\n  def diffWith[T](left: Dataset[T], right: Dataset[T], idColumns: String*): Dataset[(String, T, T)] = {\n    val df = diff(left, right, idColumns: _*)\n    diffWith(df, idColumns: _*)(left.encoder, right.encoder)\n  }\n\n  /**\n   * Returns a new Dataset that contains the differences between two Dataset of similar types `T` and `U` as tuples of\n   * type `(String, T, U)`.\n   *\n   * See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`.\n   */\n  def diffWith[T, U](\n      left: Dataset[T],\n      right: Dataset[U],\n      idColumns: Seq[String],\n      ignoreColumns: Seq[String]\n  ): Dataset[(String, T, U)] = {\n    val df = diff(left, right, idColumns, ignoreColumns)\n    diffWith(df, idColumns: _*)(left.encoder, right.encoder)\n  }\n\n  /**\n   * Returns a new Dataset that contains the differences between two Dataset of similar types `T` and `U` as tuples of\n   * type `(String, T, U)`.\n   *\n   * See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`.\n   */\n  def diffWith[T, U](\n      left: Dataset[T],\n      right: Dataset[U],\n      idColumns: java.util.List[String],\n      ignoreColumns: java.util.List[String]\n  ): Dataset[(String, T, U)] = {\n    diffWith(\n      left,\n      right,\n      JavaConverters.iterableAsScalaIterable(idColumns).toSeq,\n      JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq\n    )\n  }\n\n  private def columnsOfSide(df: DataFrame, idColumns: Seq[String], sidePrefix: String): Seq[Column] = {\n    val prefix = sidePrefix + \"_\"\n    df.columns\n      .filter(c => idColumns.contains(c) || c.startsWith(sidePrefix))\n      .map(c => if (idColumns.contains(c)) col(c) else 
col(c).as(c.replace(prefix, \"\")))\n  }\n\n  private def diffWith[T: Encoder, U: Encoder](diff: DataFrame, idColumns: String*): Dataset[(String, T, U)] = {\n    val leftColumns = columnsOfSide(diff, idColumns, options.leftColumnPrefix)\n    val rightColumns = columnsOfSide(diff, idColumns, options.rightColumnPrefix)\n\n    val diffColumn = col(options.diffColumn).as(\"_1\")\n    val leftStruct = when(col(options.diffColumn) === options.insertDiffValue, lit(null))\n      .otherwise(struct(leftColumns: _*))\n      .as(\"_2\")\n    val rightStruct = when(col(options.diffColumn) === options.deleteDiffValue, lit(null))\n      .otherwise(struct(rightColumns: _*))\n      .as(\"_3\")\n\n    val encoder: Encoder[(String, T, U)] = Encoders.tuple(\n      Encoders.STRING,\n      implicitly[Encoder[T]],\n      implicitly[Encoder[U]]\n    )\n\n    diff.select(diffColumn, leftStruct, rightStruct).as(encoder)\n  }\n\n}\n\n/**\n * Diffing singleton with default diffing options.\n */\nobject Diff {\n  val default = new Differ(DiffOptions.default)\n\n  /**\n   * Returns a new DataFrame that contains the differences between two Datasets of the same type `T`. Both Datasets must\n   * contain the same set of column names and data types. The order of columns in the two Datasets is not relevant as\n   * columns are compared based on the name, not the position.\n   *\n   * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing\n   * between two Datasets, then that row is marked as `\"C\"`hange and `\"N\"`o-change otherwise. Rows of the right Dataset,\n   * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `\"I\"`nsert. And rows of\n   * the left Dataset, that do not exist in the right Dataset are marked as `\"D\"`elete.\n   *\n   * If no id columns are given, all columns are considered id columns. Then, no `\"C\"`hange rows will appear, as all\n   * changes will exist as respective `\"D\"`elete and `\"I\"`nsert.\n   *\n   * The returned DataFrame has the `diff` column as the first column. This holds the `\"N\"`, `\"C\"`, `\"I\"` or `\"D\"`\n   * strings. The id columns follow, then the non-id columns (all remaining columns).\n   *\n   * {{{\n   *   val df1 = Seq((1, \"one\"), (2, \"two\"), (3, \"three\")).toDF(\"id\", \"value\")\n   *   val df2 = Seq((1, \"one\"), (2, \"Two\"), (4, \"four\")).toDF(\"id\", \"value\")\n   *\n   *   Diff.of(df1, df2).show()\n   *\n   *   // output:\n   *   // +----+---+-----+\n   *   // |diff| id|value|\n   *   // +----+---+-----+\n   *   // |   N|  1|  one|\n   *   // |   D|  2|  two|\n   *   // |   I|  2|  Two|\n   *   // |   D|  3|three|\n   *   // |   I|  4| four|\n   *   // +----+---+-----+\n   *\n   *   Diff.of(df1, df2, \"id\").show()\n   *\n   *   // output:\n   *   // +----+---+----------+-----------+\n   *   // |diff| id|left_value|right_value|\n   *   // +----+---+----------+-----------+\n   *   // |   N|  1|       one|        one|\n   *   // |   C|  2|       two|        Two|\n   *   // |   D|  3|     three|       null|\n   *   // |   I|  4|      null|       four|\n   *   // +----+---+----------+-----------+\n   *\n   * }}}\n   *\n   * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are\n   * id columns and appear in the same order. 
The remaining non-id columns are in the order of this Dataset.\n   */\n  @scala.annotation.varargs\n  def of[T](left: Dataset[T], right: Dataset[T], idColumns: String*): DataFrame =\n    default.diff(left, right, idColumns: _*)\n\n  /**\n   * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both\n   * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The\n   * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the\n   * position.\n   *\n   * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing\n   * between two Datasets, then that row is marked as `\"C\"`hange and `\"N\"`o-change otherwise. Rows of the right Dataset,\n   * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `\"I\"`nsert. And rows of\n   * the left Dataset, that do not exist in the right Dataset are marked as `\"D\"`elete.\n   *\n   * If no id columns are given, all columns are considered id columns. Then, no `\"C\"`hange rows will appear, as all\n   * changes will exist as respective `\"D\"`elete and `\"I\"`nsert.\n   *\n   * Values in optional ignore columns are not compared but included in the output DataFrame.\n   *\n   * The returned DataFrame has the `diff` column as the first column. This holds the `\"N\"`, `\"C\"`, `\"I\"` or `\"D\"`\n   * strings. The id columns follow, then the non-id columns (all remaining columns).\n   *\n   * {{{\n   *   val df1 = Seq((1, \"one\"), (2, \"two\"), (3, \"three\")).toDF(\"id\", \"value\")\n   *   val df2 = Seq((1, \"one\"), (2, \"Two\"), (4, \"four\")).toDF(\"id\", \"value\")\n   *\n   *   Diff.of(df1, df2).show()\n   *\n   *   // output:\n   *   // +----+---+-----+\n   *   // |diff| id|value|\n   *   // +----+---+-----+\n   *   // |   N|  1|  one|\n   *   // |   D|  2|  two|\n   *   // |   I|  2|  Two|\n   *   // |   D|  3|three|\n   *   // |   I|  4| four|\n   *   // +----+---+-----+\n   *\n   *   Diff.of(df1, df2, Seq(\"id\")).show()\n   *\n   *   // output:\n   *   // +----+---+----------+-----------+\n   *   // |diff| id|left_value|right_value|\n   *   // +----+---+----------+-----------+\n   *   // |   N|  1|       one|        one|\n   *   // |   C|  2|       two|        Two|\n   *   // |   D|  3|     three|       null|\n   *   // |   I|  4|      null|       four|\n   *   // +----+---+----------+-----------+\n   *\n   * }}}\n   *\n   * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are\n   * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.\n   */\n  def of[T, U](\n      left: Dataset[T],\n      right: Dataset[U],\n      idColumns: Seq[String],\n      ignoreColumns: Seq[String] = Seq.empty\n  ): DataFrame =\n    default.diff(left, right, idColumns, ignoreColumns)\n\n  /**\n   * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both\n   * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The\n   * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the\n   * position.\n   *\n   * Optional id columns are used to uniquely identify rows to compare. 
If values in any non-id column are differing\n   * between two Datasets, then that row is marked as `\"C\"`hange and `\"N\"`o-change otherwise. Rows of the right Dataset,\n   * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `\"I\"`nsert. And rows of\n   * the left Dataset, that do not exist in the right Dataset are marked as `\"D\"`elete.\n   *\n   * If no id columns are given, all columns are considered id columns. Then, no `\"C\"`hange rows will appear, as all\n   * changes will exist as respective `\"D\"`elete and `\"I\"`nsert.\n   *\n   * Values in optional ignore columns are not compared but included in the output DataFrame.\n   *\n   * The returned DataFrame has the `diff` column as the first column. This holds the `\"N\"`, `\"C\"`, `\"I\"` or `\"D\"`\n   * strings. The id columns follow, then the non-id columns (all remaining columns).\n   *\n   * {{{\n   *   val df1 = Seq((1, \"one\"), (2, \"two\"), (3, \"three\")).toDF(\"id\", \"value\")\n   *   val df2 = Seq((1, \"one\"), (2, \"Two\"), (4, \"four\")).toDF(\"id\", \"value\")\n   *\n   *   Diff.of(df1, df2).show()\n   *\n   *   // output:\n   *   // +----+---+-----+\n   *   // |diff| id|value|\n   *   // +----+---+-----+\n   *   // |   N|  1|  one|\n   *   // |   D|  2|  two|\n   *   // |   I|  2|  Two|\n   *   // |   D|  3|three|\n   *   // |   I|  4| four|\n   *   // +----+---+-----+\n   *\n   *   Diff.of(df1, df2, Seq(\"id\")).show()\n   *\n   *   // output:\n   *   // +----+---+----------+-----------+\n   *   // |diff| id|left_value|right_value|\n   *   // +----+---+----------+-----------+\n   *   // |   N|  1|       one|        one|\n   *   // |   C|  2|       two|        Two|\n   *   // |   D|  3|     three|       null|\n   *   // |   I|  4|      null|       four|\n   *   // +----+---+----------+-----------+\n   *\n   * }}}\n   *\n   * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are\n   * id columns and appear in the same order. 
The remaining non-id columns are in the order of this Dataset.\n   */\n  def of[T, U](\n      left: Dataset[T],\n      right: Dataset[U],\n      idColumns: java.util.List[String],\n      ignoreColumns: java.util.List[String]\n  ): DataFrame =\n    default.diff(left, right, idColumns, ignoreColumns)\n\n  /**\n   * Returns a new Dataset that contains the differences between two Datasets of the same type `T`.\n   *\n   * See `of(Dataset[T], Dataset[T], String*)`.\n   *\n   * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.\n   */\n  // no @scala.annotation.varargs here as implicit arguments are explicit in Java\n  // this signature is redundant to the other ofAs method in Java\n  def ofAs[T, V](left: Dataset[T], right: Dataset[T], idColumns: String*)(implicit\n      diffEncoder: Encoder[V]\n  ): Dataset[V] =\n    default.diffAs(left, right, idColumns: _*)\n\n  /**\n   * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`.\n   *\n   * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.\n   *\n   * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.\n   */\n  def ofAs[T, U, V](\n      left: Dataset[T],\n      right: Dataset[U],\n      idColumns: Seq[String],\n      ignoreColumns: Seq[String] = Seq.empty\n  )(implicit diffEncoder: Encoder[V]): Dataset[V] =\n    default.diffAs(left, right, idColumns, ignoreColumns)\n\n  /**\n   * Returns a new Dataset that contains the differences between two Datasets of the same type `T`.\n   *\n   * See `of(Dataset[T], Dataset[T], String*)`.\n   *\n   * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.\n   */\n  @scala.annotation.varargs\n  def ofAs[T, V](left: Dataset[T], right: Dataset[T], diffEncoder: Encoder[V], idColumns: String*): Dataset[V] =\n    default.diffAs(left, right, diffEncoder, idColumns: _*)\n\n  /**\n   * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`.\n   *\n   * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.\n   *\n   * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.\n   */\n  def ofAs[T, U, V](\n      left: Dataset[T],\n      right: Dataset[U],\n      diffEncoder: Encoder[V],\n      idColumns: Seq[String],\n      ignoreColumns: Seq[String]\n  ): Dataset[V] =\n    default.diffAs(left, right, diffEncoder, idColumns, ignoreColumns)\n\n  /**\n   * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`.\n   *\n   * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.\n   *\n   * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.\n   */\n  def ofAs[T, U, V](\n      left: Dataset[T],\n      right: Dataset[U],\n      diffEncoder: Encoder[V],\n      idColumns: java.util.List[String],\n      ignoreColumns: java.util.List[String]\n  ): Dataset[V] =\n    default.diffAs(left, right, diffEncoder, idColumns, ignoreColumns)\n\n  /**\n   * Returns a new Dataset that contains the differences between two Dataset of the same type `T` as tuples of type\n   * `(String, T, T)`.\n   *\n   * See `of(Dataset[T], Dataset[T], String*)`.\n   */\n  @scala.annotation.varargs\n  def ofWith[T](left: Dataset[T], right: Dataset[T], idColumns: String*): Dataset[(String, T, T)] =\n    default.diffWith(left, right, idColumns: _*)\n\n  /**\n   * Returns a new DataFrame that contains the differences between two Datasets of 
similar types `T` and `U` as tuples\n   * of type `(String, T, U)`.\n   *\n   * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.\n   */\n  def ofWith[T, U](\n      left: Dataset[T],\n      right: Dataset[U],\n      idColumns: Seq[String],\n      ignoreColumns: Seq[String]\n  ): Dataset[(String, T, U)] =\n    default.diffWith(left, right, idColumns, ignoreColumns)\n\n  /**\n   * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U` as tuples\n   * of type `(String, T, U)`.\n   *\n   * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.\n   */\n  def ofWith[T, U](\n      left: Dataset[T],\n      right: Dataset[U],\n      idColumns: java.util.List[String],\n      ignoreColumns: java.util.List[String]\n  ): Dataset[(String, T, U)] =\n    default.diffWith(left, right, idColumns, ignoreColumns)\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/DiffComparators.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff\n\nimport org.apache.spark.sql.Encoder\nimport org.apache.spark.sql.types.DataType\nimport uk.co.gresearch.spark.diff.comparator._\n\nimport java.time.Duration\n\nobject DiffComparators {\n\n  /**\n   * The default comparator used in [[DiffOptions.default.defaultComparator]].\n   */\n  def default(): DiffComparator = DefaultDiffComparator\n\n  /**\n   * A comparator equivalent to `Column <=> Column`. Null values are considered equal.\n   */\n  def nullSafeEqual(): DiffComparator = NullSafeEqualDiffComparator\n\n  /**\n   * Return a comparator that uses the given [[math.Equiv]] to compare values of type [[T]]. The implicit [[Encoder]] of\n   * type [[T]] determines the input data type of the comparator. Only columns of that type can be compared.\n   */\n  def equiv[T: Encoder](equiv: math.Equiv[T]): EquivDiffComparator[T] = EquivDiffComparator(equiv)\n\n  /**\n   * Return a comparator that uses the given [[math.Equiv]] to compare values of type [[T]]. Only columns of the given\n   * data type `inputType` can be compared.\n   */\n  def equiv[T](equiv: math.Equiv[T], inputType: DataType): EquivDiffComparator[T] =\n    EquivDiffComparator(equiv, inputType)\n\n  /**\n   * Return a comparator that uses the given [[math.Equiv]] to compare values of any type.\n   */\n  def equiv(equiv: math.Equiv[Any]): EquivDiffComparator[Any] = EquivDiffComparator(equiv)\n\n  /**\n   * This comparator considers values equal when they are less than `epsilon` apart. It can be configured to use\n   * `epsilon` as an absolute (`.asAbsolute()`) threshold, or as relative (`.asRelative()`) to the larger value.\n   * Further, the threshold itself can be considered equal (`.asInclusive()`) or not equal (`.asExclusive()`):\n   *\n   * <ul> <li>`DiffComparator.epsilon(epsilon).asAbsolute().asInclusive()`: `abs(left - right) ≤ epsilon`</li>\n   * <li>`DiffComparator.epsilon(epsilon).asAbsolute().asExclusive()`: `abs(left - right) < epsilon`</li>\n   * <li>`DiffComparator.epsilon(epsilon).asRelative().asInclusive()`: `abs(left - right) ≤ epsilon * max(abs(left),\n   * abs(right))`</li> <li>`DiffComparator.epsilon(epsilon).asRelative().asExclusive()`: `abs(left - right) < epsilon *\n   * max(abs(left), abs(right))`</li> </ul>\n   *\n   * Requires compared column types to implement `-`, `*`, `<`, `==`, and `abs`.\n   */\n  def epsilon(epsilon: Double): EpsilonDiffComparator = EpsilonDiffComparator(epsilon)\n\n  /**\n   * A comparator for string values.\n   *\n   * With `whitespaceAgnostic` set `true`, differences in white spaces are ignored. This ignores leading and trailing\n   * whitespaces as well. 
With `whitespaceAgnostic` set `false`, this is equal to the default string comparison (see\n   * [[default()]]).\n   */\n  def string(whitespaceAgnostic: Boolean = true): StringDiffComparator =\n    if (whitespaceAgnostic) {\n      WhitespaceDiffComparator\n    } else {\n      StringDiffComparator\n    }\n\n  /**\n   * This comparator considers two `DateType` or `TimestampType` values equal when they are at most `duration` apart.\n   * Duration is an instance of `java.time.Duration`.\n   *\n   * The comparator can be configured to consider `duration` as equal (`.asInclusive()`) or not equal\n   * (`.asExclusive()`): <ul> <li>`DiffComparator.duration(duration).asInclusive()`: `left - right ≤ duration`</li>\n   * <li>`DiffComparator.duration(duration).asExclusive()`: `left - right < duration`</li> </ul>\n   */\n  def duration(duration: Duration): DurationDiffComparator = DurationDiffComparator(duration)\n\n  /**\n   * This comparator compares two `Map[K,V]` values. They are equal when they match in all their keys and values.\n   */\n  def map[K: Encoder, V: Encoder](): DiffComparator = MapDiffComparator[K, V](keyOrderSensitive = false)\n\n  /**\n   * This comparator compares two `Map[keyType,valueType]` values. They are equal when they match in all their keys and\n   * values.\n   */\n  def map(keyType: DataType, valueType: DataType, keyOrderSensitive: Boolean = false): DiffComparator =\n    MapDiffComparator(keyType, valueType, keyOrderSensitive)\n\n  // for backward compatibility to v2.4.0 up to v2.8.0\n  // replace with default value in above map when moving to v3\n  /**\n   * This comparator compares two `Map[K,V]` values. They are equal when they match in all their keys and values.\n   *\n   * @param keyOrderSensitive\n   *   comparator compares key order if true\n   */\n  def map[K: Encoder, V: Encoder](keyOrderSensitive: Boolean): DiffComparator =\n    MapDiffComparator[K, V](keyOrderSensitive)\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/DiffOptions.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff\n\nimport org.apache.spark.sql.Encoder\nimport org.apache.spark.sql.types.{DataType, StructField}\nimport uk.co.gresearch.spark.diff\nimport uk.co.gresearch.spark.diff.DiffMode.{Default, DiffMode}\nimport uk.co.gresearch.spark.diff.comparator.{\n  DefaultDiffComparator,\n  DiffComparator,\n  EquivDiffComparator,\n  TypedDiffComparator\n}\n\nimport scala.annotation.varargs\nimport scala.collection.Map\n\n/**\n * The diff mode determines the output columns of the diffing transformation.\n */\nobject DiffMode extends Enumeration {\n  type DiffMode = Value\n\n  /**\n   * The diff mode determines the output columns of the diffing transformation.\n   *\n   *   - ColumnByColumn: The diff contains value columns from the left and right dataset, arranged column by column:\n   *     diff,( changes,) id-1, id-2, …, left-value-1, right-value-1, left-value-2, right-value-2, …\n   *\n   *   - SideBySide: The diff contains value columns from the left and right dataset, arranged side by side: diff,(\n   *     changes,) id-1, id-2, …, left-value-1, left-value-2, …, right-value-1, right-value-2, …\n   *   - LeftSide / RightSide: The diff contains value columns from the left / right dataset only.\n   */\n  val ColumnByColumn, SideBySide, LeftSide, RightSide = Value\n\n  /**\n   * The diff mode determines the output columns of the diffing transformation. The default diff mode is ColumnByColumn.\n   *\n   * Default is not a enum value here (hence the def) so that we do not have to include it in every match clause. 
We\n   * will see the respective enum value that Default points to instead.\n   */\n  def Default: diff.DiffMode.Value = ColumnByColumn\n\n  // we want to return Default's enum value for 'Default' here but cannot override super.withName.\n  def withNameOption(name: String): Option[Value] = {\n    if (\"Default\".equals(name)) {\n      Some(DiffMode.Default)\n    } else {\n      try {\n        Some(super.withName(name))\n      } catch {\n        case _: NoSuchElementException => None\n      }\n    }\n  }\n\n}\n\n/**\n * Configuration class for diffing Datasets.\n *\n * @param diffColumn\n *   name of the diff column\n * @param leftColumnPrefix\n *   prefix of columns from the left Dataset\n * @param rightColumnPrefix\n *   prefix of columns from the right Dataset\n * @param insertDiffValue\n *   value in diff column for inserted rows\n * @param changeDiffValue\n *   value in diff column for changed rows\n * @param deleteDiffValue\n *   value in diff column for deleted rows\n * @param nochangeDiffValue\n *   value in diff column for un-changed rows\n * @param changeColumn\n *   name of change column\n * @param diffMode\n *   diff output format\n * @param sparseMode\n *   un-changed values are null on both sides\n * @param defaultComparator\n *   default custom comparator\n * @param dataTypeComparators\n *   custom comparator for some data type\n * @param columnNameComparators\n *   custom comparator for some column name\n */\ncase class DiffOptions(\n    diffColumn: String,\n    leftColumnPrefix: String,\n    rightColumnPrefix: String,\n    insertDiffValue: String,\n    changeDiffValue: String,\n    deleteDiffValue: String,\n    nochangeDiffValue: String,\n    changeColumn: Option[String] = None,\n    diffMode: DiffMode = Default,\n    sparseMode: Boolean = false,\n    defaultComparator: DiffComparator = DefaultDiffComparator,\n    dataTypeComparators: Map[DataType, DiffComparator] = Map.empty,\n    columnNameComparators: Map[String, DiffComparator] = Map.empty\n) {\n  // Constructor for Java to construct default options\n  def this() = this(\"diff\", \"left\", \"right\", \"I\", \"C\", \"D\", \"N\")\n  def this(\n      diffColumn: String,\n      leftColumnPrefix: String,\n      rightColumnPrefix: String,\n      insertDiffValue: String,\n      changeDiffValue: String,\n      deleteDiffValue: String,\n      nochangeDiffValue: String,\n      changeColumn: Option[String],\n      diffMode: DiffMode,\n      sparseMode: Boolean\n  ) = {\n    this(\n      diffColumn,\n      leftColumnPrefix,\n      rightColumnPrefix,\n      insertDiffValue,\n      changeDiffValue,\n      deleteDiffValue,\n      nochangeDiffValue,\n      changeColumn,\n      diffMode,\n      sparseMode,\n      DefaultDiffComparator,\n      Map.empty,\n      Map.empty\n    )\n  }\n\n  require(leftColumnPrefix.nonEmpty, \"Left column prefix must not be empty\")\n  require(rightColumnPrefix.nonEmpty, \"Right column prefix must not be empty\")\n  require(\n    handleConfiguredCaseSensitivity(leftColumnPrefix) != handleConfiguredCaseSensitivity(rightColumnPrefix),\n    s\"Left and right column prefix must be distinct: $leftColumnPrefix\"\n  )\n\n  val diffValues = Seq(insertDiffValue, changeDiffValue, deleteDiffValue, nochangeDiffValue)\n  require(diffValues.distinct.length == diffValues.length, s\"Diff values must be distinct: $diffValues\")\n\n  require(\n    !changeColumn.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(diffColumn)),\n    s\"Change column name must be different to diff column: 
$diffColumn\"\n  )\n\n  /**\n   * Fluent method to change the diff column name. Returns a new immutable DiffOptions instance with the new diff column\n   * name.\n   * @param diffColumn\n   *   new diff column name\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withDiffColumn(diffColumn: String): DiffOptions = {\n    this.copy(diffColumn = diffColumn)\n  }\n\n  /**\n   * Fluent method to change the prefix of columns from the left Dataset. Returns a new immutable DiffOptions instance\n   * with the new column prefix.\n   * @param leftColumnPrefix\n   *   new column prefix\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withLeftColumnPrefix(leftColumnPrefix: String): DiffOptions = {\n    this.copy(leftColumnPrefix = leftColumnPrefix)\n  }\n\n  /**\n   * Fluent method to change the prefix of columns from the right Dataset. Returns a new immutable DiffOptions instance\n   * with the new column prefix.\n   * @param rightColumnPrefix\n   *   new column prefix\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withRightColumnPrefix(rightColumnPrefix: String): DiffOptions = {\n    this.copy(rightColumnPrefix = rightColumnPrefix)\n  }\n\n  /**\n   * Fluent method to change the value of inserted rows in the diff column. Returns a new immutable DiffOptions instance\n   * with the new diff value.\n   * @param insertDiffValue\n   *   new diff value\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withInsertDiffValue(insertDiffValue: String): DiffOptions = {\n    this.copy(insertDiffValue = insertDiffValue)\n  }\n\n  /**\n   * Fluent method to change the value of changed rows in the diff column. Returns a new immutable DiffOptions instance\n   * with the new diff value.\n   * @param changeDiffValue\n   *   new diff value\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withChangeDiffValue(changeDiffValue: String): DiffOptions = {\n    this.copy(changeDiffValue = changeDiffValue)\n  }\n\n  /**\n   * Fluent method to change the value of deleted rows in the diff column. Returns a new immutable DiffOptions instance\n   * with the new diff value.\n   * @param deleteDiffValue\n   *   new diff value\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withDeleteDiffValue(deleteDiffValue: String): DiffOptions = {\n    this.copy(deleteDiffValue = deleteDiffValue)\n  }\n\n  /**\n   * Fluent method to change the value of un-changed rows in the diff column. Returns a new immutable DiffOptions\n   * instance with the new diff value.\n   * @param nochangeDiffValue\n   *   new diff value\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withNochangeDiffValue(nochangeDiffValue: String): DiffOptions = {\n    this.copy(nochangeDiffValue = nochangeDiffValue)\n  }\n\n  /**\n   * Fluent method to change the change column name. Returns a new immutable DiffOptions instance with the new change\n   * column name.\n   * @param changeColumn\n   *   new change column name\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withChangeColumn(changeColumn: String): DiffOptions = {\n    this.copy(changeColumn = Some(changeColumn))\n  }\n\n  /**\n   * Fluent method to remove change column. Returns a new immutable DiffOptions instance without a change column.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withoutChangeColumn(): DiffOptions = {\n    this.copy(changeColumn = None)\n  }\n\n  /**\n   * Fluent method to change the diff mode. 
Returns a new immutable DiffOptions instance with the new diff mode.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withDiffMode(diffMode: DiffMode): DiffOptions = {\n    this.copy(diffMode = diffMode)\n  }\n\n  /**\n   * Fluent method to change the sparse mode. Returns a new immutable DiffOptions instance with the new sparse mode.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withSparseMode(sparseMode: Boolean): DiffOptions = {\n    this.copy(sparseMode = sparseMode)\n  }\n\n  /**\n   * Fluent method to add a default comparator. Returns a new immutable DiffOptions instance with the new default\n   * comparator.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withDefaultComparator(diffComparator: DiffComparator): DiffOptions = {\n    this.copy(defaultComparator = diffComparator)\n  }\n\n  /**\n   * Fluent method to add a typed equivalent operator as a default comparator. The encoder defines the input type of the\n   * comparator. Returns a new immutable DiffOptions instance with the new default comparator.\n   * @note\n   *   The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the\n   *   `DiffComparator` interface.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withDefaultComparator[T: Encoder](equiv: math.Equiv[T]): DiffOptions = {\n    withDefaultComparator(EquivDiffComparator(equiv))\n  }\n\n  /**\n   * Fluent method to add a typed equivalent operator as a default comparator. Returns a new immutable DiffOptions\n   * instance with the new default comparator.\n   * @note\n   *   The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the\n   *   `DiffComparator` interface.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withDefaultComparator[T](equiv: math.Equiv[T], inputDataType: DataType): DiffOptions = {\n    withDefaultComparator(EquivDiffComparator(equiv, inputDataType))\n  }\n\n  /**\n   * Fluent method to add an equivalent operator as a default comparator. Returns a new immutable DiffOptions instance\n   * with the new default comparator.\n   * @note\n   *   The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the\n   *   `DiffComparator` interface.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withDefaultComparator(equiv: math.Equiv[Any]): DiffOptions = {\n    withDefaultComparator(EquivDiffComparator(equiv))\n  }\n\n  /**\n   * Fluent method to add a comparator for its input data type. Returns a new immutable DiffOptions instance with the\n   * new comparator.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withComparator(diffComparator: TypedDiffComparator): DiffOptions = {\n    if (dataTypeComparators.contains(diffComparator.inputType)) {\n      throw new IllegalArgumentException(s\"A comparator for data type ${diffComparator.inputType} exists already.\")\n    }\n    this.copy(dataTypeComparators = dataTypeComparators ++ Map(diffComparator.inputType -> diffComparator))\n  }\n\n  /**\n   * Fluent method to add a comparator for one or more data types. 
Returns a new immutable DiffOptions instance with the\n   * new comparator.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  @varargs\n  def withComparator(diffComparator: DiffComparator, dataType: DataType, dataTypes: DataType*): DiffOptions = {\n    val allDataTypes = dataType +: dataTypes\n\n    diffComparator match {\n      case typed: TypedDiffComparator if allDataTypes.exists(_ != typed.inputType) =>\n        throw new IllegalArgumentException(\n          s\"Comparator with input type ${typed.inputType.simpleString} \" +\n            s\"cannot be used for data type ${allDataTypes.filter(_ != typed.inputType).map(_.simpleString).sorted.mkString(\", \")}\"\n        )\n      case _ =>\n    }\n\n    val existingDataTypes = allDataTypes.filter(dataTypeComparators.contains)\n    if (existingDataTypes.nonEmpty) {\n      throw new IllegalArgumentException(\n        s\"A comparator for data type${if (existingDataTypes.length > 1) \"s\" else \"\"} \" +\n          s\"${existingDataTypes.map(_.simpleString).sorted.mkString(\", \")} exists already.\"\n      )\n    }\n    this.copy(dataTypeComparators = dataTypeComparators ++ allDataTypes.map(dt => dt -> diffComparator))\n  }\n\n  /**\n   * Fluent method to add a comparator for one or more column names. Returns a new immutable DiffOptions instance with\n   * the new comparator.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  @varargs\n  def withComparator(diffComparator: DiffComparator, columnName: String, columnNames: String*): DiffOptions = {\n    val allColumnNames = columnName +: columnNames\n    val existingColumnNames = allColumnNames.filter(columnNameComparators.contains)\n    if (existingColumnNames.nonEmpty) {\n      throw new IllegalArgumentException(\n        s\"A comparator for column name${if (existingColumnNames.length > 1) \"s\" else \"\"} \" +\n          s\"${existingColumnNames.sorted.mkString(\", \")} exists already.\"\n      )\n    }\n    this.copy(columnNameComparators = columnNameComparators ++ allColumnNames.map(name => name -> diffComparator))\n  }\n\n  /**\n   * Fluent method to add a typed equivalent operator as a comparator for its input data type. The encoder defines the\n   * input type of the comparator. Returns a new immutable DiffOptions instance with the new comparator.\n   * @note\n   *   The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the\n   *   `DiffComparator` interface.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withComparator[T: Encoder](equiv: math.Equiv[T]): DiffOptions =\n    withComparator(EquivDiffComparator(equiv))\n\n  /**\n   * Fluent method to add a typed equivalent operator as a comparator for one or more column names. The encoder defines\n   * the input type of the comparator. Returns a new immutable DiffOptions instance with the new comparator.\n   * @note\n   *   The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the\n   *   `DiffComparator` interface.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  def withComparator[T: Encoder](equiv: math.Equiv[T], columnName: String, columnNames: String*): DiffOptions =\n    withComparator(EquivDiffComparator(equiv), columnName, columnNames: _*)\n\n  /**\n   * Fluent method to add an equivalent operator as a comparator for one or more column names. 
Returns a new immutable\n   * DiffOptions instance with the new comparator.\n   * @note\n   *   The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the\n   *   `DiffComparator` interface.\n   * @note\n   *   Java-specific method\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  @varargs\n  def withComparator[T](\n      equiv: math.Equiv[T],\n      encoder: Encoder[T],\n      columnName: String,\n      columnNames: String*\n  ): DiffOptions =\n    withComparator(EquivDiffComparator(equiv)(encoder), columnName, columnNames: _*)\n\n  /**\n   * Fluent method to add an equivalent operator as a comparator for one or more data types. Returns a new immutable\n   * DiffOptions instance with the new comparator.\n   * @note\n   *   The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the\n   *   `DiffComparator` interface.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  // There is probably no use case of calling this with multiple datatype while T not being Any\n  // But this is the only way to define withComparator[T](equiv: math.Equiv[T], dataType: DataType)\n  // without being ambiguous with withComparator(equiv: math.Equiv[Any], dataType: DataType, dataTypes: DataType*)\n  @varargs\n  def withComparator[T](equiv: math.Equiv[T], dataType: DataType, dataTypes: DataType*): DiffOptions =\n    (dataType +: dataTypes).foldLeft(this)((options, dataType) =>\n      options.withComparator(EquivDiffComparator(equiv, dataType))\n    )\n\n  /**\n   * Fluent method to add an equivalent operator as a comparator for one or more column names. Returns a new immutable\n   * DiffOptions instance with the new comparator.\n   * @note\n   *   The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the\n   *   `DiffComparator` interface.\n   * @return\n   *   new immutable DiffOptions instance\n   */\n  @varargs\n  def withComparator(equiv: math.Equiv[Any], columnName: String, columnNames: String*): DiffOptions =\n    withComparator(EquivDiffComparator(equiv), columnName, columnNames: _*)\n\n  private[diff] def comparatorFor(column: StructField): DiffComparator =\n    columnNameComparators\n      .get(column.name)\n      .orElse(dataTypeComparators.get(column.dataType))\n      .getOrElse(defaultComparator)\n}\n\nobject DiffOptions {\n\n  /**\n   * Default diffing options.\n   */\n  val default: DiffOptions = new DiffOptions()\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/comparator/DefaultDiffComparator.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff.comparator\n\nimport org.apache.spark.sql.Column\n\ncase object DefaultDiffComparator extends DiffComparator {\n  override def equiv(left: Column, right: Column): Column = NullSafeEqualDiffComparator.equiv(left, right)\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/comparator/DiffComparator.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff.comparator\n\nimport org.apache.spark.sql.Column\n\ntrait DiffComparator {\n  def equiv(left: Column, right: Column): Column\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/comparator/DurationDiffComparator.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff.comparator\n\nimport org.apache.spark.sql.Column\nimport org.apache.spark.sql.functions.abs\nimport uk.co.gresearch.spark\nimport uk.co.gresearch.spark.SparkVersion\nimport uk.co.gresearch.spark.diff.comparator.DurationDiffComparator.isNotSupportedBySpark\n\nimport java.time.Duration\n\n/**\n * Compares two timestamps and considers them equal when they are less than (or equal to when inclusive = true) a given\n * duration apart.\n *\n * @param duration\n *   equality threshold\n * @param inclusive\n *   duration is considered equal when true\n */\ncase class DurationDiffComparator(duration: Duration, inclusive: Boolean = true) extends DiffComparator {\n  if (isNotSupportedBySpark) {\n    throw new UnsupportedOperationException(\n      s\"java.time.Duration is not supported by Spark ${spark.SparkCompatVersionString}\"\n    )\n  }\n\n  override def equiv(left: Column, right: Column): Column = {\n    val inDuration =\n      if (inclusive)\n        (diff: Column) => diff <= duration\n      else\n        (diff: Column) => diff < duration\n\n    left.isNull && right.isNull ||\n    left.isNotNull && right.isNotNull && inDuration(abs(left - right))\n  }\n\n  def asInclusive(): DurationDiffComparator = if (inclusive) this else copy(inclusive = true)\n  def asExclusive(): DurationDiffComparator = if (inclusive) copy(inclusive = false) else this\n}\n\nobject DurationDiffComparator extends SparkVersion {\n  val isSupportedBySpark: Boolean = SparkMajorVersion == 3 && SparkMinorVersion >= 3 || SparkMajorVersion > 3\n  val isNotSupportedBySpark: Boolean = !isSupportedBySpark\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/comparator/EpsilonDiffComparator.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff.comparator\n\nimport org.apache.spark.sql.Column\nimport org.apache.spark.sql.functions.{abs, greatest}\n\ncase class EpsilonDiffComparator(epsilon: Double, relative: Boolean = true, inclusive: Boolean = true)\n    extends DiffComparator {\n  override def equiv(left: Column, right: Column): Column = {\n    val threshold =\n      if (relative)\n        greatest(abs(left), abs(right)) * epsilon\n      else\n        epsilon\n\n    val inEpsilon =\n      if (inclusive)\n        (diff: Column) => diff <= threshold\n      else\n        (diff: Column) => diff < threshold\n\n    left.isNull && right.isNull || left.isNotNull && right.isNotNull && inEpsilon(abs(left - right))\n  }\n\n  def asAbsolute(): EpsilonDiffComparator = if (relative) copy(relative = false) else this\n  def asRelative(): EpsilonDiffComparator = if (relative) this else copy(relative = true)\n\n  def asInclusive(): EpsilonDiffComparator = if (inclusive) this else copy(inclusive = true)\n  def asExclusive(): EpsilonDiffComparator = if (inclusive) copy(inclusive = false) else this\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff.comparator\n\nimport org.apache.spark.sql.catalyst.InternalRow\nimport org.apache.spark.sql.catalyst.encoders.encoderFor\nimport org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper\nimport org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}\nimport org.apache.spark.sql.catalyst.expressions.{BinaryExpression, BinaryOperator, Expression}\nimport org.apache.spark.sql.extension.{ColumnExtension, ExpressionExtension}\nimport org.apache.spark.sql.types.{BooleanType, DataType}\nimport org.apache.spark.sql.{Column, Encoder}\n\ntrait EquivDiffComparator[T] extends DiffComparator {\n  val equiv: math.Equiv[T]\n}\n\nprivate trait ExpressionEquivDiffComparator[T] extends EquivDiffComparator[T] {\n  def equiv(left: Expression, right: Expression): EquivExpression[T]\n  def equiv(left: Column, right: Column): Column = equiv(left.expr, right.expr).column\n}\n\ntrait TypedEquivDiffComparator[T] extends EquivDiffComparator[T] with TypedDiffComparator\n\nprivate[comparator] trait TypedEquivDiffComparatorWithInput[T]\n    extends ExpressionEquivDiffComparator[T]\n    with TypedEquivDiffComparator[T] {\n  def equiv(left: Expression, right: Expression): Equiv[T] = Equiv(left, right, equiv, inputType)\n}\n\nprivate[comparator] case class InputTypedEquivDiffComparator[T](equiv: math.Equiv[T], inputType: DataType)\n    extends TypedEquivDiffComparatorWithInput[T]\n\nobject EquivDiffComparator {\n  def apply[T: Encoder](equiv: math.Equiv[T]): TypedEquivDiffComparator[T] = EncoderEquivDiffComparator(equiv)\n  def apply[T](equiv: math.Equiv[T], inputType: DataType): TypedEquivDiffComparator[T] =\n    InputTypedEquivDiffComparator(equiv, inputType)\n  def apply(equiv: math.Equiv[Any]): EquivDiffComparator[Any] = EquivAnyDiffComparator(equiv)\n\n  private case class EncoderEquivDiffComparator[T: Encoder](equiv: math.Equiv[T])\n      extends ExpressionEquivDiffComparator[T]\n      with TypedEquivDiffComparator[T] {\n    override def inputType: DataType = encoderFor[T].schema.fields(0).dataType\n    def equiv(left: Expression, right: Expression): Equiv[T] = Equiv(left, right, equiv, inputType)\n  }\n\n  private case class EquivAnyDiffComparator(equiv: math.Equiv[Any]) extends ExpressionEquivDiffComparator[Any] {\n    def equiv(left: Expression, right: Expression): EquivExpression[Any] = EquivAny(left, right, equiv)\n  }\n}\n\nprivate trait EquivExpression[T] extends BinaryExpression {\n  val equiv: math.Equiv[T]\n\n  override def nullable: Boolean = false\n\n  override def dataType: DataType = BooleanType\n\n  override def eval(input: InternalRow): Any = {\n    val input1 = left.eval(input).asInstanceOf[T]\n    val input2 = right.eval(input).asInstanceOf[T]\n    if (input1 == null && input2 == null) {\n      true\n    } else if (input1 == null || input2 == null) {\n      false\n    } else {\n      
equiv.equiv(input1, input2)\n    }\n  }\n\n  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {\n    val eval1 = left.genCode(ctx)\n    val eval2 = right.genCode(ctx)\n    val equivRef = ctx.addReferenceObj(\"equiv\", equiv, math.Equiv.getClass.getName.stripSuffix(\"$\"))\n    ev.copy(\n      code = eval1.code + eval2.code + code\"\"\"\n        boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) ||\n           (!${eval1.isNull} && !${eval2.isNull} && $equivRef.equiv(${eval1.value}, ${eval2.value}));\"\"\",\n      isNull = FalseLiteral\n    )\n  }\n}\n\nprivate trait EquivOperator[T] extends BinaryOperator with EquivExpression[T] {\n  val equivInputType: DataType\n\n  override def inputType: DataType = equivInputType\n\n  override def symbol: String = \"≡\"\n}\n\nprivate case class Equiv[T](left: Expression, right: Expression, equiv: math.Equiv[T], equivInputType: DataType)\n    extends EquivOperator[T] {\n  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Equiv[T] =\n    copy(left = newLeft, right = newRight)\n}\n\nprivate case class EquivAny(left: Expression, right: Expression, equiv: math.Equiv[Any]) extends EquivExpression[Any] {\n\n  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): EquivAny =\n    copy(left = newLeft, right = newRight)\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/comparator/MapDiffComparator.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff.comparator\n\nimport org.apache.spark.sql.catalyst.encoders.encoderFor\nimport org.apache.spark.sql.catalyst.expressions.UnsafeMapData\nimport org.apache.spark.sql.types.{DataType, MapType}\nimport org.apache.spark.sql.{Column, Encoder}\n\nimport scala.reflect.ClassTag\n\ncase class MapDiffComparator[K, V](private val comparator: EquivDiffComparator[UnsafeMapData]) extends DiffComparator {\n  override def equiv(left: Column, right: Column): Column = comparator.equiv(left, right)\n}\n\nprivate case class MapDiffEquiv[K: ClassTag, V](keyType: DataType, valueType: DataType, keyOrderSensitive: Boolean)\n    extends math.Equiv[UnsafeMapData] {\n  override def equiv(left: UnsafeMapData, right: UnsafeMapData): Boolean = {\n\n    val leftKeys: Array[K] = left.keyArray().toArray(keyType)\n    val rightKeys: Array[K] = right.keyArray().toArray(keyType)\n\n    val leftKeysIndices: Map[K, Int] = leftKeys.zipWithIndex.toMap\n    val rightKeysIndices: Map[K, Int] = rightKeys.zipWithIndex.toMap\n\n    val leftValues = left.valueArray()\n    val rightValues = right.valueArray()\n\n    // can only be evaluated when right has same keys as left\n    lazy val valuesAreEqual = leftKeysIndices\n      .map { case (key, index) => index -> rightKeysIndices(key) }\n      .map { case (leftIndex, rightIndex) =>\n        (leftIndex, rightIndex, leftValues.isNullAt(leftIndex), rightValues.isNullAt(rightIndex))\n      }\n      .map { case (leftIndex, rightIndex, leftIsNull, rightIsNull) =>\n        leftIsNull && rightIsNull ||\n        !leftIsNull && !rightIsNull && leftValues\n          .get(leftIndex, valueType)\n          .equals(rightValues.get(rightIndex, valueType))\n      }\n\n    left.numElements() == right.numElements() &&\n    (keyOrderSensitive && leftKeys\n      .sameElements(rightKeys) || !keyOrderSensitive && leftKeys.toSet.diff(rightKeys.toSet).isEmpty) &&\n    valuesAreEqual.forall(identity)\n  }\n}\n\ncase object MapDiffComparator {\n  def apply[K: Encoder, V: Encoder](keyOrderSensitive: Boolean): MapDiffComparator[K, V] = {\n    val keyType = encoderFor[K].schema.fields(0).dataType\n    val valueType = encoderFor[V].schema.fields(0).dataType\n    val equiv = MapDiffEquiv(keyType, valueType, keyOrderSensitive)\n    val dataType = MapType(keyType, valueType)\n    val comparator = InputTypedEquivDiffComparator[UnsafeMapData](equiv, dataType)\n    MapDiffComparator[K, V](comparator)\n  }\n\n  def apply(keyType: DataType, valueType: DataType, keyOrderSensitive: Boolean): MapDiffComparator[Any, Any] = {\n    val equiv = MapDiffEquiv(keyType, valueType, keyOrderSensitive)\n    val dataType = MapType(keyType, valueType)\n    val comparator = InputTypedEquivDiffComparator[UnsafeMapData](equiv, dataType)\n    MapDiffComparator[Any, Any](comparator)\n  }\n\n  // for backward compatibility to v2.4.0 up to v2.8.0\n  // replace with default value in above apply 
when moving to v3\n  def apply[K: Encoder, V: Encoder](): MapDiffComparator[K, V] = apply(keyOrderSensitive = false)\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/comparator/NullSafeEqualDiffComparator.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff.comparator\n\nimport org.apache.spark.sql.Column\n\ncase object NullSafeEqualDiffComparator extends DiffComparator {\n  override def equiv(left: Column, right: Column): Column = left <=> right\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/comparator/TypedDiffComparator.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff.comparator\n\nimport org.apache.spark.sql.Column\nimport org.apache.spark.sql.types.{DataType, StringType}\n\ntrait TypedDiffComparator extends DiffComparator {\n  def inputType: DataType\n}\n\ntrait StringDiffComparator extends TypedDiffComparator {\n  override def inputType: DataType = StringType\n}\n\ncase object StringDiffComparator extends StringDiffComparator {\n  override def equiv(left: Column, right: Column): Column = DefaultDiffComparator.equiv(left, right)\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/comparator/WhitespaceDiffComparator.scala",
    "content": "/*\n * Copyright 2023 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff.comparator\n\nimport org.apache.spark.unsafe.types.UTF8String\n\ncase object WhitespaceDiffComparator extends TypedEquivDiffComparatorWithInput[UTF8String] with StringDiffComparator {\n  override val equiv: scala.Equiv[UTF8String] = (x: UTF8String, y: UTF8String) =>\n    x.trimAll()\n      .toString\n      .replaceAll(\"\\\\s+\", \" \")\n      .equals(\n        y.trimAll().toString.replaceAll(\"\\\\s+\", \" \")\n      )\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/diff/package.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.sql.internal.SQLConf\nimport org.apache.spark.sql.{DataFrame, Dataset, Encoder}\n\nimport java.util.Locale\n\npackage object diff {\n\n  implicit class DatasetDiff[T](ds: Dataset[T]) {\n\n    /**\n     * Returns a new DataFrame that contains the differences between this and the other Dataset of the same type `T`.\n     * Both Datasets must contain the same set of column names and data types. The order of columns in the two Datasets\n     * is not important as one column is compared to the column with the same name of the other Dataset, not the column\n     * with the same position.\n     *\n     * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing\n     * between this and the other Dataset, then that row is marked as `\"C\"`hange and `\"N\"`o-change otherwise. Rows of\n     * the other Dataset, that do not exist in this Dataset (w.r.t. the values in the id columns) are marked as\n     * `\"I\"`nsert. And rows of this Dataset, that do not exist in the other Dataset are marked as `\"D\"`elete.\n     *\n     * If no id columns are given (empty sequence), all columns are considered id columns. Then, no `\"C\"`hange rows will\n     * appear, as all changes will exists as respective `\"D\"`elete and `\"I\"`nsert.\n     *\n     * The returned DataFrame has the `diff` column as the first column. This holds the `\"N\"`, `\"C\"`, `\"I\"` or `\"D\"`\n     * strings. The id columns follow, then the non-id columns (all remaining columns).\n     *\n     * {{{\n     *   val df1 = Seq((1, \"one\"), (2, \"two\"), (3, \"three\")).toDF(\"id\", \"value\")\n     *   val df2 = Seq((1, \"one\"), (2, \"Two\"), (4, \"four\")).toDF(\"id\", \"value\")\n     *\n     *   df1.diff(df2).show()\n     *\n     *   // output:\n     *   // +----+---+-----+\n     *   // |diff| id|value|\n     *   // +----+---+-----+\n     *   // |   N|  1|  one|\n     *   // |   D|  2|  two|\n     *   // |   I|  2|  Two|\n     *   // |   D|  3|three|\n     *   // |   I|  4| four|\n     *   // +----+---+-----+\n     *\n     *   df1.diff(df2, \"id\").show()\n     *\n     *   // output:\n     *   // +----+---+----------+-----------+\n     *   // |diff| id|left_value|right_value|\n     *   // +----+---+----------+-----------+\n     *   // |   N|  1|       one|        one|\n     *   // |   C|  2|       two|        Two|\n     *   // |   D|  3|     three|       null|\n     *   // |   I|  4|      null|       four|\n     *   // +----+---+----------+-----------+\n     *\n     * }}}\n     *\n     * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset\n     * are id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.\n     *\n     * The id column names are take literally, i.e. 
\"a.field\" is interpreted as \"`a.field`, which is a column name\n     * containing a dot. This is not interpreted as a column \"a\" with a field \"field\" (struct).\n     */\n    // no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java\n    def diff(other: Dataset[T], idColumns: String*): DataFrame = {\n      Diff.of(this.ds, other, idColumns: _*)\n    }\n\n    /**\n     * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both\n     * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The\n     * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the the\n     * position.\n     *\n     * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing\n     * between this and the other Dataset, then that row is marked as `\"C\"`hange and `\"N\"`o-change otherwise. Rows of\n     * the other Dataset, that do not exist in this Dataset (w.r.t. the values in the id columns) are marked as\n     * `\"I\"`nsert. And rows of this Dataset, that do not exist in the other Dataset are marked as `\"D\"`elete.\n     *\n     * If no id columns are given (empty sequence), all columns are considered id columns. Then, no `\"C\"`hange rows will\n     * appear, as all changes will exists as respective `\"D\"`elete and `\"I\"`nsert.\n     *\n     * Values in optional ignore columns are not compared but included in the output DataFrame.\n     *\n     * The returned DataFrame has the `diff` column as the first column. This holds the `\"N\"`, `\"C\"`, `\"I\"` or `\"D\"`\n     * strings. The id columns follow, then the non-id columns (all remaining columns).\n     *\n     * {{{\n     *   val df1 = Seq((1, \"one\"), (2, \"two\"), (3, \"three\")).toDF(\"id\", \"value\")\n     *   val df2 = Seq((1, \"one\"), (2, \"Two\"), (4, \"four\")).toDF(\"id\", \"value\")\n     *\n     *   df1.diff(df2).show()\n     *\n     *   // output:\n     *   // +----+---+-----+\n     *   // |diff| id|value|\n     *   // +----+---+-----+\n     *   // |   N|  1|  one|\n     *   // |   D|  2|  two|\n     *   // |   I|  2|  Two|\n     *   // |   D|  3|three|\n     *   // |   I|  4| four|\n     *   // +----+---+-----+\n     *\n     *   df1.diff(df2, \"id\").show()\n     *\n     *   // output:\n     *   // +----+---+----------+-----------+\n     *   // |diff| id|left_value|right_value|\n     *   // +----+---+----------+-----------+\n     *   // |   N|  1|       one|        one|\n     *   // |   C|  2|       two|        Two|\n     *   // |   D|  3|     three|       null|\n     *   // |   I|  4|      null|       four|\n     *   // +----+---+----------+-----------+\n     *\n     * }}}\n     *\n     * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset\n     * are id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.\n     *\n     * The id column names are take literally, i.e. \"a.field\" is interpreted as \"`a.field`, which is a column name\n     * containing a dot. 
This is not interpreted as a column \"a\" with a field \"field\" (struct).\n     */\n    def diff[U](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): DataFrame = {\n      Diff.of(this.ds, other, idColumns, ignoreColumns)\n    }\n\n    /**\n     * Returns a new DataFrame that contains the differences between this and the other Dataset of the same type `T`.\n     *\n     * See `diff(Dataset[T], String*)`.\n     *\n     * The schema of the returned DataFrame can be configured by the given `DiffOptions`.\n     */\n    // no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java\n    def diff(other: Dataset[T], options: DiffOptions, idColumns: String*): DataFrame = {\n      new Differ(options).diff(this.ds, other, idColumns: _*)\n    }\n\n    /**\n     * Returns a new DataFrame that contains the differences between this and the other Dataset of similar types `T` and\n     * `U`.\n     *\n     * See `diff(Dataset[U], Seq[String], Seq[String])`.\n     *\n     * The schema of the returned DataFrame can be configured by the given `DiffOptions`.\n     */\n    def diff[U](\n        other: Dataset[U],\n        options: DiffOptions,\n        idColumns: Seq[String],\n        ignoreColumns: Seq[String]\n    ): DataFrame = {\n      new Differ(options).diff(this.ds, other, idColumns, ignoreColumns)\n    }\n\n    /**\n     * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`.\n     *\n     * See `diff(Dataset[T], String*)`.\n     *\n     * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.\n     */\n    // no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java\n    def diffAs[V](other: Dataset[T], idColumns: String*)(implicit diffEncoder: Encoder[V]): Dataset[V] = {\n      Diff.ofAs(this.ds, other, idColumns: _*)\n    }\n\n    /**\n     * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and\n     * `U`.\n     *\n     * See `diff(Dataset[U], Seq[String], Seq[String])`.\n     *\n     * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.\n     */\n    def diffAs[U, V](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String])(implicit\n        diffEncoder: Encoder[V]\n    ): Dataset[V] = {\n      Diff.ofAs(this.ds, other, idColumns, ignoreColumns)\n    }\n\n    /**\n     * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`.\n     *\n     * See `diff(Dataset[T], String*)`.\n     *\n     * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned\n     * Dataset can be configured by the given `DiffOptions`.\n     */\n    // no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java\n    def diffAs[V](other: Dataset[T], options: DiffOptions, idColumns: String*)(implicit\n        diffEncoder: Encoder[V]\n    ): Dataset[V] = {\n      new Differ(options).diffAs(this.ds, other, idColumns: _*)\n    }\n\n    /**\n     * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and\n     * `U`.\n     *\n     * See `diff(Dataset[U], Seq[String], Seq[String])`.\n     *\n     * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. 
The schema of the returned\n     * Dataset can be configured by the given `DiffOptions`.\n     */\n    def diffAs[U, V](other: Dataset[T], options: DiffOptions, idColumns: Seq[String], ignoreColumns: Seq[String])(\n        implicit diffEncoder: Encoder[V]\n    ): Dataset[V] = {\n      new Differ(options).diffAs(this.ds, other, idColumns, ignoreColumns)\n    }\n\n    /**\n     * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`.\n     *\n     * See `diff(Dataset[T], String*)`.\n     *\n     * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.\n     */\n    // no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java\n    def diffAs[V](other: Dataset[T], diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = {\n      Diff.ofAs(this.ds, other, diffEncoder, idColumns: _*)\n    }\n\n    /**\n     * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and\n     * `U`.\n     *\n     * See `diff(Dataset[U], Seq[String], Seq[String])`.\n     *\n     * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.\n     */\n    def diffAs[U, V](\n        other: Dataset[U],\n        diffEncoder: Encoder[V],\n        idColumns: Seq[String],\n        ignoreColumns: Seq[String]\n    ): Dataset[V] = {\n      Diff.ofAs(this.ds, other, diffEncoder, idColumns, ignoreColumns)\n    }\n\n    /**\n     * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`.\n     *\n     * See `diff(Dataset[T], String*)`.\n     *\n     * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned\n     * Dataset can be configured by the given `DiffOptions`.\n     */\n    // no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java\n    def diffAs[V](other: Dataset[T], options: DiffOptions, diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = {\n      new Differ(options).diffAs(this.ds, other, diffEncoder, idColumns: _*)\n    }\n\n    /**\n     * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and\n     * `U`.\n     *\n     * See `diff(Dataset[U], Seq[String], Seq[String])`.\n     *\n     * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. 
The schema of the returned\n     * Dataset can be configured by the given `DiffOptions`.\n     */\n    def diffAs[U, V](\n        other: Dataset[U],\n        options: DiffOptions,\n        diffEncoder: Encoder[V],\n        idColumns: Seq[String],\n        ignoreColumns: Seq[String]\n    ): Dataset[V] = {\n      new Differ(options).diffAs(this.ds, other, diffEncoder, idColumns, ignoreColumns)\n    }\n\n    /**\n     * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T` as\n     * tuples of type `(String, T, T)`.\n     *\n     * See `diff(Dataset[T], Seq[String])`.\n     */\n    def diffWith(other: Dataset[T], idColumns: String*): Dataset[(String, T, T)] =\n      Diff.default.diffWith(this.ds, other, idColumns: _*)\n\n    /**\n     * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and\n     * `U` as tuples of type `(String, T, U)`.\n     *\n     * See `diff(Dataset[U], Seq[String], Seq[String])`.\n     */\n    def diffWith[U](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[(String, T, U)] =\n      Diff.default.diffWith(this.ds, other, idColumns, ignoreColumns)\n\n    /**\n     * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T` as\n     * tuples of type `(String, T, T)`.\n     *\n     * See `diff(Dataset[T], String*)`.\n     *\n     * The schema of the returned Dataset can be configured by the given `DiffOptions`.\n     */\n    def diffWith(other: Dataset[T], options: DiffOptions, idColumns: String*): Dataset[(String, T, T)] = {\n      new Differ(options).diffWith(this.ds, other, idColumns: _*)\n    }\n\n    /**\n     * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and\n     * `U` as tuples of type `(String, T, U)`.\n     *\n     * See `diff(Dataset[U], Seq[String], Seq[String])`.\n     *\n     * The schema of the returned Dataset can be configured by the given `DiffOptions`.\n     */\n    def diffWith[U](\n        other: Dataset[U],\n        options: DiffOptions,\n        idColumns: Seq[String],\n        ignoreColumns: Seq[String]\n    ): Dataset[(String, T, U)] = {\n      new Differ(options).diffWith(this.ds, other, idColumns, ignoreColumns)\n    }\n  }\n\n  /**\n   * Produces a column name that considers configured case-sensitivity of column names. 
When case sensitivity is\n   * deactivated, it lower-cases the given column name and no-ops otherwise.\n   *\n   * @param columnName\n   *   column name\n   * @return\n   *   case sensitive or insensitive column name\n   */\n  private[diff] def handleConfiguredCaseSensitivity(columnName: String): String =\n    if (SQLConf.get.caseSensitiveAnalysis) columnName else columnName.toLowerCase(Locale.ROOT)\n\n  implicit class CaseInsensitiveSeq(seq: Seq[String]) {\n    def containsCaseSensitivity(string: String): Boolean =\n      seq.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(string))\n\n    def filterIsInCaseSensitivity(other: Iterable[String]): Seq[String] = {\n      val otherSet = other.map(handleConfiguredCaseSensitivity).toSet\n      seq.filter(v => otherSet.contains(handleConfiguredCaseSensitivity(v)))\n    }\n\n    def diffCaseSensitivity(other: Iterable[String]): Seq[String] = {\n      val otherSet = other.map(handleConfiguredCaseSensitivity).toSet\n      seq.filter(v => !otherSet.contains(handleConfiguredCaseSensitivity(v)))\n    }\n  }\n\n  implicit class CaseInsensitiveArray(array: Array[String]) {\n    def containsCaseSensitivity(string: String): Boolean =\n      array.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(string))\n    def filterIsInCaseSensitivity(other: Iterable[String]): Array[String] =\n      array.toSeq.filterIsInCaseSensitivity(other).toArray\n    def diffCaseSensitivity(other: Iterable[String]): Array[String] = array.toSeq.diffCaseSensitivity(other).toArray\n  }\n\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/group/package.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.sql.functions.{col, struct}\nimport org.apache.spark.sql.{Column, Dataset, Encoder, Encoders}\nimport uk.co.gresearch.ExtendedAny\n\npackage object group {\n\n  /**\n   * This is a Dataset of key-value tuples, that provide a flatMap function over the individual groups, while providing\n   * a sorted iterator over group values.\n   *\n   * The key-value Dataset given the constructor has to be partitioned by the key and sorted within partitions by the\n   * key and value.\n   *\n   * @param ds\n   *   the properly partitioned and sorted dataset\n   * @tparam K\n   *   type of the keys with ordering and encoder\n   * @tparam V\n   *   type of the values with encoder\n   */\n  case class SortedGroupByDataset[K: Ordering: Encoder, V: Encoder] private (ds: Dataset[(K, V)]) {\n\n    /**\n     * (Scala-specific) Applies the given function to each group of data. For each unique group, the function will be\n     * passed the group key and a sorted iterator that contains all of the elements in the group. The function can\n     * return an iterator containing elements of an arbitrary type which will be returned as a new [[Dataset]].\n     *\n     * This function does not support partial aggregation, and as a result requires shuffling all the data in the\n     * [[Dataset]]. If an application intends to perform an aggregation over each key, it is best to use the reduce\n     * function or an `org.apache.spark.sql.expressions#Aggregator`.\n     *\n     * Internally, the implementation will spill to disk if any given group is too large to fit into memory. However,\n     * users must take care to avoid materializing the whole iterator for a group (for example, by calling `toList`)\n     * unless they are sure that this is possible given the memory constraints of their cluster.\n     */\n    def flatMapSortedGroups[W: Encoder](func: (K, Iterator[V]) => TraversableOnce[W]): Dataset[W] =\n      ds.mapPartitions(new GroupedIterator(_).flatMap(v => func(v._1, v._2)))\n\n    /**\n     * (Scala-specific) Applies the given function to each group of data. For each unique group, the function s will be\n     * passed the group key to create a state instance, while the function func will be passed that state instance and\n     * group values in sequence according to the sort order in the groups. The function func can return an iterator\n     * containing elements of an arbitrary type which will be returned as a new [[Dataset]].\n     *\n     * This function does not support partial aggregation, and as a result requires shuffling all the data in the\n     * [[Dataset]]. 
If an application intends to perform an aggregation over each key, it is best to use the reduce\n     * function or an `org.apache.spark.sql.expressions#Aggregator`.\n     *\n     * Internally, the implementation will spill to disk if any given group is too large to fit into memory. However,\n     * users must take care to avoid materializing the whole iterator for a group (for example, by calling `toList`)\n     * unless they are sure that this is possible given the memory constraints of their cluster.\n     */\n    def flatMapSortedGroups[S, W: Encoder](s: K => S)(func: (S, V) => TraversableOnce[W]): Dataset[W] = {\n      ds.mapPartitions(new GroupedIterator(_).flatMap { case (k, it) =>\n        val state = s(k)\n        it.flatMap(v => func(state, v))\n      })\n    }\n  }\n\n  private[spark] object SortedGroupByDataset {\n    def apply[K: Ordering: Encoder, V](\n        ds: Dataset[V],\n        groupColumns: Seq[Column],\n        orderColumns: Seq[Column],\n        partitions: Option[Int]\n    ): SortedGroupByDataset[K, V] = {\n      // make ds encoder implicitly available\n      implicit val valueEncoder: Encoder[V] = ds.encoder\n\n      // multiple group columns are turned into a tuple,\n      // while a single group column is taken as is\n      val keyColumn =\n        if (groupColumns.length == 1)\n          groupColumns.head\n        else\n          struct(groupColumns: _*)\n\n      // all columns are turned into a single column as a struct\n      val valColumn = struct(col(\"*\"))\n\n      // repartition by group columns with given number of partitions (if given)\n      // sort within partitions by group and order columns\n      // finally, turn key and value into typed classes\n      val grouped = ds\n        .on(partitions.isDefined)\n        .either(_.repartition(partitions.get, groupColumns: _*))\n        .or(_.repartition(groupColumns: _*))\n        .sortWithinPartitions(groupColumns ++ orderColumns: _*)\n        .select(\n          keyColumn.as(\"key\").as[K],\n          valColumn.as(\"value\").as[V]\n        )\n\n      SortedGroupByDataset(grouped)\n    }\n\n    def apply[K: Ordering: Encoder, V, O: Encoder](\n        ds: Dataset[V],\n        key: V => K,\n        order: V => O,\n        partitions: Option[Int],\n        reverse: Boolean\n    ): SortedGroupByDataset[K, V] = {\n      // prepare encoder needed for this exercise\n      val keyEncoder: Encoder[K] = implicitly[Encoder[K]]\n      implicit val valueEncoder: Encoder[V] = ds.encoder\n      val orderEncoder: Encoder[O] = implicitly[Encoder[O]]\n      implicit val kvEncoder: Encoder[(K, V)] = Encoders.tuple(keyEncoder, valueEncoder)\n      implicit val kvoEncoder: Encoder[(K, V, O)] = Encoders.tuple(keyEncoder, valueEncoder, orderEncoder)\n\n      // materialise the key and order class for each value\n      val kvo = ds.map(v => (key(v), v, order(v)))\n\n      // sort by key and order column\n      def keyColumn = col(kvo.columns.head)\n\n      def orderColumn = if (reverse) col(kvo.columns.last).desc else col(kvo.columns.last)\n\n      // repartition by group columns with given number of partitions (if given)\n      // sort within partitions by group and order columns\n      // finally, turn key and value into typed classes\n      val grouped = kvo\n        .on(partitions.isDefined)\n        .either(_.repartition(partitions.get, keyColumn))\n        .or(_.repartition(keyColumn))\n        .sortWithinPartitions(keyColumn, orderColumn)\n        .map(v => (v._1, v._2))\n\n      SortedGroupByDataset(grouped)\n    }\n  
}\n\n  private[group] class GroupedIterator[K: Ordering, V](iter: Iterator[(K, V)]) extends Iterator[(K, Iterator[V])] {\n    private val values = iter.buffered\n    private var currentKey: Option[K] = None\n    private var currentGroup: Option[Iterator[V]] = None\n\n    override def hasNext: Boolean = {\n      if (currentKey.isEmpty) {\n        if (currentGroup.isDefined) {\n          // consume current group\n          val it = currentGroup.get\n          while (it.hasNext) it.next\n          currentGroup = None\n        }\n\n        if (values.hasNext) {\n          currentKey = Some(values.head._1)\n          currentGroup = Some(new GroupIterator(values))\n        }\n      }\n      currentKey.isDefined\n    }\n\n    override def next(): (K, Iterator[V]) = {\n      try {\n        (currentKey.get, currentGroup.get)\n      } finally {\n        currentKey = None\n      }\n    }\n  }\n\n  private[group] class GroupIterator[K: Ordering, V](iter: BufferedIterator[(K, V)]) extends Iterator[V] {\n    private val ordering = implicitly[Ordering[K]]\n    private val key = iter.head._1\n\n    private def identicalKeys(one: K, two: K): Boolean =\n      one == null && two == null || one != null && two != null && ordering.equiv(one, two)\n\n    override def hasNext: Boolean = iter.hasNext && identicalKeys(iter.head._1, key)\n\n    override def next(): V = iter.next._2\n  }\n\n}\n"
  },
  {
    "path": "src/main/scala/uk/co/gresearch/spark/package.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch\n\nimport org.apache.spark.internal.Logging\nimport org.apache.spark.sql._\nimport org.apache.spark.sql.ColumnName\nimport org.apache.spark.sql.catalyst.expressions.{NamedExpression, UnixMicros}\nimport org.apache.spark.sql.extension.{ColumnExtension, ExpressionExtension}\nimport org.apache.spark.sql.functions.{col, count, lit, when}\nimport org.apache.spark.sql.internal.SQLConf\nimport org.apache.spark.sql.types.{DecimalType, LongType, TimestampType}\nimport org.apache.spark.storage.StorageLevel\nimport org.apache.spark.{SparkContext, SparkFiles}\nimport uk.co.gresearch.spark.group.SortedGroupByDataset\n\nimport java.nio.file.{Files, Paths}\n\npackage object spark extends Logging with SparkVersion with BuildVersion {\n\n  /**\n   * Provides a prefix that makes any string distinct w.r.t. the given strings.\n   * @param existing\n   *   strings\n   * @return\n   *   distinct prefix\n   */\n  private[spark] def distinctPrefixFor(existing: Seq[String]): String = {\n    // count number of suffix _ for each existing column name\n    // return string with one more _ than that\n    \"_\" * (existing.map(_.takeWhile(_ == '_').length).reduceOption(_ max _).getOrElse(0) + 1)\n  }\n\n  /**\n   * Create a temporary directory in a location (driver temp dir) that will be deleted on Spark application shutdown.\n   * @param prefix\n   *   prefix string of temporary directory name\n   * @return\n   *   absolute path of temporary directory\n   */\n  def createTemporaryDir(prefix: String): String = {\n    // SparkFiles.getRootDirectory() will be deleted on spark application shutdown\n    Files.createTempDirectory(Paths.get(SparkFiles.getRootDirectory()), prefix).toAbsolutePath.toString\n  }\n\n  // https://issues.apache.org/jira/browse/SPARK-40588\n  private[spark] def writePartitionedByRequiresCaching[T](ds: Dataset[T]): Boolean = {\n    ds.sparkSession.conf\n      .get(\n        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key,\n        SQLConf.ADAPTIVE_EXECUTION_ENABLED.defaultValue.getOrElse(true).toString\n      )\n      .equalsIgnoreCase(\"true\") && Some(ds.sparkSession.version).exists(ver =>\n      Set(\"3.0.\", \"3.1.\", \"3.2.0\", \"3.2.1\", \"3.2.2\", \"3.3.0\", \"3.3.1\").exists(pat =>\n        if (pat.endsWith(\".\")) { ver.startsWith(pat) }\n        else { ver.equals(pat) || ver.startsWith(pat + \"-\") }\n      )\n    )\n  }\n\n  private[spark] def info(msg: String): Unit = logInfo(msg)\n  private[spark] def warning(msg: String): Unit = logWarning(msg)\n\n  /**\n   * Encloses the given strings with backticks (backquotes) if needed.\n   *\n   * Backticks are not needed for strings that start with a letter (`a`-`z` and `A`-`Z`) or an underscore, and contain\n   * only letters, numbers and underscores.\n   *\n   * Multiple strings will be enclosed individually and concatenated with dots (`.`).\n   *\n   * This is useful when referencing column names that contain 
special characters like dots (`.`) or backquotes.\n   *\n   * Examples:\n   * {{{\n   *   col(\"a.column\")                        // this references the field \"column\" of column \"a\"\n   *   col(\"`a.column`\")                      // this reference the column with the name \"a.column\"\n   *   col(backticks(\"column\"))               // produces \"column\"\n   *   col(backticks(\"a.column\"))             // produces \"`a.column`\"\n   *   col(backticks(\"a column\"))             // produces \"`a column`\"\n   *   col(backticks(\"`a.column`\"))           // produces \"`a.column`\"\n   *   col(backticks(\"a.column\", \"a.field\"))  // produces \"`a.column`.`a.field`\"\n   * }}}\n   *\n   * @param string\n   *   a string\n   * @param strings\n   *   more strings\n   */\n  @scala.annotation.varargs\n  def backticks(string: String, strings: String*): String =\n    Backticks.column_name(string, strings: _*)\n\n  /**\n   * Aggregate function: returns the number of items in a group that are not null.\n   */\n  def count_null(e: Column): Column = count(when(e.isNull, lit(1)))\n\n  private val nanoSecondsPerDotNetTick: Long = 100\n  private val dotNetTicksPerSecond: Long = 10000000\n  private val unixEpochDotNetTicks: Long = 621355968000000000L\n\n  /**\n   * Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be convertible to a number\n   * (e.g. string, int, long). The Spark timestamp type does not support nanoseconds, so the the last digit of the\n   * timestamp (1/10 of a microsecond) is lost.\n   *\n   * Example:\n   * {{{\n   *   df.select($\"ticks\", dotNetTicksToTimestamp($\"ticks\").as(\"timestamp\")).show(false)\n   * }}}\n   *\n   * | ticks              | timestamp                  |\n   * |:-------------------|:---------------------------|\n   * | 638155413748959318 | 2023-03-27 21:16:14.895931 |\n   *\n   * Note: the example timestamp lacks the 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to preserve the full\n   * precision of the tick timestamp.\n   *\n   * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n   *\n   * @param tickColumn\n   *   column with a tick value\n   * @return\n   *   result timestamp column\n   */\n  def dotNetTicksToTimestamp(tickColumn: Column): Column =\n    dotNetTicksToUnixEpoch(tickColumn).cast(TimestampType)\n\n  /**\n   * Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be convertible to a number\n   * (e.g. string, int, long). The Spark timestamp type does not support nanoseconds, so the the last digit of the\n   * timestamp (1/10 of a microsecond) is lost.\n   *\n   * {{{\n   *   df.select($\"ticks\", dotNetTicksToTimestamp(\"ticks\").as(\"timestamp\")).show(false)\n   * }}}\n   *\n   * | ticks              | timestamp                  |\n   * |:-------------------|:---------------------------|\n   * | 638155413748959318 | 2023-03-27 21:16:14.895931 |\n   *\n   * Note: the example timestamp lacks the 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to preserve the full\n   * precision of the tick timestamp.\n   *\n   * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n   *\n   * @param tickColumnName\n   *   name of a column with a tick value\n   * @return\n   *   result timestamp column\n   */\n  def dotNetTicksToTimestamp(tickColumnName: String): Column = dotNetTicksToTimestamp(col(tickColumnName))\n\n  /**\n   * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch decimal. 
The input column must be convertible to a number\n   * (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond).\n   *\n   * Example:\n   * {{{\n   *   df.select($\"ticks\", dotNetTicksToUnixEpoch($\"ticks\").as(\"timestamp\")).show(false)\n   * }}}\n   *\n   * | ticks              | timestamp            |\n   * |:-------------------|:---------------------|\n   * | 638155413748959318 | 1679944574.895931800 |\n   *\n   * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n   *\n   * @param tickColumn\n   *   column with a tick value\n   * @return\n   *   result unix epoch seconds column as decimal\n   */\n  def dotNetTicksToUnixEpoch(tickColumn: Column): Column =\n    (tickColumn.cast(DecimalType(19, 0)) - unixEpochDotNetTicks) / dotNetTicksPerSecond\n\n  /**\n   * Convert a .Net `DateTime.Ticks` timestamp to Unix epoch seconds. The input column must be convertible to a number\n   * (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond).\n   *\n   * Example:\n   * {{{\n   *   df.select($\"ticks\", dotNetTicksToUnixEpoch(\"ticks\").as(\"timestamp\")).show(false)\n   * }}}\n   *\n   * | ticks              | timestamp            |\n   * |:-------------------|:---------------------|\n   * | 638155413748959318 | 1679944574.895931800 |\n   *\n   * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n   *\n   * @param tickColumnName\n   *   name of column with a tick value\n   * @return\n   *   result unix epoch seconds column as decimal\n   */\n  def dotNetTicksToUnixEpoch(tickColumnName: String): Column = dotNetTicksToUnixEpoch(col(tickColumnName))\n\n  /**\n   * Convert a .Net `DateTime.Ticks` timestamp to Unix epoch nanoseconds. The input column must be convertible to a number\n   * (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond).\n   *\n   * Example:\n   * {{{\n   *   df.select($\"ticks\", dotNetTicksToUnixEpochNanos($\"ticks\").as(\"timestamp\")).show(false)\n   * }}}\n   *\n   * | ticks              | timestamp           |\n   * |:-------------------|:--------------------|\n   * | 638155413748959318 | 1679944574895931800 |\n   *\n   * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n   *\n   * @param tickColumn\n   *   column with a tick value\n   * @return\n   *   result unix epoch nanoseconds column as long\n   */\n  def dotNetTicksToUnixEpochNanos(tickColumn: Column): Column = {\n    when(\n      tickColumn <= 713589688368547758L,\n      (tickColumn.cast(LongType) - unixEpochDotNetTicks) * nanoSecondsPerDotNetTick\n    )\n  }\n\n  /**\n   * Convert a .Net `DateTime.Ticks` timestamp to Unix epoch nanoseconds. The input column must be convertible to a\n   * number (e.g. string, int, long). 
The full precision of the tick timestamp is preserved (1/10 of a microsecond).\n   *\n   * Example:\n   * {{{\n   *   df.select($\"ticks\", dotNetTicksToUnixEpochNanos(\"ticks\").as(\"timestamp\")).show(false)\n   * }}}\n   *\n   * | ticks              | timestamp           |\n   * |:-------------------|:--------------------|\n   * | 638155413748959318 | 1679944574895931800 |\n   *\n   * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n   *\n   * @param tickColumnName\n   *   name of column with a tick value\n   * @return\n   *   result unix epoch nanoseconds column as long\n   */\n  def dotNetTicksToUnixEpochNanos(tickColumnName: String): Column = dotNetTicksToUnixEpochNanos(col(tickColumnName))\n\n  /**\n   * Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp. The input column must be of TimestampType.\n   *\n   * Example:\n   * {{{\n   *   df.select($\"timestamp\", timestampToDotNetTicks($\"timestamp\").as(\"ticks\")).show(false)\n   * }}}\n   *\n   * | timestamp                  | ticks              |\n   * |:---------------------------|:-------------------|\n   * | 2023-03-27 21:16:14.895931 | 638155413748959310 |\n   *\n   * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n   *\n   * @param timestampColumn\n   *   column with a timestamp value\n   * @return\n   *   result tick value column\n   */\n  def timestampToDotNetTicks(timestampColumn: Column): Column =\n    unixEpochTenthMicrosToDotNetTicks(UnixMicros(timestampColumn.expr).column * 10)\n\n  /**\n   * Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp. The input column must be of TimestampType.\n   *\n   * Example:\n   * {{{\n   *   df.select($\"timestamp\", timestampToDotNetTicks(\"timestamp\").as(\"ticks\")).show(false)\n   * }}}\n   *\n   * | timestamp                  | ticks              |\n   * |:---------------------------|:-------------------|\n   * | 2023-03-27 21:16:14.895931 | 638155413748959310 |\n   *\n   * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n   *\n   * @param timestampColumnName\n   *   name of column with a timestamp value\n   * @return\n   *   result tick value column\n   */\n  def timestampToDotNetTicks(timestampColumnName: String): Column = timestampToDotNetTicks(col(timestampColumnName))\n\n  /**\n   * Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp. The input column must represent a numerical\n   * unix epoch timestamp, e.g. long, double, string or decimal. The input must not be of TimestampType, as that may be\n   * interpreted incorrectly. Use `timestampToDotNetTicks` for TimestampType columns instead.\n   *\n   * Example:\n   * {{{\n   *   df.select($\"unix\", unixEpochToDotNetTicks($\"unix\").as(\"ticks\")).show(false)\n   * }}}\n   *\n   * | unix                          | ticks              |\n   * |:------------------------------|:-------------------|\n   * | 1679944574.895931234000000000 | 638155413748959312 |\n   *\n   * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n   *\n   * @param unixTimeColumn\n   *   column with a unix epoch timestamp value\n   * @return\n   *   result tick value column\n   */\n  def unixEpochToDotNetTicks(unixTimeColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks(\n    unixTimeColumn.cast(DecimalType(19, 7)) * 10000000\n  )\n\n  /**\n   * Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp. The input column must represent a numerical\n   * unix epoch timestamp, e.g. long, double, string or decimal. 
The input must not be of TimestampType, as that may be\n   * interpreted incorrectly. Use `timestampToDotNetTicks` for TimestampType columns instead.\n   *\n   * Example:\n   * {{{\n   *   df.select($\"unix\", unixEpochToDotNetTicks(\"unix\").as(\"ticks\")).show(false)\n   * }}}\n   *\n   * | unix                          | ticks              |\n   * |:------------------------------|:-------------------|\n   * | 1679944574.895931234000000000 | 638155413748959312 |\n   *\n   * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n   *\n   * @param unixTimeColumnName\n   *   name of column with a unix epoch timestamp value\n   * @return\n   *   result tick value column\n   */\n  def unixEpochToDotNetTicks(unixTimeColumnName: String): Column = unixEpochToDotNetTicks(col(unixTimeColumnName))\n\n  /**\n   * Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp. The .Net ticks timestamp does not\n   * support the two lowest nanosecond digits, so only a 1/10 of a microsecond is the smallest resolution. The input\n   * column must represent a numerical unix epoch nanoseconds timestamp, e.g. long, double, string or decimal.\n   *\n   * Example:\n   * {{{\n   *   df.select($\"unix_nanos\", unixEpochNanosToDotNetTicks($\"unix_nanos\").as(\"ticks\")).show(false)\n   * }}}\n   *\n   * | unix_nanos          | ticks              |\n   * |:--------------------|:-------------------|\n   * | 1679944574895931234 | 638155413748959312 |\n   *\n   * Note: the example timestamp lacks the two lower nanosecond digits as this precision is not supported by .Net ticks.\n   *\n   * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n   *\n   * @param unixNanosColumn\n   *   column with a unix epoch timestamp value\n   * @return\n   *   result tick value column\n   */\n  def unixEpochNanosToDotNetTicks(unixNanosColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks(\n    unixNanosColumn.cast(DecimalType(21, 0)) / nanoSecondsPerDotNetTick\n  )\n\n  /**\n   * Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp. The .Net ticks timestamp does not\n   * support the two lowest nanosecond digits, so only a 1/10 of a microsecond is the smallest resolution. The input\n   * column must represent a numerical unix epoch nanoseconds timestamp, e.g. long, double, string or decimal.\n   *\n   * Example:\n   * {{{\n   *   df.select($\"unix_nanos\", unixEpochNanosToDotNetTicks($\"unix_nanos\").as(\"ticks\")).show(false)\n   * }}}\n   *\n   * | unix_nanos          | ticks              |\n   * |:--------------------|:-------------------|\n   * | 1679944574895931234 | 638155413748959312 |\n   *\n   * Note: the example timestamp lacks the two lower nanosecond digits as this precision is not supported by .Net ticks.\n   *\n   * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks\n   *\n   * @param unixNanosColumnName\n   *   name of column with a unix epoch timestamp value\n   * @return\n   *   result tick value column\n   */\n  def unixEpochNanosToDotNetTicks(unixNanosColumnName: String): Column = unixEpochNanosToDotNetTicks(\n    col(unixNanosColumnName)\n  )\n\n  private def unixEpochTenthMicrosToDotNetTicks(unixNanosColumn: Column): Column =\n    unixNanosColumn.cast(LongType) + unixEpochDotNetTicks\n\n  /**\n   * Set the job description and return the earlier description. 
With `ifNotSet == true`, the description is only set if no description is set yet.\n   *\n   * @param description\n   *   job description\n   * @param ifNotSet\n   *   job description is only set if no description is set yet\n   * @param context\n   *   spark context\n   * @return\n   *   the earlier job description, or null if none was set\n   */\n  def setJobDescription(description: String, ifNotSet: Boolean = false)(implicit context: SparkContext): String = {\n    val earlierDescriptionOption = Option(context.getLocalProperty(\"spark.job.description\"))\n    if (earlierDescriptionOption.isEmpty || !ifNotSet) {\n      context.setJobDescription(description)\n    }\n    earlierDescriptionOption.orNull\n  }\n\n  /**\n   * Adds a job description to all Spark jobs started within the given function. The current job description is restored\n   * after exit of the function.\n   *\n   * Usage example:\n   *\n   * {{{\n   *   import uk.co.gresearch.spark._\n   *\n   *   implicit val session: SparkSession = spark\n   *\n   *   val count = withJobDescription(\"parquet file\") {\n   *     val df = spark.read.parquet(\"data.parquet\")\n   *     df.count\n   *   }\n   * }}}\n   *\n   * With `ifNotSet == true`, the description is only set if no job description is set yet.\n   *\n   * Any modification to the job description during execution of the function is reverted, even if `ifNotSet == true`.\n   *\n   * @param description\n   *   job description\n   * @param ifNotSet\n   *   job description is only set if no description is set yet\n   * @param func\n   *   code to execute while job description is set\n   * @param session\n   *   spark session\n   * @tparam T\n   *   return type of func\n   */\n  def withJobDescription[T](description: String, ifNotSet: Boolean = false)(\n      func: => T\n  )(implicit session: SparkSession): T = {\n    val earlierDescription = setJobDescription(description, ifNotSet)(session.sparkContext)\n    try {\n      func\n    } finally {\n      setJobDescription(earlierDescription)(session.sparkContext)\n    }\n  }\n\n  /**\n   * Append the job description and return the earlier description.\n   *\n   * @param extraDescription\n   *   job description to append\n   * @param separator\n   *   separator to join existing and extra description with\n   * @param context\n   *   spark context\n   * @return\n   *   the earlier job description, or null if none was set\n   */\n  def appendJobDescription(extraDescription: String, separator: String, context: SparkContext): String = {\n    val earlierDescriptionOption = Option(context.getLocalProperty(\"spark.job.description\"))\n    val description = earlierDescriptionOption.map(_ + separator + extraDescription).getOrElse(extraDescription)\n    context.setJobDescription(description)\n    earlierDescriptionOption.orNull\n  }\n\n  /**\n   * Appends a job description to all Spark jobs started within the given function. 
The current Job description is\n   * extended by the separator and the extra description on entering the function, and restored after exit of the\n   * function.\n   *\n   * Usage example:\n   *\n   * {{{\n   *   import uk.co.gresearch.spark._\n   *\n   *   implicit val session: SparkSession = spark\n   *\n   *   val count = appendJobDescription(\"parquet file\") {\n   *     val df = spark.read.parquet(\"data.parquet\")\n   *     appendJobDescription(\"count\") {\n   *       df.count\n   *     }\n   *   }\n   * }}}\n   *\n   * Any modification to the job description during execution of the function is reverted.\n   *\n   * @param extraDescription\n   *   job description to be appended\n   * @param separator\n   *   separator used when appending description\n   * @param func\n   *   code to execute while job description is set\n   * @param session\n   *   spark session\n   * @tparam T\n   *   return type of func\n   */\n  def appendJobDescription[T](extraDescription: String, separator: String = \" - \")(\n      func: => T\n  )(implicit session: SparkSession): T = {\n    val earlierDescription = appendJobDescription(extraDescription, separator, session.sparkContext)\n    try {\n      func\n    } finally {\n      setJobDescription(earlierDescription)(session.sparkContext)\n    }\n  }\n\n  /**\n   * Class to extend a Spark Dataset.\n   *\n   * @param ds\n   *   dataset\n   * @tparam V\n   *   inner type of dataset\n   */\n  @deprecated(\n    \"Constructor with encoder is deprecated, the encoder argument is ignored, ds.encoder is used instead.\",\n    since = \"2.9.0\"\n  )\n  class ExtendedDataset[V](ds: Dataset[V], encoder: Encoder[V]) {\n    private val eds = ExtendedDatasetV2[V](ds)\n\n    def histogram[T: Ordering](thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame =\n      eds.histogram(thresholds, valueColumn, aggregateColumns: _*)\n\n    def writePartitionedBy(\n        partitionColumns: Seq[Column],\n        moreFileColumns: Seq[Column] = Seq.empty,\n        moreFileOrder: Seq[Column] = Seq.empty,\n        partitions: Option[Int] = None,\n        writtenProjection: Option[Seq[Column]] = None,\n        unpersistHandle: Option[UnpersistHandle] = None\n    ): DataFrameWriter[Row] =\n      eds.writePartitionedBy(\n        partitionColumns,\n        moreFileColumns,\n        moreFileOrder,\n        partitions,\n        writtenProjection,\n        unpersistHandle\n      )\n\n    def groupBySorted[K: Ordering: Encoder](cols: Column*)(order: Column*): SortedGroupByDataset[K, V] =\n      eds.groupBySorted(cols: _*)(order: _*)\n\n    def groupBySorted[K: Ordering: Encoder](partitions: Int)(cols: Column*)(\n        order: Column*\n    ): SortedGroupByDataset[K, V] =\n      eds.groupBySorted(partitions)(cols: _*)(order: _*)\n\n    def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)(\n        order: V => O\n    ): SortedGroupByDataset[K, V] =\n      eds.groupByKeySorted(key, Some(partitions))(order)\n\n    def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)(\n        order: V => O,\n        reverse: Boolean\n    ): SortedGroupByDataset[K, V] =\n      eds.groupByKeySorted(key, Some(partitions))(order, reverse)\n\n    def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Option[Int] = None)(\n        order: V => O,\n        reverse: Boolean = false\n    ): SortedGroupByDataset[K, V] =\n      eds.groupByKeySorted(key, partitions)(order, reverse)\n\n    def withRowNumbers(order: 
Column*): DataFrame =\n      eds.withRowNumbers(order: _*)\n\n    def withRowNumbers(rowNumberColumnName: String, order: Column*): DataFrame =\n      eds.withRowNumbers(rowNumberColumnName, order: _*)\n\n    def withRowNumbers(storageLevel: StorageLevel, order: Column*): DataFrame =\n      eds.withRowNumbers(storageLevel, order: _*)\n\n    def withRowNumbers(unpersistHandle: UnpersistHandle, order: Column*): DataFrame =\n      eds.withRowNumbers(unpersistHandle, order: _*)\n\n    def withRowNumbers(rowNumberColumnName: String, storageLevel: StorageLevel, order: Column*): DataFrame =\n      eds.withRowNumbers(rowNumberColumnName, storageLevel, order: _*)\n\n    def withRowNumbers(rowNumberColumnName: String, unpersistHandle: UnpersistHandle, order: Column*): DataFrame =\n      eds.withRowNumbers(rowNumberColumnName, unpersistHandle, order: _*)\n\n    def withRowNumbers(storageLevel: StorageLevel, unpersistHandle: UnpersistHandle, order: Column*): DataFrame =\n      eds.withRowNumbers(storageLevel, unpersistHandle, order: _*)\n\n    def withRowNumbers(\n        rowNumberColumnName: String,\n        storageLevel: StorageLevel,\n        unpersistHandle: UnpersistHandle,\n        order: Column*\n    ): DataFrame =\n      eds.withRowNumbers(rowNumberColumnName, storageLevel, unpersistHandle, order: _*)\n  }\n\n  /**\n   * Class to extend a Spark Dataset.\n   *\n   * @param ds\n   *   dataset\n   * @tparam V\n   *   inner type of dataset\n   */\n  def ExtendedDataset[V](ds: Dataset[V], encoder: Encoder[V]): ExtendedDataset[V] = new ExtendedDataset(ds, encoder)\n\n  /**\n   * Implicit class to extend a Spark Dataset.\n   *\n   * @param ds\n   *   dataset\n   * @tparam V\n   *   inner type of dataset\n   */\n  implicit class ExtendedDatasetV2[V](ds: Dataset[V]) {\n    private implicit val encoder: Encoder[V] = ds.encoder\n\n    /**\n     * Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in\n     * ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value\n     * in thresholds, there will be a column named s\"≤threshold\". There will also be a final column called\n     * s\">last_threshold\", that counts the remaining values that exceed the last threshold.\n     *\n     * @param thresholds\n     *   sequence of thresholds, must implement <= and > operators w.r.t. valueColumn\n     * @param valueColumn\n     *   histogram is computed for values of this column\n     * @param aggregateColumns\n     *   histogram is computed against these columns\n     * @tparam T\n     *   type of histogram thresholds\n     * @return\n     *   dataframe with aggregate and histogram columns\n     */\n    def histogram[T: Ordering](thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame =\n      Histogram.of(ds, thresholds, valueColumn, aggregateColumns: _*)\n\n    /**\n     * Writes the Dataset / DataFrame via DataFrameWriter.partitionBy. In addition to partitionBy, this method sorts the\n     * data to improve partition file size. Small partitions will contain few files, large partitions contain more\n     * files. Partition ids are contained in a single partition file per `partitionBy` partition only. Rows within the\n     * partition files are also sorted, if partitionOrder is defined.\n     *\n     * Note: With Spark 3.0, 3.1, 3.2 before 3.2.3, 3.3 before 3.3.2, and AQE enabled, an intermediate DataFrame is\n     * being cached in order to guarantee sorted output files. 
See https://issues.apache.org/jira/browse/SPARK-40588.\n     * That cached DataFrame can be unpersisted via an optional [[UnpersistHandle]] provided to this method.\n     *\n     * Calling:\n     * {{{\n     *   val unpersist = UnpersistHandle()\n     *   val writer = df.writePartitionedBy(Seq(\"a\"), Seq(\"b\"), Seq(\"c\"), Some(10), Seq($\"a\", concat($\"b\", $\"c\")), unpersist)\n     *   writer.parquet(\"data.parquet\")\n     *   unpersist()\n     * }}}\n     *\n     * is equivalent to:\n     * {{{\n     *   val cached =\n     *     df.repartitionByRange(10, $\"a\", $\"b\")\n     *       .sortWithinPartitions($\"a\", $\"b\", $\"c\")\n     *       .cache\n     *\n     *   val writer =\n     *     cached\n     *       .select($\"a\", concat($\"b\", $\"c\"))\n     *       .write\n     *       .partitionBy(\"a\")\n     *\n     *   writer.parquet(\"data.parquet\")\n     *\n     *   cached.unpersist\n     * }}}\n     *\n     * @param partitionColumns\n     *   columns used for partitioning\n     * @param moreFileColumns\n     *   columns where individual values are written to a single file\n     * @param moreFileOrder\n     *   additional columns to sort partition files\n     * @param partitions\n     *   optional number of partition files\n     * @param writtenProjection\n     *   additional transformation to be applied before calling write\n     * @param unpersistHandle\n     *   handle to unpersist internally created DataFrame after writing\n     * @return\n     *   configured DataFrameWriter\n     */\n    def writePartitionedBy(\n        partitionColumns: Seq[Column],\n        moreFileColumns: Seq[Column] = Seq.empty,\n        moreFileOrder: Seq[Column] = Seq.empty,\n        partitions: Option[Int] = None,\n        writtenProjection: Option[Seq[Column]] = None,\n        unpersistHandle: Option[UnpersistHandle] = None\n    ): DataFrameWriter[Row] = {\n      if (partitionColumns.isEmpty)\n        throw new IllegalArgumentException(s\"partition columns must not be empty\")\n\n      if (partitionColumns.exists(col => !col.isInstanceOf[ColumnName] && !col.expr.isInstanceOf[NamedExpression]))\n        throw new IllegalArgumentException(s\"partition columns must be named: ${partitionColumns.mkString(\",\")}\")\n\n      val requiresCaching = writePartitionedByRequiresCaching(ds)\n      (requiresCaching, unpersistHandle.isDefined) match {\n        case (true, false) =>\n          warning(\n            \"Partitioned-writing with AQE enabled and Spark 3.0, 3.1, 3.2 below 3.2.3, \" +\n              \"and 3.3 below 3.3.2 requires caching an intermediate DataFrame, \" +\n              \"which calling code has to unpersist once writing is done. \" +\n              \"Please provide an UnpersistHandle to DataFrame.writePartitionedBy, or UnpersistHandle.Noop. 
\" +\n              \"See https://issues.apache.org/jira/browse/SPARK-40588\"\n          )\n        case (false, true) if !unpersistHandle.get.isInstanceOf[NoopUnpersistHandle] =>\n          info(\n            \"UnpersistHandle provided to DataFrame.writePartitionedBy is not needed as \" +\n              \"partitioned-writing with AQE disabled or Spark 3.2.3, 3.3.2 or 3.4 and above \" +\n              \"does not require caching intermediate DataFrame.\"\n          )\n          unpersistHandle.get.setDataFrame(ds.sparkSession.emptyDataFrame)\n        case _ =>\n      }\n      // resolve partition column names\n      val partitionColumnNames = ds.select(partitionColumns: _*).queryExecution.analyzed.output.map(_.name)\n      val partitionColumnsMap = partitionColumnNames.zip(partitionColumns).toMap\n      val rangeColumns = partitionColumnNames.map(col) ++ moreFileColumns\n      val sortColumns = partitionColumnNames.map(col) ++ moreFileColumns ++ moreFileOrder\n      ds.toDF\n        .call(ds => partitionColumnsMap.foldLeft(ds) { case (ds, (name, col)) => ds.withColumn(name, col) })\n        .when(partitions.isEmpty)\n        .call(_.repartitionByRange(rangeColumns: _*))\n        .when(partitions.isDefined)\n        .call(_.repartitionByRange(partitions.get, rangeColumns: _*))\n        .sortWithinPartitions(sortColumns: _*)\n        .when(writtenProjection.isDefined)\n        .call(_.select(writtenProjection.get: _*))\n        .when(requiresCaching && unpersistHandle.isDefined)\n        .call(unpersistHandle.get.setDataFrame(_))\n        .write\n        .partitionBy(partitionColumnsMap.keys.toSeq: _*)\n    }\n\n    /**\n     * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns.\n     *\n     * @see\n     *   `org.apache.spark.sql.Dataset.groupByKey(T => K)`\n     *\n     * @note\n     *   Calling this method should be preferred to `groupByKey(T => K)` because the Catalyst query planner cannot\n     *   exploit existing partitioning and ordering of this Dataset with that function.\n     *\n     * {{{\n     *   ds.groupByKey[Int]($\"age\").flatMapGroups(...)\n     *   ds.groupByKey[(String, String)]($\"department\", $\"gender\").flatMapGroups(...)\n     * }}}\n     */\n    def groupByKey[K: Encoder](column: Column, columns: Column*): KeyValueGroupedDataset[K, V] =\n      ds.groupBy(column +: columns: _*).as[K, V]\n\n    /**\n     * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns.\n     *\n     * @see\n     *   `org.apache.spark.sql.Dataset.groupByKey(T => K)`\n     *\n     * @note\n     *   Calling this method should be preferred to `groupByKey(T => K)` because the Catalyst query planner cannot\n     *   exploit existing partitioning and ordering of this Dataset with that function.\n     *\n     * {{{\n     *   ds.groupByKey[Int]($\"age\").flatMapGroups(...)\n     *   ds.groupByKey[(String, String)]($\"department\", $\"gender\").flatMapGroups(...)\n     * }}}\n     */\n    def groupByKey[K: Encoder](column: String, columns: String*): KeyValueGroupedDataset[K, V] =\n      ds.groupBy(column, columns: _*).as[K, V]\n\n    /**\n     * Groups the Dataset and sorts the groups using the specified columns, so we can run further process the sorted\n     * groups. 
See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.\n     *\n     * {{{\n     *   // Enumerate elements in the sorted group\n     *   ds.groupBySorted($\"department\")($\"salary\")\n     *     .flatMapSortedGroups((key, it) => it.zipWithIndex)\n     * }}}\n     *\n     * @param cols\n     *   grouping columns\n     * @param order\n     *   sort columns\n     */\n    def groupBySorted[K: Ordering: Encoder](cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = {\n      SortedGroupByDataset(ds, cols, order, None)\n    }\n\n    /**\n     * Groups the Dataset and sorts the groups using the specified columns, so we can further process the sorted\n     * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.\n     *\n     * {{{\n     *   // Enumerate elements in the sorted group\n     *   ds.groupBySorted(10)($\"department\")($\"salary\")\n     *     .flatMapSortedGroups((key, it) => it.zipWithIndex)\n     * }}}\n     *\n     * @param partitions\n     *   number of partitions\n     * @param cols\n     *   grouping columns\n     * @param order\n     *   sort columns\n     */\n    def groupBySorted[K: Ordering: Encoder](\n        partitions: Int\n    )(cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = {\n      SortedGroupByDataset(ds, cols, order, Some(partitions))\n    }\n\n    /**\n     * Groups the Dataset and sorts the groups using the specified columns, so we can further process the sorted\n     * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.\n     *\n     * {{{\n     *   // Enumerate elements in the sorted group\n     *   ds.groupByKeySorted(row => row.getInt(0), 10)(row => row.getInt(1))\n     *     .flatMapSortedGroups((key, it) => it.zipWithIndex)\n     * }}}\n     *\n     * @param partitions\n     *   number of partitions\n     * @param key\n     *   grouping key\n     * @param order\n     *   sort key\n     */\n    def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)(\n        order: V => O\n    ): SortedGroupByDataset[K, V] =\n      groupByKeySorted(key, Some(partitions))(order)\n\n    /**\n     * Groups the Dataset and sorts the groups using the specified columns, so we can further process the sorted\n     * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.\n     *\n     * {{{\n     *   // Enumerate elements in the sorted group\n     *   ds.groupByKeySorted(row => row.getInt(0), 10)(row => row.getInt(1), true)\n     *     .flatMapSortedGroups((key, it) => it.zipWithIndex)\n     * }}}\n     *\n     * @param partitions\n     *   number of partitions\n     * @param key\n     *   grouping key\n     * @param order\n     *   sort key\n     * @param reverse\n     *   sort reverse order\n     */\n    def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)(\n        order: V => O,\n        reverse: Boolean\n    ): SortedGroupByDataset[K, V] =\n      groupByKeySorted(key, Some(partitions))(order, reverse)\n\n    /**\n     * Groups the Dataset and sorts the groups using the specified columns, so we can further process the sorted\n     * groups. 
See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.\n     *\n     * {{{\n     *   // Enumerate elements in the sorted group\n     *   ds.groupByKeySorted(row => row.getInt(0))(row => row.getInt(1), true)\n     *     .flatMapSortedGroups((key, it) => it.zipWithIndex)\n     * }}}\n     *\n     * @param partitions\n     *   optional number of partitions\n     * @param key\n     *   grouping key\n     * @param order\n     *   sort key\n     * @param reverse\n     *   sort reverse order\n     */\n    def groupByKeySorted[K: Ordering: Encoder, O: Encoder](\n        key: V => K,\n        partitions: Option[Int] = None\n    )(order: V => O, reverse: Boolean = false): SortedGroupByDataset[K, V] = {\n      SortedGroupByDataset(ds, key, order, partitions, reverse)\n    }\n\n    /**\n     * Adds a global continuous row number starting at 1.\n     *\n     * See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.\n     */\n    def withRowNumbers(order: Column*): DataFrame =\n      RowNumbers.withOrderColumns(order: _*).of(ds)\n\n    /**\n     * Adds a global continuous row number starting at 1.\n     *\n     * See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.\n     */\n    def withRowNumbers(rowNumberColumnName: String, order: Column*): DataFrame =\n      RowNumbers.withRowNumberColumnName(rowNumberColumnName).withOrderColumns(order).of(ds)\n\n    /**\n     * Adds a global continuous row number starting at 1.\n     *\n     * See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.\n     */\n    def withRowNumbers(storageLevel: StorageLevel, order: Column*): DataFrame =\n      RowNumbers.withStorageLevel(storageLevel).withOrderColumns(order).of(ds)\n\n    /**\n     * Adds a global continuous row number starting at 1.\n     *\n     * See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.\n     */\n    def withRowNumbers(unpersistHandle: UnpersistHandle, order: Column*): DataFrame =\n      RowNumbers.withUnpersistHandle(unpersistHandle).withOrderColumns(order).of(ds)\n\n    /**\n     * Adds a global continuous row number starting at 1.\n     *\n     * See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.\n     */\n    def withRowNumbers(rowNumberColumnName: String, storageLevel: StorageLevel, order: Column*): DataFrame =\n      RowNumbers\n        .withRowNumberColumnName(rowNumberColumnName)\n        .withStorageLevel(storageLevel)\n        .withOrderColumns(order)\n        .of(ds)\n\n    /**\n     * Adds a global continuous row number starting at 1.\n     *\n     * See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.\n     */\n    def withRowNumbers(rowNumberColumnName: String, unpersistHandle: UnpersistHandle, order: Column*): DataFrame =\n      RowNumbers\n        .withRowNumberColumnName(rowNumberColumnName)\n        .withUnpersistHandle(unpersistHandle)\n        .withOrderColumns(order)\n        .of(ds)\n\n    /**\n     * Adds a global continuous row number starting at 1.\n     *\n     * See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.\n     */\n    def withRowNumbers(storageLevel: StorageLevel, unpersistHandle: UnpersistHandle, order: Column*): DataFrame =\n      RowNumbers.withStorageLevel(storageLevel).withUnpersistHandle(unpersistHandle).withOrderColumns(order).of(ds)\n\n    /**\n     * Adds a global continuous row number starting at 1, after sorting rows 
by the given columns. When no columns are\n     * given, the existing order is used.\n     *\n     * Hence, the following examples are equivalent:\n     * {{{\n     *   ds.withRowNumbers($\"a\".desc, $\"b\")\n     *   ds.orderBy($\"a\".desc, $\"b\").withRowNumbers()\n     * }}}\n     *\n     * The column name of the column with the row numbers can be set via the `rowNumberColumnName` argument.\n     *\n     * To avoid some known issues optimizing the query plan, this function has to internally call\n     * `Dataset.persist(StorageLevel)` on an intermediate DataFrame. The storage level of that cached DataFrame can be\n     * set via `storageLevel`, where the default is `StorageLevel.MEMORY_AND_DISK`.\n     *\n     * That cached intermediate DataFrame can be un-persisted / un-cached as follows:\n     * {{{\n     *   import uk.co.gresearch.spark.UnpersistHandle\n     *\n     *   val unpersist = UnpersistHandle()\n     *   ds.withRowNumbers(unpersist).show()\n     *   unpersist()\n     * }}}\n     *\n     * @param rowNumberColumnName\n     *   name of the row number column\n     * @param storageLevel\n     *   storage level of the cached intermediate DataFrame\n     * @param unpersistHandle\n     *   handle to un-persist intermediate DataFrame\n     * @param order\n     *   columns to order dataframe before assigning row numbers\n     * @return\n     *   dataframe with row numbers\n     */\n    def withRowNumbers(\n        rowNumberColumnName: String,\n        storageLevel: StorageLevel,\n        unpersistHandle: UnpersistHandle,\n        order: Column*\n    ): DataFrame =\n      RowNumbers\n        .withRowNumberColumnName(rowNumberColumnName)\n        .withStorageLevel(storageLevel)\n        .withUnpersistHandle(unpersistHandle)\n        .withOrderColumns(order)\n        .of(ds)\n  }\n\n  /**\n   * Class to extend a Spark Dataframe.\n   *\n   * @param df\n   *   dataframe\n   */\n  @deprecated(\"Implicit class ExtendedDataframe is deprecated, please recompile your source code.\", since = \"2.9.0\")\n  class ExtendedDataframe(df: DataFrame) extends ExtendedDataset[Row](df, df.encoder)\n\n  /**\n   * Class to extend a Spark Dataframe.\n   *\n   * @param df\n   *   dataframe\n   */\n  def ExtendedDataframe(df: DataFrame): ExtendedDataframe = new ExtendedDataframe(df)\n\n  /**\n   * Implicit class to extend a Spark Dataframe, which is a Dataset[Row].\n   *\n   * @param df\n   *   dataframe\n   */\n  implicit class ExtendedDataframeV2(df: DataFrame) extends ExtendedDatasetV2[Row](df)\n\n}\n"
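The `histogram` extension defined in ExtendedDatasetV2 above documents its output columns but its Scaladoc carries no usage example. A minimal sketch, assuming a local SparkSession and a hypothetical DataFrame with `group` and `value` columns (all names and data below are illustrative, not taken from the source):

import org.apache.spark.sql.SparkSession
import uk.co.gresearch.spark._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// hypothetical input: one value column, one aggregate column
val df = Seq(("a", 5), ("a", 50), ("a", 500), ("b", 7)).toDF("group", "value")

// thresholds 10 and 100 produce the histogram columns "≤10", "≤100" and ">100",
// one row per distinct value of the aggregate column "group"
df.histogram(Seq(10, 100), $"value", $"group").orderBy($"group").show()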
  },
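Each of the .Net ticks helpers above documents a single direction; the sketch below chains them into a round trip. A minimal sketch, assuming a local SparkSession; the `ticks` column name and the literal tick value are illustrative only:

import org.apache.spark.sql.SparkSession
import uk.co.gresearch.spark._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(638155413748959318L).toDF("ticks")

df.select(
  $"ticks",
  dotNetTicksToUnixEpoch($"ticks").as("seconds"),     // decimal seconds since the Unix epoch
  dotNetTicksToUnixEpochNanos($"ticks").as("nanos"),  // long nanoseconds since the Unix epoch
  unixEpochToDotNetTicks(dotNetTicksToUnixEpoch($"ticks")).as("roundtrip")  // back to ticks
).show(false)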
  {
    "path": "src/main/scala/uk/co/gresearch/spark/parquet/ParquetMetaDataUtil.scala",
    "content": "/*\n * Copyright 2023 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.parquet\n\nimport org.apache.parquet.crypto.ParquetCryptoRuntimeException\nimport org.apache.parquet.hadoop.Footer\nimport org.apache.parquet.hadoop.metadata.{BlockMetaData, ColumnChunkMetaData, FileMetaData}\nimport org.apache.parquet.schema.PrimitiveType\n\nimport scala.reflect.{ClassTag, classTag}\nimport scala.util.Try\nimport scala.collection.convert.ImplicitConversions.`iterable AsScalaIterable`\n\nprivate trait MethodGuard {\n  def isSupported[T: ClassTag](methodName: String): Boolean = {\n    Try(classTag[T].runtimeClass.getMethod(methodName)).isSuccess\n  }\n\n  def guard[T, R](supported: Boolean)(f: T => R): T => Option[R] =\n    guardOption(supported)(t => Some(f(t)))\n\n  def guardOption[T, R](supported: Boolean)(f: T => Option[R]): T => Option[R] =\n    if (supported) { (v: T) =>\n      f(v)\n    } else { (_: T) =>\n      None\n    }\n}\n\n/**\n * Guard access to possibly encrypted and inaccessible metadata of a footer.\n *   - If footer is encrypted while we have no decryption keys, metadata values are None.\n *   - If footer is known not to be encrypted, metadata values are Some.\n *   - If we don't know whether the footer is encrypted, we access some metadata that we could not read if encrypted to\n *     determine the encryption state of the footer.\n */\nprivate case class FooterGuard(footer: Footer) {\n  lazy val isSafe: Boolean = {\n    // having a decryptor tells us this file is expected to be decryptable\n    Option(footer.getParquetMetadata.getFileMetaData.getFileDecryptor)\n      // otherwise, when we have an unencrypted file, we are also safe to access f\n      .orElse(\n        ParquetMetaDataUtil\n          .getEncryptionType(footer.getParquetMetadata.getFileMetaData)\n          .filter(_ == \"UNENCRYPTED\")\n      )\n      // turn to Some(true) if safe, None if unknown\n      .map(_ => true)\n      // otherwise, we access some metadata that if the footer is encrypted would fail\n      .orElse(\n        Some(\n          Try(footer.getParquetMetadata.getBlocks.headOption.map(_.getTotalByteSize))\n            // get hold of the possible exception\n            .toEither.swap.toOption\n            // no exception means safe, ignore exceptions other than ParquetCryptoRuntimeException\n            .exists(!_.isInstanceOf[ParquetCryptoRuntimeException])\n        )\n      )\n      // now is Some(true) or Some(false)\n      .get\n  }\n\n  private[parquet] def apply[T](f: => T): Option[T] = {\n    if (isSafe) { Some(f) }\n    else { None }\n  }\n}\n\nprivate[parquet] object ParquetMetaDataUtil extends MethodGuard {\n  lazy val getEncryptionTypeIsSupported: Boolean =\n    isSupported[FileMetaData](\"getEncryptionType\")\n  lazy val getEncryptionType: FileMetaData => Option[String] =\n    guard(getEncryptionTypeIsSupported) { fileMetaData: FileMetaData =>\n      fileMetaData.getEncryptionType.name()\n    }\n\n  lazy val 
getLogicalTypeAnnotationIsSupported: Boolean =\n    isSupported[PrimitiveType](\"getLogicalTypeAnnotation\")\n  lazy val getLogicalTypeAnnotation: PrimitiveType => Option[String] =\n    guardOption(getLogicalTypeAnnotationIsSupported) { (primitive: PrimitiveType) =>\n      Option(primitive.getLogicalTypeAnnotation).map(_.toString)\n    }\n\n  lazy val getOrdinalIsSupported: Boolean =\n    isSupported[BlockMetaData](\"getOrdinal\")\n  lazy val getOrdinal: BlockMetaData => Option[Int] =\n    guard(getOrdinalIsSupported) { (block: BlockMetaData) =>\n      block.getOrdinal\n    }\n\n  lazy val isEncryptedIsSupported: Boolean =\n    isSupported[ColumnChunkMetaData](\"isEncrypted\")\n  lazy val isEncrypted: ColumnChunkMetaData => Option[Boolean] =\n    guard(isEncryptedIsSupported) { (column: ColumnChunkMetaData) =>\n      column.isEncrypted\n    }\n}\n"
  },
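ParquetMetaDataUtil probes the Parquet API via reflection, so the same build degrades gracefully on Parquet versions that lack newer accessors: each guarded accessor returns None instead of throwing. A minimal sketch of a caller, assuming it lives inside the uk.co.gresearch.spark.parquet package (ParquetMetaDataUtil is package-private); the object name EncryptionInfo is hypothetical:

package uk.co.gresearch.spark.parquet

import org.apache.parquet.hadoop.Footer

private[parquet] object EncryptionInfo {
  // Some(encryption type name, e.g. "UNENCRYPTED") when this Parquet version exposes
  // FileMetaData.getEncryptionType, None when the method does not exist at runtime
  def of(footer: Footer): Option[String] =
    ParquetMetaDataUtil.getEncryptionType(footer.getParquetMetadata.getFileMetaData)
}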
  {
    "path": "src/main/scala/uk/co/gresearch/spark/parquet/package.scala",
    "content": "/*\n * Copyright 2023 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\n// hadoop and parquet dependencies provided by Spark\nimport org.apache.hadoop.conf.Configuration\nimport org.apache.hadoop.fs.Path\nimport org.apache.parquet.hadoop.metadata.BlockMetaData\nimport org.apache.parquet.hadoop.{Footer, ParquetFileReader}\nimport org.apache.spark.SparkContext\nimport org.apache.spark.sql._\nimport org.apache.spark.sql.execution.datasources.FilePartition\nimport uk.co.gresearch._\n\nimport scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsScalaMapConverter}\nimport scala.collection.convert.ImplicitConversions.`iterable AsScalaIterable`\n\npackage object parquet {\n\n  private def conf: Configuration = SparkContext.getOrCreate().hadoopConfiguration\n\n  /**\n   * Implicit class to extend a Spark DataFrameReader.\n   *\n   * @param reader\n   *   data frame reader\n   */\n  implicit class ExtendedDataFrameReader(reader: DataFrameReader) {\n\n    /**\n     * Read the metadata of Parquet files into a Dataframe.\n     *\n     * The returned DataFrame has as many partitions as there are Parquet files, at most\n     * `spark.sparkContext.defaultParallelism` partitions.\n     *\n     * This provides the following per-file information:\n     *   - filename (string): The file name\n     *   - blocks (int): Number of blocks / RowGroups in the Parquet file\n     *   - compressedBytes (long): Number of compressed bytes of all blocks\n     *   - uncompressedBytes (long): Number of uncompressed bytes of all blocks\n     *   - rows (long): Number of rows in the file\n     *   - columns (int): Number of columns in the file\n     *   - values (long): Number of values in the file\n     *   - nulls (long): Number of null values in the file\n     *   - createdBy (string): The createdBy string of the Parquet file, e.g. 
library used to write the file\n     *   - schema (string): The schema\n     *   - encryption (string): The encryption\n     *   - keyValues (string-to-string map): Key-value data of the file\n     *\n     * @param paths\n     *   one or more paths to Parquet files or directories\n     * @return\n     *   dataframe with Parquet metadata\n     */\n    @scala.annotation.varargs\n    def parquetMetadata(paths: String*): DataFrame = parquetMetadata(None, paths)\n\n    /**\n     * Read the metadata of Parquet files into a Dataframe.\n     *\n     * The returned DataFrame has as many partitions as specified via `parallelism`.\n     *\n     * This provides the following per-file information:\n     *   - filename (string): The file name\n     *   - blocks (int): Number of blocks / RowGroups in the Parquet file\n     *   - compressedBytes (long): Number of compressed bytes of all blocks\n     *   - uncompressedBytes (long): Number of uncompressed bytes of all blocks\n     *   - rows (long): Number of rows in the file\n     *   - columns (int): Number of columns in the file\n     *   - values (long): Number of values in the file\n     *   - nulls (long): Number of null values in the file\n     *   - createdBy (string): The createdBy string of the Parquet file, e.g. library used to write the file\n     *   - schema (string): The schema\n     *   - encryption (string): The encryption\n     *   - keyValues (string-to-string map): Key-value data of the file\n     *\n     * @param parallelism\n     *   number of partitions of returned DataFrame\n     * @param paths\n     *   one or more paths to Parquet files or directories\n     * @return\n     *   dataframe with Parquet metadata\n     */\n    @scala.annotation.varargs\n    def parquetMetadata(parallelism: Int, paths: String*): DataFrame = parquetMetadata(Some(parallelism), paths)\n\n    private def parquetMetadata(parallelism: Option[Int], paths: Seq[String]): DataFrame = {\n      val files = getFiles(parallelism, paths)\n\n      import files.sparkSession.implicits._\n\n      files\n        .flatMap { case (_, file) =>\n          readFooters(file).map { footer =>\n            val guard = FooterGuard(footer)\n            (\n              footer.getFile.toString,\n              footer.getParquetMetadata.getBlocks.size(),\n              guard { footer.getParquetMetadata.getBlocks.asScala.map(_.getCompressedSize).sum },\n              guard { footer.getParquetMetadata.getBlocks.asScala.map(_.getTotalByteSize).sum },\n              footer.getParquetMetadata.getBlocks.asScala.map(_.getRowCount).sum,\n              footer.getParquetMetadata.getFileMetaData.getSchema.getColumns.size(),\n              guard {\n                footer.getParquetMetadata.getBlocks.asScala.map(_.getColumns.map(_.getValueCount).sum).sum\n              },\n              // when all columns have statistics, count the null values\n              guard {\n                Option(\n                  footer.getParquetMetadata.getBlocks.asScala.flatMap(_.getColumns.map(c => Option(c.getStatistics)))\n                )\n                  .filter(_.forall(_.isDefined))\n                  .map(_.map(_.get.getNumNulls).sum)\n              },\n              footer.getParquetMetadata.getFileMetaData.getCreatedBy,\n              footer.getParquetMetadata.getFileMetaData.getSchema.toString,\n              ParquetMetaDataUtil.getEncryptionType(footer.getParquetMetadata.getFileMetaData),\n              footer.getParquetMetadata.getFileMetaData.getKeyValueMetaData.asScala,\n            )\n          }\n 
       }\n        .toDF(\n          \"filename\",\n          \"blocks\",\n          \"compressedBytes\",\n          \"uncompressedBytes\",\n          \"rows\",\n          \"columns\",\n          \"values\",\n          \"nulls\",\n          \"createdBy\",\n          \"schema\",\n          \"encryption\",\n          \"keyValues\"\n        )\n    }\n\n    /**\n     * Read the schema of Parquet files into a Dataframe.\n     *\n     * The returned DataFrame has as many partitions as there are Parquet files, at most\n     * `spark.sparkContext.defaultParallelism` partitions.\n     *\n     * This provides the following per-column information:\n     *   - filename (string): The Parquet file name\n     *   - columnName (string): The column name\n     *   - columnPath (string array): The column path\n     *   - repetition (string): The repetition\n     *   - type (string): The data type\n     *   - length (int): The length of the type\n     *   - originalType (string): The original type\n     *   - logicalType (string): The logical type annotation\n     *   - isPrimitive (boolean): True if type is primitive\n     *   - primitiveType (string): The primitive type\n     *   - primitiveOrder (string): The order of the primitive type\n     *   - maxDefinitionLevel (int): The max definition level\n     *   - maxRepetitionLevel (int): The max repetition level\n     *\n     * @param paths\n     *   one or more paths to Parquet files or directories\n     * @return\n     *   dataframe with Parquet metadata\n     */\n    @scala.annotation.varargs\n    def parquetSchema(paths: String*): DataFrame = parquetSchema(None, paths)\n\n    /**\n     * Read the schema of Parquet files into a Dataframe.\n     *\n     * The returned DataFrame has as many partitions as specified via `parallelism`.\n     *\n     * This provides the following per-column information:\n     *   - filename (string): The Parquet file name\n     *   - columnName (string): The column name\n     *   - columnPath (string array): The column path\n     *   - repetition (string): The repetition\n     *   - type (string): The data type\n     *   - length (int): The length of the type\n     *   - originalType (string): The original type\n     *   - logicalType (string): The logical type annotation\n     *   - isPrimitive (boolean): True if type is primitive\n     *   - primitiveType (string): The primitive type\n     *   - primitiveOrder (string): The order of the primitive type\n     *   - maxDefinitionLevel (int): The max definition level\n     *   - maxRepetitionLevel (int): The max repetition level\n     *\n     * @param parallelism\n     *   number of partitions of returned DataFrame\n     * @param paths\n     *   one or more paths to Parquet files or directories\n     * @return\n     *   dataframe with Parquet metadata\n     */\n    @scala.annotation.varargs\n    def parquetSchema(parallelism: Int, paths: String*): DataFrame = parquetSchema(Some(parallelism), paths)\n\n    private def parquetSchema(parallelism: Option[Int], paths: Seq[String]): DataFrame = {\n      val files = getFiles(parallelism, paths)\n\n      import files.sparkSession.implicits._\n\n      files\n        .flatMap { case (_, file) =>\n          readFooters(file).flatMap { footer =>\n            footer.getParquetMetadata.getFileMetaData.getSchema.getColumns.map { column =>\n              (\n                footer.getFile.toString,\n                Option(column.getPrimitiveType).map(_.getName),\n                column.getPath,\n                Option(column.getPrimitiveType).flatMap(v => Option(v.getRepetition)).map(_.name),\n                Option(column.getPrimitiveType).flatMap(v => 
Option(v.getPrimitiveTypeName)).map(_.name),\n                Option(column.getPrimitiveType).map(_.getTypeLength),\n                Option(column.getPrimitiveType).flatMap(v => Option(v.getOriginalType)).map(_.name),\n                Option(column.getPrimitiveType).flatMap(ParquetMetaDataUtil.getLogicalTypeAnnotation),\n                column.getPrimitiveType.isPrimitive,\n                Option(column.getPrimitiveType).map(_.getPrimitiveTypeName.name),\n                Option(column.getPrimitiveType).flatMap(v => Option(v.columnOrder)).map(_.getColumnOrderName.name),\n                column.getMaxDefinitionLevel,\n                column.getMaxRepetitionLevel,\n              )\n            }\n          }\n        }\n        .toDF(\n          \"filename\",\n          \"columnName\",\n          \"columnPath\",\n          \"repetition\",\n          \"type\",\n          \"length\",\n          \"originalType\",\n          \"logicalType\",\n          \"isPrimitive\",\n          \"primitiveType\",\n          \"primitiveOrder\",\n          \"maxDefinitionLevel\",\n          \"maxRepetitionLevel\",\n        )\n    }\n\n    /**\n     * Read the metadata of Parquet blocks into a Dataframe.\n     *\n     * The returned DataFrame has as many partitions as there are Parquet files, at most\n     * `spark.sparkContext.defaultParallelism` partitions.\n     *\n     * This provides the following per-block information:\n     *   - filename (string): The file name\n     *   - block (int): Block / RowGroup number starting at 1\n     *   - blockStart (long): Start position of the block in the Parquet file\n     *   - compressedBytes (long): Number of compressed bytes in block\n     *   - uncompressedBytes (long): Number of uncompressed bytes in block\n     *   - rows (long): Number of rows in block\n     *   - columns (int): Number of columns in block\n     *   - values (long): Number of values in block\n     *   - nulls (long): Number of null values in block\n     *\n     * @param paths\n     *   one or more paths to Parquet files or directories\n     * @return\n     *   dataframe with Parquet block metadata\n     */\n    @scala.annotation.varargs\n    def parquetBlocks(paths: String*): DataFrame = parquetBlocks(None, paths)\n\n    /**\n     * Read the metadata of Parquet blocks into a Dataframe.\n     *\n     * The returned DataFrame has as many partitions as specified via `parallelism`.\n     *\n     * This provides the following per-block information:\n     *   - filename (string): The file name\n     *   - block (int): Block / RowGroup number starting at 1 (block ordinal + 1)\n     *   - blockStart (long): Start position of the block in the Parquet file\n     *   - compressedBytes (long): Number of compressed bytes in block\n     *   - uncompressedBytes (long): Number of uncompressed bytes in block\n     *   - rows (long): Number of rows in block\n     *   - columns (int): Number of columns in block\n     *   - values (long): Number of values in block\n     *   - nulls (long): Number of null values in block\n     *\n     * @param parallelism\n     *   number of partitions of returned DataFrame\n     * @param paths\n     *   one or more paths to Parquet files or directories\n     * @return\n     *   dataframe with Parquet block metadata\n     */\n    @scala.annotation.varargs\n    def parquetBlocks(parallelism: Int, paths: String*): DataFrame = parquetBlocks(Some(parallelism), paths)\n\n    private def parquetBlocks(parallelism: Option[Int], paths: Seq[String]): DataFrame = {\n      val files = 
getFiles(parallelism, paths)\n\n      import files.sparkSession.implicits._\n\n      files\n        .flatMap { case (_, file) =>\n          readFooters(file).flatMap { footer =>\n            val guard = FooterGuard(footer)\n            footer.getParquetMetadata.getBlocks.asScala.zipWithIndex.map { case (block, idx) =>\n              (\n                footer.getFile.toString,\n                ParquetMetaDataUtil.getOrdinal(block).getOrElse(idx) + 1,\n                block.getStartingPos,\n                guard { block.getCompressedSize },\n                block.getTotalByteSize,\n                block.getRowCount,\n                block.getColumns.asScala.size,\n                guard { block.getColumns.asScala.map(_.getValueCount).sum },\n                // when all columns have statistics, count the null values\n                guard {\n                  Option(block.getColumns.asScala.map(c => Option(c.getStatistics)))\n                    .filter(_.forall(_.isDefined))\n                    .map(_.map(_.get.getNumNulls).sum)\n                },\n              )\n            }\n          }\n        }\n        .toDF(\n          \"filename\",\n          \"block\",\n          \"blockStart\",\n          \"compressedBytes\",\n          \"uncompressedBytes\",\n          \"rows\",\n          \"columns\",\n          \"values\",\n          \"nulls\"\n        )\n    }\n\n    /**\n     * Read the metadata of Parquet block columns into a Dataframe.\n     *\n     * The returned DataFrame has as many partitions as there are Parquet files, at most\n     * `spark.sparkContext.defaultParallelism` partitions.\n     *\n     * This provides the following per-block-column information:\n     *   - filename (string): The file name\n     *   - block (int): Block / RowGroup number starting at 1\n     *   - column (string): Block / RowGroup column name\n     *   - codec (string): The codec used to compress the block column values\n     *   - type (string): The data type of the block column\n     *   - encodings (string): Encodings of the block column\n     *   - encrypted (boolean): True if this block column is encrypted\n     *   - minValue (string): Minimum value of this column in this block\n     *   - maxValue (string): Maximum value of this column in this block\n     *   - columnStart (long): Start position of the block column in the Parquet file\n     *   - compressedBytes (long): Number of compressed bytes of this block column\n     *   - uncompressedBytes (long): Number of uncompressed bytes of this block column\n     *   - values (long): Number of values in this block column\n     *   - nulls (long): Number of null values in block\n     *\n     * @param paths\n     *   one or more paths to Parquet files or directories\n     * @return\n     *   dataframe with Parquet block metadata\n     */\n    @scala.annotation.varargs\n    def parquetBlockColumns(paths: String*): DataFrame = parquetBlockColumns(None, paths)\n\n    /**\n     * Read the metadata of Parquet block columns into a Dataframe.\n     *\n     * The returned DataFrame has as many partitions as specified via `parallelism`.\n     *\n     * This provides the following per-block-column information:\n     *   - filename (string): The file name\n     *   - block (int): Block / RowGroup number starting at 1 (block ordinal + 1)\n     *   - column (string): Block / RowGroup column name\n     *   - codec (string): The codec used to compress the block column values\n     *   - type (string): The data type of the block column\n     *   - encodings (string): Encodings of the block column\n     *   - encrypted (boolean): True if this block column is encrypted\n     *   - minValue (string): Minimum 
value of this column in this block\n     *   - maxValue (string): Maximum value of this column in this block\n     *   - columnStart (long): Start position of the block column in the Parquet file\n     *   - compressedBytes (long): Number of compressed bytes of this block column\n     *   - uncompressedBytes (long): Number of uncompressed bytes of this block column\n     *   - values (long): Number of values in this block column\n     *   - nulls (long): Number of null values in block\n     *\n     * @param parallelism\n     *   number of partitions of returned DataFrame\n     * @param paths\n     *   one or more paths to Parquet files or directories\n     * @return\n     *   dataframe with Parquet block metadata\n     */\n    @scala.annotation.varargs\n    def parquetBlockColumns(parallelism: Int, paths: String*): DataFrame = parquetBlockColumns(Some(parallelism), paths)\n\n    private def parquetBlockColumns(parallelism: Option[Int], paths: Seq[String]): DataFrame = {\n      val files = getFiles(parallelism, paths)\n\n      import files.sparkSession.implicits._\n\n      files\n        .flatMap { case (_, file) =>\n          readFooters(file).flatMap { footer =>\n            val guard = FooterGuard(footer)\n            footer.getParquetMetadata.getBlocks.asScala.zipWithIndex.flatMap { case (block, idx) =>\n              block.getColumns.asScala.map { column =>\n                (\n                  footer.getFile.toString,\n                  ParquetMetaDataUtil.getOrdinal(block).getOrElse(idx) + 1,\n                  column.getPath.toSeq,\n                  guard { column.getCodec.toString },\n                  guard { column.getPrimitiveType.toString },\n                  guard { column.getEncodings.asScala.toSeq.map(_.toString).sorted },\n                  ParquetMetaDataUtil.isEncrypted(column),\n                  guard { Option(column.getStatistics).map(_.minAsString) },\n                  guard { Option(column.getStatistics).map(_.maxAsString) },\n                  guard { column.getStartingPos },\n                  guard { column.getTotalSize },\n                  guard { column.getTotalUncompressedSize },\n                  guard { column.getValueCount },\n                  guard { Option(column.getStatistics).map(_.getNumNulls) },\n                )\n              }\n            }\n          }\n        }\n        .toDF(\n          \"filename\",\n          \"block\",\n          \"column\",\n          \"codec\",\n          \"type\",\n          \"encodings\",\n          \"encrypted\",\n          \"minValue\",\n          \"maxValue\",\n          \"columnStart\",\n          \"compressedBytes\",\n          \"uncompressedBytes\",\n          \"values\",\n          \"nulls\"\n        )\n    }\n\n    /**\n     * Read the metadata of how Spark partitions Parquet files into a Dataframe.\n     *\n     * The returned DataFrame has as many partitions as there are Parquet files, at most\n     * `spark.sparkContext.defaultParallelism` partitions.\n     *\n     * This provides the following per-partition information:\n     *   - partition (int): The Spark partition id\n     *   - start (long): The start position of the partition\n     *   - end (long): The end position of the partition\n     *   - length (long): The length of the partition\n     *   - blocks (int): The number of Parquet blocks / RowGroups in this partition\n     *   - compressedBytes (long): The number of compressed bytes in this partition\n     *   - uncompressedBytes (long): The number of uncompressed bytes in this partition\n     
*   - rows (long): The number of rows in this partition\n     *   - columns (int): Number of columns in the file\n     *   - values (long): The number of values in this partition\n     *   - filename (string): The Parquet file name\n     *   - fileLength (long): The length of the Parquet file\n     *\n     * @param paths\n     *   one or more paths to Parquet files or directories\n     * @return\n     *   dataframe with Spark Parquet partition metadata\n     */\n    @scala.annotation.varargs\n    def parquetPartitions(paths: String*): DataFrame = parquetPartitions(None, paths)\n\n    /**\n     * Read the metadata of how Spark partitions Parquet files into a Dataframe.\n     *\n     * The returned DataFrame has as many partitions as specified via `parallelism`.\n     *\n     * This provides the following per-partition information:\n     *   - partition (int): The Spark partition id\n     *   - start (long): The start position of the partition\n     *   - end (long): The end position of the partition\n     *   - length (long): The length of the partition\n     *   - blocks (int): The number of Parquet blocks / RowGroups in this partition\n     *   - compressedBytes (long): The number of compressed bytes in this partition\n     *   - uncompressedBytes (long): The number of uncompressed bytes in this partition\n     *   - rows (long): The number of rows in this partition\n     *   - columns (int): Number of columns in the file\n     *   - values (long): The number of values in this partition\n     *   - filename (string): The Parquet file name\n     *   - fileLength (long): The length of the Parquet file\n     *\n     * @param parallelism\n     *   number of partitions of returned DataFrame\n     * @param paths\n     *   one or more paths to Parquet files or directories\n     * @return\n     *   dataframe with Spark Parquet partition metadata\n     */\n    @scala.annotation.varargs\n    def parquetPartitions(parallelism: Int, paths: String*): DataFrame = parquetPartitions(Some(parallelism), paths)\n\n    private def parquetPartitions(parallelism: Option[Int], paths: Seq[String]): DataFrame = {\n      val files = getFiles(parallelism, paths)\n\n      import files.sparkSession.implicits._\n\n      files\n        .flatMap { case (part, file) =>\n          readFooters(file)\n            .map(footer => (footer, getBlocks(footer, file.start, file.length)))\n            .map { case (footer, blocks) =>\n              (\n                part,\n                file.start,\n                file.start + file.length,\n                file.length,\n                blocks.size,\n                blocks.map(_.getCompressedSize).sum,\n                blocks.map(_.getTotalByteSize).sum,\n                blocks.map(_.getRowCount).sum,\n                blocks\n                  .map(_.getColumns.map(_.getPath.mkString(\".\")).toSet)\n                  .foldLeft(Set.empty[String])((left, right) => left.union(right))\n                  .size,\n                blocks.map(_.getColumns.asScala.map(_.getValueCount).sum).sum,\n                // when all columns have statistics, count the null values\n                Option(blocks.flatMap(_.getColumns.asScala.map(c => Option(c.getStatistics))))\n                  .filter(_.forall(_.isDefined))\n                  .map(_.map(_.get.getNumNulls).sum),\n                footer.getFile.toString,\n                file.fileSize,\n              )\n            }\n        }\n        .toDF(\n          \"partition\",\n          \"start\",\n          \"end\",\n          \"length\",\n   
       \"blocks\",\n          \"compressedBytes\",\n          \"uncompressedBytes\",\n          \"rows\",\n          \"columns\",\n          \"values\",\n          \"nulls\",\n          \"filename\",\n          \"fileLength\"\n        )\n    }\n\n    private def getFiles(parallelism: Option[Int], paths: Seq[String]): Dataset[(Int, SplitFile)] = {\n      val df = reader.parquet(paths: _*)\n      val parts = df.rdd.partitions\n        .flatMap(part =>\n          part\n            .asInstanceOf[FilePartition]\n            .files\n            .map(file => (part.index, SplitFile(file)))\n        )\n        .toSeq\n        .distinct\n\n      import df.sparkSession.implicits._\n\n      parts\n        .toDS()\n        .when(parallelism.isDefined)\n        .call(_.repartition(parallelism.get))\n    }\n  }\n\n  private def readFooters(file: SplitFile): Iterable[Footer] = {\n    val path = new Path(file.filePath)\n    val status = path.getFileSystem(conf).getFileStatus(path)\n    ParquetFileReader.readFooters(conf, status, false).asScala\n  }\n\n  private def getBlocks(footer: Footer, start: Long, length: Long): Seq[BlockMetaData] = {\n    footer.getParquetMetadata.getBlocks.asScala\n      .map(block => (block, block.getStartingPos + block.getCompressedSize / 2))\n      .filter { case (_, midBlock) => start <= midBlock && midBlock < start + length }\n      .map(_._1)\n      .toSeq\n  }\n\n}\n"
  },
  {
    "path": "src/main/scala-spark-3.2/uk/co/gresearch/spark/parquet/SplitFile.scala",
    "content": "/*\n * Copyright 2023 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.parquet\n\nimport org.apache.spark.sql.execution.datasources.PartitionedFile\n\nprivate[spark] case class SplitFile(filePath: String, start: Long, length: Long, fileSize: Option[Long])\n\nprivate[spark] object SplitFile {\n  def apply(file: PartitionedFile): SplitFile = SplitFile(file.filePath, file.start, file.length, None)\n}\n"
  },
  {
    "path": "src/main/scala-spark-3.3/uk/co/gresearch/spark/parquet/SplitFile.scala",
    "content": "/*\n * Copyright 2023 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.parquet\n\nimport org.apache.spark.sql.execution.datasources.PartitionedFile\n\nprivate[spark] case class SplitFile(filePath: String, start: Long, length: Long, fileSize: Option[Long])\n\nprivate[spark] object SplitFile {\n  def apply(file: PartitionedFile): SplitFile = SplitFile(file.filePath, file.start, file.length, Some(file.fileSize))\n}\n"
  },
  {
    "path": "src/main/scala-spark-3.5/org/apache/spark/sql/extension/package.scala",
    "content": "/*\n * Copyright 2024 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage org.apache.spark.sql\n\nimport org.apache.spark.sql.catalyst.expressions.Expression\n\npackage object extension {\n  implicit class ColumnExtension(col: Column) {\n    // Column.expr exists in this Spark version and earlier\n    def sql: String = col.expr.sql\n  }\n\n  implicit class ExpressionExtension(expr: Expression) {\n    def column: Column = new Column(expr)\n  }\n}\n"
  },
  {
    "path": "src/main/scala-spark-3.5/uk/co/gresearch/spark/Backticks.scala",
    "content": "/*\n * Copyright 2021 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport java.util.regex.Pattern\n\nobject Backticks {\n\n  // https://github.com/apache/spark/blob/523ff15/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/QuotingUtils.scala#L46\n  private val validIdentPattern = Pattern.compile(\"^[a-zA-Z_][a-zA-Z0-9_]*\")\n\n  /**\n   * Detects if column name part requires quoting.\n   * https://github.com/apache/spark/blob/523ff15/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/QuotingUtils.scala#L48\n   */\n  private def needQuote(part: String): Boolean = {\n    !validIdentPattern.matcher(part).matches()\n  }\n\n  /**\n   * Encloses the given strings with backticks (backquotes) if needed.\n   *\n   * Backticks are not needed for strings that start with a letter (`a`-`z` and `A`-`Z`) or an underscore,\n   * and contain only letters, numbers and underscores.\n   *\n   * Multiple strings will be enclosed individually and concatenated with dots (`.`).\n   *\n   * This is useful when referencing column names that contain special characters like dots (`.`) or backquotes.\n   *\n   * Examples:\n   * {{{\n   *   col(\"a.column\")                                    // this references the field \"column\" of column \"a\"\n   *   col(\"`a.column`\")                                  // this reference the column with the name \"a.column\"\n   *   col(Backticks.column_name(\"column\"))               // produces \"column\"\n   *   col(Backticks.column_name(\"a.column\"))             // produces \"`a.column`\"\n   *   col(Backticks.column_name(\"a column\"))             // produces \"`a column`\"\n   *   col(Backticks.column_name(\"`a.column`\"))           // produces \"`a.column`\"\n   *   col(Backticks.column_name(\"a.column\", \"a.field\"))  // produces \"`a.column`.`a.field`\"\n   * }}}\n   *\n   * @param string\n   *   a string\n   * @param strings\n   *   more strings\n   * @return\n   */\n  @scala.annotation.varargs\n  def column_name(string: String, strings: String*): String =\n    (string +: strings)\n      .map(s => if (needQuote(s)) s\"`${s.replace(\"`\", \"``\")}`\" else s)\n      .mkString(\".\")\n\n}\n"
  },
  {
    "path": "src/main/scala-spark-4.0/org/apache/spark/sql/extension/package.scala",
    "content": "/*\n * Copyright 2024 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage org.apache.spark.sql\n\nimport org.apache.spark.sql.catalyst.expressions.Expression\nimport org.apache.spark.sql.classic.ExpressionUtils.{column => toColumn, expression}\n\npackage object extension {\n  implicit class ColumnExtension(col: Column) {\n    def expr: Expression = expression(col)\n    def sql: String = col.node.sql\n  }\n\n  implicit class ExpressionExtension(expr: Expression) {\n    def column: Column = toColumn(expr)\n  }\n}\n"
  },
  {
    "path": "src/main/scala-spark-4.0/uk/co/gresearch/spark/Backticks.scala",
    "content": "/*\n * Copyright 2021 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.sql.catalyst.util.QuotingUtils\n\nobject Backticks {\n\n  /**\n   * Encloses the given strings with backticks (backquotes) if needed.\n   *\n   * Backticks are not needed for strings that start with a letter (`a`-`z` and `A`-`Z`) or an underscore,\n   * and contain only letters, numbers and underscores.\n   *\n   * Multiple strings will be enclosed individually and concatenated with dots (`.`).\n   *\n   * This is useful when referencing column names that contain special characters like dots (`.`) or backquotes.\n   *\n   * Examples:\n   * {{{\n   *   col(\"a.column\")                                    // this references the field \"column\" of column \"a\"\n   *   col(\"`a.column`\")                                  // this reference the column with the name \"a.column\"\n   *   col(Backticks.column_name(\"column\"))               // produces \"column\"\n   *   col(Backticks.column_name(\"a.column\"))             // produces \"`a.column`\"\n   *   col(Backticks.column_name(\"a column\"))             // produces \"`a column`\"\n   *   col(Backticks.column_name(\"`a.column`\"))           // produces \"`a.column`\"\n   *   col(Backticks.column_name(\"a.column\", \"a.field\"))  // produces \"`a.column`.`a.field`\"\n   * }}}\n   *\n   * @param string\n   *   a string\n   * @param strings\n   *   more strings\n   * @return\n   */\n  @scala.annotation.varargs\n  def column_name(string: String, strings: String*): String =\n    QuotingUtils.quoted(Array.from(string +: strings))\n\n}\n"
  },
  {
    "path": "src/main/scala-spark-4.0/uk/co/gresearch/spark/parquet/SplitFile.scala",
    "content": "/*\n * Copyright 2023 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.parquet\n\nimport org.apache.spark.sql.execution.datasources.PartitionedFile\n\nprivate[spark] case class SplitFile(filePath: String, start: Long, length: Long, fileSize: Option[Long])\n\nprivate[spark] object SplitFile {\n  def apply(file: PartitionedFile): SplitFile = SplitFile(file.filePath.toString, file.start, file.length, Some(file.fileSize))\n}\n"
  },
  {
    "path": "src/test/java/uk/co/gresearch/test/SparkJavaTests.java",
    "content": "/*\n * Copyright 2021 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n// these tests are deliberately located outside uk.co.gresearch.spark to show how imports look for Java\npackage uk.co.gresearch.test;\n\nimport org.apache.spark.SparkConf;\nimport org.apache.spark.sql.*;\nimport org.apache.spark.sql.execution.CacheManager;\nimport org.apache.spark.storage.StorageLevel;\nimport org.junit.AfterClass;\nimport org.junit.Assert;\nimport org.junit.BeforeClass;\nimport org.junit.Test;\nimport uk.co.gresearch.spark.Backticks;\nimport uk.co.gresearch.spark.Histogram;\nimport uk.co.gresearch.spark.RowNumbers;\nimport uk.co.gresearch.spark.UnpersistHandle;\nimport uk.co.gresearch.spark.diff.JavaValue;\n\nimport java.util.Arrays;\nimport java.util.List;\n\npublic class SparkJavaTests {\n    private static SparkSession spark;\n    private static Dataset<JavaValue> dataset;\n\n    @BeforeClass\n    public static void beforeClass() {\n        spark = SparkSession\n                .builder()\n                .master(\"local[*]\")\n                .config(new SparkConf().set(\"fs.defaultFS\", \"file:///\"))\n                .appName(\"Diff Java Suite\")\n                .getOrCreate();\n\n        JavaValue valueOne = new JavaValue(1, \"one\", 1.0);\n        JavaValue valueTwo = new JavaValue(2, \"two\", 2.0);\n        JavaValue valueThree = new JavaValue(3, \"three\", 3.0);\n        Encoder<JavaValue> encoder = Encoders.bean(JavaValue.class);\n        dataset = spark.createDataset(Arrays.asList(valueOne, valueTwo, valueThree), encoder);\n    }\n\n    @Test\n    public void testBackticks() {\n        Assert.assertEquals(\"col\", Backticks.column_name(\"col\"));\n        Assert.assertEquals(\"`a.col`\", Backticks.column_name(\"a.col\"));\n        Assert.assertEquals(\"a.col\", Backticks.column_name(\"a\", \"col\"));\n        Assert.assertEquals(\"some.more.columns\", Backticks.column_name(\"some\", \"more\", \"columns\"));\n        Assert.assertEquals(\"some.`more.columns`\", Backticks.column_name(\"some\", \"more.columns\"));\n        Assert.assertEquals(\"some.more.dotted.columns\", Backticks.column_name(\"some\", \"more\", \"dotted\", \"columns\"));\n        Assert.assertEquals(\"some.more.`dotted.columns`\", Backticks.column_name(\"some\", \"more\", \"dotted.columns\"));\n    }\n\n    @Test\n    public void testHistogram() {\n        Dataset<Row> histogram = Histogram.of(dataset, Arrays.asList(0, 1, 2), new Column(\"id\"));\n        List<Row> expected = Arrays.asList(RowFactory.create(0, 1, 1, 1));\n        Assert.assertEquals(expected, histogram.collectAsList());\n    }\n\n    @Test\n    public void testHistogramWithAggColumn() {\n        Dataset<Row> histogram = Histogram.of(dataset, Arrays.asList(0, 1, 2), new Column(\"id\"), new Column(\"label\"));\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(\"one\", 0, 1, 0, 0),\n                RowFactory.create(\"three\", 0, 0, 0, 1),\n                
RowFactory.create(\"two\", 0, 0, 1, 0)\n        );\n        Assert.assertEquals(expected, histogram.sort(\"label\").collectAsList());\n    }\n\n    @Test\n    public void testRowNumbers() {\n        Dataset<Row> withRowNumbers = RowNumbers.of(dataset);\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(1, \"one\", 1.0, 1),\n                RowFactory.create(2, \"two\", 2.0, 2),\n                RowFactory.create(3, \"three\", 3.0, 3)\n        );\n        Assert.assertEquals(expected, withRowNumbers.orderBy(\"id\").collectAsList());\n    }\n\n    @Test\n    public void testRowNumbersOrderOneColumn() {\n        Dataset<Row> withRowNumbers = RowNumbers.withOrderColumns(dataset.col(\"id\").desc()).of(dataset);\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(1, \"one\", 1.0, 3),\n                RowFactory.create(2, \"two\", 2.0, 2),\n                RowFactory.create(3, \"three\", 3.0, 1)\n        );\n        Assert.assertEquals(expected, withRowNumbers.orderBy(\"id\").collectAsList());\n    }\n\n    @Test\n    public void testRowNumbersOrderTwoColumns() {\n        Dataset<Row> withRowNumbers = RowNumbers.withOrderColumns(dataset.col(\"id\"), dataset.col(\"label\")).of(dataset);\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(1, \"one\", 1.0, 1),\n                RowFactory.create(2, \"two\", 2.0, 2),\n                RowFactory.create(3, \"three\", 3.0, 3)\n        );\n        Assert.assertEquals(expected, withRowNumbers.orderBy(\"id\").collectAsList());\n    }\n\n    @Test\n    public void testRowNumbersOrderDesc() {\n        Dataset<Row> withRowNumbers = RowNumbers.withOrderColumns(dataset.col(\"id\").desc()).of(dataset);\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(1, \"one\", 1.0, 3),\n                RowFactory.create(2, \"two\", 2.0, 2),\n                RowFactory.create(3, \"three\", 3.0, 1)\n        );\n        Assert.assertEquals(expected, withRowNumbers.orderBy(\"id\").collectAsList());\n    }\n\n    @Test\n    public void testRowNumbersUnpersist() {\n        CacheManager cacheManager = SparkJavaTests.spark.sharedState().cacheManager();\n        cacheManager.clearCache();\n        Assert.assertTrue(cacheManager.isEmpty());\n\n        UnpersistHandle unpersist = new UnpersistHandle();\n        Dataset<Row> withRowNumbers = RowNumbers.withUnpersistHandle(unpersist).of(dataset);\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(1, \"one\", 1.0, 1),\n                RowFactory.create(2, \"two\", 2.0, 2),\n                RowFactory.create(3, \"three\", 3.0, 3)\n        );\n        Assert.assertEquals(expected, withRowNumbers.orderBy(\"id\").collectAsList());\n\n        Assert.assertFalse(cacheManager.isEmpty());\n        unpersist.apply(true);\n        Assert.assertTrue(cacheManager.isEmpty());\n    }\n\n    @Test\n    public void testRowNumbersStorageLevelAndUnpersist() {\n        CacheManager cacheManager = SparkJavaTests.spark.sharedState().cacheManager();\n        cacheManager.clearCache();\n        Assert.assertTrue(cacheManager.isEmpty());\n\n        UnpersistHandle unpersist = new UnpersistHandle();\n        RowNumbers.withStorageLevel(StorageLevel.MEMORY_ONLY()).withUnpersistHandle(unpersist).of(dataset);\n\n        Assert.assertFalse(cacheManager.isEmpty());\n        unpersist.apply(true);\n        Assert.assertTrue(cacheManager.isEmpty());\n    }\n\n    @Test\n    public void testRowNumbersColumnName() {\n        
Dataset<Row> withRowNumbers = RowNumbers.withRowNumberColumnName(\"row\").of(dataset);\n        Assert.assertEquals(Arrays.asList(\"id\", \"label\", \"score\", \"row\"), Arrays.asList(withRowNumbers.columns()));\n\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(1, \"one\", 1.0, 1),\n                RowFactory.create(2, \"two\", 2.0, 2),\n                RowFactory.create(3, \"three\", 3.0, 3)\n        );\n        Assert.assertEquals(expected, withRowNumbers.orderBy(\"id\").collectAsList());\n    }\n\n    @AfterClass\n    public static void afterClass() {\n        if (spark != null) {\n            spark.stop();\n        }\n    }\n\n}\n"
  },
  {
    "path": "src/test/java/uk/co/gresearch/test/diff/DiffJavaTests.java",
    "content": "/*\n * Copyright 2021 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff;\n\nimport org.apache.spark.SparkConf;\nimport org.apache.spark.sql.*;\nimport org.apache.spark.sql.types.DataTypes;\nimport org.junit.AfterClass;\nimport org.junit.Assert;\nimport org.junit.BeforeClass;\nimport org.junit.Test;\nimport scala.Tuple3;\nimport scala.math.Equiv;\nimport uk.co.gresearch.spark.diff.comparator.DiffComparator;\n\nimport java.util.Arrays;\nimport java.util.Collections;\nimport java.util.List;\n\nimport static java.lang.Math.abs;\n\npublic class DiffJavaTests {\n    private static SparkSession spark;\n    private static Dataset<JavaValue> left;\n    private static Dataset<JavaValue> right;\n\n    @BeforeClass\n    public static void beforeClass() {\n        spark = SparkSession\n                    .builder()\n                    .master(\"local[*]\")\n                    .config(new SparkConf().set(\"fs.defaultFS\", \"file:///\"))\n                    .appName(\"Diff Java Suite\")\n                    .getOrCreate();\n\n        JavaValue valueOne = new JavaValue(1, \"one\", 1.0);\n        JavaValue valueTwo = new JavaValue(2, \"two\", 2.0);\n        JavaValue valueThree = new JavaValue(3, \"three\", 3.0);\n        JavaValue valueThreeScored = new JavaValue(3, \"three\", 3.1);\n        JavaValue valueFour = new JavaValue(4, \"four\", 4.0);\n        Encoder<JavaValue> encoder = Encoders.bean(JavaValue.class);\n\n        left = spark.createDataset(Arrays.asList(valueOne, valueTwo, valueThree), encoder);\n        right = spark.createDataset(Arrays.asList(valueTwo, valueThreeScored, valueFour), encoder);\n    }\n\n    @Test\n    public void testDiff() {\n        Dataset<Row> diff = Diff.of(left.toDF(), right.toDF(), \"id\");\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(\"D\", 1, \"one\", null, 1.0, null),\n                RowFactory.create(\"N\", 2, \"two\", \"two\", 2.0, 2.0),\n                RowFactory.create(\"C\", 3, \"three\", \"three\", 3.0, 3.1),\n                RowFactory.create(\"I\", 4, null, \"four\", null, 4.0)\n        );\n        Assert.assertEquals(expected, diff.sort(\"id\").collectAsList());\n    }\n\n    @Test\n    public void testDiffNoKey() {\n        Dataset<Row> diff = Diff.of(left, right);\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(\"D\", 1, \"one\", 1.0),\n                RowFactory.create(\"N\", 2, \"two\", 2.0),\n                RowFactory.create(\"D\", 3, \"three\", 3.0),\n                RowFactory.create(\"I\", 3, \"three\", 3.1),\n                RowFactory.create(\"I\", 4, \"four\", 4.0)\n        );\n        Assert.assertEquals(expected, diff.sort(\"id\", \"diff\").collectAsList());\n    }\n\n    @Test\n    public void testDiffSingleKey() {\n        Dataset<Row> diff = Diff.of(left, right, \"id\");\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(\"D\", 1, \"one\", null, 1.0, 
null),\n                RowFactory.create(\"N\", 2, \"two\", \"two\", 2.0, 2.0),\n                RowFactory.create(\"C\", 3, \"three\", \"three\", 3.0, 3.1),\n                RowFactory.create(\"I\", 4, null, \"four\", null, 4.0)\n        );\n        Assert.assertEquals(expected, diff.sort(\"id\").collectAsList());\n    }\n\n    @Test\n    public void testDiffMultipleKeys() {\n        Dataset<Row> diff = Diff.of(left, right, \"id\", \"label\");\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(\"D\", 1, \"one\", 1.0, null),\n                RowFactory.create(\"N\", 2, \"two\", 2.0, 2.0),\n                RowFactory.create(\"C\", 3, \"three\", 3.0, 3.1),\n                RowFactory.create(\"I\", 4, \"four\", null, 4.0)\n        );\n        Assert.assertEquals(expected, diff.sort(\"id\").collectAsList());\n    }\n\n    @Test\n    public void testDiffIgnoredColumn() {\n        Dataset<Row> diff = Diff.of(left, right, Collections.singletonList(\"id\"), Collections.singletonList(\"score\"));\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(\"D\", 1, \"one\", null, 1.0, null),\n                RowFactory.create(\"N\", 2, \"two\", \"two\", 2.0, 2.0),\n                RowFactory.create(\"N\", 3, \"three\", \"three\", 3.0, 3.1),\n                RowFactory.create(\"I\", 4, null, \"four\", null, 4.0)\n        );\n        Assert.assertEquals(expected, diff.sort(\"id\").collectAsList());\n    }\n\n    @Test\n    public void testDiffAs() {\n        Encoder<JavaValueAs> encoder = Encoders.bean(JavaValueAs.class);\n        Dataset<JavaValueAs> diff = Diff.ofAs(left.toDF(), right.toDF(), encoder, \"id\");\n        List<JavaValueAs> expected = Arrays.asList(\n                new JavaValueAs(\"D\", 1, \"one\", null, 1.0, null),\n                new JavaValueAs(\"N\", 2, \"two\", \"two\", 2.0, 2.0),\n                new JavaValueAs(\"C\", 3, \"three\", \"three\", 3.0, 3.1),\n                new JavaValueAs(\"I\", 4, null, \"four\", null, 4.0)\n        );\n        Assert.assertEquals(expected, diff.sort(\"id\").collectAsList());\n    }\n\n    @Test\n    public void testDiffOfWith() {\n        Dataset<Tuple3<String, JavaValue, JavaValue>> diff = Diff.ofWith(left, right, \"id\");\n        List<Tuple3<String, JavaValue, JavaValue>> expected = Arrays.asList(\n                new Tuple3<>(\"D\", new JavaValue(1, \"one\", 1.0), null),\n                new Tuple3<>(\"N\", new JavaValue(2, \"two\", 2.0), new JavaValue(2, \"two\", 2.0)),\n                new Tuple3<>(\"C\", new JavaValue(3, \"three\", 3.0), new JavaValue(3, \"three\", 3.1)),\n                new Tuple3<>(\"I\", null, new JavaValue(4, \"four\", 4.0))\n        );\n        Assert.assertEquals(expected, diff.sort(\"id\").collectAsList());\n    }\n\n    @Test\n    public void testDiffer() {\n        DiffOptions options = new DiffOptions();\n\n        Differ differ = new Differ(options);\n        Dataset<Row> diff = differ.diff(left, right, \"id\");\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(\"D\", 1, \"one\", null, 1.0, null),\n                RowFactory.create(\"N\", 2, \"two\", \"two\", 2.0, 2.0),\n                RowFactory.create(\"C\", 3, \"three\", \"three\", 3.0, 3.1),\n                RowFactory.create(\"I\", 4, null, \"four\", null, 4.0)\n        );\n        Assert.assertEquals(expected, diff.sort(\"id\").collectAsList());\n    }\n\n    @Test\n    public void testDifferWithIgnored() {\n        DiffOptions options = new DiffOptions();\n\n        
Differ differ = new Differ(options);\n        Dataset<Row> diff = differ.diff(left, right, Collections.singletonList(\"id\"), Collections.singletonList(\"score\"));\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(\"D\", 1, \"one\", null, 1.0, null),\n                RowFactory.create(\"N\", 2, \"two\", \"two\", 2.0, 2.0),\n                RowFactory.create(\"N\", 3, \"three\", \"three\", 3.0, 3.1),\n                RowFactory.create(\"I\", 4, null, \"four\", null, 4.0)\n        );\n        Assert.assertEquals(expected, diff.sort(\"id\").collectAsList());\n        List<String> columns = Arrays.asList(diff.schema().fieldNames());\n        Assert.assertEquals(Arrays.asList(\"diff\", \"id\", \"left_label\", \"right_label\", \"left_score\", \"right_score\"), columns);\n    }\n\n    @Test\n    public void testDiffWithOptions() {\n        DiffOptions options = new DiffOptions(\n                \"action\",\n                \"before\", \"after\",\n                \"+\", \"~\", \"-\", \"=\",\n                scala.Option.apply(null),\n                DiffMode.ColumnByColumn(),\n                false\n        );\n\n        Differ differ = new Differ(options);\n        Dataset<Row> diff = differ.diff(left, right, \"id\");\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(\"-\", 1, \"one\", null, 1.0, null),\n                RowFactory.create(\"=\", 2, \"two\", \"two\", 2.0, 2.0),\n                RowFactory.create(\"~\", 3, \"three\", \"three\", 3.0, 3.1),\n                RowFactory.create(\"+\", 4, null, \"four\", null, 4.0)\n        );\n        Assert.assertEquals(expected, diff.sort(\"id\").collectAsList());\n        List<String> names = Arrays.asList(diff.schema().fieldNames());\n        Assert.assertEquals(Arrays.asList(\"action\", \"id\", \"before_label\", \"after_label\", \"before_score\", \"after_score\"), names);\n    }\n\n    @Test\n    public void testDiffWithComparators() {\n        DiffComparator comparator = DiffComparators.epsilon(0.100000001).asInclusive().asAbsolute();\n        testDiffWithComparator(new DiffOptions().withComparator(comparator, DataTypes.DoubleType));\n        testDiffWithComparator(new DiffOptions().withComparator(comparator, \"score\"));\n\n        Equiv<Double> equivDouble = (Double x, Double y) -> x == null && y == null || x != null && y != null &&\n                abs(x - y) <= 0.1000000001;\n        testDiffWithComparator(new DiffOptions().withComparator(equivDouble, Encoders.DOUBLE()));\n        testDiffWithComparator(new DiffOptions().withComparator(equivDouble, Encoders.DOUBLE(), \"score\"));\n        testDiffWithComparator(new DiffOptions().withComparator(equivDouble, DataTypes.DoubleType));\n\n        Equiv<Object> equivAny = (x, y) -> x == null && y == null || x instanceof Double && y instanceof Double &&\n                abs((Double) x - (Double) y) <= 0.1000000001;\n        testDiffWithComparator(new DiffOptions().withComparator(equivAny, DataTypes.DoubleType));\n        testDiffWithComparator(new DiffOptions().withComparator(equivAny, \"score\"));\n    }\n\n    private void testDiffWithComparator(DiffOptions options) {\n        Differ differ = new Differ(options);\n\n        Dataset<Row> diff = differ.diff(left, right, \"id\");\n        List<Row> expected = Arrays.asList(\n                RowFactory.create(\"D\", 1, \"one\", null, 1.0, null),\n                RowFactory.create(\"N\", 2, \"two\", \"two\", 2.0, 2.0),\n                // this is only considered un-changed because of the epsilon 
diff comparator for column 'score'\n                RowFactory.create(\"N\", 3, \"three\", \"three\", 3.0, 3.1),\n                RowFactory.create(\"I\", 4, null, \"four\", null, 4.0)\n        );\n        Assert.assertEquals(expected, diff.sort(\"id\").collectAsList());\n    }\n\n    @AfterClass\n    public static void afterClass() {\n        if (spark != null) {\n            spark.stop();\n        }\n    }\n}\n"
  },
  {
    "path": "src/test/java/uk/co/gresearch/test/diff/JavaValue.java",
    "content": "/*\n * Copyright 2021 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff;\n\nimport java.io.Serializable;\nimport java.util.Objects;\n\npublic class JavaValue implements Serializable {\n    private Integer id;\n    private String label;\n    private Double score;\n\n    public JavaValue() { }\n\n    public JavaValue(Integer id, String label, Double score) {\n        this.id = id;\n        this.label = label;\n        this.score = score;\n    }\n\n    public Integer getId() {\n        return id;\n    }\n\n    public void setId(Integer id) {\n        this.id = id;\n    }\n\n    public String getLabel() {\n        return label;\n    }\n\n    public void setLabel(String label) {\n        this.label = label;\n    }\n\n    public Double getScore() {\n        return score;\n    }\n\n    public void setScore(Double score) {\n        this.score = score;\n    }\n\n    @Override\n    public boolean equals(Object o) {\n        if (this == o) return true;\n        if (o == null || getClass() != o.getClass()) return false;\n\n        JavaValue javaValue = (JavaValue) o;\n        return Objects.equals(id, javaValue.id) && Objects.equals(label, javaValue.label) && Objects.equals(score, javaValue.score);\n    }\n\n    @Override\n    public int hashCode() {\n        return Objects.hash(id, label, score);\n    }\n\n    @Override\n    public String toString() {\n        return \"JavaValue{id=\" + id + \", label='\" + label + \"', score=\" + score + '}';\n    }\n}\n"
  },
  {
    "path": "src/test/java/uk/co/gresearch/test/diff/JavaValueAs.java",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff;\n\nimport java.io.Serializable;\nimport java.util.Objects;\n\npublic class JavaValueAs implements Serializable {\n    private String diff;\n    private Integer id;\n    private String left_label;\n    private String right_label;\n    private Double left_score;\n    private Double right_score;\n\n    public JavaValueAs() { }\n\n    public JavaValueAs(String diff, Integer id, String left_label, String right_label, Double left_score, Double right_score) {\n        this.diff = diff;\n        this.id = id;\n        this.left_label = left_label;\n        this.right_label = right_label;\n        this.left_score = left_score;\n        this.right_score = right_score;\n    }\n\n    public String getDiff() {\n        return diff;\n    }\n\n    public void setDiff(String diff) {\n        this.diff = diff;\n    }\n\n    public Integer getId() {\n        return id;\n    }\n\n    public void setId(Integer id) {\n        this.id = id;\n    }\n\n    public String getLeft_label() {\n        return left_label;\n    }\n\n    public void setLeft_label(String left_label) {\n        this.left_label = left_label;\n    }\n\n    public String getRight_label() {\n        return right_label;\n    }\n\n    public void setRight_label(String right_label) {\n        this.right_label = right_label;\n    }\n\n    public Double getLeft_score() {\n        return left_score;\n    }\n\n    public void setLeft_score(Double left_score) {\n        this.left_score = left_score;\n    }\n\n    public Double getRight_score() {\n        return right_score;\n    }\n\n    public void setRight_score(Double right_score) {\n        this.right_score = right_score;\n    }\n\n    @Override\n    public boolean equals(Object o) {\n        if (this == o) return true;\n        if (o == null || getClass() != o.getClass()) return false;\n\n        JavaValueAs that = (JavaValueAs) o;\n        return Objects.equals(diff, that.diff) && Objects.equals(id, that.id) && Objects.equals(left_label, that.left_label) && Objects.equals(right_label, that.right_label) && Objects.equals(left_score, that.left_score) && Objects.equals(right_score, that.right_score);\n    }\n\n    @Override\n    public int hashCode() {\n        return Objects.hash(diff, id, left_label, right_label, left_score, right_score);\n    }\n\n    @Override\n    public String toString() {\n        return \"JavaValueAs{\" +\n                \"diff='\" + diff + \"', \" +\n                \"id=\" + id + \", \" +\n                \"left_label='\" + left_label + \"', \" +\n                \"right_label='\" + right_label + \"', \" +\n                \"left_score=\" + left_score + \", \" +\n                \"right_score=\" + right_score +\n                '}';\n    }\n}\n"
  },
  {
    "path": "src/test/resources/log4j.properties",
    "content": "#\n# Copyright 2020 G-Research\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Set everything to be logged to the console\nlog4j.rootCategory=WARN, console\nlog4j.appender.console=org.apache.log4j.ConsoleAppender\nlog4j.appender.console.target=System.err\nlog4j.appender.console.layout=org.apache.log4j.PatternLayout\nlog4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n\n\n# Set the default spark-shell log level to WARN. When running the spark-shell, the\n# log level for this class is used to overwrite the root logger's log level, so that\n# the user can have different defaults for the shell and regular Spark apps.\nlog4j.logger.org.apache.spark.repl.Main=WARN\n\n# Settings to quiet third party logs that are too verbose\nlog4j.logger.org.sparkproject.jetty=WARN\nlog4j.logger.org.sparkproject.jetty.util.component.AbstractLifeCycle=ERROR\nlog4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO\nlog4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO\nlog4j.logger.org.apache.parquet=ERROR\nlog4j.logger.parquet=ERROR\n\n# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support\nlog4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL\nlog4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR\n\n# Set G-Research Spark logging to DEBUG\nlog4j.logger.uk.co.gresearch.spark=DEBUG\n"
  },
  {
    "path": "src/test/resources/log4j2.properties",
    "content": "#\n# Copyright 2020 G-Research\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Set everything to be logged to the console\nrootLogger.level = warn\nrootLogger.appenderRef.stdout.ref = console\n\n# In the pattern layout configuration below, we specify an explicit `%ex` conversion\n# pattern for logging Throwables. If this was omitted, then (by default) Log4J would\n# implicitly add an `%xEx` conversion pattern which logs stacktraces with additional\n# class packaging information. That extra information can sometimes add a substantial\n# performance overhead, so we disable it in our default logging config.\n# For more information, see SPARK-39361.\nappender.console.type = Console\nappender.console.name = console\nappender.console.target = SYSTEM_ERR\nappender.console.layout.type = PatternLayout\nappender.console.layout.pattern = %d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n%ex\n\n# Set the default spark-shell/spark-sql log level to WARN. When running the\n# spark-shell/spark-sql, the log level for these classes is used to overwrite\n# the root logger's log level, so that the user can have different defaults\n# for the shell and regular Spark apps.\nlogger.repl.name = org.apache.spark.repl.Main\nlogger.repl.level = warn\n\nlogger.thriftserver.name = org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver\nlogger.thriftserver.level = warn\n\n# Settings to quiet third party logs that are too verbose\nlogger.jetty1.name = org.sparkproject.jetty\nlogger.jetty1.level = warn\nlogger.jetty2.name = org.sparkproject.jetty.util.component.AbstractLifeCycle\nlogger.jetty2.level = error\nlogger.replexprTyper.name = org.apache.spark.repl.SparkIMain$exprTyper\nlogger.replexprTyper.level = info\nlogger.replSparkILoopInterpreter.name = org.apache.spark.repl.SparkILoop$SparkILoopInterpreter\nlogger.replSparkILoopInterpreter.level = info\nlogger.parquet1.name = org.apache.parquet\nlogger.parquet1.level = error\nlogger.parquet2.name = parquet\nlogger.parquet2.level = error\n\n# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support\nlogger.RetryingHMSHandler.name = org.apache.hadoop.hive.metastore.RetryingHMSHandler\nlogger.RetryingHMSHandler.level = fatal\nlogger.FunctionRegistry.name = org.apache.hadoop.hive.ql.exec.FunctionRegistry\nlogger.FunctionRegistry.level = error\n\n# For deploying Spark ThriftServer\n# SPARK-34128: Suppress undesirable TTransportException warnings involved in THRIFT-4805\nappender.console.filter.1.type = RegexFilter\nappender.console.filter.1.regex = .*Thrift error occurred during processing of message.*\nappender.console.filter.1.onMatch = deny\nappender.console.filter.1.onMismatch = neutral\n\n# Set G-Research Spark logging to DEBUG\nlogger.GRSpark.name = uk.co.gresearch.spark\nlogger.GRSpark.level = debug\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/GroupBySuite.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.sql.{DataFrame, Dataset, KeyValueGroupedDataset, Row}\nimport uk.co.gresearch.spark.GroupBySortedSuite.{valueRowToTuple, valueToTuple}\nimport uk.co.gresearch.spark.group.SortedGroupByDataset\nimport uk.co.gresearch.test.Spec\n\nimport scala.language.implicitConversions\n\ncase class Val(id: Int, seq: Int, value: Double)\ncase class State(init: Int) {\n  var sum: Int = init\n\n  def add(i: Int): Int = {\n    sum = sum + i\n    sum\n  }\n}\n\nclass GroupBySuite extends Spec with SparkTestSession {\n\n  import spark.implicits._\n\n  // format: off\n  val ds: Dataset[Val] = Seq(\n    Val(1, 1, 1.1),\n    Val(1, 2, 1.2),\n    Val(1, 3, 1.3),\n    Val(1, 3, 1.31),\n\n    Val(2, 1, 2.1),\n    Val(2, 2, 2.2),\n    Val(2, 3, 2.3),\n\n    Val(3, 1, 3.1),\n  ).reverse.toDS().repartition(3).cache()\n  // format: on\n\n  val df: DataFrame = ds.toDF()\n\n  it(\"should ds.groupByKey\") {\n    testGroupBy(ds.groupByKey($\"id\"))\n    testGroupBy(ds.groupByKey(\"id\"))\n  }\n\n  it(\"should df.groupByKey\") {\n    testGroupBy(df.groupByKey($\"id\"))\n    testGroupBy(df.groupByKey(\"id\"))\n  }\n\n  def testGroupBy[T](ds: KeyValueGroupedDataset[Int, T]): Unit = {\n    val actual = ds\n      .mapGroups { (key, it) => (key, it.length) }\n      .collect()\n      .sortBy(v => v._1)\n\n    val expected = Seq(\n      // (key, group length)\n      (1, 4),\n      (2, 3),\n      (3, 1),\n    )\n\n    assert(actual === expected)\n  }\n\n  describe(\"ds.groupBySorted\") {\n    testGroupByIdSortBySeq(ds.groupBySorted($\"id\")($\"seq\", $\"value\"))\n    testGroupByIdSortBySeqDesc(ds.groupBySorted($\"id\")($\"seq\".desc, $\"value\".desc))\n    testGroupByIdSortBySeqWithPartitionNum(ds.groupBySorted(10)($\"id\")($\"seq\", $\"value\"))\n    testGroupByIdSortBySeqDescWithPartitionNum(ds.groupBySorted(10)($\"id\")($\"seq\".desc, $\"value\".desc))\n    testGroupByIdSeqSortByValue(ds.groupBySorted($\"id\", $\"seq\")($\"value\"))\n  }\n\n  describe(\"ds.groupByKeySorted\") {\n    testGroupByIdSortBySeq(ds.groupByKeySorted(v => v.id)(v => (v.seq, v.value)))\n    testGroupByIdSortBySeqDesc(ds.groupByKeySorted(v => v.id)(v => (v.seq, v.value), reverse = true))\n    testGroupByIdSortBySeqWithPartitionNum(ds.groupByKeySorted(v => v.id, partitions = Some(10))(v => (v.seq, v.value)))\n    testGroupByIdSortBySeqDescWithPartitionNum(\n      ds.groupByKeySorted(v => v.id, partitions = Some(10))(v => (v.seq, v.value), reverse = true)\n    )\n    testGroupByIdSeqSortByValue(ds.groupByKeySorted(v => (v.id, v.seq))(v => v.value))\n  }\n\n  describe(\"df.groupBySorted\") {\n    testGroupByIdSortBySeq(df.groupBySorted($\"id\")($\"seq\", $\"value\"))\n    testGroupByIdSortBySeqDesc(df.groupBySorted($\"id\")($\"seq\".desc, $\"value\".desc))\n    testGroupByIdSortBySeqWithPartitionNum(df.groupBySorted(10)($\"id\")($\"seq\", $\"value\"))\n    
testGroupByIdSortBySeqDescWithPartitionNum(df.groupBySorted(10)($\"id\")($\"seq\".desc, $\"value\".desc))\n    testGroupByIdSeqSortByValue(df.groupBySorted($\"id\", $\"seq\")($\"value\"))\n  }\n\n  describe(\"df.groupByKeySorted\") {\n    testGroupByIdSortBySeq(df.groupByKeySorted(v => v.getInt(0))(v => (v.getInt(1), v.getDouble(2))))\n    testGroupByIdSortBySeqDesc(\n      df.groupByKeySorted(v => v.getInt(0))(v => (v.getInt(1), v.getDouble(2)), reverse = true)\n    )\n    testGroupByIdSortBySeqWithPartitionNum(\n      df.groupByKeySorted(v => v.getInt(0), partitions = Some(10))(v => (v.getInt(1), v.getDouble(2)))\n    )\n    testGroupByIdSortBySeqDescWithPartitionNum(\n      df.groupByKeySorted(v => v.getInt(0), partitions = Some(10))(v => (v.getInt(1), v.getDouble(2)), reverse = true)\n    )\n    testGroupByIdSeqSortByValue(df.groupByKeySorted(v => (v.getInt(0), v.getInt(1)))(v => v.getDouble(2)))\n  }\n\n  def testGroupByIdSortBySeq[T](ds: SortedGroupByDataset[Int, T])(implicit asTuple: T => (Int, Int, Double)): Unit = {\n\n    it(\"should flatMapSortedGroups\") {\n      val actual = ds\n        .flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, asTuple(v._1))))\n        .collect()\n        .sortBy(v => (v._1, v._2))\n\n      // format: off\n      val expected = Seq(\n        // (key, group index, value)\n        (1, 0, (1, 1, 1.1)),\n        (1, 1, (1, 2, 1.2)),\n        (1, 2, (1, 3, 1.3)),\n        (1, 3, (1, 3, 1.31)),\n\n        (2, 0, (2, 1, 2.1)),\n        (2, 1, (2, 2, 2.2)),\n        (2, 2, (2, 3, 2.3)),\n\n        (3, 0, (3, 1, 3.1)),\n      )\n      // format: on\n\n      assert(actual === expected)\n    }\n\n    it(\"should flatMapSortedGroups with state\") {\n      val actual = ds\n        .flatMapSortedGroups(key => State(key))((state, v) => Iterator((asTuple(v), state.add(asTuple(v)._2))))\n        .collect()\n        .sortBy(v => (v._1._1, v._1._2))\n\n      // format: off\n      val expected = Seq(\n        // (value, state)\n        ((1, 1, 1.1), 1 + 1),\n        ((1, 2, 1.2), 1 + 1 + 2),\n        ((1, 3, 1.3), 1 + 1 + 2 + 3),\n        ((1, 3, 1.31), 1 + 1 + 2 + 3 + 3),\n\n        ((2, 1, 2.1), 2 + 1),\n        ((2, 2, 2.2), 2 + 1 + 2),\n        ((2, 3, 2.3), 2 + 1 + 2 + 3),\n\n        ((3, 1, 3.1), 3 + 1),\n      )\n      // format: on\n\n      assert(actual === expected)\n    }\n\n  }\n\n  def testGroupByIdSortBySeqDesc[T](\n      ds: SortedGroupByDataset[Int, T]\n  )(implicit asTuple: T => (Int, Int, Double)): Unit = {\n    it(\"should flatMapSortedGroups reverse\") {\n      val actual = ds\n        .flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, asTuple(v._1))))\n        .collect()\n        .sortBy(v => (v._1, v._2))\n\n      // format: off\n      val expected = Seq(\n        // (key, group index, value)\n        (1, 0, (1, 3, 1.31)),\n        (1, 1, (1, 3, 1.3)),\n        (1, 2, (1, 2, 1.2)),\n        (1, 3, (1, 1, 1.1)),\n\n        (2, 0, (2, 3, 2.3)),\n        (2, 1, (2, 2, 2.2)),\n        (2, 2, (2, 1, 2.1)),\n\n        (3, 0, (3, 1, 3.1)),\n      )\n      // format: on\n\n      assert(actual === expected)\n    }\n\n  }\n\n  def testGroupByIdSortBySeqWithPartitionNum[T](ds: SortedGroupByDataset[Int, T], partitions: Int = 10)(implicit\n      asTuple: T => (Int, Int, Double)\n  ): Unit = {\n\n    it(\"should flatMapSortedGroups with partition num\") {\n      val grouped = ds\n        .flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, asTuple(v._1))))\n      val actual = grouped\n        .collect()\n  
      .sortBy(v => (v._1, v._2))\n\n      // format: off\n      val expected = Seq(\n        // (key, group index, value)\n        (1, 0, (1, 1, 1.1)),\n        (1, 1, (1, 2, 1.2)),\n        (1, 2, (1, 3, 1.3)),\n        (1, 3, (1, 3, 1.31)),\n\n        (2, 0, (2, 1, 2.1)),\n        (2, 1, (2, 2, 2.2)),\n        (2, 2, (2, 3, 2.3)),\n\n        (3, 0, (3, 1, 3.1)),\n      )\n      // format: on\n\n      val partitionSizes = grouped.mapPartitions(it => Iterator.single(it.length)).collect()\n      assert(partitionSizes.length === partitions)\n      assert(partitionSizes.sum === actual.length)\n      assert(actual === expected)\n    }\n\n  }\n\n  def testGroupByIdSortBySeqDescWithPartitionNum[T](ds: SortedGroupByDataset[Int, T], partitions: Int = 10)(implicit\n      asTuple: T => (Int, Int, Double)\n  ): Unit = {\n    it(\"should flatMapSortedGroups with partition num and reverse\") {\n      val grouped = ds\n        .flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, asTuple(v._1))))\n      val actual = grouped\n        .collect()\n        .sortBy(v => (v._1, v._2))\n\n      // format: off\n      val expected = Seq(\n        // (key, group index, value)\n        (1, 0, (1, 3, 1.31)),\n        (1, 1, (1, 3, 1.3)),\n        (1, 2, (1, 2, 1.2)),\n        (1, 3, (1, 1, 1.1)),\n\n        (2, 0, (2, 3, 2.3)),\n        (2, 1, (2, 2, 2.2)),\n        (2, 2, (2, 1, 2.1)),\n\n        (3, 0, (3, 1, 3.1)),\n      )\n      // format: on\n\n      val partitionSizes = grouped.mapPartitions(it => Iterator.single(it.length)).collect()\n      assert(partitionSizes.length === partitions)\n      assert(partitionSizes.sum === actual.length)\n      assert(actual === expected)\n    }\n  }\n\n  def testGroupByIdSeqSortByValue[T](\n      ds: SortedGroupByDataset[(Int, Int), T]\n  )(implicit asTuple: T => (Int, Int, Double)): Unit = {\n\n    it(\"should flatMapSortedGroups with tuple key\") {\n      val actual = ds\n        .flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, asTuple(v._1))))\n        .collect()\n        .sortBy(v => (v._1, v._2))\n\n      // format: off\n      val expected = Seq(\n        // (key, group index, value)\n        ((1, 1), 0, (1, 1, 1.1)),\n\n        ((1, 2), 0, (1, 2, 1.2)),\n\n        ((1, 3), 0, (1, 3, 1.3)),\n        ((1, 3), 1, (1, 3, 1.31)),\n\n        ((2, 1), 0, (2, 1, 2.1)),\n\n        ((2, 2), 0, (2, 2, 2.2)),\n\n        ((2, 3), 0, (2, 3, 2.3)),\n\n        ((3, 1), 0, (3, 1, 3.1)),\n      )\n      // format: on\n\n      assert(actual === expected)\n    }\n\n    it(\"should flatMapSortedGroups with tuple key and state\") {\n      val actual = ds\n        .flatMapSortedGroups(key => State(key._1))((state, v) => Iterator((asTuple(v), state.add(asTuple(v)._2))))\n        .collect()\n        .sortBy(v => (v._1._1, v._1._2))\n\n      // format: off\n      val expected = Seq(\n        // (value, state)\n        ((1, 1, 1.1), 1 + 1),\n        ((1, 2, 1.2), 1 + 2),\n        ((1, 3, 1.3), 1 + 3),\n        ((1, 3, 1.31), 1 + 3 + 3),\n\n        ((2, 1, 2.1), 2 + 1),\n        ((2, 2, 2.2), 2 + 2),\n        ((2, 3, 2.3), 2 + 3),\n\n        ((3, 1, 3.1), 3 + 1),\n      )\n      // format: on\n\n      assert(actual === expected)\n    }\n\n  }\n\n}\n\nobject GroupBySortedSuite {\n  implicit def valueToTuple(value: Val): (Int, Int, Double) = (value.id, value.seq, value.value)\n  implicit def valueRowToTuple(value: Row): (Int, Int, Double) = (value.getInt(0), value.getInt(1), value.getDouble(2))\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/HistogramSuite.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.sql.{AnalysisException, DataFrame, Dataset}\nimport org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}\nimport uk.co.gresearch.test.Suite\n\ncase class IntValue(id: Int, title: String, value: Int)\ncase class DoubleValue(id: Int, title: String, value: Double)\n\nclass HistogramSuite extends Suite with SparkTestSession {\n\n  import spark.implicits._\n\n  val ints: Dataset[IntValue] = Seq(\n    IntValue(1, \"one\", 1),\n    IntValue(1, \"one\", 2),\n    IntValue(1, \"one\", 3),\n    IntValue(2, \"two\", 10),\n    IntValue(2, \"two\", 11),\n    IntValue(2, \"two\", 11),\n    IntValue(2, \"two\", 12),\n    IntValue(3, \"three\", -123),\n    IntValue(3, \"three\", 100),\n    IntValue(3, \"three\", 1024),\n    IntValue(4, \"four\", 0)\n  ).toDS()\n  val intThresholds = Seq(-200, -100, 0, 100, 200)\n\n  val doubles: Dataset[DoubleValue] = Seq(\n    DoubleValue(1, \"one\", 1.0),\n    DoubleValue(1, \"one\", 2.0),\n    DoubleValue(1, \"one\", 3.0),\n    DoubleValue(2, \"two\", 10.0),\n    DoubleValue(2, \"two\", 11.0),\n    DoubleValue(2, \"two\", 11.0),\n    DoubleValue(2, \"two\", 12.0),\n    DoubleValue(3, \"three\", -123.0),\n    DoubleValue(3, \"three\", 100.0),\n    DoubleValue(3, \"three\", 1024.0),\n    DoubleValue(4, \"four\", 0.0)\n  ).toDS()\n  val doubleThresholds = Seq(-200.0, -100.0, 0.0, 100.0, 200.0)\n\n  val expectedHistogram = Seq(\n    Seq(1, 0, 0, 0, 3, 0, 0),\n    Seq(2, 0, 0, 0, 4, 0, 0),\n    Seq(3, 0, 1, 0, 1, 0, 1),\n    Seq(4, 0, 0, 1, 0, 0, 0)\n  )\n  val expectedSchema: StructType = StructType(\n    Seq(\n      StructField(\"id\", IntegerType, nullable = false),\n      StructField(\"≤-200\", LongType, nullable = true),\n      StructField(\"≤-100\", LongType, nullable = true),\n      StructField(\"≤0\", LongType, nullable = true),\n      StructField(\"≤100\", LongType, nullable = true),\n      StructField(\"≤200\", LongType, nullable = true),\n      StructField(\">200\", LongType, nullable = true)\n    )\n  )\n  val expectedSchema2: StructType = StructType(\n    Seq(\n      StructField(\"id\", IntegerType, nullable = false),\n      StructField(\"title\", StringType, nullable = true),\n      StructField(\"≤-200\", LongType, nullable = true),\n      StructField(\"≤-100\", LongType, nullable = true),\n      StructField(\"≤0\", LongType, nullable = true),\n      StructField(\"≤100\", LongType, nullable = true),\n      StructField(\"≤200\", LongType, nullable = true),\n      StructField(\">200\", LongType, nullable = true)\n    )\n  )\n  val expectedDoubleSchema: StructType = StructType(\n    Seq(\n      StructField(\"id\", IntegerType, nullable = false),\n      StructField(\"≤-200.0\", LongType, nullable = true),\n      StructField(\"≤-100.0\", LongType, nullable = true),\n      StructField(\"≤0.0\", LongType, nullable = true),\n      StructField(\"≤100.0\", 
LongType, nullable = true),\n      StructField(\"≤200.0\", LongType, nullable = true),\n      StructField(\">200.0\", LongType, nullable = true)\n    )\n  )\n\n  test(\"histogram with no aggregate columns\") {\n    val histogram = ints.histogram(intThresholds, $\"value\")\n    val actual = histogram.collect().toSeq.map(_.toSeq)\n    assert(histogram.schema === StructType(expectedSchema.fields.toSeq.filterNot(_.name.equals(\"id\"))))\n    assert(actual === Seq(Seq(0, 1, 1, 8, 0, 1)))\n  }\n\n  test(\"histogram with one aggregate column\") {\n    val histogram = ints.histogram(intThresholds, $\"value\", $\"id\")\n    val actual = histogram.orderBy($\"id\").collect().toSeq.map(_.toSeq)\n    assert(histogram.schema === expectedSchema)\n    assert(actual === expectedHistogram)\n  }\n\n  test(\"histogram with two aggregate columns\") {\n    val histogram = ints.histogram(intThresholds, $\"value\", $\"id\", $\"title\")\n    val actual = histogram.orderBy($\"id\").collect().toSeq.map(_.toSeq)\n    assert(histogram.schema === expectedSchema2)\n    assert(\n      actual === Seq(\n        Seq(1, \"one\", 0, 0, 0, 3, 0, 0),\n        Seq(2, \"two\", 0, 0, 0, 4, 0, 0),\n        Seq(3, \"three\", 0, 1, 0, 1, 0, 1),\n        Seq(4, \"four\", 0, 0, 1, 0, 0, 0)\n      )\n    )\n  }\n\n  test(\"histogram with int values\") {\n    val histogram = ints.histogram(intThresholds, $\"value\", $\"id\")\n    val actual = histogram.orderBy($\"id\").collect().toSeq.map(_.toSeq)\n    assert(histogram.schema === expectedSchema)\n    assert(actual === expectedHistogram)\n  }\n\n  test(\"histogram with double values\") {\n    val histogram = doubles.histogram(doubleThresholds, $\"value\", $\"id\")\n    val actual = histogram.orderBy($\"id\").collect().toSeq.map(_.toSeq)\n    assert(histogram.schema === expectedDoubleSchema)\n    assert(actual === expectedHistogram)\n  }\n\n  test(\"histogram with int values and double thresholds\") {\n    val histogram = ints.histogram(doubleThresholds, $\"value\", $\"id\")\n    val actual = histogram.orderBy($\"id\").collect().toSeq.map(_.toSeq)\n    assert(histogram.schema === expectedDoubleSchema)\n    assert(actual === expectedHistogram)\n  }\n\n  test(\"histogram with double values and int thresholds\") {\n    val histogram = doubles.histogram(intThresholds, $\"value\", $\"id\")\n    val actual = histogram.orderBy($\"id\").collect().toSeq.map(_.toSeq)\n    assert(histogram.schema === expectedSchema)\n    assert(actual === expectedHistogram)\n  }\n\n  test(\"histogram with no thresholds\") {\n    val exception = intercept[IllegalArgumentException] {\n      ints.histogram(Seq.empty[Int], $\"value\", $\"id\")\n    }\n    assert(exception.getMessage === \"Thresholds must not be empty\")\n  }\n\n  test(\"histogram with one threshold\") {\n    val histogram = ints.histogram(Seq(0), $\"value\", $\"id\")\n    val actual = histogram.orderBy($\"id\").collect().toSeq.map(_.toSeq)\n    assert(\n      histogram.schema === StructType(\n        Seq(\n          StructField(\"id\", IntegerType, nullable = false),\n          StructField(\"≤0\", LongType, nullable = true),\n          StructField(\">0\", LongType, nullable = true)\n        )\n      )\n    )\n    assert(\n      actual === Seq(\n        Seq(1, 0, 3),\n        Seq(2, 0, 4),\n        Seq(3, 1, 2),\n        Seq(4, 1, 0)\n      )\n    )\n  }\n\n  test(\"histogram with duplicate thresholds\") {\n    val exception = intercept[IllegalArgumentException] {\n      ints.histogram(Seq(-200, -100, 0, 0, 100, 200), $\"value\", $\"id\")\n    }\n    
assert(exception.getMessage === \"Thresholds must not contain duplicates: -200,-100,0,0,100,200\")\n  }\n\n  test(\"histogram with non-existing value columns\") {\n    val exception = intercept[AnalysisException] {\n      ints.histogram(Seq(0, -200, 100, -100, 200), $\"does-not-exist\", $\"id\")\n    }\n    assert(\n      // format: off\n      exception.getMessage.startsWith(\"cannot resolve '`does-not-exist`' given input columns: [id, title, value]\") ||\n        exception.getMessage.startsWith(\"Column '`does-not-exist`' does not exist. Did you mean one of the following? [title, id, value]\") ||\n        exception.getMessage.startsWith(\"[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `does-not-exist` cannot be resolved. Did you mean one of the following? [`title`, `id`, `value`]\") ||\n        exception.getMessage.startsWith(\"[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable, or function parameter with name `does-not-exist` cannot be resolved. Did you mean one of the following? [`title`, `id`, `value`]\")\n      // format: on\n    )\n  }\n\n  test(\"histogram with non-existing aggregate column\") {\n    val exception = intercept[AnalysisException] {\n      ints.histogram(intThresholds, $\"value\", $\"does-not-exist\")\n    }\n    assert(\n      // format: off\n      exception.getMessage.startsWith(\"cannot resolve '`does-not-exist`' given input columns: [\") ||\n        exception.getMessage.startsWith(\"Column '`does-not-exist`' does not exist. Did you mean one of the following? [\") ||\n        exception.getMessage.startsWith(\"[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `does-not-exist` cannot be resolved. Did you mean one of the following? [\") ||\n        exception.getMessage.startsWith(\"[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable, or function parameter with name `does-not-exist` cannot be resolved. Did you mean one of the following? [\")\n      // format: on\n    )\n  }\n\n  test(\"histogram with int values on DataFrame\") {\n    val histogram = ints.toDF().histogram(intThresholds, $\"value\", $\"id\")\n    val actual = histogram.orderBy($\"id\").collect().toSeq.map(_.toSeq)\n    assert(histogram.schema === expectedSchema)\n    assert(actual === expectedHistogram)\n  }\n\n  test(\"histogram with double values on DataFrame\") {\n    val histogram = doubles.toDF().histogram(doubleThresholds, $\"value\", $\"id\")\n    val actual = histogram.orderBy($\"id\").collect().toSeq.map(_.toSeq)\n    assert(histogram.schema === expectedDoubleSchema)\n    assert(actual === expectedHistogram)\n  }\n\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/SparkSuite.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.{SparkFiles, TaskContext}\nimport org.apache.spark.sql._\nimport org.apache.spark.sql.extension.ColumnExtension\nimport org.apache.spark.sql.functions._\nimport org.apache.spark.sql.types._\nimport org.apache.spark.storage.StorageLevel\nimport org.apache.spark.storage.StorageLevel.{DISK_ONLY, MEMORY_AND_DISK, MEMORY_ONLY, NONE, OFF_HEAP}\nimport uk.co.gresearch.ExtendedAny\nimport uk.co.gresearch.spark.SparkSuite.{Value, collectJobDescription}\nimport uk.co.gresearch.test.Suite\n\nimport java.nio.file.Paths\nimport java.sql.Timestamp\nimport java.time.Instant\n\nclass SparkSuite extends Suite with SparkTestSession with SparkSuiteHelper {\n\n  import spark.implicits._\n\n  val emptyDataset: Dataset[Value] = spark.emptyDataset[Value]\n  val emptyDataFrame: DataFrame = spark.createDataFrame(Seq.empty[Value])\n\n  test(\"Get Spark version\") {\n    assert(\n      VersionString.contains(s\"-$BuildSparkCompatVersionString-\") || VersionString.endsWith(\n        s\"-$BuildSparkCompatVersionString\"\n      )\n    )\n\n    assert(spark.version.startsWith(s\"$BuildSparkCompatVersionString.\"))\n    assert(SparkVersion === BuildSparkVersion)\n    assert(SparkCompatVersion === BuildSparkCompatVersion)\n    assert(SparkCompatVersionString === BuildSparkCompatVersionString)\n    assert(SparkMajorVersion === BuildSparkMajorVersion)\n    assert(SparkMinorVersion === BuildSparkMinorVersion)\n    assert(SparkPatchVersion === BuildSparkPatchVersion)\n  }\n\n  Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, NONE).foreach { level =>\n    Seq(\n      (\"UnpersistHandle\", UnpersistHandle()),\n      (\"SilentUnpersistHandle\", SilentUnpersistHandle())\n    ).foreach { case (handleClass, unpersist) =>\n      test(s\"$handleClass does unpersist set DataFrame with $level\") {\n        val cacheManager = spark.sharedState.cacheManager\n        cacheManager.clearCache()\n        assert(cacheManager.isEmpty === true)\n\n        val ds = createEmptyDataset[String]()\n        assert(cacheManager.lookupCachedData(ds).isDefined === false)\n\n        unpersist.setDataFrame(ds.toDF())\n        assert(cacheManager.lookupCachedData(ds).isDefined === false)\n\n        ds.cache()\n        assert(cacheManager.lookupCachedData(ds).isDefined === true)\n\n        unpersist(blocking = true)\n        assert(cacheManager.lookupCachedData(ds).isDefined === false)\n\n        // calling this twice does not throw any errors\n        unpersist()\n      }\n    }\n  }\n\n  Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, NONE).foreach { level =>\n    test(s\"NoopUnpersistHandle does not unpersist set DataFrame with $level\") {\n      val cacheManager = spark.sharedState.cacheManager\n      cacheManager.clearCache()\n      assert(cacheManager.isEmpty === true)\n\n      val ds = createEmptyDataset[String]()\n      assert(cacheManager.lookupCachedData(ds).isDefined === 
false)\n\n      val unpersist = UnpersistHandle.Noop\n      unpersist.setDataFrame(ds.toDF())\n      assert(cacheManager.lookupCachedData(ds).isDefined === false)\n\n      ds.cache()\n      assert(cacheManager.lookupCachedData(ds).isDefined === true)\n\n      unpersist(blocking = true)\n      assert(cacheManager.lookupCachedData(ds).isDefined === true)\n\n      // calling this twice does not throw any errors\n      unpersist()\n    }\n  }\n\n  Seq(\n    (\"UnpersistHandle\", UnpersistHandle()),\n    (\"SilentUnpersistHandle\", SilentUnpersistHandle())\n  ).foreach { case (handleClass, unpersist) =>\n    test(s\"$handleClass throws on setting DataFrame twice\") {\n      unpersist.setDataFrame(spark.emptyDataFrame)\n      assert(intercept[IllegalStateException] {\n        unpersist.setDataFrame(spark.emptyDataFrame)\n      }.getMessage === s\"DataFrame has been set already, it cannot be reused.\")\n    }\n  }\n\n  test(\"UnpersistHandle throws on unpersist if no DataFrame is set\") {\n    val unpersist = UnpersistHandle()\n    assert(intercept[IllegalStateException] { unpersist() }.getMessage === s\"DataFrame has to be set first\")\n  }\n\n  test(\"UnpersistHandle throws on unpersist with blocking if no DataFrame is set\") {\n    val unpersist = UnpersistHandle()\n    assert(intercept[IllegalStateException] {\n      unpersist(blocking = true)\n    }.getMessage === s\"DataFrame has to be set first\")\n  }\n\n  test(\"SilentUnpersistHandle does not throw on unpersist if no DataFrame is set\") {\n    val unpersist = SilentUnpersistHandle()\n    unpersist()\n  }\n\n  test(\"SilentUnpersistHandle does not throw on unpersist with blocking if no DataFrame is set\") {\n    val unpersist = SilentUnpersistHandle()\n    unpersist(blocking = true)\n  }\n\n  test(\"backticks\") {\n    assert(backticks(\"column\") === \"column\")\n    assert(backticks(\"a.column\") === \"`a.column`\")\n    assert(backticks(\"a column\") === \"`a column`\")\n    assert(backticks(\"a `column`\") === \"`a ``column```\")\n    assert(backticks(\"column\", \"a.field\") === \"column.`a.field`\")\n    assert(backticks(\"a.column\", \"a.field\") === \"`a.column`.`a.field`\")\n    assert(backticks(\"the.alias\", \"a.column\", \"a.field\") === \"`the.alias`.`a.column`.`a.field`\")\n  }\n\n  test(\"count_null\") {\n    val df = Seq(\n      (1, \"some\"),\n      (2, \"text\"),\n      (3, \"and\"),\n      (4, \"some\"),\n      (5, \"null\"),\n      (6, \"values\"),\n      (7, null),\n      (8, null)\n    ).toDF(\"id\", \"str\")\n    val actual =\n      df.select(\n        count($\"id\").as(\"ids\"),\n        count($\"str\").as(\"strs\"),\n        count_null($\"id\").as(\"null ids\"),\n        count_null($\"str\").as(\"null strs\")\n      ).collect()\n        .head\n    assert(actual === Row(8, 6, 0, 2))\n  }\n\n  def assertJobDescription(expected: Option[String]): Unit = {\n    val descriptions = collectJobDescription(spark)\n    assert(descriptions === 0.to(2).map(id => (id, id, expected.orNull)))\n  }\n\n  test(\"without job description\") {\n    assertJobDescription(None)\n  }\n\n  test(\"with job description\") {\n    implicit val session: SparkSession = spark\n\n    assertJobDescription(None)\n    withJobDescription(\"test job description\") {\n      assertJobDescription(Some(\"test job description\"))\n      spark.sparkContext.setJobDescription(\"modified\")\n      assertJobDescription(Some(\"modified\"))\n    }\n    assertJobDescription(None)\n  }\n\n  test(\"with existing job description\") {\n    implicit val session: 
SparkSession = spark\n\n    assertJobDescription(None)\n    withJobDescription(\"outer job description\") {\n      assertJobDescription(Some(\"outer job description\"))\n      withJobDescription(\"inner job description\") {\n        assertJobDescription(Some(\"inner job description\"))\n        spark.sparkContext.setJobDescription(\"modified\")\n        assertJobDescription(Some(\"modified\"))\n      }\n      assertJobDescription(Some(\"outer job description\"))\n    }\n    assertJobDescription(None)\n  }\n\n  test(\"with existing job description if not set\") {\n    implicit val session: SparkSession = spark\n\n    assertJobDescription(None)\n    withJobDescription(\"outer job description\") {\n      assertJobDescription(Some(\"outer job description\"))\n      withJobDescription(\"inner job description\", true) {\n        assertJobDescription(Some(\"outer job description\"))\n        spark.sparkContext.setJobDescription(\"modified\")\n        assertJobDescription(Some(\"modified\"))\n      }\n      assertJobDescription(Some(\"outer job description\"))\n    }\n    assertJobDescription(None)\n  }\n\n  test(\"append job description\") {\n    implicit val session: SparkSession = spark\n\n    assertJobDescription(None)\n    appendJobDescription(\"test\") {\n      assertJobDescription(Some(\"test\"))\n      appendJobDescription(\"job\") {\n        assertJobDescription(Some(\"test - job\"))\n        appendJobDescription(\"description\", \" \") {\n          assertJobDescription(Some(\"test - job description\"))\n          spark.sparkContext.setJobDescription(\"modified\")\n          assertJobDescription(Some(\"modified\"))\n        }\n        assertJobDescription(Some(\"test - job\"))\n      }\n      assertJobDescription(Some(\"test\"))\n    }\n    assertJobDescription(None)\n  }\n\n  def assertIsDataset[T](actual: Dataset[T]): Unit = {\n    // if calling class compiles, we assert success\n    // further we evaluate the dataset to see this works as well\n    actual.collect()\n  }\n\n  def assertIsGenericType[T](actual: T): Unit = {\n    // if calling class compiles, we assert success\n  }\n\n  test(\"call dataset-to-dataset transformation\") {\n    assertIsDataset[Value](spark.emptyDataset[Value].transform(_.sort()))\n    assertIsDataset[Value](spark.emptyDataset[Value].call(_.sort()))\n  }\n\n  test(\"call dataset-to-dataframe transformation\") {\n    assertIsDataset[Row](spark.emptyDataset[Value].transform(_.drop(\"string\")))\n    assertIsDataset[Row](spark.emptyDataset[Value].call(_.drop(\"string\")))\n  }\n\n  test(\"call dataframe-to-dataset transformation\") {\n    assertIsDataset[Value](spark.createDataFrame(Seq.empty[Value]).transform(_.as[Value]))\n    assertIsDataset[Value](spark.createDataFrame(Seq.empty[Value]).call(_.as[Value]))\n  }\n\n  test(\"call dataframe-to-dataframe transformation\") {\n    assertIsDataset[Row](spark.createDataFrame(Seq.empty[Value]).transform(_.drop(\"string\")))\n    assertIsDataset[Value](spark.createDataFrame(Seq.empty[Value]).call(_.as[Value]))\n  }\n\n  Seq(true, false).foreach { condition =>\n    test(s\"call on $condition condition dataset-to-dataset transformation\") {\n      assertIsGenericType[Dataset[Value]](\n        emptyDataset.transform(_.on(condition).call(_.sort()))\n      )\n      assertIsGenericType[Dataset[Value]](\n        emptyDataset.on(condition).call(_.sort())\n      )\n    }\n\n    test(s\"call on $condition condition dataframe-to-dataframe transformation\") {\n      assertIsGenericType[DataFrame](\n        
emptyDataFrame.transform(_.on(condition).call(_.drop(\"string\")))\n      )\n      assertIsGenericType[DataFrame](\n        emptyDataFrame.on(condition).call(_.drop(\"string\"))\n      )\n    }\n\n    test(s\"when $condition call dataset-to-dataset transformation\") {\n      assertIsDataset[Value](\n        emptyDataset.transform(_.when(condition).call(_.sort()))\n      )\n      assertIsDataset[Value](\n        emptyDataset.when(condition).call(_.sort())\n      )\n    }\n\n    test(s\"when $condition call dataframe-to-dataframe transformation\") {\n      assertIsDataset[Row](\n        emptyDataFrame.transform(_.when(condition).call(_.drop(\"string\")))\n      )\n      assertIsDataset[Row](\n        emptyDataFrame.when(condition).call(_.drop(\"string\"))\n      )\n    }\n\n    test(s\"call on $condition condition either dataset-to-dataset transformation\") {\n      assertIsGenericType[Dataset[Value]](\n        spark\n          .emptyDataset[Value]\n          .transform(\n            _.on(condition)\n              .either(_.sort())\n              .or(_.orderBy())\n          )\n      )\n    }\n\n    test(s\"call on $condition condition either dataset-to-dataframe transformation\") {\n      assertIsGenericType[DataFrame](\n        spark\n          .emptyDataset[Value]\n          .transform(\n            _.on(condition)\n              .either(_.drop(\"string\"))\n              .or(_.withColumnRenamed(\"string\", \"value\"))\n          )\n      )\n    }\n\n    test(s\"call on $condition condition either dataframe-to-dataset transformation\") {\n      assertIsGenericType[Dataset[Value]](\n        spark\n          .createDataFrame(Seq.empty[Value])\n          .transform(\n            _.on(condition)\n              .either(_.as[Value])\n              .or(_.as[Value])\n          )\n      )\n    }\n\n    test(s\"call on $condition condition either dataframe-to-dataframe transformation\") {\n      assertIsGenericType[DataFrame](\n        spark\n          .createDataFrame(Seq.empty[Value])\n          .transform(\n            _.on(condition)\n              .either(_.drop(\"string\"))\n              .or(_.withColumnRenamed(\"string\", \"value\"))\n          )\n      )\n    }\n  }\n\n  test(\"on true condition call either writer-to-writer methods\") {\n    assertIsGenericType[DataFrameWriter[Value]](\n      spark\n        .emptyDataset[Value]\n        .write\n        .on(true)\n        .either(_.partitionBy(\"id\"))\n        .or(_.bucketBy(10, \"id\"))\n        .mode(SaveMode.Overwrite)\n    )\n  }\n\n  test(\"on false condition call either writer-to-writer methods\") {\n    assertIsGenericType[DataFrameWriter[Value]](\n      spark\n        .emptyDataset[Value]\n        .write\n        .on(false)\n        .either(_.partitionBy(\"id\"))\n        .or(_.bucketBy(10, \"id\"))\n        .mode(SaveMode.Overwrite)\n    )\n  }\n\n  test(\"on true condition call either writer-to-unit methods\") {\n    withTempPath { dir =>\n      assertIsGenericType[Unit](\n        spark\n          .emptyDataset[Value]\n          .write\n          .on(true)\n          .either(_.csv(dir.getAbsolutePath))\n          .or(_.csv(dir.getAbsolutePath))\n      )\n    }\n  }\n\n  test(\"on false condition call either writer-to-unit methods\") {\n    withTempPath { dir =>\n      assertIsGenericType[Unit](\n        spark\n          .emptyDataset[Value]\n          .write\n          .on(false)\n          .either(_.csv(dir.getAbsolutePath))\n          .or(_.csv(dir.getAbsolutePath))\n      )\n    }\n  }\n\n  test(\"global row number preserves 
order\") {\n    doTestWithRowNumbers()() { df =>\n      assert(df.columns === Seq(\"id\", \"rand\", \"row_number\"))\n    }\n  }\n\n  test(\"global row number respects order\") {\n    doTestWithRowNumbers { df => df.repartition(100) }($\"id\")()\n  }\n\n  test(\"global row number supports multiple order columns\") {\n    doTestWithRowNumbers { df => df.repartition(100) }($\"id\", $\"rand\", rand())()\n  }\n\n  test(\"global row number allows desc order\") {\n    doTestWithRowNumbers { df => df.repartition(100) }($\"id\".desc)()\n  }\n\n  Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, OFF_HEAP, NONE).foreach { level =>\n    test(s\"global row number with $level\") {\n      if (\n        level.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)\n      ) {\n        assertThrows[IllegalArgumentException] {\n          doTestWithRowNumbers(storageLevel = level)($\"id\")()\n        }\n      } else {\n        doTestWithRowNumbers(storageLevel = level)($\"id\")()\n      }\n    }\n  }\n\n  Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, OFF_HEAP, NONE).foreach { level =>\n    test(s\"global row number allows to unpersist with $level\") {\n      if (\n        level.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)\n      ) {\n        assertThrows[IllegalArgumentException] {\n          doTestWithRowNumbers(storageLevel = level)($\"id\")()\n        }\n      } else {\n        val cacheManager = spark.sharedState.cacheManager\n        cacheManager.clearCache()\n        assert(cacheManager.isEmpty === true)\n\n        val unpersist = UnpersistHandle()\n        doTestWithRowNumbers(storageLevel = level, unpersistHandle = unpersist)($\"id\")()\n        assert(cacheManager.isEmpty === false)\n        unpersist(true)\n        assert(cacheManager.isEmpty === true)\n      }\n    }\n  }\n\n  test(\"global row number with existing row_number column\") {\n    // this overwrites the existing column 'row_number' (formerly 'rand') with the row numbers\n    doTestWithRowNumbers { df => df.withColumnRenamed(\"rand\", \"row_number\") }() { df =>\n      assert(df.columns === Seq(\"id\", \"row_number\"))\n    }\n  }\n\n  test(\"global row number with custom row_number column\") {\n    // this puts the row numbers in the column \"row\", which is not the default column name\n    doTestWithRowNumbers(df => df.withColumnRenamed(\"rand\", \"row_number\"), rowNumberColumnName = \"row\")() { df =>\n      assert(df.columns === Seq(\"id\", \"row_number\", \"row\"))\n    }\n  }\n\n  test(\"global row number with internal column names\") {\n    val cols =\n      Seq(\"mono_id\", \"partition_id\", \"local_row_number\", \"max_local_row_number\", \"cum_row_numbers\", \"partition_offset\")\n    var prefix: String = null\n\n    doTestWithRowNumbers { df =>\n      prefix = distinctPrefixFor(df.columns)\n      cols.foldLeft(df) { (df, name) => df.withColumn(prefix + name, rand()) }\n    }() { df =>\n      assert(df.columns === Seq(\"id\", \"rand\") ++ cols.map(prefix + _) :+ \"row_number\")\n    }\n  }\n\n  def doTestWithRowNumbers(\n      transform: DataFrame => DataFrame = identity,\n      rowNumberColumnName: String = \"row_number\",\n      storageLevel: StorageLevel = MEMORY_AND_DISK,\n      unpersistHandle: UnpersistHandle = UnpersistHandle.Noop\n  )(columns: Column*)(handle: DataFrame => Unit = identity[DataFrame]): Unit = {\n    val partitions = 10\n    val rowsPerPartition = 1000\n    val rows = partitions * rowsPerPartition\n    
assert(partitions > 1)\n    assert(rowsPerPartition > 1)\n\n    val df = spark\n      .range(1, rows + 1, 1, partitions)\n      .withColumn(\"rand\", rand())\n      .transform(transform)\n      .withRowNumbers(\n        rowNumberColumnName = rowNumberColumnName,\n        storageLevel = storageLevel,\n        unpersistHandle = unpersistHandle,\n        columns: _*\n      )\n      .cache()\n\n    try {\n      // testing with descending order is only supported for a single column\n      val desc = columns.map(_.sql) match {\n        case Seq(so) if so.contains(\"DESC\") => true\n        case _                              => false\n      }\n\n      // assert row numbers are correct\n      assertRowNumbers(df, rows, desc, rowNumberColumnName)\n      handle(df)\n    } finally {\n      // always unpersist\n      df.unpersist(true)\n    }\n  }\n\n  def assertRowNumbers(df: DataFrame, rows: Int, desc: Boolean, rowNumberColumnName: String): Unit = {\n    val expect = if (desc) {\n      $\"id\" === (lit(rows) - col(rowNumberColumnName) + 1)\n    } else {\n      $\"id\" === col(rowNumberColumnName)\n    }\n\n    val correctRowNumbers = df.where(expect).count()\n    val incorrectRowNumbers = df.where(!expect).count()\n    assert(correctRowNumbers === rows)\n    assert(incorrectRowNumbers === 0)\n  }\n\n  test(\".Net ticks to Spark timestamp / unix epoch seconds / nanoseconds\") {\n    val df = Seq(\n      (1, 599266080000000000L),\n      (2, 621355968000000000L),\n      (3, 638155413748959308L),\n      (4, 638155413748959309L),\n      (5, 638155413748959310L),\n      // results in largest possible unix epoch nanos\n      (6, 713589688368547758L),\n      (7, 3155378975999999999L)\n    ).toDF(\"id\", \"ts\")\n\n    val plan = df\n      .select(\n        $\"id\",\n        dotNetTicksToTimestamp($\"ts\"),\n        dotNetTicksToTimestamp(\"ts\"),\n        dotNetTicksToUnixEpoch($\"ts\"),\n        dotNetTicksToUnixEpoch(\"ts\"),\n        dotNetTicksToUnixEpochNanos($\"ts\"),\n        dotNetTicksToUnixEpochNanos(\"ts\")\n      )\n      .orderBy($\"id\")\n    assert(\n      plan.schema.fields.map(_.dataType) === Seq(\n        IntegerType,\n        TimestampType,\n        TimestampType,\n        DecimalType(29, 9),\n        DecimalType(29, 9),\n        LongType,\n        LongType\n      )\n    )\n\n    val actual = plan.collect()\n\n    assert(\n      actual.map(_.getTimestamp(1)) === Seq(\n        Timestamp.from(Instant.parse(\"1900-01-01T00:00:00Z\")),\n        Timestamp.from(Instant.parse(\"1970-01-01T00:00:00Z\")),\n        Timestamp.from(Instant.parse(\"2023-03-27T19:16:14.89593Z\")),\n        Timestamp.from(Instant.parse(\"2023-03-27T19:16:14.89593Z\")),\n        Timestamp.from(Instant.parse(\"2023-03-27T19:16:14.895931Z\")),\n        // largest possible unix epoch nanos\n        Timestamp.from(Instant.parse(\"2262-04-11T23:47:16.854775Z\")),\n        Timestamp.from(Instant.parse(\"9999-12-31T23:59:59.999999Z\")),\n      )\n    )\n    assert(actual.map(_.getTimestamp(2)) === actual.map(_.getTimestamp(1)))\n\n    assert(\n      actual.map(_.getDecimal(3)).map(BigDecimal(_)) === Array(\n        BigDecimal(-2208988800000000000L, 9),\n        BigDecimal(0, 9),\n        BigDecimal(1679944574895930800L, 9),\n        BigDecimal(1679944574895930900L, 9),\n        BigDecimal(1679944574895931000L, 9),\n        // largest possible unix epoch nanos\n        BigDecimal(9223372036854775800L, 9),\n        BigDecimal(2534023007999999999L, 7).setScale(9),\n      )\n    )\n    assert(actual.map(_.getDecimal(4)) === 
actual.map(_.getDecimal(3)))\n\n    assert(\n      actual.map(row =>\n        if (BigDecimal(row.getDecimal(3)) <= BigDecimal(9223372036854775800L, 9)) row.getLong(5) else null\n      ) === actual.map(row =>\n        if (BigDecimal(row.getDecimal(3)) <= BigDecimal(9223372036854775800L, 9))\n          row.getDecimal(3).multiply(new java.math.BigDecimal(1000000000)).longValue()\n        else null\n      )\n    )\n    assert(actual.map(_.get(6)) === actual.map(_.get(5)))\n  }\n\n  test(\"Spark timestamp to .Net ticks\") {\n    val df = Seq(\n      (1, Timestamp.from(Instant.parse(\"1900-01-01T00:00:00Z\"))),\n      (2, Timestamp.from(Instant.parse(\"1970-01-01T00:00:00Z\"))),\n      (3, Timestamp.from(Instant.parse(\"2023-03-27T19:16:14.895931Z\"))),\n      (4, Timestamp.from(Instant.parse(\"9999-12-31T23:59:59.999999Z\"))),\n    ).toDF(\"id\", \"ts\")\n\n    if (Some(spark.sparkContext.version).exists(_.startsWith(\"3.0.\"))) {\n      assertThrows[NotImplementedError] {\n        df.select(timestampToDotNetTicks($\"ts\"))\n      }\n    } else {\n      val plan = df\n        .select(\n          $\"id\",\n          timestampToDotNetTicks($\"ts\"),\n          timestampToDotNetTicks(\"ts\"),\n        )\n        .orderBy($\"id\")\n\n      assert(\n        plan.schema.fields.map(_.dataType) === Seq(\n          IntegerType,\n          LongType,\n          LongType\n        )\n      )\n\n      val actual = plan.collect()\n\n      assert(\n        actual.map(_.getLong(1)) === Seq(\n          599266080000000000L,\n          621355968000000000L,\n          638155413748959310L,\n          3155378975999999990L\n        )\n      )\n      assert(actual.map(_.getLong(2)) === actual.map(_.getLong(1)))\n\n      val message = intercept[AnalysisException] {\n        Seq(1L).toDF(\"ts\").select(timestampToDotNetTicks($\"ts\")).collect()\n      }.getMessage\n\n      // SparkMajorVersion == 2 no supported\n      if (SparkMajorVersion == 3 && SparkMinorVersion < 4) {\n        assert(\n          message.startsWith(\n            \"cannot resolve 'unix_micros(ts)' due to data type mismatch: argument 1 requires timestamp type, however, 'ts' is of bigint type.;\"\n          )\n        )\n      } else if (SparkMajorVersion == 3 && SparkMinorVersion >= 4) {\n        assert(\n          message.startsWith(\n            \"[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve \\\"unix_micros(ts)\\\" due to data type mismatch: Parameter 1 requires the \\\"TIMESTAMP\\\" type, however \\\"ts\\\" has the type \\\"BIGINT\\\".\"\n          )\n        )\n      } else { // SparkMajorVersion > 3\n        assert(\n          message.startsWith(\n            \"[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve \\\"unix_micros(ts)\\\" due to data type mismatch: The first parameter requires the \\\"TIMESTAMP\\\" type, however \\\"ts\\\" has the type \\\"BIGINT\\\".\"\n          )\n        )\n      }\n    }\n  }\n\n  test(\"Unix epoch to .Net ticks\") {\n    def df[T: Encoder](v: T): DataFrame =\n      spark.createDataset(Seq(v)).withColumnRenamed(\"value\", \"ts\")\n\n    Seq(\n      df(BigDecimal(1679944574895931234L, 9)),\n      df(\"1679944574.895931234\"),\n      df(1679944574.895931234),\n      df(1679944574L),\n      df(1679944574),\n    ).foreach { df =>\n      this.withClue(df.schema.fields.head.dataType) {\n        val plan = df.select(\n          unixEpochToDotNetTicks($\"ts\"),\n          unixEpochToDotNetTicks(\"ts\")\n        )\n        assert(plan.schema.fields.map(_.dataType) === Seq(LongType, LongType))\n\n        
val actual = plan.collect()\n\n        assert(actual.length === 1)\n        assert(actual.head.isNullAt(0) === false)\n        assert(actual.head.isNullAt(1) === false)\n\n        if (Set(IntegerType, LongType).map(_.asInstanceOf[DataType]).contains(df.schema.head.dataType)) {\n          // long and integer also works, but without sub-second precision\n          assert(actual.head.getLong(0) === 638155413740000000L)\n        } else {\n          // lowest two nanosecond digits get lost\n          assert(actual.head.getLong(0) === 638155413748959312L)\n        }\n        assert(actual.head.getLong(1) === actual.head.getLong(0))\n      }\n    }\n  }\n\n  test(\"Unix epoch nanos to .Net ticks\") {\n    def df[T: Encoder](v: T): DataFrame =\n      spark.createDataset(Seq(v)).withColumnRenamed(\"value\", \"ts\")\n\n    Seq(\n      df(BigDecimal(1679944574895931234L)),\n      df(\"1679944574895931234\"),\n      df(1679944574895931234L),\n      df(1679944574895931234.5),\n    ).foreach { df =>\n      this.withClue(df.schema.fields.head.dataType) {\n        val plan = df.select(\n          unixEpochNanosToDotNetTicks($\"ts\"),\n          unixEpochNanosToDotNetTicks(\"ts\")\n        )\n        assert(plan.schema.fields.map(_.dataType) === Seq(LongType, LongType))\n\n        val actual = plan.collect()\n\n        assert(actual.length === 1)\n        assert(actual.head.isNullAt(0) === false)\n        assert(actual.head.isNullAt(1) === false)\n        if (df.schema.fields.head.dataType == DoubleType) {\n          // The initial double value can represent the epoch nanos only as 1.67994457489593114E18\n          assert(actual.head.getLong(0) === 638155413748959311L)\n        } else {\n          assert(actual.head.getLong(0) === 638155413748959312L)\n        }\n        assert(actual.head.getLong(1) === actual.head.getLong(0))\n      }\n    }\n  }\n\n  test(\"distinct prefix for\") {\n    assert(distinctPrefixFor(Seq.empty[String]) === \"_\")\n    assert(distinctPrefixFor(Seq(\"a\")) === \"_\")\n    assert(distinctPrefixFor(Seq(\"abc\")) === \"_\")\n    assert(distinctPrefixFor(Seq(\"a\", \"bc\", \"def\")) === \"_\")\n    assert(distinctPrefixFor(Seq(\"_a\")) === \"__\")\n    assert(distinctPrefixFor(Seq(\"_abc\")) === \"__\")\n    assert(distinctPrefixFor(Seq(\"a\", \"_bc\", \"__def\")) === \"___\")\n  }\n\n  test(\"Spark temp dir\") {\n    val dir = createTemporaryDir(\"test\")\n    assert(Paths.get(dir).toAbsolutePath.toString.startsWith(SparkFiles.getRootDirectory()))\n  }\n}\n\nobject SparkSuite {\n  case class Value(id: Int, string: String)\n\n  def collectJobDescription(spark: SparkSession): Array[(Long, Long, String)] = {\n    import spark.implicits._\n    spark\n      .range(0, 3, 1, 3)\n      .mapPartitions(it =>\n        it.map(id => (id, TaskContext.get().partitionId(), TaskContext.get().getLocalProperty(\"spark.job.description\")))\n      )\n      .as[(Long, Long, String)]\n      .sort()\n      .collect()\n  }\n\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/SparkTestSession.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.SparkContext\nimport org.apache.spark.sql.catalyst.plans.SQLHelper\nimport org.apache.spark.sql.{SQLContext, SparkSession}\n\ntrait SparkTestSession extends SQLHelper {\n\n  lazy val spark: SparkSession = {\n    SparkSession\n      .builder()\n      .master(\"local[1]\")\n      .appName(\"spark test example\")\n      .config(\"spark.sql.shuffle.partitions\", 2)\n      .config(\"spark.local.dir\", \".\")\n      .enableHiveSupport()\n      .getOrCreate()\n  }\n\n  lazy val sc: SparkContext = spark.sparkContext\n\n  lazy val sql: SQLContext = spark.sqlContext\n\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.sql.Dataset\nimport org.apache.spark.sql.functions.{col, reverse}\nimport org.apache.spark.sql.internal.SQLConf\nimport uk.co.gresearch.spark.WritePartitionedSuite.Value\nimport uk.co.gresearch.spark.UnpersistHandle.withUnpersist\nimport uk.co.gresearch.test.Suite\n\nimport java.io.File\nimport java.sql.Date\nimport scala.io.Source\n\nclass WritePartitionedSuite extends Suite with SparkTestSession {\n\n  import spark.implicits._\n\n  val values: Dataset[Value] = Seq(\n    Value(1, Date.valueOf(\"2020-07-01\"), \"one\"),\n    Value(1, Date.valueOf(\"2020-07-02\"), \"One\"),\n    Value(1, Date.valueOf(\"2020-07-03\"), \"ONE\"),\n    Value(1, Date.valueOf(\"2020-07-04\"), \"one\"),\n    Value(2, Date.valueOf(\"2020-07-01\"), \"two\"),\n    Value(2, Date.valueOf(\"2020-07-02\"), \"Two\"),\n    Value(2, Date.valueOf(\"2020-07-03\"), \"TWO\"),\n    Value(2, Date.valueOf(\"2020-07-04\"), \"two\"),\n    Value(3, Date.valueOf(\"2020-07-01\"), \"three\"),\n    Value(4, Date.valueOf(\"2020-07-01\"), \"four\")\n  ).toDS()\n\n  test(\"write partitionedBy requires caching with AQE enabled\") {\n    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> \"true\") {\n      Some(spark.version)\n        .map(version => Set(\"3.0.\", \"3.1.\", \"3.2.0\", \"3.2.1\", \"3.2.2\", \"3.3.0\", \"3.3.1\").exists(version.startsWith))\n        .foreach(expected => assert(writePartitionedByRequiresCaching(values) === expected))\n    }\n  }\n\n  test(\"write partitionedBy requires no caching with AQE disabled\") {\n    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> \"false\") {\n      assert(writePartitionedByRequiresCaching(values) === false)\n    }\n  }\n\n  test(\"write with one partition column\") {\n    withTempPath { dir =>\n      withUnpersist() { handle =>\n        values.writePartitionedBy(Seq($\"id\"), unpersistHandle = Some(handle)).csv(dir.getAbsolutePath)\n      }\n\n      val partitions = dir.list().filter(_.startsWith(\"id=\")).sorted\n      assert(partitions === Seq(\"id=1\", \"id=2\", \"id=3\", \"id=4\"))\n      partitions.foreach { partition =>\n        val files = new File(dir, partition).list().filter(file => file.startsWith(\"part-\") && file.endsWith(\".csv\"))\n        assert(files.length === 1)\n      }\n    }\n  }\n\n  test(\"write with two partition column\") {\n    withTempPath { dir =>\n      withUnpersist() { handle =>\n        values.writePartitionedBy(Seq($\"id\", $\"date\"), unpersistHandle = Some(handle)).csv(dir.getAbsolutePath)\n      }\n\n      val ids = dir.list().filter(_.startsWith(\"id=\")).sorted\n      assert(ids === Seq(\"id=1\", \"id=2\", \"id=3\", \"id=4\"))\n      val dates = ids.flatMap { id =>\n        new File(dir, id).list().filter(file => file.startsWith(\"date=\"))\n      }.toSet\n      assert(dates === Set(\"date=2020-07-01\", \"date=2020-07-04\", \"date=2020-07-02\", 
\"date=2020-07-03\"))\n    }\n  }\n\n  test(\"write with more partition columns\") {\n    withTempPath { dir =>\n      withUnpersist() { handle =>\n        values.writePartitionedBy(Seq($\"id\"), Seq($\"date\"), unpersistHandle = Some(handle)).csv(dir.getAbsolutePath)\n      }\n\n      val partitions = dir.list().filter(_.startsWith(\"id=\")).sorted\n      assert(partitions === Seq(\"id=1\", \"id=2\", \"id=3\", \"id=4\"))\n      partitions.foreach { partition =>\n        val files = new File(dir, partition).list().filter(file => file.startsWith(\"part-\") && file.endsWith(\".csv\"))\n        files.foreach(println)\n        assert(files.length >= 1 && files.length <= 2)\n      }\n    }\n  }\n\n  test(\"write with one partition\") {\n    withTempPath { dir =>\n      withUnpersist() { handle =>\n        values\n          .writePartitionedBy(Seq($\"id\"), Seq($\"date\"), partitions = Some(1), unpersistHandle = Some(handle))\n          .csv(dir.getAbsolutePath)\n      }\n\n      val partitions = dir.list().filter(_.startsWith(\"id=\")).sorted\n      assert(partitions === Seq(\"id=1\", \"id=2\", \"id=3\", \"id=4\"))\n      val files = partitions.flatMap { partition =>\n        new File(dir, partition).list().filter(file => file.startsWith(\"part-\") && file.endsWith(\".csv\"))\n      }\n      assert(files.toSet.size === 1)\n    }\n  }\n\n  test(\"write with partition order\") {\n    withTempPath { dir =>\n      withUnpersist() { handle =>\n        values\n          .writePartitionedBy(Seq($\"id\"), Seq.empty, Seq($\"date\"), unpersistHandle = Some(handle))\n          .csv(dir.getAbsolutePath)\n      }\n\n      val partitions = dir.list().filter(_.startsWith(\"id=\")).sorted\n      assert(partitions === Seq(\"id=1\", \"id=2\", \"id=3\", \"id=4\"))\n      partitions.foreach { partition =>\n        val file = new File(dir, partition)\n        val files = file.list().filter(file => file.startsWith(\"part-\") && file.endsWith(\".csv\"))\n        assert(files.length === 1)\n\n        val source = Source.fromFile(new File(file, files(0)))\n        val lines =\n          try source.getLines().toList\n          finally source.close()\n        partition match {\n          case \"id=1\" =>\n            assert(\n              lines === Seq(\n                \"2020-07-01,one\",\n                \"2020-07-02,One\",\n                \"2020-07-03,ONE\",\n                \"2020-07-04,one\"\n              )\n            )\n          case \"id=2\" =>\n            assert(\n              lines === Seq(\n                \"2020-07-01,two\",\n                \"2020-07-02,Two\",\n                \"2020-07-03,TWO\",\n                \"2020-07-04,two\"\n              )\n            )\n          case \"id=3\" =>\n            assert(\n              lines === Seq(\n                \"2020-07-01,three\"\n              )\n            )\n          case \"id=4\" =>\n            assert(\n              lines === Seq(\n                \"2020-07-01,four\"\n              )\n            )\n        }\n      }\n    }\n  }\n\n  test(\"write with desc partition order\") {\n    withTempPath { dir =>\n      withUnpersist() { handle =>\n        values\n          .writePartitionedBy(Seq($\"id\"), Seq.empty, Seq($\"date\".desc), unpersistHandle = Some(handle))\n          .csv(dir.getAbsolutePath)\n      }\n\n      val partitions = dir.list().filter(_.startsWith(\"id=\")).sorted\n      assert(partitions === Seq(\"id=1\", \"id=2\", \"id=3\", \"id=4\"))\n      partitions.foreach { partition =>\n        val file = new File(dir, partition)\n        
val files = file.list().filter(file => file.startsWith(\"part-\") && file.endsWith(\".csv\"))\n        assert(files.length === 1)\n\n        val source = Source.fromFile(new File(file, files(0)))\n        val lines =\n          try source.getLines().toList\n          finally source.close()\n        partition match {\n          case \"id=1\" =>\n            assert(\n              lines === Seq(\n                \"2020-07-04,one\",\n                \"2020-07-03,ONE\",\n                \"2020-07-02,One\",\n                \"2020-07-01,one\"\n              )\n            )\n          case \"id=2\" =>\n            assert(\n              lines === Seq(\n                \"2020-07-04,two\",\n                \"2020-07-03,TWO\",\n                \"2020-07-02,Two\",\n                \"2020-07-01,two\"\n              )\n            )\n          case \"id=3\" =>\n            assert(\n              lines === Seq(\n                \"2020-07-01,three\"\n              )\n            )\n          case \"id=4\" =>\n            assert(\n              lines === Seq(\n                \"2020-07-01,four\"\n              )\n            )\n        }\n      }\n    }\n  }\n\n  test(\"write with write projection\") {\n    val projection = Some(Seq(col(\"id\"), reverse(col(\"value\"))))\n    withTempPath { path =>\n      withUnpersist() { handle =>\n        values\n          .writePartitionedBy(\n            Seq($\"id\"),\n            Seq.empty,\n            Seq($\"date\"),\n            writtenProjection = projection,\n            unpersistHandle = Some(handle)\n          )\n          .csv(path.getAbsolutePath)\n      }\n\n      val partitions = path.list().filter(_.startsWith(\"id=\")).sorted\n      assert(partitions === Seq(\"id=1\", \"id=2\", \"id=3\", \"id=4\"))\n      partitions.foreach { partition =>\n        val dir = new File(path, partition)\n        val files = dir.list().filter(file => file.startsWith(\"part-\") && file.endsWith(\".csv\"))\n        assert(files.length === 1)\n\n        val lines = files.flatMap { file =>\n          val source = Source.fromFile(new File(dir, file))\n          try source.getLines().toList\n          finally source.close()\n        }\n\n        partition match {\n          case \"id=1\" => assert(lines === Seq(\"eno\", \"enO\", \"ENO\", \"eno\"))\n          case \"id=2\" => assert(lines === Seq(\"owt\", \"owT\", \"OWT\", \"owt\"))\n          case \"id=3\" => assert(lines === Seq(\"eerht\"))\n          case \"id=4\" => assert(lines === Seq(\"ruof\"))\n        }\n      }\n\n    }\n  }\n\n  test(\"write with un-named partition columns\") {\n    assertThrows[IllegalArgumentException] {\n      values.writePartitionedBy(Seq($\"id\" + 1))\n    }\n  }\n\n  test(\"write dataframe\") {\n    withTempPath { dir =>\n      withUnpersist() { handle =>\n        values.toDF().writePartitionedBy(Seq($\"id\", $\"date\"), unpersistHandle = Some(handle)).csv(dir.getAbsolutePath)\n      }\n\n      val ids = dir.list().filter(_.startsWith(\"id=\")).sorted\n      assert(ids === Seq(\"id=1\", \"id=2\", \"id=3\", \"id=4\"))\n      val dates = ids.flatMap { id =>\n        new File(dir, id).list().filter(file => file.startsWith(\"date=\"))\n      }.toSet\n      assert(dates === Set(\"date=2020-07-01\", \"date=2020-07-04\", \"date=2020-07-02\", \"date=2020-07-03\"))\n    }\n  }\n\n}\n\nobject WritePartitionedSuite {\n  case class Value(id: Int, date: Date, value: String)\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/diff/AppSuite.scala",
    "content": "/*\n * Copyright 2023 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff\n\nimport org.apache.spark.sql.SaveMode\nimport uk.co.gresearch.spark.SparkTestSession\nimport uk.co.gresearch.test.Suite\n\nimport java.io.File\n\nclass AppSuite extends Suite with SparkTestSession {\n\n  import spark.implicits._\n\n  test(\"run app with file and hive table\") {\n    withTempPath { path =>\n      // write left dataframe as csv\n      val leftPath = new File(path, \"left.csv\").getAbsolutePath\n      DiffSuite.left(spark).write.csv(leftPath)\n\n      // write right dataframe as parquet table\n      DiffSuite.right(spark).write.format(\"parquet\").mode(SaveMode.Overwrite).saveAsTable(\"right_parquet\")\n\n      // launch app\n      val jsonPath = new File(path, \"diff.json\").getAbsolutePath\n      App.main(\n        Array(\n          \"--left-format\",\n          \"csv\",\n          \"--left-schema\",\n          \"id int, value string\",\n          \"--output-format\",\n          \"json\",\n          \"--id\",\n          \"id\",\n          leftPath,\n          \"right_parquet\",\n          jsonPath\n        )\n      )\n\n      // assert written diff\n      val actual = spark.read.json(jsonPath)\n      assert(actual.orderBy($\"id\").collect() === DiffSuite.expectedDiff)\n    }\n  }\n\n  Seq(Set(\"I\"), Set(\"C\"), Set(\"D\"), Set(\"N\"), Set(\"I\", \"C\", \"D\")).foreach { filter =>\n    test(s\"run app with filter ${filter.mkString(\"[\", \",\", \"]\")}\") {\n      withTempPath { path =>\n        // write left dataframe as parquet\n        val leftPath = new File(path, \"left.parquet\").getAbsolutePath\n        DiffSuite.left(spark).write.parquet(leftPath)\n\n        // write right dataframe as csv\n        val rightPath = new File(path, \"right.parquet\").getAbsolutePath\n        DiffSuite.right(spark).write.parquet(rightPath)\n\n        // launch app\n        val outputPath = new File(path, \"diff.parquet\").getAbsolutePath\n        App.main(\n          Array(\n            \"--format\",\n            \"parquet\",\n            \"--id\",\n            \"id\",\n          ) ++ filter.toSeq.flatMap(f => Array(\"--filter\", f)) ++ Array(\n            leftPath,\n            rightPath,\n            outputPath\n          )\n        )\n\n        // assert written diff\n        val actual = spark.read.parquet(outputPath).orderBy($\"id\").collect()\n        val expected = DiffSuite.expectedDiff.filter(row => filter.contains(row.getString(0)))\n        assert(actual === expected)\n        assert(expected.nonEmpty)\n      }\n    }\n  }\n\n  test(s\"run app with unknown filter\") {\n    withTempPath { path =>\n      // write left dataframe as parquet\n      val leftPath = new File(path, \"left.parquet\").getAbsolutePath\n      DiffSuite.left(spark).write.parquet(leftPath)\n\n      // write right dataframe as csv\n      val rightPath = new File(path, \"right.parquet\").getAbsolutePath\n      
DiffSuite.right(spark).write.parquet(rightPath)\n\n      // launch app\n      val outputPath = new File(path, \"diff.parquet\").getAbsolutePath\n      assertThrows[RuntimeException](\n        App.main(\n          Array(\n            \"--format\",\n            \"parquet\",\n            \"--id\",\n            \"id\",\n            \"--filter\",\n            \"A\",\n            leftPath,\n            rightPath,\n            outputPath\n          )\n        )\n      )\n    }\n  }\n\n  test(\"run app writing stats\") {\n    withTempPath { path =>\n      // write left dataframe as parquet\n      val leftPath = new File(path, \"left.parquet\").getAbsolutePath\n      DiffSuite.left(spark).write.parquet(leftPath)\n\n      // write right dataframe as parquet\n      val rightPath = new File(path, \"right.parquet\").getAbsolutePath\n      DiffSuite.right(spark).write.parquet(rightPath)\n\n      // launch app\n      val outputPath = new File(path, \"diff.parquet\").getAbsolutePath\n      App.main(\n        Array(\n          \"--format\",\n          \"parquet\",\n          \"--statistics\",\n          \"--id\",\n          \"id\",\n          leftPath,\n          rightPath,\n          outputPath\n        )\n      )\n\n      // assert written diff statistics\n      val actual = spark.read.parquet(outputPath).as[(String, Long)].collect().toMap\n      val expected = DiffSuite.expectedDiff.groupBy(row => row.getString(0)).mapValues(_.length).toMap\n      assert(actual === expected)\n    }\n  }\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/diff/DiffComparatorSuite.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff\n\nimport org.apache.spark.sql._\nimport org.apache.spark.sql.catalyst.encoders.ExpressionEncoder\nimport org.apache.spark.sql.functions.{abs, lit, when}\nimport org.apache.spark.sql.internal.SQLConf\nimport org.apache.spark.sql.types._\nimport org.apache.spark.unsafe.types.UTF8String\nimport uk.co.gresearch.spark.SparkTestSession\nimport uk.co.gresearch.spark.diff.DiffComparatorSuite.{\n  decimalEnc,\n  optionsWithRelaxedComparators,\n  optionsWithTightComparators\n}\nimport uk.co.gresearch.spark.diff.comparator._\nimport uk.co.gresearch.test.Suite\n\nimport java.sql.{Date, Timestamp}\nimport java.time.Duration\nimport java.util\n\ncase class Numbers(\n    id: Int,\n    longValue: Long,\n    floatValue: Float,\n    doubleValue: Double,\n    decimalValue: Decimal,\n    someInt: Option[Int],\n    someLong: Option[Long]\n)\ncase class Strings(id: Int, string: String)\ncase class Dates(id: Int, date: Date)\ncase class Times(id: Int, time: Timestamp)\ncase class Maps(id: Int, map: Map[Int, Long])\n\nclass DiffComparatorSuite extends Suite with SparkTestSession {\n\n  import spark.implicits._\n\n  lazy val left: Dataset[Numbers] = Seq(\n    Numbers(1, 1L, 1.0f, 1.0, Decimal(10, 8, 3), None, None),\n    Numbers(2, 2L, 2.0f, 2.0, Decimal(20, 8, 3), Some(2), Some(2L)),\n    Numbers(3, 3L, 3.0f, 3.0, Decimal(30, 8, 3), Some(3), Some(3L)),\n    Numbers(4, 4L, 4.0f, 4.0, Decimal(40, 8, 3), Some(4), None),\n    Numbers(5, 5L, 5.0f, 5.0, Decimal(50, 8, 3), None, Some(5L)),\n  ).toDS()\n\n  lazy val right: Dataset[Numbers] = Seq(\n    Numbers(1, 1L, 1.0f, 1.0, Decimal(10, 8, 3), None, None),\n    Numbers(2, 3L, 2.001f, 2.001, Decimal(21, 8, 3), Some(3), Some(3L)),\n    Numbers(3, 5L, 3.01f, 3.01, Decimal(32, 8, 3), Some(5), Some(5L)),\n    Numbers(4, 4L, 4.0f, 4.0, Decimal(40, 8, 3), None, Some(4L)),\n    Numbers(6, 6L, 6.0f, 6.0, Decimal(60, 8, 3), Some(6), Some(6L)),\n  ).toDS()\n\n  lazy val rightSign: Dataset[Numbers] = Seq(\n    Numbers(1, 1L, 1.0f, 1.0, Decimal(10, 8, 3), None, None),\n    Numbers(2, -2L, -2.0f, -2.0, Decimal(-20, 8, 3), Some(-2), Some(-2L)),\n    Numbers(3, -4L, -4.0f, -4.0, Decimal(-40, 8, 3), Some(-4), Some(-4L)),\n    Numbers(4, 4L, 4.0f, 4.0, Decimal(40, 8, 3), None, Some(4L)),\n    Numbers(6, 6L, 6.0f, 6.0, Decimal(60, 8, 3), Some(6), Some(6L)),\n  ).toDS()\n\n  lazy val leftStrings: DataFrame = Seq(\n    (1, Some(\"1\")),\n    (2, None),\n    (3, Some(\"3\")),\n    (4, None)\n  ).toDF(\"id\", \"string\")\n  lazy val rightStrings: DataFrame = Seq(\n    (1, Some(\"1\")),\n    (2, Some(\"2\")),\n    (3, None),\n    (4, None)\n  ).toDF(\"id\", \"string\")\n\n  lazy val leftDates: Dataset[Dates] = Seq(\n    Dates(1, Date.valueOf(\"2000-01-01\")),\n    Dates(2, Date.valueOf(\"2000-01-02\")),\n    Dates(3, Date.valueOf(\"2000-01-03\")),\n    Dates(4, Date.valueOf(\"2000-01-04\")),\n  ).toDS()\n\n  lazy val rightDates: 
Dataset[Dates] = Seq(\n    Dates(1, Date.valueOf(\"2000-01-01\")),\n    Dates(2, Date.valueOf(\"2000-01-03\")),\n    Dates(3, Date.valueOf(\"2000-01-03\")),\n    Dates(5, Date.valueOf(\"2000-01-05\")),\n  ).toDS()\n\n  lazy val leftTimes: Dataset[Times] = Seq(\n    Times(1, Timestamp.valueOf(\"2000-01-01 12:01:00\")),\n    Times(2, Timestamp.valueOf(\"2000-01-02 12:02:00\")),\n    Times(3, Timestamp.valueOf(\"2000-01-03 12:03:00\")),\n    Times(4, Timestamp.valueOf(\"2000-01-04 12:04:00\")),\n  ).toDS()\n\n  lazy val rightTimes: Dataset[Times] = Seq(\n    Times(1, Timestamp.valueOf(\"2000-01-01 12:01:00\")),\n    Times(2, Timestamp.valueOf(\"2000-01-02 12:03:00\")),\n    Times(3, Timestamp.valueOf(\"2000-01-03 12:03:00\")),\n    Times(5, Timestamp.valueOf(\"2000-01-04 12:05:00\")),\n  ).toDS()\n\n  lazy val leftMaps: Dataset[Maps] = Seq(\n    Maps(1, Map(1 -> 1L, 2 -> 2L, 3 -> 3L)),\n    Maps(2, Map(1 -> 2L, 2 -> 2L, 3 -> 3L)),\n    Maps(3, Map(1 -> 3L, 2 -> 2L, 3 -> 3L)),\n    Maps(4, Map(1 -> 4L, 2 -> 2L, 3 -> 3L)),\n    Maps(6, Map(1 -> 1L, 2 -> 2L, 3 -> 3L)),\n    Maps(7, Map(1 -> 4L, 2 -> 2L, 3 -> 3L)),\n  ).toDS()\n\n  lazy val rightMaps: Dataset[Maps] = Seq(\n    Maps(1, Map(1 -> 1L, 2 -> 2L, 3 -> 3L)),\n    Maps(2, Map(1 -> 2L, 2 -> 3L, 3 -> 3L)),\n    Maps(3, Map(1 -> 3L, 2 -> 2L, 4 -> 4L)),\n    Maps(5, Map(1 -> 4L, 2 -> 2L, 3 -> 3L)),\n    Maps(6, Map(3 -> 3L, 2 -> 2L, 1 -> 1L)),\n    Maps(7, Map(3 -> 4L, 2 -> 2L, 1 -> 1L)),\n  ).toDS()\n\n  def doTest(\n      optionsWithTightComparators: DiffOptions,\n      optionsWithRelaxedComparators: DiffOptions,\n      left: DataFrame = this.left.toDF(),\n      right: DataFrame = this.right.toDF()\n  ): Unit = {\n    // left and right numbers have some differences\n    val actualWithoutComparators = left.diff(right, \"id\").orderBy($\"id\")\n\n    // our tight comparators are just too strict to still see differences\n    val actualWithTightComparators = left.diff(right, optionsWithTightComparators, \"id\").orderBy($\"id\")\n    val expectedWithTightComparators = actualWithoutComparators\n    assert(actualWithTightComparators.collect() === expectedWithTightComparators.collect())\n\n    // the relaxed comparators are just relaxed enough to not see any differences\n    // they still see changes to / from null values\n    val actualWithRelaxedComparators = left.diff(right, optionsWithRelaxedComparators, \"id\").orderBy($\"id\")\n    val expectedWithRelaxedComparators = actualWithoutComparators\n      // the comparators are relaxed so that all changes disappear\n      .withColumn(\"diff\", when($\"id\" === 2, lit(\"N\")).otherwise($\"diff\"))\n    assert(actualWithRelaxedComparators.collect() === expectedWithRelaxedComparators.collect())\n  }\n\n  Seq(\"true\", \"false\").foreach { codegen =>\n    test(s\"diff with custom comparator - codegen enabled=$codegen\") {\n      withSQLConf(\n        SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen,\n        SQLConf.CODEGEN_FALLBACK.key -> \"false\"\n      ) {\n        doTest(optionsWithTightComparators, optionsWithRelaxedComparators)\n      }\n    }\n  }\n\n  def alwaysTrueEquiv: math.Equiv[Any] = (_: Any, _: Any) => true\n\n  Seq(\n    \"default diff comparator\" -> DiffOptions.default\n      .withDefaultComparator((left: Column, right: Column) => abs(left) <=> abs(right)),\n    \"default encoder equiv\" -> DiffOptions.default\n      .withDefaultComparator((left: Int, right: Int) => left.abs == right.abs)\n      // the non-default comparator here are required because the default only supports int\n  
    // see \"encoder equiv\" test below\n      .withComparator((left: Long, right: Long) => left.abs == right.abs)\n      .withComparator((left: Float, right: Float) => left.abs == right.abs)\n      .withComparator((left: Double, right: Double) => left.abs == right.abs)\n      .withComparator((left: Decimal, right: Decimal) => left.abs == right.abs),\n    \"default typed equiv\" -> DiffOptions.default\n      .withDefaultComparator((left: Int, right: Int) => left.abs == right.abs, IntegerType)\n      // the non-default comparator here are required because the default only supports int\n      // see \"encoder equiv\" test below\n      .withComparator((left: Long, right: Long) => left.abs == right.abs)\n      .withComparator((left: Float, right: Float) => left.abs == right.abs)\n      .withComparator((left: Double, right: Double) => left.abs == right.abs)\n      .withComparator((left: Decimal, right: Decimal) => left.abs == right.abs),\n    \"default any equiv\" -> DiffOptions.default\n      .withDefaultComparator((_: Any, _: Any) => true),\n    \"typed diff comparator\" -> DiffOptions.default\n      .withComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs))\n      .withComparator(\n        (left: Column, right: Column) => abs(left) <=> abs(right),\n        LongType,\n        FloatType,\n        DoubleType,\n        DecimalType(38, 18)\n      ),\n    \"typed diff comparator for type\" -> DiffOptions.default\n      // only works if data type is equal to input type of typed diff comparator\n      .withComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs), IntegerType)\n      .withComparator(\n        (left: Column, right: Column) => abs(left) <=> abs(right),\n        LongType,\n        FloatType,\n        DoubleType,\n        DecimalType(38, 18)\n      ),\n    \"diff comparator for type\" -> DiffOptions.default\n      .withComparator((left: Column, right: Column) => abs(left) <=> abs(right), IntegerType)\n      .withComparator(\n        (left: Column, right: Column) => abs(left) <=> abs(right),\n        LongType,\n        FloatType,\n        DoubleType,\n        DecimalType(38, 18)\n      ),\n    \"diff comparator for name\" -> DiffOptions.default\n      .withComparator((left: Column, right: Column) => abs(left) <=> abs(right), \"someInt\")\n      .withComparator(\n        (left: Column, right: Column) => abs(left) <=> abs(right),\n        \"longValue\",\n        \"floatValue\",\n        \"doubleValue\",\n        \"someLong\",\n        \"decimalValue\"\n      ),\n    \"encoder equiv\" -> DiffOptions.default\n      .withComparator((left: Int, right: Int) => left.abs == right.abs)\n      .withComparator((left: Long, right: Long) => left.abs == right.abs)\n      .withComparator((left: Float, right: Float) => left.abs == right.abs)\n      .withComparator((left: Double, right: Double) => left.abs == right.abs)\n      .withComparator((left: Decimal, right: Decimal) => left.abs == right.abs),\n    \"encoder equiv for column name\" -> DiffOptions.default\n      .withComparator((left: Int, right: Int) => left.abs == right.abs, \"someInt\")\n      .withComparator((left: Long, right: Long) => left.abs == right.abs, \"longValue\", \"someLong\")\n      .withComparator((left: Float, right: Float) => left.abs == right.abs, \"floatValue\")\n      .withComparator((left: Double, right: Double) => left.abs == right.abs, \"doubleValue\")\n      .withComparator((left: Decimal, right: Decimal) => left.abs == right.abs, \"decimalValue\"),\n    \"equiv encoder for 
column name\" -> DiffOptions.default\n      .withComparator((left: Int, right: Int) => left.abs == right.abs, Encoders.scalaInt, \"someInt\")\n      .withComparator((left: Long, right: Long) => left.abs == right.abs, Encoders.scalaLong, \"longValue\", \"someLong\")\n      .withComparator((left: Float, right: Float) => left.abs == right.abs, Encoders.scalaFloat, \"floatValue\")\n      .withComparator((left: Double, right: Double) => left.abs == right.abs, Encoders.scalaDouble, \"doubleValue\")\n      .withComparator(\n        (left: Decimal, right: Decimal) => left.abs == right.abs,\n        ExpressionEncoder[Decimal](),\n        \"decimalValue\"\n      ),\n    \"typed equiv for type\" -> DiffOptions.default\n      .withComparator((left: Int, right: Int) => left.abs == right.abs, IntegerType)\n      .withComparator(alwaysTrueEquiv, LongType, FloatType, DoubleType, DecimalType(38, 18)),\n    \"any equiv for column name\" -> DiffOptions.default\n      .withComparator(alwaysTrueEquiv, \"someInt\")\n      .withComparator(alwaysTrueEquiv, \"longValue\", \"floatValue\", \"doubleValue\", \"someLong\", \"decimalValue\")\n  ).foreach { case (label, options) =>\n    test(s\"with comparator - $label\") {\n      val diffWithoutComparators = left.diff(rightSign, \"id\")\n\n      assert(diffWithoutComparators.where($\"diff\" === \"C\").count() === 3)\n\n      val allValuesEqual = Set(\"default any equiv\", \"any equiv for type\", \"any equiv for column name\").contains(label)\n      val unchangedIds = if (allValuesEqual) Seq(2, 3) else Seq(2)\n      val expected =\n        diffWithoutComparators.withColumn(\"diff\", when($\"id\".isin(unchangedIds: _*), lit(\"N\")).otherwise($\"diff\"))\n      assert(expected.where($\"diff\" === \"C\").count() === 3 - unchangedIds.size)\n\n      val actual = left.diff(rightSign, options, \"id\").orderBy($\"id\").collect()\n      assert(actual !== diffWithoutComparators.orderBy($\"id\").collect())\n      assert(actual === expected.orderBy($\"id\").collect())\n    }\n  }\n\n  test(\"null-aware comparator\") {\n    val options = DiffOptions.default.withComparator(\n      // only if this method is called with nulls, the expected result can occur\n      (x: Column, y: Column) => x.isNull || y.isNull || x === y,\n      StringType\n    )\n\n    val diff = leftStrings.diff(rightStrings, options, \"id\").orderBy($\"id\").collect()\n    assert(\n      diff === Seq(\n        Row(\"N\", 1, \"1\", \"1\"),\n        Row(\"N\", 2, null, \"2\"),\n        Row(\"N\", 3, \"3\", null),\n        Row(\"N\", 4, null, null),\n      )\n    )\n  }\n\n  Seq(\n    \"diff comparator\" -> (DiffOptions.default\n      .withDefaultComparator((_: Column, _: Column) => lit(1)),\n    Seq(\n      \"'(1 AND 1)' requires boolean type, not int\", // until Spark 3.3\n      \"\\\"(1 AND 1)\\\" due to data type mismatch: \" + // Spark 3.4 and beyond\n        \"the binary operator requires the input type \\\"BOOLEAN\\\", not \\\"INT\\\".\"\n    )),\n    \"encoder equiv\" -> (DiffOptions.default\n      .withDefaultComparator((_: Int, _: Int) => true),\n    Seq(\n      \"'(longValue ≡ longValue)' requires int type, not bigint\", // Spark 3.2 and 3.3\n      \"\\\"(longValue ≡ longValue)\\\" due to data type mismatch: \" + // Spark 3.4 and beyond\n        \"the binary operator requires the input type \\\"INT\\\", not \\\"BIGINT\\\".\"\n    )),\n    \"typed equiv\" -> (DiffOptions.default\n      .withDefaultComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs, IntegerType)),\n    Seq(\n      
\"'(longValue ≡ longValue)' requires int type, not bigint\", // Spark 3.2 and 3.3\n      \"\\\"(longValue ≡ longValue)\\\" due to data type mismatch: \" + // Spark 3.4 and beyond\n        \"the binary operator requires the input type \\\"INT\\\", not \\\"BIGINT\\\".\"\n    ))\n  ).foreach { case (label, (options, expecteds)) =>\n    test(s\"with comparator of incompatible type - $label\") {\n      val exception = intercept[AnalysisException] {\n        left.diff(right, options, \"id\")\n      }\n      assert(expecteds.nonEmpty)\n      assert(expecteds.exists(expected => exception.message.contains(expected)), exception.message)\n    }\n  }\n\n  test(\"absolute epsilon comparator (inclusive)\") {\n    val optionsWithTightComparator =\n      DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(0.5).asAbsolute().asInclusive())\n    val optionsWithRelaxedComparator =\n      DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.0).asAbsolute().asInclusive())\n    doTest(optionsWithTightComparator, optionsWithRelaxedComparator)\n  }\n\n  test(\"absolute epsilon comparator (exclusive)\") {\n    val optionsWithTightComparator =\n      DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.0).asAbsolute().asExclusive())\n    val optionsWithRelaxedComparator =\n      DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.001).asAbsolute().asExclusive())\n    doTest(optionsWithTightComparator, optionsWithRelaxedComparator)\n  }\n\n  test(\"relative epsilon comparator (inclusive)\") {\n    val optionsWithTightComparator =\n      DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(0.1).asRelative().asInclusive())\n    val optionsWithRelaxedComparator =\n      DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1 / 3.0).asRelative().asInclusive())\n    doTest(optionsWithTightComparator, optionsWithRelaxedComparator)\n  }\n\n  test(\"relative epsilon comparator (exclusive)\") {\n    val optionsWithTightComparator =\n      DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1 / 3.0).asRelative().asExclusive())\n    val optionsWithRelaxedComparator =\n      DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1 / 3.0 + .001).asRelative().asExclusive())\n    doTest(optionsWithTightComparator, optionsWithRelaxedComparator)\n  }\n\n  test(\"whitespace agnostic string comparator\") {\n    val left = Seq(Strings(1, \"one\"), Strings(2, \"two spaces \"), Strings(3, \"three\"), Strings(4, \"four\")).toDF()\n    val right =\n      Seq(Strings(1, \"one\"), Strings(2, \" two \\t\\nspaces\"), Strings(3, \"three\\nspaces\"), Strings(5, \"five\")).toDF()\n    val optionsWithTightComparator =\n      DiffOptions.default.withComparator(DiffComparators.string(whitespaceAgnostic = false))\n    val optionsWithRelaxedComparator =\n      DiffOptions.default.withComparator(DiffComparators.string(whitespaceAgnostic = true))\n    doTest(optionsWithTightComparator, optionsWithRelaxedComparator, left, right)\n  }\n\n  if (DurationDiffComparator.isNotSupportedBySpark) {\n    test(\"duration comparator not supported\") {\n      assertThrows[UnsupportedOperationException] {\n        DiffComparators.duration(Duration.ofHours(1))\n      }\n      assertThrows[UnsupportedOperationException] {\n        DurationDiffComparator(Duration.ofHours(1))\n      }\n    }\n  } else {\n    test(\"duration comparator with date (inclusive)\") {\n      val optionsWithTightComparator =\n        
DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(23)).asInclusive(), \"date\")\n      val optionsWithRelaxedComparator =\n        DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(24)).asInclusive(), \"date\")\n      doTest(optionsWithTightComparator, optionsWithRelaxedComparator, leftDates.toDF, rightDates.toDF)\n    }\n\n    test(\"duration comparator with date (exclusive)\") {\n      val optionsWithTightComparator =\n        DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(24)).asExclusive(), \"date\")\n      val optionsWithRelaxedComparator =\n        DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(25)).asExclusive(), \"date\")\n      doTest(optionsWithTightComparator, optionsWithRelaxedComparator, leftDates.toDF, rightDates.toDF)\n    }\n\n    test(\"duration comparator with time (inclusive)\") {\n      val optionsWithTightComparator =\n        DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(59)).asInclusive(), \"time\")\n      val optionsWithRelaxedComparator =\n        DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(60)).asInclusive(), \"time\")\n      doTest(optionsWithTightComparator, optionsWithRelaxedComparator, leftTimes.toDF, rightTimes.toDF)\n    }\n\n    test(\"duration comparator with time (exclusive)\") {\n      val optionsWithTightComparator =\n        DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(60)).asExclusive(), \"time\")\n      val optionsWithRelaxedComparator =\n        DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(61)).asExclusive(), \"time\")\n      doTest(optionsWithTightComparator, optionsWithRelaxedComparator, leftTimes.toDF, rightTimes.toDF)\n    }\n\n    test(\"changeset accounts for comparators\") {\n      val changesetOptions = DiffOptions.default\n        .withComparator(DiffComparators.epsilon(10).asAbsolute().asInclusive(), \"longValue\")\n        .withChangeColumn(\"changeset\")\n\n      lazy val left: Dataset[Numbers] = Seq(\n        Numbers(1, 1L, 1.0f, 1.0, Decimal(10, 8, 3), None, None),\n        Numbers(2, 2L, 2.0f, 2.0, Decimal(20, 8, 3), Some(2), Some(2L)),\n        Numbers(3, 3L, 3.0f, 3.0, Decimal(30, 8, 3), Some(3), Some(3L)),\n        Numbers(4, 4L, 4.0f, 4.0, Decimal(40, 8, 3), Some(4), None),\n        Numbers(5, 5L, 5.0f, 5.0, Decimal(50, 8, 3), None, Some(5L)),\n      ).toDS()\n\n      lazy val right: Dataset[Numbers] = Seq(\n        Numbers(1, 1L, 1.0f, 1.0, Decimal(10, 8, 3), None, None),\n        Numbers(2, 8L, 2.0f, 2.0, Decimal(20, 8, 3), Some(2), Some(2L)),\n        Numbers(3, 9L, 6.0f, 3.0, Decimal(30, 8, 3), Some(3), Some(3L)),\n        Numbers(4, 10L, 4.0f, 4.0, Decimal(40, 8, 3), Some(4), None),\n        Numbers(5, 11L, 5.0f, 5.0, Decimal(50, 8, 3), None, Some(5L)),\n      ).toDS()\n\n      val rs = left.diff(right, changesetOptions, \"id\").where($\"diff\" === \"C\")\n      assert(rs.count() == 1, \"Only one row should differ with the numeric comparator applied\")\n      val changesInDifferingRow: util.List[String] = rs.head.getList[String](1)\n      assert(\n        changesInDifferingRow.get(0) == \"floatValue\",\n        \"Only floatVal differs after considering the comparators so the changeset should be size 1\"\n      )\n    }\n  }\n\n  Seq(true, false).foreach { sensitive =>\n    Seq(true, false).foreach { codegen =>\n      Seq(true, false).foreach { typed =>\n        val typedLabel = if 
(typed) \"typed\" else \"untyped\"\n\n        test(s\"map comparator $typedLabel - keyOrderSensitive=$sensitive - codegen enabled=$codegen\") {\n          withSQLConf(\n            SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString,\n            SQLConf.CODEGEN_FALLBACK.key -> \"false\"\n          ) {\n            val options =\n              if (typed) {\n                DiffOptions.default.withComparator(DiffComparators.map[Int, Long](sensitive), \"map\")\n              } else {\n                DiffOptions.default.withComparator(DiffComparators.map(IntegerType, LongType, sensitive), \"map\")\n              }\n\n            val actual = leftMaps.diff(rightMaps, options, \"id\").orderBy($\"id\").collect()\n            val diffs =\n              Seq((1, \"N\"), (2, \"C\"), (3, \"C\"), (4, \"D\"), (5, \"I\"), (6, if (sensitive) \"C\" else \"N\"), (7, \"C\"))\n                .toDF(\"id\", \"diff\")\n            val expected = leftMaps\n              .withColumnRenamed(\"map\", \"left_map\")\n              .join(rightMaps.withColumnRenamed(\"map\", \"right_map\"), Seq(\"id\"), \"fullouter\")\n              .join(diffs, \"id\")\n              .select($\"diff\", $\"id\", $\"left_map\", $\"right_map\")\n              .orderBy($\"id\")\n              .collect()\n            assert(actual === expected)\n          }\n        }\n      }\n    }\n  }\n\n  case object IntEquiv extends math.Equiv[Int] {\n    override def equiv(x: Int, y: Int): Boolean = true\n  }\n\n  case object AnyEquiv extends math.Equiv[Any] {\n    override def equiv(x: Any, y: Any): Boolean = true\n  }\n\n  val diffComparatorMethodTests: Seq[(String, (() => DiffComparator, DiffComparator))] =\n    if (DurationDiffComparator.isSupportedBySpark) {\n      Seq(\n        \"duration\" -> (() => DiffComparators.duration(Duration.ofSeconds(1)).asExclusive(), DurationDiffComparator(\n          Duration.ofSeconds(1),\n          inclusive = false\n        ))\n      )\n    } else\n      { Seq.empty } ++ Seq(\n        \"default\" -> (() => DiffComparators.default(), DefaultDiffComparator),\n        \"nullSafeEqual\" -> (() => DiffComparators.nullSafeEqual(), NullSafeEqualDiffComparator),\n        \"equiv with encoder\" -> (() => DiffComparators.equiv(IntEquiv), EquivDiffComparator(IntEquiv)),\n        \"equiv with type\" -> (() =>\n          DiffComparators.equiv(IntEquiv, IntegerType), EquivDiffComparator(IntEquiv, IntegerType)),\n        \"equiv with any\" -> (() => DiffComparators.equiv(AnyEquiv), EquivDiffComparator(AnyEquiv)),\n        \"epsilon\" -> (() => DiffComparators.epsilon(1.0).asAbsolute().asExclusive(), EpsilonDiffComparator(\n          1.0,\n          relative = false,\n          inclusive = false\n        ))\n      )\n\n  diffComparatorMethodTests.foreach { case (label, (method, expected)) =>\n    test(s\"DiffComparator.$label\") {\n      val actual = method()\n      assert(actual === expected)\n    }\n  }\n}\n\nobject DiffComparatorSuite {\n  implicit val intEnc: Encoder[Int] = Encoders.scalaInt\n  implicit val longEnc: Encoder[Long] = Encoders.scalaLong\n  implicit val floatEnc: Encoder[Float] = Encoders.scalaFloat\n  implicit val doubleEnc: Encoder[Double] = Encoders.scalaDouble\n  implicit val decimalEnc: Encoder[Decimal] = ExpressionEncoder()\n\n  val tightIntComparator: EquivDiffComparator[Int] = EquivDiffComparator((x: Int, y: Int) => math.abs(x - y) < 1)\n  val tightLongComparator: EquivDiffComparator[Long] = EquivDiffComparator((x: Long, y: Long) => math.abs(x - y) < 1)\n  val tightFloatComparator: 
EquivDiffComparator[Float] =\n    EquivDiffComparator((x: Float, y: Float) => math.abs(x - y) < 0.001)\n  val tightDoubleComparator: EquivDiffComparator[Double] =\n    EquivDiffComparator((x: Double, y: Double) => math.abs(x - y) < 0.001)\n  val tightDecimalComparator: EquivDiffComparator[Decimal] =\n    EquivDiffComparator[Decimal]((x: Decimal, y: Decimal) => (x - y).abs < Decimal(0.001))\n\n  val optionsWithTightComparators: DiffOptions = DiffOptions.default\n    .withComparator(tightIntComparator, IntegerType)\n    .withComparator(tightLongComparator, LongType)\n    .withComparator(tightFloatComparator, \"floatValue\")\n    .withComparator(tightDoubleComparator, \"doubleValue\")\n    .withComparator(tightDecimalComparator, DecimalType(38, 18))\n\n  val relaxedIntComparator: EquivDiffComparator[Int] = EquivDiffComparator((x: Int, y: Int) => math.abs(x - y) <= 1)\n  val relaxedLongComparator: EquivDiffComparator[Long] = EquivDiffComparator((x: Long, y: Long) => math.abs(x - y) <= 1)\n  val relaxedFloatComparator: EquivDiffComparator[Float] =\n    EquivDiffComparator((x: Float, y: Float) => math.abs(x - y) <= 0.001)\n  val relaxedDoubleComparator: EquivDiffComparator[Double] =\n    EquivDiffComparator((x: Double, y: Double) => math.abs(x - y) <= 0.001)\n  val relaxedDecimalComparator: EquivDiffComparator[Decimal] =\n    EquivDiffComparator[Decimal]((x: Decimal, y: Decimal) => (x - y).abs <= Decimal(0.001))\n\n  val optionsWithRelaxedComparators: DiffOptions = DiffOptions.default\n    .withComparator(relaxedIntComparator, IntegerType)\n    .withComparator(relaxedLongComparator, LongType)\n    .withComparator(relaxedFloatComparator, \"floatValue\")\n    .withComparator(relaxedDoubleComparator, \"doubleValue\")\n    .withComparator(relaxedDecimalComparator, DecimalType(38, 18))\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/diff/DiffOptionsSuite.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff\n\nimport org.apache.spark.sql.Column\nimport org.apache.spark.sql.functions.lit\nimport org.apache.spark.sql.internal.SQLConf\nimport org.apache.spark.sql.types._\nimport uk.co.gresearch.spark.SparkTestSession\nimport uk.co.gresearch.spark.diff.comparator.{DefaultDiffComparator, DiffComparator, EquivDiffComparator}\nimport uk.co.gresearch.test.Suite\n\nclass DiffOptionsSuite extends Suite with SparkTestSession {\n\n  import spark.implicits._\n\n  test(\"diff options with empty diff column name\") {\n    // test the copy method (constructor), not the fluent methods\n    val default = DiffOptions.default\n    val options = default.copy(diffColumn = \"\")\n    assert(options.diffColumn.isEmpty)\n  }\n\n  test(\"diff options left and right prefixes\") {\n    // test the copy method (constructor), not the fluent methods\n    val default = DiffOptions.default\n    doTestRequirement(default.copy(leftColumnPrefix = \"\"), \"Left column prefix must not be empty\")\n    doTestRequirement(default.copy(rightColumnPrefix = \"\"), \"Right column prefix must not be empty\")\n\n    val prefix = \"prefix\"\n    doTestRequirement(\n      default.copy(leftColumnPrefix = prefix, rightColumnPrefix = prefix),\n      s\"Left and right column prefix must be distinct: $prefix\"\n    )\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      doTestRequirement(\n        default.copy(leftColumnPrefix = prefix.toLowerCase, rightColumnPrefix = prefix.toUpperCase),\n        s\"Left and right column prefix must be distinct: $prefix\"\n      )\n    }\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"true\") {\n      default.copy(leftColumnPrefix = prefix.toLowerCase, rightColumnPrefix = prefix.toUpperCase)\n    }\n  }\n\n  test(\"diff options diff value\") {\n    // test the copy method (constructor), not the fluent methods\n    val default = DiffOptions.default\n\n    val emptyInsertDiffValueOpts = default.copy(insertDiffValue = \"\")\n    assert(emptyInsertDiffValueOpts.insertDiffValue.isEmpty)\n    val emptyChangeDiffValueOpts = default.copy(changeDiffValue = \"\")\n    assert(emptyChangeDiffValueOpts.changeDiffValue.isEmpty)\n    val emptyDeleteDiffValueOpts = default.copy(deleteDiffValue = \"\")\n    assert(emptyDeleteDiffValueOpts.deleteDiffValue.isEmpty)\n    val emptyNochangeDiffValueOpts = default.copy(nochangeDiffValue = \"\")\n    assert(emptyNochangeDiffValueOpts.nochangeDiffValue.isEmpty)\n\n    Seq(\"value\", \"\").foreach { value =>\n      doTestRequirement(\n        default.copy(insertDiffValue = value, changeDiffValue = value),\n        s\"Diff values must be distinct: List($value, $value, D, N)\"\n      )\n      doTestRequirement(\n        default.copy(insertDiffValue = value, deleteDiffValue = value),\n        s\"Diff values must be distinct: List($value, C, $value, N)\"\n      )\n      doTestRequirement(\n        
default.copy(insertDiffValue = value, nochangeDiffValue = value),\n        s\"Diff values must be distinct: List($value, C, D, $value)\"\n      )\n      doTestRequirement(\n        default.copy(changeDiffValue = value, deleteDiffValue = value),\n        s\"Diff values must be distinct: List(I, $value, $value, N)\"\n      )\n      doTestRequirement(\n        default.copy(changeDiffValue = value, nochangeDiffValue = value),\n        s\"Diff values must be distinct: List(I, $value, D, $value)\"\n      )\n      doTestRequirement(\n        default.copy(deleteDiffValue = value, nochangeDiffValue = value),\n        s\"Diff values must be distinct: List(I, C, $value, $value)\"\n      )\n    }\n  }\n\n  test(\"diff options with change column name same as diff column\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      doTestRequirement(\n        DiffOptions.default.withDiffColumn(\"same\").withChangeColumn(\"same\"),\n        \"Change column name must be different to diff column: same\"\n      )\n      doTestRequirement(\n        DiffOptions.default.withChangeColumn(\"same\").withDiffColumn(\"same\"),\n        \"Change column name must be different to diff column: same\"\n      )\n\n      doTestRequirement(\n        DiffOptions.default.withDiffColumn(\"SAME\").withChangeColumn(\"same\"),\n        \"Change column name must be different to diff column: SAME\"\n      )\n      doTestRequirement(\n        DiffOptions.default.withChangeColumn(\"SAME\").withDiffColumn(\"same\"),\n        \"Change column name must be different to diff column: same\"\n      )\n    }\n\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"true\") {\n      DiffOptions.default.withDiffColumn(\"SAME\").withChangeColumn(\"same\")\n      DiffOptions.default.withChangeColumn(\"SAME\").withDiffColumn(\"same\")\n    }\n  }\n\n  test(\"diff options with comparators\") {\n    case class Comparator(name: String) extends DiffComparator {\n      override def equiv(left: Column, right: Column): Column = DefaultDiffComparator.equiv(left, right)\n    }\n    val cmp1 = Comparator(\"cmp1\")\n    val cmp2 = Comparator(\"cmp2\")\n    val cmp3 = Comparator(\"cmp3\")\n    val cmp4 = Comparator(\"cmp4\")\n\n    val options = DiffOptions.default\n      .withDefaultComparator(cmp1)\n      .withComparator(cmp2, IntegerType, LongType)\n      .withComparator(cmp3, DoubleType)\n      .withComparator(cmp4, \"col1\", \"col2\")\n\n    assert(options.comparatorFor(StructField(\"col1\", IntegerType)) === cmp4)\n    assert(options.comparatorFor(StructField(\"col1\", LongType)) === cmp4)\n    assert(options.comparatorFor(StructField(\"col2\", StringType)) === cmp4)\n    assert(options.comparatorFor(StructField(\"col3\", IntegerType)) === cmp2)\n    assert(options.comparatorFor(StructField(\"col3\", LongType)) === cmp2)\n    assert(options.comparatorFor(StructField(\"col4\", DoubleType)) === cmp3)\n    assert(options.comparatorFor(StructField(\"col5\", FloatType)) === cmp1)\n  }\n\n  Seq(\n    (\n      \"single type\",\n      (options: DiffOptions) => options.withComparator(DiffComparators.default(), IntegerType),\n      \"A comparator for data type int exists already.\"\n    ),\n    (\n      \"multiple types\",\n      (options: DiffOptions) => options.withComparator(DiffComparators.default(), IntegerType, FloatType),\n      \"A comparator for data types float, int exists already.\"\n    ),\n    (\n      \"single column\",\n      (options: DiffOptions) => options.withComparator(DiffComparators.default(), \"col1\"),\n      \"A comparator for column 
name col1 exists already.\"\n    ),\n    (\n      \"multiple columns\",\n      (options: DiffOptions) => options.withComparator(DiffComparators.default(), \"col2\", \"col1\"),\n      \"A comparator for column names col1, col2 exists already.\"\n    ),\n  ).foreach { case (label, call, expected) =>\n    test(s\"diff options with duplicate comparator - $label\") {\n      val options = DiffOptions.default\n        .withComparator(DiffComparators.default(), IntegerType, FloatType)\n        .withComparator(DiffComparators.default(), \"col1\", \"col2\")\n      val exception = intercept[IllegalArgumentException] { call(options) }\n      assert(exception.getMessage === expected)\n    }\n  }\n\n  test(\"diff options with typed diff comparator for other data type\") {\n    val exceptionSingle = intercept[IllegalArgumentException] {\n      DiffOptions.default\n        .withComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs), LongType)\n    }\n    assert(exceptionSingle.getMessage.contains(\"Comparator with input type int cannot be used for data type bigint\"))\n\n    val exceptionMulti = intercept[IllegalArgumentException] {\n      DiffOptions.default\n        .withComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs), LongType, FloatType)\n    }\n    assert(\n      exceptionMulti.getMessage.contains(\"Comparator with input type int cannot be used for data type bigint, float\")\n    )\n  }\n\n  test(\"fluent methods of diff options\") {\n    assert(\n      DiffMode.Default != DiffMode.LeftSide,\n      \"test assumption on default diff mode must hold, otherwise test is trivial\"\n    )\n\n    val cmp1 = new DiffComparator {\n      override def equiv(left: Column, right: Column): Column = lit(true)\n    }\n    val cmp2 = new DiffComparator {\n      override def equiv(left: Column, right: Column): Column = lit(true)\n    }\n    val cmp3 = new DiffComparator {\n      override def equiv(left: Column, right: Column): Column = lit(true)\n    }\n\n    val options = DiffOptions.default\n      .withDiffColumn(\"d\")\n      .withLeftColumnPrefix(\"l\")\n      .withRightColumnPrefix(\"r\")\n      .withInsertDiffValue(\"i\")\n      .withChangeDiffValue(\"c\")\n      .withDeleteDiffValue(\"d\")\n      .withNochangeDiffValue(\"n\")\n      .withChangeColumn(\"change\")\n      .withDiffMode(DiffMode.LeftSide)\n      .withSparseMode(true)\n      .withDefaultComparator(cmp1)\n      .withComparator(cmp2, IntegerType)\n      .withComparator(cmp3, \"col1\")\n\n    val dexpectedDefCmp = cmp1\n    val expectedDtCmps = Map(IntegerType.asInstanceOf[DataType] -> cmp2)\n    val expectedColCmps = Map(\"col1\" -> cmp3)\n    val expected = DiffOptions(\n      \"d\",\n      \"l\",\n      \"r\",\n      \"i\",\n      \"c\",\n      \"d\",\n      \"n\",\n      Some(\"change\"),\n      DiffMode.LeftSide,\n      sparseMode = true,\n      dexpectedDefCmp,\n      expectedDtCmps,\n      expectedColCmps\n    )\n    assert(options === expected)\n  }\n\n  def doTestRequirement(f: => Any, expected: String): Unit = {\n    assert(intercept[IllegalArgumentException](f).getMessage === s\"requirement failed: $expected\")\n  }\n\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff\n\nimport org.apache.spark.sql.functions.regexp_replace\nimport org.apache.spark.sql.internal.SQLConf\nimport org.apache.spark.sql.types._\nimport org.apache.spark.sql.{Dataset, Encoders, Row, SparkSession}\nimport uk.co.gresearch.spark.{SparkTestSession, distinctPrefixFor}\nimport uk.co.gresearch.test.Suite\n\ncase class Empty()\ncase class Value(id: Int, value: Option[String])\ncase class Value2(id: Int, seq: Option[Int], value: Option[String])\ncase class Value3(id: Int, left_value: String, right_value: String, value: String)\ncase class Value4(id: Int, diff: String)\ncase class Value4b(id: Int, change: String)\ncase class Value5(first_id: Int, id: String)\ncase class Value6(id: Int, label: String)\ncase class Value7(id: Int, value: Option[String], label: Option[String])\ncase class Value8(id: Int, seq: Option[Int], value: Option[String], meta: Option[String])\ncase class Value9(id: Int, seq: Option[Int], value: Option[String], info: Option[String])\ncase class Value9up(ID: Int, SEQ: Option[Int], VALUE: Option[String], INFO: Option[String])\n\ncase class ValueLeft(left_id: Int, value: Option[String])\ncase class ValueRight(right_id: Int, value: Option[String])\n\ncase class DiffAs(diff: String, id: Int, left_value: Option[String], right_value: Option[String])\ncase class DiffAs8(\n    diff: String,\n    id: Int,\n    seq: Option[Int],\n    left_value: Option[String],\n    right_value: Option[String],\n    left_meta: Option[String],\n    right_meta: Option[String]\n)\ncase class DiffAs8SideBySide(\n    diff: String,\n    id: Int,\n    seq: Option[Int],\n    left_value: Option[String],\n    left_meta: Option[String],\n    right_value: Option[String],\n    right_meta: Option[String]\n)\ncase class DiffAs8OneSide(diff: String, id: Int, seq: Option[Int], value: Option[String], meta: Option[String])\ncase class DiffAs8changes(\n    diff: String,\n    changed: Array[String],\n    id: Int,\n    seq: Option[Int],\n    left_value: Option[String],\n    right_value: Option[String],\n    left_meta: Option[String],\n    right_meta: Option[String]\n)\ncase class DiffAs8and9(\n    diff: String,\n    id: Int,\n    seq: Option[Int],\n    left_value: Option[String],\n    right_value: Option[String],\n    left_meta: Option[String],\n    right_info: Option[String]\n)\n\ncase class DiffAsCustom(action: String, id: Int, before_value: Option[String], after_value: Option[String])\ncase class DiffAsSubset(diff: String, id: Int, left_value: Option[String])\ncase class DiffAsExtra(diff: String, id: Int, left_value: Option[String], right_value: Option[String], extra: String)\ncase class DiffAsOneSide(diff: String, id: Int, value: Option[String])\n\nobject DiffSuite {\n  def left(spark: SparkSession): Dataset[Value] = {\n    import spark.implicits._\n    Seq(\n      Value(1, Some(\"one\")),\n      Value(2, Some(\"two\")),\n      Value(3, Some(\"three\"))\n    
).toDS()\n  }\n\n  def right(spark: SparkSession): Dataset[Value] = {\n    import spark.implicits._\n    Seq(\n      Value(1, Some(\"one\")),\n      Value(2, Some(\"Two\")),\n      Value(4, Some(\"four\"))\n    ).toDS()\n  }\n\n  val expectedDiff: Seq[Row] = Seq(\n    Row(\"N\", 1, \"one\", \"one\"),\n    Row(\"C\", 2, \"two\", \"Two\"),\n    Row(\"D\", 3, \"three\", null),\n    Row(\"I\", 4, null, \"four\")\n  )\n\n}\n\nclass DiffSuite extends Suite with SparkTestSession {\n\n  import spark.implicits._\n\n  lazy val left: Dataset[Value] = DiffSuite.left(spark)\n  lazy val right: Dataset[Value] = DiffSuite.right(spark)\n\n  lazy val left7: Dataset[Value7] = Seq(\n    Value7(1, Some(\"one\"), Some(\"one label\")),\n    Value7(2, Some(\"two\"), Some(\"two labels\")),\n    Value7(3, Some(\"three\"), Some(\"three labels\")),\n    Value7(4, Some(\"four\"), Some(\"four labels\")),\n    Value7(5, None, None),\n    Value7(6, Some(\"six\"), Some(\"six labels\")),\n    Value7(7, Some(\"seven\"), Some(\"seven labels\")),\n    Value7(9, None, None)\n  ).toDS()\n\n  lazy val right7: Dataset[Value7] = Seq(\n    Value7(1, Some(\"One\"), Some(\"one label\")),\n    Value7(2, Some(\"two\"), Some(\"two Labels\")),\n    Value7(3, Some(\"Three\"), Some(\"Three Labels\")),\n    Value7(4, None, None),\n    Value7(5, Some(\"five\"), Some(\"five labels\")),\n    Value7(6, Some(\"six\"), Some(\"six labels\")),\n    Value7(8, Some(\"eight\"), Some(\"eight labels\")),\n    Value7(10, None, None)\n  ).toDS()\n\n  lazy val left8: Dataset[Value8] = Seq(\n    Value8(1, Some(1), Some(\"one\"), Some(\"user1\")),\n    Value8(1, Some(2), Some(\"one\"), None),\n    Value8(1, Some(3), Some(\"one\"), Some(\"user1\")),\n    Value8(2, None, Some(\"two\"), Some(\"user2\")),\n    Value8(2, Some(1), Some(\"two\"), None),\n    Value8(2, Some(2), Some(\"two\"), None),\n    Value8(3, None, None, None)\n  ).toDS()\n\n  lazy val right8: Dataset[Value8] = Seq(\n    Value8(1, Some(1), Some(\"one\"), Some(\"user2\")),\n    Value8(1, Some(2), Some(\"one\"), Some(\"user2\")),\n    Value8(1, Some(3), Some(\"one\"), None),\n    Value8(2, None, Some(\"two\"), Some(\"user2\")),\n    Value8(2, Some(2), Some(\"Two\"), Some(\"user1\")),\n    Value8(2, Some(3), Some(\"two\"), Some(\"user2\")),\n    Value8(3, None, None, None)\n  ).toDS()\n\n  lazy val right9: Dataset[Value9] =\n    right8.withColumn(\"info\", regexp_replace($\"meta\", \"user\", \"info\")).drop(\"meta\").as[Value9]\n\n  lazy val expectedDiffColumns: Seq[String] = Seq(\"diff\", \"id\", \"left_value\", \"right_value\")\n\n  lazy val expectedDiff: Seq[Row] = DiffSuite.expectedDiff\n\n  lazy val expectedReverseDiff: Seq[Row] = Seq(\n    Row(\"N\", 1, \"one\", \"one\"),\n    Row(\"C\", 2, \"Two\", \"two\"),\n    Row(\"I\", 3, null, \"three\"),\n    Row(\"D\", 4, \"four\", null)\n  )\n\n  lazy val expectedDiffAs: Seq[DiffAs] =\n    expectedDiff.map(r => DiffAs(r.getString(0), r.getInt(1), Option(r.getString(2)), Option(r.getString(3))))\n\n  lazy val expectedDiff7: Seq[Row] = Seq(\n    Row(\"C\", 1, \"one\", \"One\", \"one label\", \"one label\"),\n    Row(\"C\", 2, \"two\", \"two\", \"two labels\", \"two Labels\"),\n    Row(\"C\", 3, \"three\", \"Three\", \"three labels\", \"Three Labels\"),\n    Row(\"C\", 4, \"four\", null, \"four labels\", null),\n    Row(\"C\", 5, null, \"five\", null, \"five labels\"),\n    Row(\"N\", 6, \"six\", \"six\", \"six labels\", \"six labels\"),\n    Row(\"D\", 7, \"seven\", null, \"seven labels\", null),\n    Row(\"I\", 8, null, \"eight\", null, \"eight 
labels\"),\n    Row(\"D\", 9, null, null, null, null),\n    Row(\"I\", 10, null, null, null, null)\n  )\n\n  lazy val expectedSideBySideDiff7: Seq[Row] = expectedDiff7.map(row =>\n    Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4), row.getString(3), row.getString(5))\n  )\n  lazy val expectedLeftSideDiff7: Seq[Row] =\n    expectedDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4)))\n  lazy val expectedRightSideDiff7: Seq[Row] =\n    expectedDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(3), row.getString(5)))\n\n  lazy val expectedSparseDiff7: Seq[Row] = Seq(\n    Row(\"C\", 1, \"one\", \"One\", null, null),\n    Row(\"C\", 2, null, null, \"two labels\", \"two Labels\"),\n    Row(\"C\", 3, \"three\", \"Three\", \"three labels\", \"Three Labels\"),\n    Row(\"C\", 4, \"four\", null, \"four labels\", null),\n    Row(\"C\", 5, null, \"five\", null, \"five labels\"),\n    Row(\"N\", 6, null, null, null, null),\n    Row(\"D\", 7, \"seven\", null, \"seven labels\", null),\n    Row(\"I\", 8, null, \"eight\", null, \"eight labels\"),\n    Row(\"D\", 9, null, null, null, null),\n    Row(\"I\", 10, null, null, null, null)\n  )\n\n  lazy val expectedSideBySideSparseDiff7: Seq[Row] = expectedSparseDiff7.map(row =>\n    Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4), row.getString(3), row.getString(5))\n  )\n  lazy val expectedLeftSideSparseDiff7: Seq[Row] =\n    expectedSparseDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4)))\n  lazy val expectedRightSideSparseDiff7: Seq[Row] =\n    expectedSparseDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(3), row.getString(5)))\n\n  lazy val expectedDiff7WithChanges: Seq[Row] = Seq(\n    Row(\"C\", Seq(\"value\"), 1, \"one\", \"One\", \"one label\", \"one label\"),\n    Row(\"C\", Seq(\"label\"), 2, \"two\", \"two\", \"two labels\", \"two Labels\"),\n    Row(\"C\", Seq(\"value\", \"label\"), 3, \"three\", \"Three\", \"three labels\", \"Three Labels\"),\n    Row(\"C\", Seq(\"value\", \"label\"), 4, \"four\", null, \"four labels\", null),\n    Row(\"C\", Seq(\"value\", \"label\"), 5, null, \"five\", null, \"five labels\"),\n    Row(\"N\", Seq.empty[String], 6, \"six\", \"six\", \"six labels\", \"six labels\"),\n    Row(\"D\", null, 7, \"seven\", null, \"seven labels\", null),\n    Row(\"I\", null, 8, null, \"eight\", null, \"eight labels\"),\n    Row(\"D\", null, 9, null, null, null, null),\n    Row(\"I\", null, 10, null, null, null, null)\n  )\n\n  lazy val expectedDiff8: Seq[Row] = Seq(\n    Row(\"N\", 1, 1, \"one\", \"one\", \"user1\", \"user2\"),\n    Row(\"N\", 1, 2, \"one\", \"one\", null, \"user2\"),\n    Row(\"N\", 1, 3, \"one\", \"one\", \"user1\", null),\n    Row(\"N\", 2, null, \"two\", \"two\", \"user2\", \"user2\"),\n    Row(\"D\", 2, 1, \"two\", null, null, null),\n    Row(\"C\", 2, 2, \"two\", \"Two\", null, \"user1\"),\n    Row(\"I\", 2, 3, null, \"two\", null, \"user2\"),\n    Row(\"N\", 3, null, null, null, null, null)\n  )\n\n  lazy val expectedDiff8and9: Seq[Row] = Seq(\n    Row(\"N\", 1, 1, \"one\", \"one\", \"user1\", \"info2\"),\n    Row(\"N\", 1, 2, \"one\", \"one\", null, \"info2\"),\n    Row(\"N\", 1, 3, \"one\", \"one\", \"user1\", null),\n    Row(\"N\", 2, null, \"two\", \"two\", \"user2\", \"info2\"),\n    Row(\"D\", 2, 1, \"two\", null, null, null),\n    Row(\"C\", 2, 2, \"two\", \"Two\", null, \"info1\"),\n    Row(\"I\", 2, 3, null, \"two\", null, \"info2\"),\n    
Row(\"N\", 3, null, null, null, null, null)\n  )\n\n  lazy val expectedSideBySideDiff8: Seq[Row] =\n    expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5), r.get(4), r.get(6)))\n  lazy val expectedLeftSideDiff8: Seq[Row] =\n    expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5)))\n  lazy val expectedRightSideDiff8: Seq[Row] =\n    expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(4), r.get(6)))\n\n  lazy val expectedSparseDiff8: Seq[Row] = Seq(\n    Row(\"N\", 1, 1, null, null, \"user1\", \"user2\"),\n    Row(\"N\", 1, 2, null, null, null, \"user2\"),\n    Row(\"N\", 1, 3, null, null, \"user1\", null),\n    Row(\"N\", 2, null, null, null, null, null),\n    Row(\"D\", 2, 1, \"two\", null, null, null),\n    Row(\"C\", 2, 2, \"two\", \"Two\", null, \"user1\"),\n    Row(\"I\", 2, 3, null, \"two\", null, \"user2\"),\n    Row(\"N\", 3, null, null, null, null, null)\n  )\n\n  lazy val expectedSideBySideSparseDiff8: Seq[Row] =\n    expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5), r.get(4), r.get(6)))\n  lazy val expectedLeftSideSparseDiff8: Seq[Row] =\n    expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5)))\n  lazy val expectedRightSideSparseDiff8: Seq[Row] =\n    expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(4), r.get(6)))\n\n  lazy val expectedDiffAs8: Seq[DiffAs8] = expectedDiff8.map(r =>\n    DiffAs8(\n      r.getString(0),\n      r.getInt(1),\n      Some(r).filterNot(_.isNullAt(2)).map(_.getInt(2)),\n      Option(r.getString(3)),\n      Option(r.getString(4)),\n      Option(r.getString(5)),\n      Option(r.getString(6))\n    )\n  )\n\n  lazy val expectedDiff8WithChanges: Seq[Row] = expectedDiff8.map(r =>\n    Row(\n      r.get(0),\n      r.get(0) match {\n        case \"N\" => Seq.empty\n        case \"I\" => null\n        case \"C\" => Seq(\"value\")\n        case \"D\" => null\n      },\n      r.get(1),\n      r.get(2),\n      r.getString(3),\n      r.getString(4),\n      r.getString(5),\n      r.getString(6)\n    )\n  )\n\n  lazy val expectedDiffAs8and9: Seq[DiffAs8and9] = expectedDiff8and9.map(r =>\n    DiffAs8and9(\n      r.getString(0),\n      r.getInt(1),\n      Some(r).filterNot(_.isNullAt(2)).map(_.getInt(2)),\n      Option(r.getString(3)),\n      Option(r.getString(4)),\n      Option(r.getString(5)),\n      Option(r.getString(6))\n    )\n  )\n\n  lazy val expectedDiffWith8and9: Seq[(String, Value8, Value9)] = expectedDiffAs8and9.map(v =>\n    (\n      v.diff,\n      if (v.diff == \"I\") null else Value8(v.id, v.seq, v.left_value, v.left_meta),\n      if (v.diff == \"D\") null else Value9(v.id, v.seq, v.right_value, v.right_info)\n    )\n  )\n\n  lazy val expectedDiffWith8and9up: Seq[(String, Value8, Value9up)] =\n    expectedDiffWith8and9.map(t => t.copy(_3 = Option(t._3).map(v => Value9up(v.id, v.seq, v.value, v.info)).orNull))\n\n  test(\"diff dataframe with duplicate columns\") {\n    val df = Seq(1).toDF(\"id\").select($\"id\", $\"id\")\n\n    doTestRequirement(\n      df.diff(df, \"id\"),\n      \"The datasets have duplicate columns.\\n\" +\n        \"Left column names: id, id\\nRight column names: id, id\"\n    )\n  }\n\n  test(\"diff with no id column\") {\n    val expected = Seq(\n      Row(\"N\", 1, \"one\"),\n      Row(\"D\", 2, \"two\"),\n      Row(\"I\", 2, \"Two\"),\n      Row(\"D\", 3, \"three\"),\n      Row(\"I\", 4, \"four\")\n    )\n\n    val actual = left.diff(right).orderBy(\"id\", \"diff\")\n\n    
assert(actual.columns === Seq(\"diff\", \"id\", \"value\"))\n    assert(actual.collect() === expected)\n  }\n\n  test(\"diff with no id columns ids taken from left\") {\n    // we can check from where ids are taken only with case insensitivity\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      val left = this.left.toDF()\n      val right = this.right.toDF(\"ID\", \"VALUE\")\n\n      assert(left.diff(right).columns === Seq(\"diff\", \"id\", \"value\"))\n      assert(right.diff(left).columns === Seq(\"diff\", \"ID\", \"VALUE\"))\n    }\n  }\n\n  test(\"diff with one id column\") {\n    val actual = left.diff(right, \"id\").orderBy(\"id\")\n    val reverse = right.diff(left, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === expectedDiffColumns)\n    assert(actual.collect() === expectedDiff)\n    assert(reverse.columns === expectedDiffColumns)\n    assert(reverse.collect() === expectedReverseDiff)\n  }\n\n  test(\"diff with one ID column case-insensitive\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      val actual = left.diff(right, \"ID\").orderBy(\"ID\")\n      val reverse = right.diff(left, \"ID\").orderBy(\"ID\")\n\n      assert(actual.columns === Seq(\"diff\", \"ID\", \"left_value\", \"right_value\"))\n      assert(actual.collect() === expectedDiff)\n      assert(reverse.columns === Seq(\"diff\", \"ID\", \"left_value\", \"right_value\"))\n      assert(reverse.collect() === expectedReverseDiff)\n    }\n  }\n\n  test(\"diff with one id column case-sensitive\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"true\") {\n      doTestRequirement(left.diff(right, \"ID\"), \"Some id columns do not exist: ID missing among id, value\")\n\n      val actual = left.diff(right, \"id\").orderBy(\"id\")\n      val reverse = right.diff(left, \"id\").orderBy(\"id\")\n\n      assert(actual.columns === expectedDiffColumns)\n      assert(actual.collect() === expectedDiff)\n      assert(reverse.columns === expectedDiffColumns)\n      assert(reverse.collect() === expectedReverseDiff)\n    }\n  }\n\n  test(\"diff with two id columns\") {\n    val left = Seq(\n      Value2(1, Some(1), Some(\"one\")),\n      Value2(2, Some(1), Some(\"two.one\")),\n      Value2(2, Some(2), Some(\"two.two\")),\n      Value2(3, Some(1), Some(\"three\"))\n    ).toDS()\n\n    val right = Seq(\n      Value2(1, Some(1), Some(\"one\")),\n      Value2(2, Some(1), Some(\"two.one\")),\n      Value2(2, Some(2), Some(\"two.Two\")),\n      Value2(4, Some(1), Some(\"four\"))\n    ).toDS()\n\n    val expected = Seq(\n      Row(\"N\", 1, 1, \"one\", \"one\"),\n      Row(\"N\", 2, 1, \"two.one\", \"two.one\"),\n      Row(\"C\", 2, 2, \"two.two\", \"two.Two\"),\n      Row(\"D\", 3, 1, \"three\", null),\n      Row(\"I\", 4, 1, null, \"four\")\n    )\n\n    val actual = left.diff(right, \"id\", \"seq\").orderBy(\"id\", \"seq\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"seq\", \"left_value\", \"right_value\"))\n    assert(actual.collect() === expected)\n  }\n\n  test(\"diff with all id columns\") {\n    val expected = Seq(\n      Row(\"N\", 1, \"one\"),\n      Row(\"D\", 2, \"two\"),\n      Row(\"I\", 2, \"Two\"),\n      Row(\"D\", 3, \"three\"),\n      Row(\"I\", 4, \"four\")\n    )\n\n    val actual = left.diff(right, \"id\", \"value\").orderBy(\"id\", \"diff\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"value\"))\n    assert(actual.collect() === expected)\n  }\n\n  test(\"diff with null values\") {\n    val left = Seq(\n      Value(1, None),\n      Value(2, None),\n      
Value(3, Some(\"three\")),\n      Value(4, None)\n    ).toDS()\n\n    val right = Seq(\n      Value(1, None),\n      Value(2, Some(\"two\")),\n      Value(3, None),\n      Value(5, None)\n    ).toDS()\n\n    val expected = Seq(\n      Row(\"N\", 1, null, null),\n      Row(\"C\", 2, null, \"two\"),\n      Row(\"C\", 3, \"three\", null),\n      Row(\"D\", 4, null, null),\n      Row(\"I\", 5, null, null)\n    )\n\n    val actual = left.diff(right, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"left_value\", \"right_value\"))\n    assert(actual.collect() === expected)\n  }\n\n  test(\"diff with null id values\") {\n    val left = Seq(\n      Value2(1, None, Some(\"one\")),\n      Value2(2, Some(1), Some(\"two.one\")),\n      Value2(2, Some(2), Some(\"two.two\")),\n      Value2(3, None, Some(\"three\"))\n    ).toDS()\n\n    val right = Seq(\n      Value2(1, None, Some(\"one\")),\n      Value2(2, Some(1), Some(\"two.one\")),\n      Value2(2, Some(2), Some(\"two.Two\")),\n      Value2(4, None, Some(\"four\"))\n    ).toDS()\n\n    val expected = Seq(\n      Row(\"N\", 1, None.orNull, \"one\", \"one\"),\n      Row(\"N\", 2, 1, \"two.one\", \"two.one\"),\n      Row(\"C\", 2, 2, \"two.two\", \"two.Two\"),\n      Row(\"D\", 3, None.orNull, \"three\", None.orNull),\n      Row(\"I\", 4, None.orNull, None.orNull, \"four\")\n    )\n\n    val actual = left.diff(right, \"id\", \"seq\").orderBy(\"id\", \"seq\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"seq\", \"left_value\", \"right_value\"))\n    assert(actual.collect() === expected)\n  }\n\n  /**\n   * Tests the column order of the produced diff DataFrame.\n   */\n  test(\"diff column order\") {\n    // left has same schema as right but different column order\n    val left = Seq(\n      // value1, id, value2, seq, value3\n      (\"val1.1.1\", 1, \"val1.1.2\", 1, \"val1.1.3\"),\n      (\"val1.2.1\", 1, \"val1.2.2\", 2, \"val1.2.3\"),\n      (\"val2.1.1\", 2, \"val2.1.2\", 1, \"val2.1.3\")\n    ).toDF(\"value1\", \"id\", \"value2\", \"seq\", \"value3\")\n    val right = Seq(\n      // value2, seq, value3, id, value1\n      (\"val1.1.2\", 1, \"val1.1.3\", 1, \"val1.1.1\"),\n      (\"val1.2.2\", 2, \"val1.2.3 changed\", 1, \"val1.2.1\"),\n      (\"val2.2.2\", 2, \"val2.2.3\", 2, \"val2.2.1\")\n    ).toDF(\"value2\", \"seq\", \"value3\", \"id\", \"value1\")\n\n    // diffing left to right provides schema of result DataFrame different to right-to-left diff\n    {\n      val expected = Seq(\n        Row(\"N\", 1, 1, \"val1.1.1\", \"val1.1.1\", \"val1.1.2\", \"val1.1.2\", \"val1.1.3\", \"val1.1.3\"),\n        Row(\"C\", 1, 2, \"val1.2.1\", \"val1.2.1\", \"val1.2.2\", \"val1.2.2\", \"val1.2.3\", \"val1.2.3 changed\"),\n        Row(\"D\", 2, 1, \"val2.1.1\", null, \"val2.1.2\", null, \"val2.1.3\", null),\n        Row(\"I\", 2, 2, null, \"val2.2.1\", null, \"val2.2.2\", null, \"val2.2.3\")\n      )\n      val expectedColumns = Seq(\n        \"diff\",\n        \"id\",\n        \"seq\",\n        \"left_value1\",\n        \"right_value1\",\n        \"left_value2\",\n        \"right_value2\",\n        \"left_value3\",\n        \"right_value3\"\n      )\n\n      val actual = left.diff(right, \"id\", \"seq\").orderBy(\"id\", \"seq\")\n\n      assert(actual.columns === expectedColumns)\n      assert(actual.collect() === expected)\n    }\n\n    // diffing right to left provides different schema of result DataFrame\n    {\n      val expected = Seq(\n        Row(\"N\", 1, 1, \"val1.1.2\", \"val1.1.2\", \"val1.1.3\", 
\"val1.1.3\", \"val1.1.1\", \"val1.1.1\"),\n        Row(\"C\", 1, 2, \"val1.2.2\", \"val1.2.2\", \"val1.2.3 changed\", \"val1.2.3\", \"val1.2.1\", \"val1.2.1\"),\n        Row(\"I\", 2, 1, null, \"val2.1.2\", null, \"val2.1.3\", null, \"val2.1.1\"),\n        Row(\"D\", 2, 2, \"val2.2.2\", null, \"val2.2.3\", null, \"val2.2.1\", null)\n      )\n      val expectedColumns = Seq(\n        \"diff\",\n        \"id\",\n        \"seq\",\n        \"left_value2\",\n        \"right_value2\",\n        \"left_value3\",\n        \"right_value3\",\n        \"left_value1\",\n        \"right_value1\"\n      )\n\n      val actual = right.diff(left, \"id\", \"seq\").orderBy(\"id\", \"seq\")\n\n      assert(actual.columns === expectedColumns)\n      assert(actual.collect() === expected)\n    }\n\n    // diffing left to right without id columns takes column order of left\n    {\n      val expected = Seq(\n        Row(\"N\", \"val1.1.1\", 1, \"val1.1.2\", 1, \"val1.1.3\"),\n        Row(\"D\", \"val1.2.1\", 1, \"val1.2.2\", 2, \"val1.2.3\"),\n        Row(\"I\", \"val1.2.1\", 1, \"val1.2.2\", 2, \"val1.2.3 changed\"),\n        Row(\"D\", \"val2.1.1\", 2, \"val2.1.2\", 1, \"val2.1.3\"),\n        Row(\"I\", \"val2.2.1\", 2, \"val2.2.2\", 2, \"val2.2.3\")\n      )\n      val expectedColumns = Seq(\n        \"diff\",\n        \"value1\",\n        \"id\",\n        \"value2\",\n        \"seq\",\n        \"value3\"\n      )\n\n      val actual = left.diff(right).orderBy(\"id\", \"seq\", \"diff\")\n\n      assert(actual.columns === expectedColumns)\n      assert(actual.collect() === expected)\n    }\n\n    // diffing right to left without id columns takes column order of right\n    {\n      val expected = Seq(\n        Row(\"N\", \"val1.1.1\", 1, \"val1.1.2\", 1, \"val1.1.3\"),\n        Row(\"D\", \"val1.2.1\", 1, \"val1.2.2\", 2, \"val1.2.3\"),\n        Row(\"I\", \"val1.2.1\", 1, \"val1.2.2\", 2, \"val1.2.3 changed\"),\n        Row(\"D\", \"val2.1.1\", 2, \"val2.1.2\", 1, \"val2.1.3\"),\n        Row(\"I\", \"val2.2.1\", 2, \"val2.2.2\", 2, \"val2.2.3\")\n      )\n      val expectedColumns = Seq(\n        \"diff\",\n        \"value1\",\n        \"id\",\n        \"value2\",\n        \"seq\",\n        \"value3\"\n      )\n\n      val actual = left.diff(right).orderBy(\"id\", \"seq\", \"diff\")\n\n      assert(actual.columns === expectedColumns)\n      assert(actual.collect() === expected)\n    }\n  }\n\n  test(\"diff DataFrames\") {\n    val actual = left.toDF().diff(right.toDF(), \"id\").orderBy(\"id\")\n    val reverse = right.toDF().diff(left.toDF(), \"id\").orderBy(\"id\")\n\n    assert(actual.columns === expectedDiffColumns)\n    assert(actual.collect() === expectedDiff)\n    assert(reverse.columns === expectedDiffColumns)\n    assert(reverse.collect() === expectedReverseDiff)\n  }\n\n  test(\"diff with output columns in T\") {\n    val left = Seq(Value3(1, \"left\", \"right\", \"value\")).toDS()\n    val right = Seq(Value3(1, \"Left\", \"Right\", \"Value\")).toDS()\n\n    val actual = left.diff(right, \"id\")\n    val expectedColumns = Seq(\n      \"diff\",\n      \"id\",\n      \"left_left_value\",\n      \"right_left_value\",\n      \"left_right_value\",\n      \"right_right_value\",\n      \"left_value\",\n      \"right_value\"\n    )\n    val expectedDiff = Seq(\n      Row(\"C\", 1, \"left\", \"Left\", \"right\", \"Right\", \"value\", \"Value\")\n    )\n\n    assert(actual.columns === expectedColumns)\n    assert(actual.collect() === expectedDiff)\n  }\n\n  test(\"diff with id column diff in T\") {\n    val left 
= Seq(Value4(1, \"diff\")).toDS()\n    val right = Seq(Value4(1, \"Diff\")).toDS()\n\n    doTestRequirement(left.diff(right), \"The id columns must not contain the diff column name 'diff': id, diff\")\n    doTestRequirement(left.diff(right, \"diff\"), \"The id columns must not contain the diff column name 'diff': diff\")\n    doTestRequirement(\n      left.diff(right, \"diff\", \"id\"),\n      \"The id columns must not contain the diff column name 'diff': diff, id\"\n    )\n\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      doTestRequirement(\n        left\n          .withColumnRenamed(\"diff\", \"Diff\")\n          .diff(right.withColumnRenamed(\"diff\", \"Diff\"), \"Diff\", \"id\"),\n        \"The id columns must not contain the diff column name 'diff': Diff, id\"\n      )\n    }\n\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"true\") {\n      left\n        .withColumnRenamed(\"diff\", \"Diff\")\n        .diff(right.withColumnRenamed(\"diff\", \"Diff\"), \"Diff\", \"id\")\n    }\n  }\n\n  test(\"diff with non-id column diff in T\") {\n    val left = Seq(Value4(1, \"diff\")).toDS()\n    val right = Seq(Value4(1, \"Diff\")).toDS()\n\n    val actual = left.diff(right, \"id\")\n    val expectedColumns = Seq(\n      \"diff\",\n      \"id\",\n      \"left_diff\",\n      \"right_diff\"\n    )\n    val expectedDiff = Seq(\n      Row(\"C\", 1, \"diff\", \"Diff\")\n    )\n\n    assert(actual.columns === expectedColumns)\n    assert(actual.collect() === expectedDiff)\n  }\n\n  test(\"diff where non-id column produces diff column name\") {\n    val options = DiffOptions.default\n      .withDiffColumn(\"a_value\")\n      .withLeftColumnPrefix(\"a\")\n      .withRightColumnPrefix(\"b\")\n\n    doTestRequirement(\n      left.diff(right, options, \"id\"),\n      \"The column prefixes 'a' and 'b', together with these non-id columns \" +\n        \"must not produce the diff column name 'a_value': value\"\n    )\n    doTestRequirement(\n      left.diff(right, options.withDiffColumn(\"b_value\"), \"id\"),\n      \"The column prefixes 'a' and 'b', together with these non-id columns \" +\n        \"must not produce the diff column name 'b_value': value\"\n    )\n  }\n\n  test(\"diff with left-side mode where non-id column would produce diff column name\") {\n    val options = DiffOptions.default\n      .withDiffColumn(\"a_value\")\n      .withLeftColumnPrefix(\"a\")\n      .withRightColumnPrefix(\"b\")\n      .withDiffMode(DiffMode.LeftSide)\n\n    left.diff(right, options, \"id\")\n  }\n\n  test(\"diff with right-side mode where non-id column would produce diff column name\") {\n    val options = DiffOptions.default\n      .withDiffColumn(\"b_value\")\n      .withLeftColumnPrefix(\"a\")\n      .withRightColumnPrefix(\"b\")\n      .withDiffMode(DiffMode.RightSide)\n\n    left.diff(right, options, \"id\")\n  }\n\n  test(\"diff where case-insensitive non-id column produces diff column name\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      val options = DiffOptions.default\n        .withDiffColumn(\"a_value\")\n        .withLeftColumnPrefix(\"A\")\n        .withRightColumnPrefix(\"B\")\n\n      doTestRequirement(\n        left.diff(right, options, \"id\"),\n        \"The column prefixes 'A' and 'B', together with these non-id columns \" +\n          \"must not produce the diff column name 'a_value': value\"\n      )\n      doTestRequirement(\n        left.diff(right, options.withDiffColumn(\"b_value\"), \"id\"),\n        \"The column prefixes 'A' and 'B', together with 
these non-id columns \" +\n          \"must not produce the diff column name 'b_value': value\"\n      )\n    }\n  }\n\n  test(\"diff with left-side mode where case-insensitive non-id column would produce diff column name\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      val options = DiffOptions.default\n        .withDiffColumn(\"a_value\")\n        .withLeftColumnPrefix(\"A\")\n        .withRightColumnPrefix(\"B\")\n        .withDiffMode(DiffMode.LeftSide)\n\n      left.diff(right, options, \"id\")\n    }\n  }\n\n  test(\"diff with right-side mode where case-insensitive non-id column would produce diff column name\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      val options = DiffOptions.default\n        .withDiffColumn(\"a_value\")\n        .withLeftColumnPrefix(\"A\")\n        .withRightColumnPrefix(\"B\")\n        .withDiffMode(DiffMode.RightSide)\n\n      left.diff(right, options, \"id\")\n    }\n  }\n\n  test(\"diff where case-sensitive non-id column produces non-conflicting diff column name\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"true\") {\n      val options = DiffOptions.default\n        .withDiffColumn(\"a_value\")\n        .withLeftColumnPrefix(\"A\")\n        .withRightColumnPrefix(\"B\")\n\n      val actual = left.diff(right, options, \"id\").orderBy(\"id\")\n      val expectedColumns = Seq(\n        \"a_value\",\n        \"id\",\n        \"A_value\",\n        \"B_value\"\n      )\n\n      assert(actual.columns === expectedColumns)\n      assert(actual.collect() === expectedDiff)\n    }\n  }\n\n  test(\"diff with id column change in T\") {\n    val left = Seq(Value4b(1, \"change\")).toDS()\n    val right = Seq(Value4b(1, \"Change\")).toDS()\n\n    val options = DiffOptions.default.withChangeColumn(\"change\")\n\n    doTestRequirement(\n      left.diff(right, options),\n      \"The id columns must not contain the change column name 'change': id, change\"\n    )\n    doTestRequirement(\n      left.diff(right, options, \"change\"),\n      \"The id columns must not contain the change column name 'change': change\"\n    )\n    doTestRequirement(\n      left.diff(right, options, \"change\", \"id\"),\n      \"The id columns must not contain the change column name 'change': change, id\"\n    )\n\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      doTestRequirement(\n        left\n          .withColumnRenamed(\"change\", \"Change\")\n          .diff(right.withColumnRenamed(\"change\", \"Change\"), options, \"Change\", \"id\"),\n        \"The id columns must not contain the change column name 'change': Change, id\"\n      )\n    }\n\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"true\") {\n      left\n        .withColumnRenamed(\"change\", \"Change\")\n        .diff(right.withColumnRenamed(\"change\", \"Change\"), options, \"Change\", \"id\")\n    }\n  }\n\n  test(\"diff with non-id column change in T\") {\n    val left = Seq(Value4b(1, \"change\")).toDS()\n    val right = Seq(Value4b(1, \"Change\")).toDS()\n\n    val options = DiffOptions.default.withChangeColumn(\"change\")\n\n    val actual = left.diff(right, options, \"id\")\n    val expectedColumns = Seq(\n      \"diff\",\n      \"change\",\n      \"id\",\n      \"left_change\",\n      \"right_change\"\n    )\n    val expectedDiff = Seq(\n      Row(\"C\", Seq(\"change\"), 1, \"change\", \"Change\")\n    )\n\n    assert(actual.columns === expectedColumns)\n    assert(actual.collect() === expectedDiff)\n  }\n\n  test(\"diff where non-id column produces change 
column name\") {\n    val options = DiffOptions.default\n      .withChangeColumn(\"a_value\")\n      .withLeftColumnPrefix(\"a\")\n      .withRightColumnPrefix(\"b\")\n\n    doTestRequirement(\n      left.diff(right, options, \"id\"),\n      \"The column prefixes 'a' and 'b', together with these non-id columns \" +\n        \"must not produce the change column name 'a_value': value\"\n    )\n  }\n\n  test(\"diff where case-insensitive non-id column produces change column name\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      val options = DiffOptions.default\n        .withChangeColumn(\"a_value\")\n        .withLeftColumnPrefix(\"A\")\n        .withRightColumnPrefix(\"B\")\n\n      doTestRequirement(\n        left.diff(right, options, \"id\"),\n        \"The column prefixes 'A' and 'B', together with these non-id columns \" +\n          \"must not produce the change column name 'a_value': value\"\n      )\n    }\n  }\n\n  test(\"diff where case-sensitive non-id column produces non-conflicting change column name\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"true\") {\n      val options = DiffOptions.default\n        .withChangeColumn(\"a_value\")\n        .withLeftColumnPrefix(\"A\")\n        .withRightColumnPrefix(\"B\")\n\n      val actual = left7.diff(right7, options, \"id\").orderBy(\"id\")\n      val expectedColumns = Seq(\n        \"diff\",\n        \"a_value\",\n        \"id\",\n        \"A_value\",\n        \"B_value\",\n        \"A_label\",\n        \"B_label\"\n      )\n\n      assert(actual.columns === expectedColumns)\n      assert(actual.collect() === expectedDiff7WithChanges)\n    }\n  }\n\n  test(\"diff where non-id column produces id column name\") {\n    val options = DiffOptions.default\n      .withLeftColumnPrefix(\"first\")\n      .withRightColumnPrefix(\"second\")\n\n    val left = Seq(Value5(1, \"value\")).toDS()\n    val right = Seq(Value5(1, \"Value\")).toDS()\n\n    doTestRequirement(\n      left.diff(right, options, \"first_id\"),\n      \"The column prefixes 'first' and 'second', together with these non-id columns \" +\n        \"must not produce any id column name 'first_id': id\"\n    )\n  }\n\n  test(\"diff where case-insensitive non-id column produces id column name\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      val options = DiffOptions.default\n        .withLeftColumnPrefix(\"FIRST\")\n        .withRightColumnPrefix(\"SECOND\")\n\n      val left = Seq(Value5(1, \"value\")).toDS()\n      val right = Seq(Value5(1, \"Value\")).toDS()\n\n      doTestRequirement(\n        left.diff(right, options, \"first_id\"),\n        \"The column prefixes 'FIRST' and 'SECOND', together with these non-id columns \" +\n          \"must not produce any id column name 'first_id': id\"\n      )\n    }\n  }\n\n  test(\"diff where case-sensitive non-id column produces non-conflicting id column name\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"true\") {\n      val options = DiffOptions.default\n        .withLeftColumnPrefix(\"FIRST\")\n        .withRightColumnPrefix(\"SECOND\")\n\n      val left = Seq(Value5(1, \"value\")).toDS()\n      val right = Seq(Value5(1, \"Value\")).toDS()\n\n      val actual = left.diff(right, options, \"first_id\")\n      val expectedColumns = Seq(\n        \"diff\",\n        \"first_id\",\n        \"FIRST_id\",\n        \"SECOND_id\"\n      )\n\n      assert(actual.columns === expectedColumns)\n      assert(actual.collect() === Seq(Row(\"C\", 1, \"value\", \"Value\")))\n    }\n  }\n\n  
test(\"diff with custom diff options\") {\n    val options = DiffOptions(\"action\", \"before\", \"after\", \"new\", \"change\", \"del\", \"eq\")\n\n    val expected = Seq(\n      Row(\"eq\", 1, \"one\", \"one\"),\n      Row(\"change\", 2, \"two\", \"Two\"),\n      Row(\"del\", 3, \"three\", null),\n      Row(\"new\", 4, null, \"four\")\n    )\n\n    val actual = left.diff(right, options, \"id\").orderBy(\"id\", \"action\")\n\n    assert(actual.columns === Seq(\"action\", \"id\", \"before_value\", \"after_value\"))\n    assert(actual.collect() === expected)\n  }\n\n  test(\"diff of empty schema\") {\n    val left = Seq(Empty()).toDS()\n    val right = Seq(Empty()).toDS()\n\n    doTestRequirement(left.diff(right), \"The schema must not be empty\")\n  }\n\n  test(\"diff similar with ignored columns and empty schema\") {\n    val left = Seq((1, \"info\")).toDF(\"id\", \"info\")\n    val right = Seq((1, \"meta\")).toDF(\"id\", \"meta\")\n\n    doTestRequirement(\n      left.diff(right, Seq.empty, Seq(\"id\", \"info\", \"meta\")),\n      \"The schema except ignored columns must not be empty\"\n    )\n  }\n\n  test(\"diff with different types\") {\n    // different value types only compiles with DataFrames\n    val left = Seq((1, \"str\")).toDF(\"id\", \"value\")\n    val right = Seq((1, 2)).toDF(\"id\", \"value\")\n\n    doTestRequirement(\n      left.diff(right),\n      \"The datasets do not have the same schema.\\n\" +\n        \"Left extra columns: value (StringType)\\n\" +\n        \"Right extra columns: value (IntegerType)\"\n    )\n  }\n\n  test(\"diff with ignored columns of different types\") {\n    // different value types only compile with DataFrames\n    val left = Seq((1, \"str\")).toDF(\"id\", \"value\")\n    val right = Seq((1, 2)).toDF(\"id\", \"value\")\n\n    val actual = left.diff(right, Seq.empty, Seq(\"value\"))\n    assert(\n      ignoreNullable(actual.schema) === StructType(\n        Seq(\n          StructField(\"diff\", StringType),\n          StructField(\"id\", IntegerType),\n          StructField(\"left_value\", StringType),\n          StructField(\"right_value\", IntegerType),\n        )\n      )\n    )\n    assert(actual.collect() === Seq(Row(\"N\", 1, \"str\", 2)))\n  }\n\n  test(\"diff with different nullability\") {\n    val leftSchema = StructType(left.schema.fields.map(_.copy(nullable = true)))\n    val rightSchema = StructType(right.schema.fields.map(_.copy(nullable = false)))\n\n    // different value types only compiles with DataFrames\n    val left2 = sql.createDataFrame(left.toDF().rdd, leftSchema)\n    val right2 = sql.createDataFrame(right.toDF().rdd, rightSchema)\n\n    val actual = left2.diff(right2, \"id\").orderBy(\"id\")\n    val reverse = right2.diff(left2, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === expectedDiffColumns)\n    assert(actual.collect() === expectedDiff)\n    assert(reverse.columns === expectedDiffColumns)\n    assert(reverse.collect() === expectedReverseDiff)\n  }\n\n  test(\"diff with different column names\") {\n    // different column names only compiles with DataFrames\n    val left = Seq((1, \"str\")).toDF(\"id\", \"value\")\n    val right = Seq((1, \"str\")).toDF(\"id\", \"comment\")\n\n    doTestRequirement(\n      left.diff(right, \"id\"),\n      \"The datasets do not have the same schema.\\n\" +\n        \"Left extra columns: value (StringType)\\n\" +\n        \"Right extra columns: comment (StringType)\"\n    )\n  }\n\n  test(\"diff with case-insensitive column names\") {\n    
withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      // different column names only compiles with DataFrames\n      val left = this.left.toDF(\"id\", \"value\")\n      val right = this.right.toDF(\"ID\", \"VaLuE\")\n\n      val actual = left.diff(right, \"id\").orderBy(\"id\")\n      val reverse = right.diff(left, \"id\").orderBy(\"id\")\n\n      assert(actual.columns === Seq(\"diff\", \"id\", \"left_value\", \"right_VaLuE\"))\n      assert(actual.collect() === expectedDiff)\n      assert(reverse.columns === Seq(\"diff\", \"id\", \"left_VaLuE\", \"right_value\"))\n      assert(reverse.collect() === expectedReverseDiff)\n    }\n  }\n\n  test(\"diff with case-sensitive column names\") {\n    // different column names only compiles with DataFrames\n    val left = this.left.toDF(\"id\", \"value\")\n    val right = this.right.toDF(\"ID\", \"VaLuE\")\n\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"true\") {\n      doTestRequirement(\n        left.diff(right, \"id\"),\n        \"The datasets do not have the same schema.\\n\" +\n          \"Left extra columns: id (IntegerType), value (StringType)\\n\" +\n          \"Right extra columns: ID (IntegerType), VaLuE (StringType)\"\n      )\n    }\n  }\n\n  test(\"diff of non-existing id column\") {\n    doTestRequirement(\n      left.diff(right, \"does not exists\"),\n      \"Some id columns do not exist: does not exists missing among id, value\"\n    )\n  }\n\n  test(\"diff with different number of columns\") {\n    // different column names only compiles with DataFrames\n    val left = Seq((1, \"str\")).toDF(\"id\", \"value\")\n    val right = Seq((1, 1, \"str\")).toDF(\"id\", \"seq\", \"value\")\n\n    doTestRequirement(\n      left.diff(right, \"id\"),\n      \"The number of columns doesn't match.\\n\" +\n        \"Left column names (2): id, value\\n\" +\n        \"Right column names (3): id, seq, value\"\n    )\n  }\n\n  test(\"diff similar with ignored column and different number of columns\") {\n    val left = Seq((1, \"str\", \"meta\")).toDF(\"id\", \"value\", \"meta\")\n    val right = Seq((1, 1, \"str\")).toDF(\"id\", \"seq\", \"value\")\n\n    doTestRequirement(\n      left.diff(right, Seq(\"id\"), Seq(\"meta\")),\n      \"The number of columns doesn't match.\\n\" +\n        \"Left column names except ignored columns (2): id, value\\n\" +\n        \"Right column names except ignored columns (3): id, seq, value\"\n    )\n  }\n\n  test(\"diff as U\") {\n    val actual = left.diffAs[DiffAs](right, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"left_value\", \"right_value\"))\n    assert(actual.collect() === expectedDiffAs)\n  }\n\n  test(\"diff as U with encoder\") {\n    val encoder = Encoders.product[DiffAs]\n\n    val actual = left.diffAs(right, encoder, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"left_value\", \"right_value\"))\n    assert(actual.collect() === expectedDiffAs)\n  }\n\n  test(\"diff as U with encoder and custom options\") {\n    val options = DiffOptions(\"action\", \"before\", \"after\", \"new\", \"change\", \"del\", \"eq\")\n    val encoder = Encoders.product[DiffAsCustom]\n\n    val actions = Seq(\n      (DiffOptions.default.insertDiffValue, \"new\"),\n      (DiffOptions.default.changeDiffValue, \"change\"),\n      (DiffOptions.default.deleteDiffValue, \"del\"),\n      (DiffOptions.default.nochangeDiffValue, \"eq\")\n    ).toDF(\"diff\", \"action\")\n\n    val expected = expectedDiffAs\n      .toDS()\n      .join(actions, \"diff\")\n      
.select($\"action\", $\"id\", $\"left_value\".as(\"before_value\"), $\"right_value\".as(\"after_value\"))\n      .as[DiffAsCustom]\n      .collect()\n\n    val actual = left.diffAs(right, options, encoder, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"action\", \"id\", \"before_value\", \"after_value\"))\n    assert(actual.collect() === expected)\n  }\n\n  test(\"diff as U with subset of columns\") {\n    val expected = expectedDiff.map(row => DiffAsSubset(row.getString(0), row.getInt(1), Option(row.getString(2))))\n\n    val actual = left.diffAs[DiffAsSubset](right, \"id\").orderBy(\"id\")\n\n    assert(Seq(\"diff\", \"id\", \"left_value\").forall(column => actual.columns.contains(column)))\n    assert(actual.collect() === expected)\n  }\n\n  test(\"diff as U with extra column\") {\n    doTestRequirement(\n      left.diffAs[DiffAsExtra](right, \"id\"),\n      \"Diff encoder's columns must be part of the diff result schema, these columns are unexpected: extra\"\n    )\n  }\n\n  test(\"diff with change column\") {\n    val options = DiffOptions.default.withChangeColumn(\"changes\")\n    val actual = left7.diff(right7, options, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"changes\", \"id\", \"left_value\", \"right_value\", \"left_label\", \"right_label\"))\n    assert(\n      actual.schema === StructType(\n        Seq(\n          StructField(\"diff\", StringType, nullable = false),\n          StructField(\"changes\", ArrayType(StringType, containsNull = false), nullable = true),\n          StructField(\"id\", IntegerType, nullable = true),\n          StructField(\"left_value\", StringType, nullable = true),\n          StructField(\"right_value\", StringType, nullable = true),\n          StructField(\"left_label\", StringType, nullable = true),\n          StructField(\"right_label\", StringType, nullable = true)\n        )\n      )\n    )\n    assert(actual.collect() === expectedDiff7WithChanges)\n  }\n\n  test(\"diff with change column without id columns\") {\n    val options = DiffOptions.default.withChangeColumn(\"changes\")\n    val actual = left7.diff(right7, options)\n\n    assert(actual.columns === Seq(\"diff\", \"changes\", \"id\", \"value\", \"label\"))\n    assert(\n      actual.schema === StructType(\n        Seq(\n          StructField(\"diff\", StringType, nullable = false),\n          StructField(\"changes\", ArrayType(StringType, containsNull = false), nullable = true),\n          StructField(\"id\", IntegerType, nullable = true),\n          StructField(\"value\", StringType, nullable = true),\n          StructField(\"label\", StringType, nullable = true)\n        )\n      )\n    )\n    assert(\n      actual.select($\"diff\", $\"changes\").distinct().orderBy($\"diff\").collect() ===\n        Seq(Row(\"D\", null), Row(\"I\", null), Row(\"N\", Seq.empty[String]))\n    )\n  }\n\n  test(\"diff with change column name in non-id columns\") {\n    val options = DiffOptions.default.withChangeColumn(\"value\")\n    val actual = left7.diff(right7, options, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"value\", \"id\", \"left_value\", \"right_value\", \"left_label\", \"right_label\"))\n    assert(actual.collect() === expectedDiff7WithChanges)\n  }\n\n  test(\"diff with change column name in id columns\") {\n    val options = DiffOptions.default.withChangeColumn(\"value\")\n    doTestRequirement(\n      left.diff(right, options, \"id\", \"value\"),\n      \"The id columns must not contain the change column name 
'value': id, value\"\n    )\n  }\n\n  test(\"diff with column-by-column diff mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.ColumnByColumn)\n    val actual = left7.diff(right7, options, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"left_value\", \"right_value\", \"left_label\", \"right_label\"))\n    assert(actual.collect() === expectedDiff7)\n  }\n\n  test(\"diff with side-by-side diff mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.SideBySide)\n    val actual = left7.diff(right7, options, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"left_value\", \"left_label\", \"right_value\", \"right_label\"))\n    assert(actual.collect() === expectedSideBySideDiff7)\n  }\n\n  test(\"diff with left-side diff mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.LeftSide)\n    val actual = left7.diff(right7, options, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"value\", \"label\"))\n    assert(actual.collect() === expectedLeftSideDiff7)\n  }\n\n  test(\"diff with right-side diff mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.RightSide)\n    val actual = left7.diff(right7, options, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"value\", \"label\"))\n    assert(actual.collect() === expectedRightSideDiff7)\n  }\n\n  test(\"diff as U with left-side diff mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.LeftSide)\n    val actual = left.diffAs[DiffAsOneSide](right, options, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"value\"))\n    val expected: Seq[DiffAsOneSide] = Seq(\n      DiffAsOneSide(\"N\", 1, Some(\"one\")),\n      DiffAsOneSide(\"C\", 2, Some(\"two\")),\n      DiffAsOneSide(\"D\", 3, Some(\"three\")),\n      DiffAsOneSide(\"I\", 4, None)\n    )\n    assert(actual.collect() === expected)\n  }\n\n  test(\"diff as U with right-side diff mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.RightSide)\n    val actual = left.diffAs[DiffAsOneSide](right, options, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"value\"))\n    val expected: Seq[DiffAsOneSide] = Seq(\n      DiffAsOneSide(\"N\", 1, Some(\"one\")),\n      DiffAsOneSide(\"C\", 2, Some(\"Two\")),\n      DiffAsOneSide(\"D\", 3, None),\n      DiffAsOneSide(\"I\", 4, Some(\"four\"))\n    )\n    assert(actual.collect() === expected)\n  }\n\n  test(\"diff with left-side diff mode and diff column name in value columns\") {\n    val options = DiffOptions.default.withDiffColumn(\"value\").withDiffMode(DiffMode.LeftSide)\n    doTestRequirement(\n      left.diff(right, options, \"id\"),\n      \"The left non-id columns must not contain the diff column name 'value': value\"\n    )\n  }\n\n  test(\"diff with right-side diff mode and diff column name in value columns\") {\n    val options = DiffOptions.default.withDiffColumn(\"value\").withDiffMode(DiffMode.RightSide)\n    doTestRequirement(\n      right.diff(right, options, \"id\"),\n      \"The right non-id columns must not contain the diff column name 'value': value\"\n    )\n  }\n\n  test(\"diff with left-side diff mode and change column name in value columns\") {\n    val options = DiffOptions.default.withChangeColumn(\"value\").withDiffMode(DiffMode.LeftSide)\n    doTestRequirement(\n      left.diff(right, options, \"id\"),\n      \"The left non-id columns must not 
contain the change column name 'value': value\"\n    )\n  }\n\n  test(\"diff with right-side diff mode and change column name in value columns\") {\n    val options = DiffOptions.default.withChangeColumn(\"value\").withDiffMode(DiffMode.RightSide)\n    doTestRequirement(\n      right.diff(right, options, \"id\"),\n      \"The right non-id columns must not contain the change column name 'value': value\"\n    )\n  }\n\n  test(\"diff with dots in diff column\") {\n    val options = DiffOptions.default\n      .withDiffColumn(\"the.diff\")\n\n    val actual = left.diff(right, options, \"id\").orderBy(\"id\")\n    val expectedDiffColumns = Seq(\"the.diff\", \"id\", \"left_value\", \"right_value\")\n\n    assert(actual.columns === expectedDiffColumns)\n    assert(actual.collect() === expectedDiff)\n  }\n\n  test(\"diff with dots in change column\") {\n    val options = DiffOptions.default\n      .withChangeColumn(\"the.changes\")\n\n    val actual = left7.diff(right7, options, \"id\").orderBy(\"id\")\n    val expectedDiffColumns = Seq(\"diff\", \"the.changes\", \"id\", \"left_value\", \"right_value\", \"left_label\", \"right_label\")\n\n    assert(actual.columns === expectedDiffColumns)\n    assert(actual.collect() === expectedDiff7WithChanges)\n  }\n\n  test(\"diff with dots in prefixes\") {\n    val options = DiffOptions.default\n      .withLeftColumnPrefix(\"left.prefix\")\n      .withRightColumnPrefix(\"right.prefix\")\n\n    val actual = left.diff(right, options, \"id\").orderBy(\"id\")\n    val expectedDiffColumns = Seq(\"diff\", \"id\", \"left.prefix_value\", \"right.prefix_value\")\n\n    assert(actual.columns === expectedDiffColumns)\n    assert(actual.collect() === expectedDiff)\n  }\n\n  test(\"diff with dot in id column\") {\n    val l = left7.withColumnRenamed(\"id\", \"the.id\")\n    val r = right7.withColumnRenamed(\"id\", \"the.id\")\n\n    val actual = l.diff(r, \"the.id\").orderBy(\"`the.id`\")\n    val expectedDiffColumns = Seq(\"diff\", \"the.id\", \"left_value\", \"right_value\", \"left_label\", \"right_label\")\n\n    assert(actual.columns === expectedDiffColumns)\n    assert(actual.collect() === expectedDiff7)\n  }\n\n  test(\"diff with dot in value column\") {\n    val l = left7.withColumnRenamed(\"value\", \"the.value\")\n    val r = right7.withColumnRenamed(\"value\", \"the.value\")\n\n    val actual = l.diff(r, \"id\").orderBy(\"id\")\n    val expectedDiffColumns = Seq(\"diff\", \"id\", \"left_the.value\", \"right_the.value\", \"left_label\", \"right_label\")\n\n    assert(actual.columns === expectedDiffColumns)\n    assert(actual.collect() === expectedDiff7)\n  }\n\n  test(\"diff with left-side diff mode and dot in value column\") {\n    val l = left7.withColumnRenamed(\"value\", \"the.value\")\n    val r = right7.withColumnRenamed(\"value\", \"the.value\")\n    val options = DiffOptions.default.withDiffMode(DiffMode.LeftSide)\n\n    val actual = l.diff(r, options, \"id\").orderBy(\"id\")\n    val expectedDiffColumns = Seq(\"diff\", \"id\", \"the.value\", \"label\")\n\n    assert(actual.columns === expectedDiffColumns)\n    assert(actual.collect() === expectedLeftSideDiff7)\n  }\n\n  test(\"diff with right-side diff mode and dot in value column\") {\n    val l = left7.withColumnRenamed(\"value\", \"the.value\")\n    val r = right7.withColumnRenamed(\"value\", \"the.value\")\n    val options = DiffOptions.default.withDiffMode(DiffMode.RightSide)\n\n    val actual = l.diff(r, options, \"id\").orderBy(\"id\")\n    val expectedDiffColumns = Seq(\"diff\", \"id\", 
\"the.value\", \"label\")\n\n    assert(actual.columns === expectedDiffColumns)\n    assert(actual.collect() === expectedRightSideDiff7)\n  }\n\n  test(\"diff with column-by-column and sparse mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.ColumnByColumn).withSparseMode(true)\n    val actual = left7.diff(right7, options, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"left_value\", \"right_value\", \"left_label\", \"right_label\"))\n    assert(actual.collect() === expectedSparseDiff7)\n  }\n\n  test(\"diff with side-by-side and sparse mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.SideBySide).withSparseMode(true)\n    val actual = left7.diff(right7, options, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"left_value\", \"left_label\", \"right_value\", \"right_label\"))\n    assert(actual.collect() === expectedSideBySideSparseDiff7)\n  }\n\n  test(\"diff with left side and sparse mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.LeftSide).withSparseMode(true)\n    val actual = left7.diff(right7, options, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"value\", \"label\"))\n    assert(actual.collect() === expectedLeftSideSparseDiff7)\n  }\n\n  test(\"diff with right side and sparse mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.RightSide).withSparseMode(true)\n    val actual = left7.diff(right7, options, \"id\").orderBy(\"id\")\n\n    assert(actual.columns === Seq(\"diff\", \"id\", \"value\", \"label\"))\n    assert(actual.collect() === expectedRightSideSparseDiff7)\n  }\n\n  def ignoreNullable(schema: StructType): StructType = {\n    schema.copy(fields =\n      schema.fields\n        .map(_.copy(nullable = true))\n        .map(field =>\n          field.dataType match {\n            case a: ArrayType => field.copy(dataType = a.copy(containsNull = false))\n            case _            => field\n          }\n        )\n    )\n  }\n\n  def assertIgnoredColumns[T](actual: Dataset[T], expected: Seq[T], expectedSchema: StructType): Unit = {\n    // ignore nullable\n    assert(ignoreNullable(actual.schema) === ignoreNullable(expectedSchema))\n    assert(actual.orderBy(\"id\", \"seq\").collect() === expected)\n  }\n\n  test(\"diff with ignored columns\") {\n    assertIgnoredColumns(\n      left8.diff(right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedDiff8,\n      Encoders.product[DiffAs8].schema\n    )\n    assertIgnoredColumns(\n      Diff.of(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedDiff8,\n      Encoders.product[DiffAs8].schema\n    )\n    assertIgnoredColumns(\n      Diff.default.diff(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedDiff8,\n      Encoders.product[DiffAs8].schema\n    )\n\n    assertIgnoredColumns[DiffAs8](\n      left8.diffAs(right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedDiffAs8,\n      Encoders.product[DiffAs8].schema\n    )\n    assertIgnoredColumns[DiffAs8](\n      Diff.ofAs(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedDiffAs8,\n      Encoders.product[DiffAs8].schema\n    )\n    assertIgnoredColumns[DiffAs8](\n      Diff.default.diffAs(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedDiffAs8,\n      Encoders.product[DiffAs8].schema\n    )\n\n    val expected = expectedDiff8\n      .map(row =>\n        (\n          row.getString(0),\n          Value8(\n            
row.getInt(1),\n            Option(row.get(2)).map(_.asInstanceOf[Int]),\n            Option(row.getString(3)),\n            Option(row.getString(5))\n          ),\n          Value8(\n            row.getInt(1),\n            Option(row.get(2)).map(_.asInstanceOf[Int]),\n            Option(row.getString(4)),\n            Option(row.getString(6))\n          )\n        )\n      )\n      .map { case (diff, left, right) =>\n        (\n          diff,\n          if (diff == \"I\") null else left,\n          if (diff == \"D\") null else right\n        )\n      }\n\n    assertDiffWith(left8.diffWith(right8, Seq(\"id\", \"seq\"), Seq(\"meta\")).collect(), expected)\n    assertDiffWith(Diff.ofWith(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")).collect(), expected)\n    assertDiffWith(Diff.default.diffWith(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")).collect(), expected)\n  }\n\n  test(\"diff with ignored and change columns\") {\n    val options = DiffOptions.default.withChangeColumn(\"changed\")\n    val differ = new Differ(options)\n\n    assertIgnoredColumns(\n      left8.diff(right8, options, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedDiff8WithChanges,\n      Encoders.product[DiffAs8changes].schema\n    )\n    assertIgnoredColumns(\n      differ.diff(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedDiff8WithChanges,\n      Encoders.product[DiffAs8changes].schema\n    )\n  }\n\n  test(\"diff with ignored columns and column-by-column diff mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.ColumnByColumn)\n    val differ = new Differ(options)\n\n    assertIgnoredColumns(\n      left8.diff(right8, options, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedDiff8,\n      Encoders.product[DiffAs8].schema\n    )\n    assertIgnoredColumns(\n      differ.diff(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedDiff8,\n      Encoders.product[DiffAs8].schema\n    )\n  }\n\n  test(\"diff with ignored columns and side-by-side diff mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.SideBySide)\n    val differ = new Differ(options)\n\n    assertIgnoredColumns(\n      left8.diff(right8, options, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedSideBySideDiff8,\n      Encoders.product[DiffAs8SideBySide].schema\n    )\n    assertIgnoredColumns(\n      differ.diff(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedSideBySideDiff8,\n      Encoders.product[DiffAs8SideBySide].schema\n    )\n  }\n\n  test(\"diff with ignored columns and left-side diff mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.LeftSide)\n    val differ = new Differ(options)\n\n    assertIgnoredColumns(\n      left8.diff(right8, options, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedLeftSideDiff8,\n      Encoders.product[DiffAs8OneSide].schema\n    )\n    assertIgnoredColumns(\n      differ.diff(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedLeftSideDiff8,\n      Encoders.product[DiffAs8OneSide].schema\n    )\n  }\n\n  test(\"diff with ignored columns and right-side diff mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.RightSide)\n    val differ = new Differ(options)\n\n    assertIgnoredColumns(\n      left8.diff(right8, options, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedRightSideDiff8,\n      Encoders.product[DiffAs8OneSide].schema\n    )\n    assertIgnoredColumns(\n      differ.diff(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      
expectedRightSideDiff8,\n      Encoders.product[DiffAs8OneSide].schema\n    )\n  }\n\n  test(\"diff with ignored columns, column-by-column diff and sparse mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.ColumnByColumn).withSparseMode(true)\n    val differ = new Differ(options)\n\n    assertIgnoredColumns(\n      left8.diff(right8, options, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedSparseDiff8,\n      Encoders.product[DiffAs8].schema\n    )\n    assertIgnoredColumns(\n      differ.diff(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedSparseDiff8,\n      Encoders.product[DiffAs8].schema\n    )\n  }\n\n  test(\"diff with ignored columns, side-by-side diff and sparse mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.SideBySide).withSparseMode(true)\n    val differ = new Differ(options)\n\n    assertIgnoredColumns(\n      left8.diff(right8, options, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedSideBySideSparseDiff8,\n      Encoders.product[DiffAs8SideBySide].schema\n    )\n    assertIgnoredColumns(\n      differ.diff(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedSideBySideSparseDiff8,\n      Encoders.product[DiffAs8SideBySide].schema\n    )\n  }\n\n  test(\"diff with ignored columns, left-side diff and sparse mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.LeftSide).withSparseMode(true)\n    val differ = new Differ(options)\n\n    assertIgnoredColumns(\n      left8.diff(right8, options, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedLeftSideSparseDiff8,\n      Encoders.product[DiffAs8OneSide].schema\n    )\n    assertIgnoredColumns(\n      differ.diff(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedLeftSideSparseDiff8,\n      Encoders.product[DiffAs8OneSide].schema\n    )\n  }\n\n  test(\"diff with ignored columns, right-side diff and sparse mode\") {\n    val options = DiffOptions.default.withDiffMode(DiffMode.RightSide).withSparseMode(true)\n    val differ = new Differ(options)\n\n    assertIgnoredColumns(\n      left8.diff(right8, options, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedRightSideSparseDiff8,\n      Encoders.product[DiffAs8OneSide].schema\n    )\n    assertIgnoredColumns(\n      differ.diff(left8, right8, Seq(\"id\", \"seq\"), Seq(\"meta\")),\n      expectedRightSideSparseDiff8,\n      Encoders.product[DiffAs8OneSide].schema\n    )\n  }\n\n  test(\"diff similar with ignored columns\") {\n    val expectedSchema = StructType(\n      Seq(\n        StructField(\"diff\", StringType),\n        StructField(\"id\", IntegerType),\n        StructField(\"seq\", IntegerType),\n        StructField(\"left_value\", StringType),\n        StructField(\"right_value\", StringType),\n        StructField(\"left_meta\", StringType),\n        StructField(\"right_info\", StringType),\n      )\n    )\n\n    assertIgnoredColumns(left8.diff(right9, Seq(\"id\", \"seq\"), Seq(\"meta\", \"info\")), expectedDiff8and9, expectedSchema)\n    assertIgnoredColumns(\n      Diff.of(left8, right9, Seq(\"id\", \"seq\"), Seq(\"meta\", \"info\")),\n      expectedDiff8and9,\n      expectedSchema\n    )\n    assertIgnoredColumns(\n      Diff.default.diff(left8, right9, Seq(\"id\", \"seq\"), Seq(\"meta\", \"info\")),\n      expectedDiff8and9,\n      expectedSchema\n    )\n\n    assertIgnoredColumns[DiffAs8and9](\n      left8.diffAs(right9, Seq(\"id\", \"seq\"), Seq(\"meta\", \"info\")),\n      expectedDiffAs8and9,\n      expectedSchema\n    )\n    
assertIgnoredColumns[DiffAs8and9](\n      Diff.ofAs(left8, right9, Seq(\"id\", \"seq\"), Seq(\"meta\", \"info\")),\n      expectedDiffAs8and9,\n      expectedSchema\n    )\n    assertIgnoredColumns[DiffAs8and9](\n      Diff.default.diffAs(left8, right9, Seq(\"id\", \"seq\"), Seq(\"meta\", \"info\")),\n      expectedDiffAs8and9,\n      expectedSchema\n    )\n\n    val expectedSchemaWith = StructType(\n      Seq(\n        StructField(\"_1\", StringType),\n        StructField(\n          \"_2\",\n          StructType(\n            Seq(\n              StructField(\"id\", IntegerType, nullable = true),\n              StructField(\"seq\", IntegerType, nullable = true),\n              StructField(\"value\", StringType, nullable = true),\n              StructField(\"meta\", StringType, nullable = true)\n            )\n          )\n        ),\n        StructField(\n          \"_3\",\n          StructType(\n            Seq(\n              StructField(\"id\", IntegerType, nullable = true),\n              StructField(\"seq\", IntegerType, nullable = true),\n              StructField(\"value\", StringType, nullable = true),\n              StructField(\"info\", StringType, nullable = true)\n            )\n          )\n        ),\n      )\n    )\n\n    assertDiffWithSchema(\n      left8.diffWith(right9, Seq(\"id\", \"seq\"), Seq(\"meta\", \"info\")),\n      expectedDiffWith8and9,\n      expectedSchemaWith\n    )\n    assertDiffWithSchema(\n      Diff.ofWith(left8, right9, Seq(\"id\", \"seq\"), Seq(\"meta\", \"info\")),\n      expectedDiffWith8and9,\n      expectedSchemaWith\n    )\n    assertDiffWithSchema(\n      Diff.default.diffWith(left8, right9, Seq(\"id\", \"seq\"), Seq(\"meta\", \"info\")),\n      expectedDiffWith8and9,\n      expectedSchemaWith\n    )\n  }\n\n  test(\"diff similar with ignored columns of different type\") {\n    // TODO\n  }\n\n  test(\"diff with ignored columns case-insensitive\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      val left = left8.toDF(\"id\", \"seq\", \"value\", \"meta\")\n      val right = right8.toDF(\"ID\", \"SEQ\", \"VALUE\", \"META\")\n\n      def expectedSchema(id: String, seq: String): StructType =\n        StructType(\n          Seq(\n            StructField(\"diff\", StringType),\n            StructField(id, IntegerType),\n            StructField(seq, IntegerType),\n            StructField(\"left_value\", StringType),\n            StructField(\"right_VALUE\", StringType),\n            StructField(\"left_meta\", StringType),\n            StructField(\"right_META\", StringType),\n          )\n        )\n\n      assertIgnoredColumns(left.diff(right, Seq(\"iD\", \"sEq\"), Seq(\"MeTa\")), expectedDiff8, expectedSchema(\"iD\", \"sEq\"))\n      assertIgnoredColumns(\n        Diff.of(left, right, Seq(\"Id\", \"SeQ\"), Seq(\"mEtA\")),\n        expectedDiff8,\n        expectedSchema(\"Id\", \"SeQ\")\n      )\n      assertIgnoredColumns(\n        Diff.default.diff(left, right, Seq(\"ID\", \"SEQ\"), Seq(\"META\")),\n        expectedDiff8,\n        expectedSchema(\"ID\", \"SEQ\")\n      )\n\n      assertIgnoredColumns[DiffAs8](\n        left.diffAs(right, Seq(\"id\", \"seq\"), Seq(\"MeTa\")),\n        expectedDiffAs8,\n        expectedSchema(\"id\", \"seq\")\n      )\n      assertIgnoredColumns[DiffAs8](\n        Diff.ofAs(left, right, Seq(\"id\", \"seq\"), Seq(\"mEtA\")),\n        expectedDiffAs8,\n        expectedSchema(\"id\", \"seq\")\n      )\n      assertIgnoredColumns[DiffAs8](\n        Diff.default.diffAs(left, right, Seq(\"id\", \"seq\"), 
Seq(\"meta\")),\n        expectedDiffAs8,\n        expectedSchema(\"id\", \"seq\")\n      )\n\n      // TODO: add diffWith\n    }\n  }\n\n  test(\"diff with ignored columns case-sensitive\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"true\") {\n      val left = left8.toDF(\"id\", \"seq\", \"value\", \"meta\")\n      val right = right8.toDF(\"ID\", \"SEQ\", \"VALUE\", \"META\")\n\n      doTestRequirement(\n        left.diff(right, Seq(\"Id\", \"SeQ\"), Seq(\"MeTa\")),\n        \"The datasets do not have the same schema.\\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)\"\n      )\n      doTestRequirement(\n        Diff.of(left, right, Seq(\"Id\", \"SeQ\"), Seq(\"MeTa\")),\n        \"The datasets do not have the same schema.\\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)\"\n      )\n      doTestRequirement(\n        Diff.default.diff(left, right, Seq(\"Id\", \"SeQ\"), Seq(\"MeTa\")),\n        \"The datasets do not have the same schema.\\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)\"\n      )\n\n      doTestRequirement(\n        left8.diff(right8, Seq(\"Id\", \"SeQ\"), Seq(\"MeTa\")),\n        \"Some id columns do not exist: Id, SeQ missing among id, seq, value, meta\"\n      )\n      doTestRequirement(\n        Diff.of(left8, right8, Seq(\"Id\", \"SeQ\"), Seq(\"MeTa\")),\n        \"Some id columns do not exist: Id, SeQ missing among id, seq, value, meta\"\n      )\n      doTestRequirement(\n        Diff.default.diff(left8, right8, Seq(\"Id\", \"SeQ\"), Seq(\"MeTa\")),\n        \"Some id columns do not exist: Id, SeQ missing among id, seq, value, meta\"\n      )\n\n      doTestRequirement(\n        left8.diff(right8, Seq(\"id\", \"seq\"), Seq(\"MeTa\")),\n        \"Some ignore columns do not exist: MeTa missing among id, meta, seq, value\"\n      )\n      doTestRequirement(\n        Diff.of(left8, right8, Seq(\"id\", \"seq\"), Seq(\"MeTa\")),\n        \"Some ignore columns do not exist: MeTa missing among id, meta, seq, value\"\n      )\n      doTestRequirement(\n        Diff.default.diff(left8, right8, Seq(\"id\", \"seq\"), Seq(\"MeTa\")),\n        \"Some ignore columns do not exist: MeTa missing among id, meta, seq, value\"\n      )\n    }\n  }\n\n  test(\"diff similar with ignored columns case-insensitive\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"false\") {\n      val left = left8.toDF(\"id\", \"seq\", \"value\", \"meta\").as[Value8]\n      val right = right9.toDF(\"ID\", \"SEQ\", \"VALUE\", \"INFO\").as[Value9up]\n\n      def expectedSchema(id: String, seq: String): StructType =\n        StructType(\n          Seq(\n            StructField(\"diff\", StringType),\n            StructField(id, IntegerType),\n            StructField(seq, IntegerType),\n            StructField(\"left_value\", StringType),\n            StructField(\"right_VALUE\", StringType),\n            StructField(\"left_meta\", StringType),\n            StructField(\"right_INFO\", StringType),\n          )\n        )\n\n      assertIgnoredColumns(\n        left.diff(right, Seq(\"iD\", \"sEq\"), Seq(\"MeTa\", \"InFo\")),\n        expectedDiff8and9,\n        
expectedSchema(\"iD\", \"sEq\")\n      )\n      assertIgnoredColumns(\n        Diff.of(left, right, Seq(\"Id\", \"SeQ\"), Seq(\"mEtA\", \"iNfO\")),\n        expectedDiff8and9,\n        expectedSchema(\"Id\", \"SeQ\")\n      )\n      assertIgnoredColumns(\n        Diff.default.diff(left, right, Seq(\"ID\", \"SEQ\"), Seq(\"META\", \"INFO\")),\n        expectedDiff8and9,\n        expectedSchema(\"ID\", \"SEQ\")\n      )\n\n      // TODO: remove generic type\n      assertIgnoredColumns[DiffAs8and9](\n        left.diffAs(right, Seq(\"id\", \"seq\"), Seq(\"MeTa\", \"InFo\")),\n        expectedDiffAs8and9,\n        expectedSchema(\"id\", \"seq\")\n      )\n      assertIgnoredColumns[DiffAs8and9](\n        Diff.ofAs(left, right, Seq(\"id\", \"seq\"), Seq(\"mEtA\", \"iNfO\")),\n        expectedDiffAs8and9,\n        expectedSchema(\"id\", \"seq\")\n      )\n      assertIgnoredColumns[DiffAs8and9](\n        Diff.default.diffAs(left, right, Seq(\"id\", \"seq\"), Seq(\"meta\", \"info\")),\n        expectedDiffAs8and9,\n        expectedSchema(\"id\", \"seq\")\n      )\n\n      def expectedSchemaWith(id: String, seq: String): StructType =\n        StructType(\n          Seq(\n            StructField(\"_1\", StringType, nullable = false),\n            StructField(\n              \"_2\",\n              StructType(\n                Seq(\n                  StructField(id, IntegerType),\n                  StructField(seq, IntegerType),\n                  StructField(\"value\", StringType),\n                  StructField(\"meta\", StringType)\n                )\n              ),\n              nullable = true\n            ),\n            StructField(\n              \"_3\",\n              StructType(\n                Seq(\n                  StructField(id, IntegerType),\n                  StructField(seq, IntegerType),\n                  StructField(\"VALUE\", StringType),\n                  StructField(\"INFO\", StringType)\n                )\n              ),\n              nullable = true\n            ),\n          )\n        )\n\n      assertIgnoredColumns[(String, Value8, Value9up)](\n        left.diffWith(right, Seq(\"iD\", \"sEq\"), Seq(\"MeTa\", \"InFo\")),\n        expectedDiffWith8and9up,\n        expectedSchemaWith(\"iD\", \"sEq\")\n      )\n      assertIgnoredColumns[(String, Value8, Value9up)](\n        Diff.ofWith(left, right, Seq(\"Id\", \"SeQ\"), Seq(\"mEtA\", \"iNfO\")),\n        expectedDiffWith8and9up,\n        expectedSchemaWith(\"Id\", \"SeQ\")\n      )\n      assertIgnoredColumns[(String, Value8, Value9up)](\n        Diff.default.diffWith(left, right, Seq(\"ID\", \"SEQ\"), Seq(\"META\", \"INFO\")),\n        expectedDiffWith8and9up,\n        expectedSchemaWith(\"ID\", \"SEQ\")\n      )\n    }\n  }\n\n  test(\"diff similar with ignored columns case-sensitive\") {\n    withSQLConf(SQLConf.CASE_SENSITIVE.key -> \"true\") {\n      val left = left8.toDF(\"id\", \"seq\", \"value\", \"meta\").as[Value8]\n      val right = right9.toDF(\"ID\", \"SEQ\", \"VALUE\", \"INFO\").as[Value9up]\n\n      doTestRequirement(\n        left.diff(right, Seq(\"Id\", \"SeQ\"), Seq(\"MeTa\", \"InFo\")),\n        \"The datasets do not have the same schema.\\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)\"\n      )\n      doTestRequirement(\n        Diff.of(left, right, Seq(\"Id\", \"SeQ\"), Seq(\"MeTa\", \"InFo\")),\n        \"The datasets do not have the same 
schema.\\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)\"\n      )\n      doTestRequirement(\n        Diff.default.diff(left, right, Seq(\"Id\", \"SeQ\"), Seq(\"MeTa\", \"InFo\")),\n        \"The datasets do not have the same schema.\\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)\"\n      )\n\n      doTestRequirement(\n        left8.diff(right9, Seq(\"Id\", \"SeQ\"), Seq(\"MeTa\", \"InFo\")),\n        \"The datasets do not have the same schema.\\nLeft extra columns: meta (StringType)\\nRight extra columns: info (StringType)\"\n      )\n      doTestRequirement(\n        Diff.of(left8, right9, Seq(\"Id\", \"SeQ\"), Seq(\"MeTa\", \"InFo\")),\n        \"The datasets do not have the same schema.\\nLeft extra columns: meta (StringType)\\nRight extra columns: info (StringType)\"\n      )\n      doTestRequirement(\n        Diff.default.diff(left8, right9, Seq(\"Id\", \"SeQ\"), Seq(\"MeTa\", \"InFo\")),\n        \"The datasets do not have the same schema.\\nLeft extra columns: meta (StringType)\\nRight extra columns: info (StringType)\"\n      )\n\n      doTestRequirement(\n        left8.diff(right9, Seq(\"Id\", \"SeQ\"), Seq(\"meta\", \"info\")),\n        \"Some id columns do not exist: Id, SeQ missing among id, seq, value\"\n      )\n      doTestRequirement(\n        Diff.of(left8, right9, Seq(\"Id\", \"SeQ\"), Seq(\"meta\", \"info\")),\n        \"Some id columns do not exist: Id, SeQ missing among id, seq, value\"\n      )\n      doTestRequirement(\n        Diff.default.diff(left8, right9, Seq(\"Id\", \"SeQ\"), Seq(\"meta\", \"info\")),\n        \"Some id columns do not exist: Id, SeQ missing among id, seq, value\"\n      )\n    }\n  }\n\n  def assertDiffWith[T](actual: Seq[T], expected: Seq[T]): Unit = {\n    assert(actual.toSet === expected.toSet)\n    assert(actual.length === expected.length)\n  }\n\n  def assertDiffWithSchema[T](actual: Dataset[T], expected: Seq[T], expectedSchema: StructType): Unit = {\n    // ignore nullable\n    assert(ignoreNullable(actual.schema) === ignoreNullable(expectedSchema))\n    assertDiffWith(actual.collect(), expected)\n  }\n\n  test(\"diffWith\") {\n    val expected = Seq(\n      (\"N\", Value(1, Some(\"one\")), Value(1, Some(\"one\"))),\n      (\"I\", null, Value(4, Some(\"four\"))),\n      (\"C\", Value(2, Some(\"two\")), Value(2, Some(\"Two\"))),\n      (\"D\", Value(3, Some(\"three\")), null)\n    )\n\n    assertDiffWith(left.diffWith(right, \"id\").collect(), expected)\n    assertDiffWith(Diff.ofWith(left, right, \"id\").collect(), expected)\n    assertDiffWith(Diff.default.diffWith(left, right, \"id\").collect(), expected)\n  }\n\n  test(\"diffWith left-prefixed id\") {\n    val prefixedLeft = left.select($\"id\".as(\"left_id\"), $\"value\").as[ValueLeft]\n    val prefixedRight = right.select($\"id\".as(\"left_id\"), $\"value\").as[ValueLeft]\n\n    val expected = Seq(\n      (\"N\", ValueLeft(1, Some(\"one\")), ValueLeft(1, Some(\"one\"))),\n      (\"I\", null, ValueLeft(4, Some(\"four\"))),\n      (\"C\", ValueLeft(2, Some(\"two\")), ValueLeft(2, Some(\"Two\"))),\n      (\"D\", ValueLeft(3, Some(\"three\")), null)\n    )\n\n    assertDiffWith(prefixedLeft.diffWith(prefixedRight, \"left_id\").collect(), expected)\n    assertDiffWith(Diff.ofWith(prefixedLeft, 
prefixedRight, \"left_id\").collect(), expected)\n    assertDiffWith(Diff.default.diffWith(prefixedLeft, prefixedRight, \"left_id\").collect(), expected)\n  }\n\n  test(\"diffWith right-prefixed id\") {\n    val prefixedLeft = left.select($\"id\".as(\"right_id\"), $\"value\").as[ValueRight]\n    val prefixedRight = right.select($\"id\".as(\"right_id\"), $\"value\").as[ValueRight]\n\n    val expected = Seq(\n      (\"N\", ValueRight(1, Some(\"one\")), ValueRight(1, Some(\"one\"))),\n      (\"I\", null, ValueRight(4, Some(\"four\"))),\n      (\"C\", ValueRight(2, Some(\"two\")), ValueRight(2, Some(\"Two\"))),\n      (\"D\", ValueRight(3, Some(\"three\")), null)\n    )\n    assertDiffWith(prefixedLeft.diffWith(prefixedRight, \"right_id\").collect(), expected)\n    assertDiffWith(Diff.ofWith(prefixedLeft, prefixedRight, \"right_id\").collect(), expected)\n    assertDiffWith(Diff.default.diffWith(prefixedLeft, prefixedRight, \"right_id\").collect(), expected)\n  }\n\n  def doTestRequirement(f: => Any, expected: String): Unit = {\n    assert(intercept[IllegalArgumentException](f).getMessage === s\"requirement failed: $expected\")\n  }\n\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/diff/examples/Examples.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.diff.examples\n\nimport uk.co.gresearch.spark.SparkTestSession\nimport uk.co.gresearch.spark.diff.{DatasetDiff, DiffMode, DiffOptions}\nimport uk.co.gresearch.test.Suite\n\ncase class Value(id: Int, value: Option[String], label: Option[String])\n\nclass Examples extends Suite with SparkTestSession {\n\n  test(\"issue\") {\n    import spark.implicits._\n    val originalDF =\n      Seq((1, \"gaurav\", \"jaipur\", 550, 70000), (2, \"sunil\", \"noida\", 600, 80000), (3, \"rishi\", \"ahmedabad\", 510, 65000))\n        .toDF(\"id\", \"name\", \"city\", \"credit_score\", \"credit_limit\")\n    val changedDF =\n      Seq((1, \"gaurav\", \"jaipur\", 550, 70000), (2, \"sunil\", \"noida\", 650, 90000), (4, \"Joshua\", \"cochin\", 612, 85000))\n        .toDF(\"id\", \"name\", \"city\", \"credit_score\", \"credit_limit\")\n    val options = DiffOptions.default.withChangeColumn(\"changes\")\n    val diff = originalDF.diff(changedDF, options, \"id\")\n    diff.show(false)\n  }\n\n  test(\"examples\") {\n    import spark.implicits._\n\n    val left = Seq(\n      Value(1, Some(\"one\"), None),\n      Value(2, Some(\"two\"), Some(\"number two\")),\n      Value(3, Some(\"three\"), Some(\"number three\")),\n      Value(4, Some(\"four\"), Some(\"number four\")),\n      Value(5, Some(\"five\"), Some(\"number five\"))\n    ).toDS\n\n    val right = Seq(\n      Value(1, Some(\"one\"), Some(\"one\")),\n      Value(2, Some(\"Two\"), Some(\"number two\")),\n      Value(3, Some(\"Three\"), Some(\"number Three\")),\n      Value(4, Some(\"four\"), Some(\"number four\")),\n      Value(6, Some(\"six\"), Some(\"number six\"))\n    ).toDS\n\n    {\n      Seq(DiffMode.ColumnByColumn, DiffMode.SideBySide, DiffMode.LeftSide, DiffMode.RightSide).foreach { mode =>\n        Seq(false, true).foreach { sparse =>\n          val options = DiffOptions.default.withDiffMode(mode)\n          left.diff(right, options, \"id\").orderBy(\"id\").show(false)\n        }\n      }\n    }\n  }\n\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/group/GroupSuite.scala",
    "content": "/*\n * Copyright 2022 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.group\n\nimport uk.co.gresearch.test.Spec\n\nclass GroupSuite extends Spec {\n\n  describe(\"GroupedIterator\") {\n\n    describe(\"should work with empty iterator\") {\n      test[Int, Double](() => Iterator.empty)\n    }\n\n    describe(\"should work with null key\") {\n      test(() => Iterator((null, 1.0), (null, 2.0), (\"2\", 3.0), (\"2\", 4.0), (\"3\", 5.0)))\n    }\n\n    describe(\"should work with None key\") {\n      test(() => Iterator((None, 1.0), (None, 2.0), (Some(2), 3.0), (Some(2), 4.0), (Some(3), 5.0)))\n    }\n\n    describe(\"should work with null values\") {\n      test(() => Iterator((1, \"1.0\"), (1, null), (2, null), (2, \"4.0\"), (3, \"5.0\")))\n    }\n\n    describe(\"should work with None values\") {\n      test(() => Iterator((1, Some(1.0)), (1, None), (2, None), (2, Some(4.0)), (3, Some(5.0))))\n    }\n\n    describe(\"should work with one 1-element groups\") {\n      test(() => Iterator((1, 1.0)))\n    }\n\n    describe(\"should work with many 1-element groups\") {\n      test(() => Iterator((1, 1.0), (2, 2.0), (3, 3.0)))\n    }\n\n    describe(\"should work with one group\") {\n      test(() => Iterator((1, 1.0), (1, 2.0), (1, 3.0)))\n    }\n\n    describe(\"should work with many groups\") {\n      test(() => Iterator((1, 1.0), (1, 2.0), (2, 3.0), (2, 4.0), (3, 5.0)))\n    }\n\n    def test[K: Ordering, V](func: () => Iterator[(K, V)]): Unit = {\n      testWithUnconsumedGroups(func)\n      testWithPartiallyConsumedGroups(func)\n      testWithFullyConsumedGroups(func)\n      testWithMultipleHasNext(func)\n    }\n\n    def testWithUnconsumedGroups[K: Ordering, V](func: () => Iterator[(K, V)]): Unit = {\n      val existingKeys = func().map(_._1).toSet.toList\n\n      // this does not consume any group iterators\n      it(\"and unconsumed groups\") {\n        val git = new GroupedIterator(func())\n        val actualKeys = git.map(_._1).toList\n        assert(actualKeys === existingKeys)\n      }\n\n      // tests a specific group not being consumed at all\n      def testUnconsumedKey(unconsumedKey: K, func: () => Iterator[(K, V)]): Unit = {\n        // we expect all tuples (k, Some(v)), except for k == unconsumedKey, where we expect (k, None)\n        // here we consume all groups (it.toList), which is tested elsewhere to work\n        val expected = new GroupedIterator(func())\n          .map {\n            case (key, it) if key == unconsumedKey => it.toList; (key, Iterator(None))\n            case (key, it)                         => (key, it.map(Some(_)))\n          }\n          .flatMap { case (k, it) => it.map(v => (k, v)) }\n          .toList\n\n        // here we do not consume the group with key `unconsumedKey`\n        val actual = new GroupedIterator(func())\n          .map {\n            case (key, _) if key == unconsumedKey => (key, Iterator(None))\n            case (key, it)                        => 
(key, it.map(Some(_)))\n          }\n          .flatMap { case (k, it) => it.map(v => (k, v)) }\n          .toList\n\n        assert(actual === expected)\n      }\n\n      // this does not consume the first group iterator\n      it(\"and unconsumed first group\") {\n        if (existingKeys.nonEmpty) {\n          val firstKey = existingKeys.head\n          testUnconsumedKey(firstKey, func)\n        }\n      }\n\n      // this does not consume the second group iterator\n      it(\"and unconsumed second group\") {\n        if (existingKeys.length >= 2) {\n          val secondKey = existingKeys.tail.head\n          testUnconsumedKey(secondKey, func)\n        }\n      }\n\n      // this does not consume the last group iterator\n      it(\"and unconsumed last group\") {\n        if (existingKeys.nonEmpty) {\n          val lastKey = existingKeys.last\n          testUnconsumedKey(lastKey, func)\n        }\n      }\n    }\n\n    def testWithPartiallyConsumedGroups[K: Ordering, V](func: () => Iterator[(K, V)]): Unit = {\n      val existingKeys = func().map(_._1).toSet.toList\n\n      // this consumes only the first value of each group iterator\n      it(\"and partially consumed groups\") {\n        val git = new GroupedIterator(func())\n        val actualKeyValues = git.map { case (k, it) => (k, it.next()) }.toList\n        val expectedKeyValues = func().toList.groupBy(_._1).mapValues(_.head).values.toMap\n        val expectedKeyValuesOrdered = existingKeys zip existingKeys.map(expectedKeyValues)\n        assert(actualKeyValues === expectedKeyValuesOrdered)\n      }\n\n      // tests a specific group being only partially consumed\n      def testPartiallyConsumedKey(partiallyConsumedKey: K, func: () => Iterator[(K, V)]): Unit = {\n        // we expect all tuples (k, v), except for k == partiallyConsumedKey,\n        // where we expect only the first tuple with k == partiallyConsumedKey\n        // here we consume all groups (it.toList), which is tested elsewhere to work\n        val expected = new GroupedIterator(func())\n          .map {\n            case (key, it) if key == partiallyConsumedKey => (key, Iterator(it.toList.head))\n            case (key, it)                                => (key, it)\n          }\n          .flatMap { case (k, it) => it.map(v => (k, v)) }\n          .toList\n\n        // here we only consume the first element of the group with key `partiallyConsumedKey`\n        val actual = new GroupedIterator(func())\n          .map {\n            case (key, it) if key == partiallyConsumedKey => (key, Iterator(it.next()))\n            case (key, it)                                => (key, it)\n          }\n          .flatMap { case (k, it) => it.map(v => (k, v)) }\n          .toList\n\n        assert(actual === expected)\n      }\n\n      // this consumes the first group iterator only partially\n      it(\"and partially consumed first group\") {\n        if (existingKeys.nonEmpty) {\n          val firstKey = existingKeys.head\n          testPartiallyConsumedKey(firstKey, func)\n        }\n      }\n\n      // this consumes the second group iterator only partially\n      it(\"and partially consumed second group\") {\n        if (existingKeys.length >= 2) {\n          val secondKey = existingKeys.tail.head\n          testPartiallyConsumedKey(secondKey, func)\n        }\n      }\n\n      // this consumes the last group iterator only partially\n      it(\"and partially consumed last group\") {\n        if (existingKeys.nonEmpty) {\n          val lastKey = existingKeys.last\n          
testPartiallyConsumedKey(lastKey, func)\n        }\n      }\n    }\n\n    def testWithFullyConsumedGroups[K: Ordering, V](func: () => Iterator[(K, V)]): Unit = {\n      // this consumes all group iterators\n      it(\"and fully consumed groups\") {\n        val expected = func().toList\n        val actual = new GroupedIterator(func()).flatMap { case (k, it) =>\n          it.map(v => (k, v))\n        }.toList\n        assert(actual === expected)\n      }\n    }\n\n    def testWithMultipleHasNext[K: Ordering, V](func: () => Iterator[(K, V)]): Unit = {\n      it(\"and multiple calls to hasNext\") {\n        val iter = func()\n        val isEmpty = iter.hasNext\n\n        val git = new GroupedIterator(iter)\n        assert(git.hasNext === isEmpty)\n        assert(git.hasNext === isEmpty)\n        assert(git.hasNext === isEmpty)\n\n        while (git.hasNext) {\n          assert(git.hasNext === true)\n          assert(git.hasNext === true)\n          git.next()\n        }\n      }\n    }\n\n  }\n\n  describe(\"GroupIterator\") {\n    it(\"should not work with empty iterator\") {\n      assertThrows[NoSuchElementException] { new GroupIterator[Int, Double](Iterator.empty.buffered) }\n    }\n\n    describe(\"should iterate only over current key\") {\n      def test[K: Ordering, V](it: Seq[(K, V)], expectedValues: Seq[V]): Unit = {\n        val git = new GroupIterator[K, V](it.iterator.buffered)\n        assert(git.toList === expectedValues)\n      }\n\n      it(\"for null key\") {\n        test(Seq((null, 1.0), (null, 2.0), (\"1\", 3.0)), Seq(1.0, 2.0))\n      }\n\n      describe(\"for single key\") {\n        it(\"and single value\") {\n          test(Seq((1, 1.0)), Seq(1.0))\n        }\n        it(\"and multiple values\") {\n          test(Seq((1, 1.0), (1, 2.0)), Seq(1.0, 2.0))\n          test(Seq((1, 1.0), (1, 2.0), (1, 3.0)), Seq(1.0, 2.0, 3.0))\n        }\n      }\n\n      describe(\"for multiple keys\") {\n        it(\"and single value\") {\n          test(Seq((1, 1.0), (2, 2.0)), Seq(1.0))\n        }\n        it(\"and multiple values\") {\n          test(Seq((1, 1.0), (1, 2.0), (2, 3.0)), Seq(1.0, 2.0))\n          test(Seq((1, 1.0), (1, 2.0), (1, 3.0), (2, 4.0)), Seq(1.0, 2.0, 3.0))\n        }\n      }\n    }\n  }\n\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/parquet/ParquetSuite.scala",
    "content": "/*\n * Copyright 2023 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark.parquet\n\nimport org.apache.spark.SparkException\nimport org.apache.spark.sql.Row.unapplySeq\n// array_contains and map_keys could be replaced with map_contains_key, but that is not available before Spark 3.4\nimport org.apache.spark.sql.functions.{array_contains, lit, map_keys, regexp_replace, spark_partition_id, when}\nimport org.apache.spark.sql.types._\nimport org.apache.spark.sql.{Column, DataFrame, Row}\nimport org.scalatest.tagobjects.Slow\nimport uk.co.gresearch._\nimport uk.co.gresearch.spark.{SparkTestSession, SparkVersion}\nimport uk.co.gresearch.test.Suite\n\nclass ParquetSuite extends Suite with SparkTestSession with SparkVersion {\n\n  import spark.implicits._\n\n  // These parquet test files have been created as follows:\n  //   import org.apache.spark.sql.SaveMode\n  //   spark.sparkContext.hadoopConfiguration.setInt(\"parquet.block.size\", 1024)\n  //   spark.range(100).select($\"id\", rand().as(\"val\")).repartitionByRange(1, $\"id\").write.parquet(\"test.parquet\")\n  //   spark.range(100, 300, 1).select($\"id\", rand().as(\"val\")).repartitionByRange(1, $\"id\").write.mode(SaveMode.Append).parquet(\"test.parquet\")eq((3, \"three\"), (4, \"four\"), (5, \"five\"), (6, \"six\"), (7, \"seven\")).toDF(\"id\", \"value\").repartitionByRange(1, $\"id\").write.mode(SaveMode.Append).parquet(\"test.parquet\")\n  //   spark.range(100).withColumn(\"val\", $\"id\".cast(\"float\")).write.option(\"parquet.encryption.footer.key\", \"key\").option(\"parquet.encryption.plaintext.footer\", \"true\").option(\"parquet.encryption.column.keys\", \"key:val\").parquet(\"src/test/files/encrypted.parquet\")\n  //   spark.range(100).withColumn(\"val\", $\"id\".cast(\"float\")).write.option(\"parquet.encryption.footer.key\", \"key\").option(\"parquet.encryption.column.keys\", \"key:val\").parquet(\"src/test/files/encrypted-column.parquet\")\n  val testFile = \"src/test/files/test.parquet\"\n  val nestedFile = \"src/test/files/nested.parquet\"\n  val encryptedFilePlaintextFooter = \"src/test/files/encrypted1.parquet\"\n  val encryptedFileEncryptedFooter = \"src/test/files/encrypted2.parquet\"\n\n  val parallelisms = Seq(None, Some(1), Some(2), Some(8))\n\n  def assertDf(\n      actual: DataFrame,\n      order: Seq[Column],\n      expectedSchema: StructType,\n      expectedRows: Seq[Row],\n      expectedParallelism: Option[Int] = None,\n      postProcess: DataFrame => DataFrame = identity\n  ): Unit = {\n    assert(actual.schema === expectedSchema)\n\n    if (expectedParallelism.isDefined) {\n      assert(actual.rdd.getNumPartitions === expectedParallelism.get)\n    } else {\n      assert(actual.rdd.getNumPartitions === actual.sparkSession.sparkContext.defaultParallelism)\n    }\n\n    val replaced =\n      actual\n        .orderBy(order: _*)\n        .withColumn(\n          \"filename\",\n          regexp_replace(regexp_replace($\"filename\", 
\".*/test.parquet/\", \"\"), \".*/nested.parquet\", \"nested.parquet\")\n        )\n        .when(actual.columns.contains(\"schema\"))\n        .call(_.withColumn(\"schema\", regexp_replace($\"schema\", \"\\n\", \"\\\\\\\\n\")))\n        .call(postProcess)\n    assert(replaced.collect() === expectedRows)\n  }\n\n  val hasEncryptionType: Boolean = ParquetMetaDataUtil.getEncryptionTypeIsSupported\n  val UNENCRYPTED: String = if (hasEncryptionType) \"UNENCRYPTED\" else null\n  val PLAINTEXT_FOOTER: String = if (hasEncryptionType) \"PLAINTEXT_FOOTER\" else null\n  val ENCRYPTED_FOOTER: String = if (hasEncryptionType) \"ENCRYPTED_FOOTER\" else null\n\n  parallelisms.foreach { parallelism =>\n    test(s\"read parquet metadata (parallelism=${parallelism.map(_.toString).getOrElse(\"None\")})\") {\n      val createdBy = \"parquet-mr version 1.12.2 (build 77e30c8093386ec52c3cfa6c34b7ef3321322c94)\"\n      val schema = \"message spark_schema {\\\\n  required int64 id;\\\\n  required double val;\\\\n}\\\\n\"\n      val keyValues = Map(\n        \"org.apache.spark.version\" -> \"3.3.0\",\n        \"org.apache.spark.sql.parquet.row.metadata\" -> \"\"\"{\"type\":\"struct\",\"fields\":[{\"name\":\"id\",\"type\":\"long\",\"nullable\":false,\"metadata\":{}},{\"name\":\"val\",\"type\":\"double\",\"nullable\":false,\"metadata\":{}}]}\"\"\"\n      )\n\n      assertDf(\n        spark.read\n          .when(parallelism.isDefined)\n          .either(_.parquetMetadata(parallelism.get, testFile))\n          .or(_.parquetMetadata(testFile)),\n        Seq($\"filename\"),\n        StructType(\n          Seq(\n            StructField(\"filename\", StringType, nullable = true),\n            StructField(\"blocks\", IntegerType, nullable = false),\n            StructField(\"compressedBytes\", LongType, nullable = true),\n            StructField(\"uncompressedBytes\", LongType, nullable = true),\n            StructField(\"rows\", LongType, nullable = false),\n            StructField(\"columns\", IntegerType, nullable = false),\n            StructField(\"values\", LongType, nullable = true),\n            StructField(\"nulls\", LongType, nullable = true),\n            StructField(\"createdBy\", StringType, nullable = true),\n            StructField(\"schema\", StringType, nullable = true),\n            StructField(\"encryption\", StringType, nullable = true),\n            StructField(\"keyValues\", MapType(StringType, StringType, valueContainsNull = true), nullable = true),\n          )\n        ),\n        Seq(\n          Row(\"file1.parquet\", 1, 1268, 1652, 100, 2, 200, 0, createdBy, schema, UNENCRYPTED, keyValues),\n          Row(\"file2.parquet\", 2, 2539, 3302, 200, 2, 400, 0, createdBy, schema, UNENCRYPTED, keyValues),\n        ),\n        parallelism\n      )\n    }\n  }\n\n  parallelisms.foreach { parallelism =>\n    test(s\"read parquet schema (parallelism=${parallelism.map(_.toString).getOrElse(\"None\")})\") {\n      assertDf(\n        spark.read\n          .when(parallelism.isDefined)\n          .either(_.parquetSchema(parallelism.get, nestedFile))\n          .or(_.parquetSchema(nestedFile)),\n        Seq($\"filename\", $\"columnPath\"),\n        StructType(\n          Seq(\n            StructField(\"filename\", StringType, nullable = true),\n            StructField(\"columnName\", StringType, nullable = true),\n            StructField(\"columnPath\", ArrayType(StringType, containsNull = true), nullable = true),\n            StructField(\"repetition\", StringType, nullable = true),\n            
StructField(\"type\", StringType, nullable = true),\n            StructField(\"length\", IntegerType, nullable = true),\n            StructField(\"originalType\", StringType, nullable = true),\n            StructField(\"logicalType\", StringType, nullable = true),\n            StructField(\"isPrimitive\", BooleanType, nullable = false),\n            StructField(\"primitiveType\", StringType, nullable = true),\n            StructField(\"primitiveOrder\", StringType, nullable = true),\n            StructField(\"maxDefinitionLevel\", IntegerType, nullable = false),\n            StructField(\"maxRepetitionLevel\", IntegerType, nullable = false),\n          )\n        ),\n        // format: off\n        Seq(\n          Row(\"nested.parquet\", \"a\", Seq(\"a\"), \"REQUIRED\", \"INT64\", 0, null, null, true, \"INT64\", \"TYPE_DEFINED_ORDER\", 0, 0),\n          Row(\"nested.parquet\", \"x\", Seq(\"b\", \"x\"), \"REQUIRED\", \"INT32\", 0, null, null, true, \"INT32\", \"TYPE_DEFINED_ORDER\", 1, 0),\n          Row(\"nested.parquet\", \"y\", Seq(\"b\", \"y\"), \"REQUIRED\", \"DOUBLE\", 0, null, null, true, \"DOUBLE\", \"TYPE_DEFINED_ORDER\", 1, 0),\n          Row(\"nested.parquet\", \"z\", Seq(\"b\", \"z\"), \"OPTIONAL\", \"INT64\", 0, \"TIMESTAMP_MICROS\", \"TIMESTAMP(MICROS,true)\", true, \"INT64\", \"TYPE_DEFINED_ORDER\", 2, 0),\n          Row(\"nested.parquet\", \"element\", Seq(\"c\", \"list\", \"element\"), \"OPTIONAL\", \"BINARY\", 0, \"UTF8\", \"STRING\", true, \"BINARY\", \"TYPE_DEFINED_ORDER\", 3, 1),\n        ),\n        // format: on\n        parallelism\n      )\n    }\n  }\n\n  parallelisms.foreach { parallelism =>\n    test(s\"read parquet blocks (parallelism=${parallelism.map(_.toString).getOrElse(\"None\")})\") {\n      assertDf(\n        spark.read\n          .when(parallelism.isDefined)\n          .either(_.parquetBlocks(parallelism.get, testFile))\n          .or(_.parquetBlocks(testFile)),\n        Seq($\"filename\", $\"block\"),\n        StructType(\n          Seq(\n            StructField(\"filename\", StringType, nullable = true),\n            StructField(\"block\", IntegerType, nullable = false),\n            StructField(\"blockStart\", LongType, nullable = false),\n            StructField(\"compressedBytes\", LongType, nullable = true),\n            StructField(\"uncompressedBytes\", LongType, nullable = false),\n            StructField(\"rows\", LongType, nullable = false),\n            StructField(\"columns\", IntegerType, nullable = false),\n            StructField(\"values\", LongType, nullable = true),\n            StructField(\"nulls\", LongType, nullable = true),\n          )\n        ),\n        Seq(\n          Row(\"file1.parquet\", 1, 4, 1268, 1652, 100, 2, 200, 0),\n          Row(\"file2.parquet\", 1, 4, 1269, 1651, 100, 2, 200, 0),\n          Row(\"file2.parquet\", 2, 1273, 1270, 1651, 100, 2, 200, 0),\n        ),\n        parallelism\n      )\n    }\n  }\n\n  parallelisms.foreach { parallelism =>\n    test(s\"read parquet block columns (parallelism=${parallelism.map(_.toString).getOrElse(\"None\")})\") {\n      assertDf(\n        spark.read\n          .when(parallelism.isDefined)\n          .either(_.parquetBlockColumns(parallelism.get, testFile))\n          .or(_.parquetBlockColumns(testFile)),\n        Seq($\"filename\", $\"block\", $\"column\"),\n        StructType(\n          Seq(\n            StructField(\"filename\", StringType, nullable = true),\n            StructField(\"block\", IntegerType, nullable = false),\n            StructField(\"column\", 
ArrayType(StringType), nullable = true),\n            StructField(\"codec\", StringType, nullable = true),\n            StructField(\"type\", StringType, nullable = true),\n            StructField(\"encodings\", ArrayType(StringType), nullable = true),\n            StructField(\"encrypted\", BooleanType, nullable = true),\n            StructField(\"minValue\", StringType, nullable = true),\n            StructField(\"maxValue\", StringType, nullable = true),\n            StructField(\"columnStart\", LongType, nullable = true),\n            StructField(\"compressedBytes\", LongType, nullable = true),\n            StructField(\"uncompressedBytes\", LongType, nullable = true),\n            StructField(\"values\", LongType, nullable = true),\n            StructField(\"nulls\", LongType, nullable = true),\n          )\n        ),\n        // format: off\n        Seq(\n          Row(\"file1.parquet\", 1, \"[id]\", \"SNAPPY\", \"required int64 id\", \"[BIT_PACKED, PLAIN]\", false, \"0\", \"99\", 4, 437, 826, 100, 0),\n          Row(\"file1.parquet\", 1, \"[val]\", \"SNAPPY\", \"required double val\", \"[BIT_PACKED, PLAIN]\", false, \"0.005067503372006343\", \"0.9973357672164814\", 441, 831, 826, 100, 0),\n          Row(\"file2.parquet\", 1, \"[id]\", \"SNAPPY\", \"required int64 id\", \"[BIT_PACKED, PLAIN]\", false, \"100\", \"199\", 4, 438, 825, 100, 0),\n          Row(\"file2.parquet\", 1, \"[val]\", \"SNAPPY\", \"required double val\", \"[BIT_PACKED, PLAIN]\", false, \"0.010617521596503865\", \"0.999189783846449\", 442, 831, 826, 100, 0),\n          Row(\"file2.parquet\", 2, \"[id]\", \"SNAPPY\", \"required int64 id\", \"[BIT_PACKED, PLAIN]\", false, \"200\", \"299\", 1273, 440, 826, 100, 0),\n          Row(\"file2.parquet\", 2, \"[val]\", \"SNAPPY\", \"required double val\", \"[BIT_PACKED, PLAIN]\", false, \"0.011277044401634018\", \"0.970525681750662\", 1713, 830, 825, 100, 0),\n        ),\n        // format: on\n        parallelism,\n        (df: DataFrame) =>\n          df\n            .withColumn(\"column\", $\"column\".cast(StringType))\n            .withColumn(\"encodings\", $\"encodings\".cast(StringType))\n      )\n    }\n  }\n\n  if (sys.env.get(\"CI_SLOW_TESTS\").exists(_.equals(\"1\"))) {\n    Seq(1, 3, 7, 13, 19, 29, 61, 127, 251).foreach { partitionSize =>\n      test(s\"read parquet partitions ($partitionSize bytes)\", Slow) {\n        withSQLConf(\"spark.sql.files.maxPartitionBytes\" -> partitionSize.toString) {\n          val parquet = spark.read.parquet(testFile).cache()\n\n          val rows = spark.read\n            .parquet(testFile)\n            .mapPartitions(it => Iterator(it.length))\n            .select(spark_partition_id().as(\"partition\"), $\"value\".as(\"actual_rows\"))\n          val partitions = spark.read\n            .parquetPartitions(testFile)\n            .join(rows, Seq(\"partition\"), \"left\")\n            .select($\"partition\", $\"start\", $\"end\", $\"length\", $\"rows\", $\"actual_rows\", $\"filename\")\n\n          if (\n            partitions\n              .where(\n                $\"rows\" =!= $\"actual_rows\" || ($\"rows\" =!= 0 || $\"actual_rows\" =!= 0) && $\"length\" =!= partitionSize\n              )\n              .head(1)\n              .nonEmpty\n          ) {\n            partitions\n              .orderBy($\"start\")\n              .where($\"rows\" =!= 0 || $\"actual_rows\" =!= 0)\n              .show(false)\n            fail()\n          }\n\n          parquet.unpersist()\n        }\n      }\n    }\n  }\n\n  Map(\n    None -> Seq(\n     
 Row(0, 1930, 1930, 1, 1268, 1652, 100, 2, 200, 0, \"file1.parquet\", 1930),\n      Row(0, 3493, 3493, 2, 2539, 3302, 200, 2, 400, 0, \"file2.parquet\", 3493),\n    ),\n    Some(8192) -> Seq(\n      Row(0, 1930, 1930, 1, 1268, 1652, 100, 2, 200, 0, \"file1.parquet\", 1930),\n      Row(0, 3493, 3493, 2, 2539, 3302, 200, 2, 400, 0, \"file2.parquet\", 3493),\n    ),\n    Some(1024) -> Seq(\n      Row(0, 1024, 1024, 1, 1268, 1652, 100, 2, 200, 0, \"file1.parquet\", 1930),\n      Row(1024, 1930, 906, 0, 0, 0, 0, 0, 0, 0, \"file1.parquet\", 1930),\n      Row(0, 1024, 1024, 1, 1269, 1651, 100, 2, 200, 0, \"file2.parquet\", 3493),\n      Row(1024, 2048, 1024, 1, 1270, 1651, 100, 2, 200, 0, \"file2.parquet\", 3493),\n      Row(2048, 3072, 1024, 0, 0, 0, 0, 0, 0, 0, \"file2.parquet\", 3493),\n      Row(3072, 3493, 421, 0, 0, 0, 0, 0, 0, 0, \"file2.parquet\", 3493),\n    ),\n    Some(512) -> Seq(\n      Row(0, 512, 512, 0, 0, 0, 0, 0, 0, 0, \"file1.parquet\", 1930),\n      Row(512, 1024, 512, 1, 1268, 1652, 100, 2, 200, 0, \"file1.parquet\", 1930),\n      Row(1024, 1536, 512, 0, 0, 0, 0, 0, 0, 0, \"file1.parquet\", 1930),\n      Row(1536, 1930, 394, 0, 0, 0, 0, 0, 0, 0, \"file1.parquet\", 1930),\n      Row(0, 512, 512, 0, 0, 0, 0, 0, 0, 0, \"file2.parquet\", 3493),\n      Row(512, 1024, 512, 1, 1269, 1651, 100, 2, 200, 0, \"file2.parquet\", 3493),\n      Row(1024, 1536, 512, 0, 0, 0, 0, 0, 0, 0, \"file2.parquet\", 3493),\n      Row(1536, 2048, 512, 1, 1270, 1651, 100, 2, 200, 0, \"file2.parquet\", 3493),\n      Row(2048, 2560, 512, 0, 0, 0, 0, 0, 0, 0, \"file2.parquet\", 3493),\n      Row(2560, 3072, 512, 0, 0, 0, 0, 0, 0, 0, \"file2.parquet\", 3493),\n      Row(3072, 3493, 421, 0, 0, 0, 0, 0, 0, 0, \"file2.parquet\", 3493),\n    ),\n  ).foreach { case (partitionSize, expectedRows) =>\n    parallelisms.foreach { parallelism =>\n      test(s\"read parquet partitions (${partitionSize\n          .getOrElse(\"default\")} bytes) (parallelism=${parallelism.map(_.toString).getOrElse(\"None\")})\") {\n        withSQLConf(\n          partitionSize.map(size => Seq(\"spark.sql.files.maxPartitionBytes\" -> size.toString)).getOrElse(Seq.empty): _*\n        ) {\n          val expected = expectedRows.map {\n            case row if SparkMajorVersion > 3 || SparkMinorVersion >= 3 => row\n            case row => Row(unapplySeq(row).get.updated(11, null): _*)\n          }\n\n          val actual = spark.read\n            .when(parallelism.isDefined)\n            .either(_.parquetPartitions(parallelism.get, testFile))\n            .or(_.parquetPartitions(testFile))\n            .cache()\n\n          val partitions = actual.select($\"partition\").as[Int].collect()\n          if (partitionSize.isDefined) {\n            assert(partitions.indices === partitions.sorted)\n          } else {\n            assert(Seq(0, 0) === partitions)\n          }\n\n          val schema = StructType(\n            Seq(\n              StructField(\"partition\", IntegerType, nullable = false),\n              StructField(\"start\", LongType, nullable = false),\n              StructField(\"end\", LongType, nullable = false),\n              StructField(\"length\", LongType, nullable = false),\n              StructField(\"blocks\", IntegerType, nullable = false),\n              StructField(\"compressedBytes\", LongType, nullable = false),\n              StructField(\"uncompressedBytes\", LongType, nullable = false),\n              StructField(\"rows\", LongType, nullable = false),\n              StructField(\"columns\", IntegerType, nullable = 
false),\n              StructField(\"values\", LongType, nullable = false),\n              StructField(\"nulls\", LongType, nullable = true),\n              StructField(\"filename\", StringType, nullable = true),\n              StructField(\"fileLength\", LongType, nullable = true),\n            )\n          )\n\n          assertDf(actual, Seq($\"filename\", $\"start\"), schema, expected, parallelism, df => df.drop(\"partition\"))\n          actual.unpersist()\n        }\n      }\n    }\n  }\n\n  test(\"read encrypted parquets\") {\n    // used to collect the result and make it comparable to expected rows\n    def collect(df: DataFrame): Array[Row] = {\n      val expectedKeyValues = array_contains(map_keys($\"keyValues\"), \"org.apache.spark.version\") &&\n        array_contains(map_keys($\"keyValues\"), \"org.apache.spark.sql.parquet.row.metadata\")\n      df.orderBy($\"filename\")\n        .withColumn(\"filename\", regexp_replace($\"filename\", \".*/\", \"\"))\n        .when(df.columns.contains(\"createdBy\"))\n        .call(_.withColumn(\"createdBy\", when($\"createdBy\".isNotNull, lit(\"…\"))))\n        .when(df.columns.contains(\"schema\"))\n        .call(_.withColumn(\"schema\", when($\"schema\".isNotNull, lit(\"…\"))))\n        .when(df.columns.contains(\"keyValues\"))\n        .call(_.withColumn(\"keyValues\", expectedKeyValues))\n        .collect()\n    }\n\n    val hasIsEncrypted: Boolean = ParquetMetaDataUtil.isEncryptedIsSupported\n    val isEncrypted: Option[Boolean] = if (hasIsEncrypted) Some(true) else None\n    val isNotEncrypted: Option[Boolean] = if (hasIsEncrypted) Some(false) else None\n\n    // we are reading the encrypted file with plaintext footer once without any decryption keys\n    assert(\n      collect(spark.read.parquetMetadata(encryptedFilePlaintextFooter)) === Seq(\n        Row(\"encrypted1.parquet\", 1, null, null, 100, 2, null, null, \"…\", \"…\", PLAINTEXT_FOOTER, true),\n      )\n    )\n    assert(\n      collect(spark.read.parquetSchema(encryptedFilePlaintextFooter)) === Seq(\n      // format: off\n      Row(\"encrypted1.parquet\", \"id\", Seq(\"id\"), \"REQUIRED\", \"INT64\", 0, null, null, true, \"INT64\", \"TYPE_DEFINED_ORDER\", 0, 0),\n      Row(\"encrypted1.parquet\", \"val\", Seq(\"val\"), \"REQUIRED\", \"FLOAT\", 0, null, null, true, \"FLOAT\", \"TYPE_DEFINED_ORDER\", 0, 0),\n      // format: on\n      )\n    )\n    assert(\n      collect(spark.read.parquetBlocks(encryptedFilePlaintextFooter)) === Seq(\n        Row(\"encrypted1.parquet\", 1, 4, null, 1358, 100, 2, null, null),\n      )\n    )\n    assert(\n      collect(spark.read.parquetBlockColumns(encryptedFilePlaintextFooter)) === Seq(\n        Row(\"encrypted1.parquet\", 1, Seq(\"id\"), null, null, null, false, null, null, null, null, null, null, null),\n        Row(\"encrypted1.parquet\", 1, Seq(\"val\"), null, null, null, true, null, null, null, null, null, null, null),\n      )\n    )\n    assertThrows[SparkException] {\n      collect(spark.read.parquetPartitions(encryptedFilePlaintextFooter))\n    }\n\n    // we cannot read the encrypted file with encrypted footer\n    assertThrows[SparkException] {\n      collect(spark.read.parquetMetadata(encryptedFileEncryptedFooter))\n    }\n    assertThrows[SparkException] {\n      collect(spark.read.parquetSchema(encryptedFileEncryptedFooter))\n    }\n    assertThrows[SparkException] {\n      collect(spark.read.parquetBlocks(encryptedFileEncryptedFooter))\n    }\n    assertThrows[SparkException] {\n      
collect(spark.read.parquetBlockColumns(encryptedFileEncryptedFooter))\n    }\n    assertThrows[SparkException] {\n      collect(spark.read.parquetPartitions(encryptedFileEncryptedFooter))\n    }\n\n    // we can read all encrypted files with footer encryption key, column key not needed\n    spark.sparkContext.hadoopConfiguration\n      .set(\"parquet.crypto.factory.class\", \"org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory\")\n    spark.sparkContext.hadoopConfiguration\n      .set(\"parquet.encryption.kms.client.class\", \"org.apache.parquet.crypto.keytools.mocks.InMemoryKMS\")\n    spark.sparkContext.hadoopConfiguration\n      .set(\"parquet.encryption.key.list\", \"key:AAECAAECAAECAAECAAECAA==\")\n\n    // either size or null, depending on the Spark version (see SplitFile.fileSize)\n    val hasSplitFileSize = (size: Long) =>\n      Some(SparkMajorVersion > 3 || SparkMinorVersion >= 3)\n        .filter(_ == true)\n        .map(_ => size)\n\n    // reading the encrypted file with plaintext footer now reveals all metadata\n    assert(\n      collect(spark.read.parquetMetadata(encryptedFilePlaintextFooter)) === Seq(\n        Row(\"encrypted1.parquet\", 1, 1004, 1358, 100, 2, 200, 0, \"…\", \"…\", PLAINTEXT_FOOTER, true),\n      )\n    )\n    assert(\n      collect(spark.read.parquetSchema(encryptedFilePlaintextFooter)) === Seq(\n      // format: off\n      Row(\"encrypted1.parquet\", \"id\", Seq(\"id\"), \"REQUIRED\", \"INT64\", 0, null, null, true, \"INT64\", \"TYPE_DEFINED_ORDER\", 0, 0),\n      Row(\"encrypted1.parquet\", \"val\", Seq(\"val\"), \"REQUIRED\", \"FLOAT\", 0, null, null, true, \"FLOAT\", \"TYPE_DEFINED_ORDER\", 0, 0),\n      // format: on\n      )\n    )\n    assert(\n      collect(spark.read.parquetBlocks(encryptedFilePlaintextFooter)) === Seq(\n        Row(\"encrypted1.parquet\", 1, 4, 1004, 1358, 100, 2, 200, 0),\n      )\n    )\n    assert(\n      collect(spark.read.parquetBlockColumns(encryptedFilePlaintextFooter)) === Seq(\n      // format: off\n      Row(\"encrypted1.parquet\", 1, Seq(\"id\"), \"SNAPPY\", \"required int64 id\", Seq(\"BIT_PACKED\", \"PLAIN\"), isNotEncrypted.orNull, \"0\", \"99\", 4, 437, 826, 100, 0),\n      Row(\"encrypted1.parquet\", 1, Seq(\"val\"), \"SNAPPY\", \"required float val\", Seq(\"BIT_PACKED\", \"PLAIN\"), isEncrypted.orNull, \"-0.0\", \"99.0\", 441, 567, 532, 100, 0),\n      // format: on\n      )\n    )\n    assert(\n      collect(spark.read.parquetPartitions(encryptedFilePlaintextFooter)) === Seq(\n        Row(0, 0, 2705, 2705, 1, 1004, 1358, 100, 2, 200, 0, \"encrypted1.parquet\", hasSplitFileSize(2705).orNull),\n      )\n    )\n\n    // we can now read the encrypted file with encrypted footer\n    assert(\n      collect(spark.read.parquetMetadata(encryptedFileEncryptedFooter)) === Seq(\n        Row(\"encrypted2.parquet\", 1, 1004, 1358, 100, 2, 200, 0, \"…\", \"…\", ENCRYPTED_FOOTER, true),\n      )\n    )\n    assert(\n      collect(spark.read.parquetSchema(encryptedFileEncryptedFooter)) === Seq(\n      // format: off\n      Row(\"encrypted2.parquet\", \"id\", Seq(\"id\"), \"REQUIRED\", \"INT64\", 0, null, null, true, \"INT64\", \"TYPE_DEFINED_ORDER\", 0, 0),\n      Row(\"encrypted2.parquet\", \"val\", Seq(\"val\"), \"REQUIRED\", \"FLOAT\", 0, null, null, true, \"FLOAT\", \"TYPE_DEFINED_ORDER\", 0, 0),\n      // format: on\n      )\n    )\n    assert(\n      collect(spark.read.parquetBlocks(encryptedFileEncryptedFooter)) === Seq(\n        Row(\"encrypted2.parquet\", 1, 4, 1004, 1358, 100, 2, 200, 0),\n      )\n    
)\n    assert(\n      collect(spark.read.parquetBlockColumns(encryptedFileEncryptedFooter)) === Seq(\n      // format: off\n      Row(\"encrypted2.parquet\", 1, Seq(\"id\"), \"SNAPPY\", \"required int64 id\", Seq(\"BIT_PACKED\", \"PLAIN\"), isNotEncrypted.orNull, \"0\", \"99\", 4, 437, 826, 100, 0),\n      Row(\"encrypted2.parquet\", 1, Seq(\"val\"), \"SNAPPY\", \"required float val\", Seq(\"BIT_PACKED\", \"PLAIN\"), isEncrypted.orNull, \"-0.0\", \"99.0\", 441, 567, 532, 100, 0),\n      // format: on\n      )\n    )\n    assert(\n      collect(spark.read.parquetPartitions(encryptedFileEncryptedFooter)) === Seq(\n        Row(0, 0, 2691, 2691, 1, 1004, 1358, 100, 2, 200, 0, \"encrypted2.parquet\", hasSplitFileSize(2691).orNull),\n      )\n    )\n  }\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/spark/test/package.scala",
    "content": "/*\n * Copyright 2020 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.hadoop.conf.Configuration\nimport org.apache.hadoop.fs.{FileSystem, Path}\n\nimport java.io.File\n\npackage object test {\n\n  protected def withTempPath[T](f: File => T): T = {\n    val dir = File.createTempFile(\"test\", \".tmp\")\n    dir.delete()\n\n    try {\n      f(dir)\n    } finally {\n      FileSystem.get(new Configuration()).delete(new Path(dir.getAbsolutePath), true)\n    }\n  }\n\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/test/ClasspathSuite.scala",
    "content": "/*\n * Copyright 2025 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.test\n\nimport scala.reflect.io.Path\nimport uk.co.gresearch.spark.BuildVersion\n\nclass ClasspathSuite extends Spec with BuildVersion {\n  describe(\"The classpath\") {\n    val classpath = System.getProperty(\"java.class.path\").split(\":\").filter(_.contains(\"target\"))\n\n    val resourceUrl = getClass.getResource(\"/log4j2.properties\")\n    val testClasses = Path(resourceUrl.getPath).parent\n\n    it(\"should contain compiled test classes\") {\n      assert(classpath.contains(testClasses.path))\n    }\n\n    val isIntegrationTest = System.getenv().getOrDefault(\"CI_INTEGRATION_TEST\", \"false\") == \"true\"\n    val jarFilename = s\"spark-extension_$BuildScalaCompatVersionString-$VersionString.jar\"\n    val jar = testClasses.parent.resolve(Path(jarFilename)).path\n    val classes = testClasses.parent.resolve(Path(\"classes\")).path\n\n    it(\"should contain compiled classes but not the jar\") {\n      assume(!isIntegrationTest)\n\n      // unit testing does not require the jar to be in the classpath\n      assert(!classpath.contains(jar))\n      // but the path to the compiled classes\n      assert(classpath.contains(classes))\n    }\n\n    it(\"should contain the jar but not compiled classes\") {\n      assume(isIntegrationTest)\n\n      // integration testing requires the jar to be in the classpath\n      assert(classpath.contains(jar))\n      // but not the path to the compiled classes\n      assert(!classpath.contains(classes))\n    }\n  }\n}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/test/Spec.scala",
    "content": "/*\n * Copyright 2025 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.test\n\nimport org.junit.runner.RunWith\nimport org.scalatest.funspec.AnyFunSpec\nimport org.scalatestplus.junit.JUnitRunner\n\n@RunWith(classOf[JUnitRunner])\nclass Spec extends AnyFunSpec {}\n"
  },
  {
    "path": "src/test/scala/uk/co/gresearch/test/Suite.scala",
    "content": "/*\n * Copyright 2025 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.test\n\nimport org.junit.runner.RunWith\nimport org.scalatest.funsuite.AnyFunSuite\nimport org.scalatestplus.junit.JUnitRunner\n\n@RunWith(classOf[JUnitRunner])\nclass Suite extends AnyFunSuite {}\n"
  },
  {
    "path": "src/test/scala-spark-3/uk/co/gresearch/spark/SparkSuiteHelper.scala",
    "content": "/*\n * Copyright 2024 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.sql.{Dataset, Encoder}\n\ntrait SparkSuiteHelper {\n  self: SparkTestSession =>\n  def createEmptyDataset[T : Encoder](): Dataset[T] = {\n    spark.emptyDataset[T](implicitly[Encoder[T]])\n  }\n}\n"
  },
  {
    "path": "src/test/scala-spark-4/uk/co/gresearch/spark/SparkSuiteHelper.scala",
    "content": "/*\n * Copyright 2024 G-Research\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\npackage uk.co.gresearch.spark\n\nimport org.apache.spark.sql.{Encoder, classic}\nimport org.apache.spark.sql.classic.Dataset\n\ntrait SparkSuiteHelper {\n  self: SparkTestSession =>\n  def createEmptyDataset[T : Encoder](): Dataset[T] = {\n    spark.emptyDataset[T](implicitly[Encoder[T]]).asInstanceOf[classic.Dataset[T]]\n  }\n}\n"
  },
  {
    "path": "test-release.py",
    "content": "# this requires parquet-hadoop-*-tests.jar\n# fetch with mvn dependency:get -Dtransitive=false -Dartifact=org.apache.parquet:parquet-hadoop:1.16.0:jar:tests\nfrom pathlib import Path\nhadoop_parquet_tests = f\"{Path.home()}/.m2/repository/org/apache/parquet/parquet-hadoop/1.16.0/parquet-hadoop-1.16.0-tests.jar\"\n\nfrom pyspark import SparkConf\nfrom pyspark.sql import SparkSession\n\n# noinspection PyUnresolvedReferences\nimport gresearch.spark.diff\nimport gresearch.spark.parquet\n\n\nconf = SparkConf().setAppName('integration test').setMaster('local[2]')\nconf = conf.setAll([\n    ('spark.ui.showConsoleProgress', 'false'),\n    ('spark.locality.wait', '0'),\n    ('spark.jars', hadoop_parquet_tests),\n    ('spark.driver.extraClassPath', hadoop_parquet_tests),\n])\n\nspark = SparkSession \\\n    .builder \\\n    .config(conf=conf) \\\n    .getOrCreate()\n\nspark.sparkContext.setLogLevel(\"WARN\")\n\nleft = spark.createDataFrame([(1, \"one\"), (2, \"two\"), (3, \"three\")], [\"id\", \"value\"])\nright = spark.createDataFrame([(1, \"one\"), (2, \"Two\"), (4, \"four\")], [\"id\", \"value\"])\n\nleft.diff(right).show()\n\n\nfor file in [\"test.parquet\", \"nested.parquet\", \"encrypted1.parquet\"]:\n    print(file)\n    path = f\"src/test/files/{file}\"\n    spark.read.parquet_metadata(path).show()\n    spark.read.parquet_schema(path).show()\n    spark.read.parquet_blocks(path).show()\n    spark.read.parquet_block_columns(path).show()\n    if file != \"encrypted1.parquet\":\n        spark.read.parquet_partitions(path).show()\n\n# configure footer key only\nhc = spark.sparkContext._jsc.hadoopConfiguration()\nhc.set(\"parquet.crypto.factory.class\", \"org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory\")\nhc.set(\"parquet.encryption.kms.client.class\", \"org.apache.parquet.crypto.keytools.mocks.InMemoryKMS\")\nhc.set(\"parquet.encryption.key.list\", \"key:AAECAAECAAECAAECAAECAA==\")\n\nfor file in [\"encrypted1.parquet\", \"encrypted2.parquet\"]:\n    print(file)\n    path = f\"src/test/files/{file}\"\n    spark.read.parquet_metadata(path).show()\n    spark.read.parquet_schema(path).show()\n    spark.read.parquet_blocks(path).show()\n    spark.read.parquet_block_columns(path).show()\n    spark.read.parquet_partitions(path).show()"
  },
  {
    "path": "test-release.scala",
    "content": "// this requires parquet-hadoop-*-tests.jar\n// fetch with mvn dependency:get -Dtransitive=false -Dartifact=org.apache.parquet:parquet-hadoop:1.16.0:jar:tests\n// run with dependency ~/.m2/repository/org/apache/parquet/parquet-hadoop/1.16.0/parquet-hadoop-1.16.0-tests.jar\n\nimport org.apache.spark.sql.DataFrame\n\ndef assertSize(df: DataFrame, size: Long): Unit = {\n  Console.println(s\"expect $size rows\")\n  df.show()\n  assert(df.collect().size == size)\n}\n\n\nimport uk.co.gresearch.spark.diff._\nimport uk.co.gresearch.spark.parquet._\n\ntry {\n  val left = Seq((1, \"one\"), (2, \"two\"), (3, \"three\")).toDF(\"id\", \"value\")\n  val right = Seq((1, \"one\"), (2, \"Two\"), (4, \"four\")).toDF(\"id\", \"value\")\n  assertSize(left.diff(right), 5)\n\n  Seq(\n    (\"test.parquet\", (2, 4, 3, 6, 2)),\n    (\"nested.parquet\", (1, 5, 1, 5, 1)),\n    (\"encrypted1.parquet\", (1, 2, 1, 2, 1))\n  ).foreach { case (file, rows) =>\n    Console.println(file)\n    val path = s\"src/test/files/$file\"\n    val (metadataRows, schemaRows, blockRows, blockColumnRows, partitionRows) = rows\n    assertSize(spark.read.parquetMetadata(path), metadataRows)\n    assertSize(spark.read.parquetSchema(path), schemaRows)\n    assertSize(spark.read.parquetBlocks(path), blockRows)\n    assertSize(spark.read.parquetBlockColumns(path), blockColumnRows)\n    if (file != \"encrypted1.parquet\") {\n      assertSize(spark.read.parquetPartitions(path), partitionRows)\n    }\n  }\n\n  // configure footer key only\n  val hc = spark.sparkContext.hadoopConfiguration\n  hc.set(\"parquet.crypto.factory.class\", \"org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory\")\n  hc.set(\"parquet.encryption.kms.client.class\", \"org.apache.parquet.crypto.keytools.mocks.InMemoryKMS\")\n  hc.set(\"parquet.encryption.key.list\", \"key:AAECAAECAAECAAECAAECAA==\")\n\n  Seq(\n    (\"encrypted1.parquet\", (1, 2, 1, 2, 1)),\n    (\"encrypted2.parquet\", (1, 2, 1, 2, 1))\n  ).foreach { case (file, rows) =>\n    Console.println(file)\n    val path = s\"src/test/files/$file\"\n    val (metadataRows, schemaRows, blockRows, blockColumnRows, partitionRows) = rows\n    assertSize(spark.read.parquetMetadata(path), metadataRows)\n    assertSize(spark.read.parquetSchema(path), schemaRows)\n    assertSize(spark.read.parquetBlocks(path), blockRows)\n    assertSize(spark.read.parquetBlockColumns(path), blockColumnRows)\n    if (file != \"encrypted1.parquet\") {\n      assertSize(spark.read.parquetPartitions(path), partitionRows)\n    }\n  }\n} catch {\n  case e: Throwable => sys.exit(1)\n}\n\nsys.exit(0)\n"
  },
  {
    "path": "test-release.sh",
    "content": "#!/bin/bash\n\nset -eo pipefail\n\nversion=$(grep --max-count=1 \"<version>.*</version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\")\n\nspark_major=$(grep --max-count=1 \"<spark.major.version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\")\nspark_minor=$(grep --max-count=1 \"<spark.minor.version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\")\nspark_patch=$(grep --max-count=1 \"<spark.patch.version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\")\nspark_compat=\"$spark_major.$spark_minor\"\nspark=\"$spark_major.$spark_minor.$spark_patch\"\n\nscala_major=$(grep --max-count=1 \"<scala.major.version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\")\nscala_minor=$(grep --max-count=1 \"<scala.minor.version>\" pom.xml | sed -E -e \"s/\\s*<[^>]+>//g\")\nscala_compat=\"$scala_major.$scala_minor\"\n\necho\necho \"Testing Spark $spark and Scala $scala_compat\"\necho\n\nif [ ! -e \"spark-$spark-$scala_compat\" ]\nthen\n    if [[ \"$scala_compat\" == \"2.12\" ]]\n    then\n        if [[ \"$spark_compat\" < \"3.3\" ]]\n        then\n            hadoop=\"hadoop2.7\"\n        else\n            hadoop=\"hadoop3\"\n        fi\n    elif [[ \"$scala_compat\" == \"2.13\" ]]\n    then\n        if [[ \"$spark_compat\" < \"3.3\" ]]\n        then\n            hadoop=\"hadoop3.2-scala2.13\"\n        else\n            hadoop=\"hadoop3-scala2.13\"\n        fi\n    else\n        hadoop=\"without-hadoop\"\n    fi\n    wget --progress=dot:giga https://archive.apache.org/dist/spark/spark-$spark/spark-$spark-bin-$hadoop.tgz -O - | tar -xzC .\n    ln -s spark-$spark-bin-$hadoop spark-$spark-$scala_compat\nfi\n\necho \"Fetching Release Test Dependencies\"\nmvn dependency:get -Dtransitive=false -Dartifact=org.apache.parquet:parquet-hadoop:1.16.0:jar:tests\n\necho \"Testing Scala\"\nspark-$spark-$scala_compat/bin/spark-shell --packages uk.co.gresearch.spark:spark-extension_$scala_compat:$version --repositories https://oss.sonatype.org/content/groups/staging/  --jars ~/.m2/repository/org/apache/parquet/parquet-hadoop/1.16.0/parquet-hadoop-1.16.0-tests.jar < test-release.scala\n\necho \"Testing Python with Scala package\"\nspark-$spark-$scala_compat/bin/spark-submit --packages uk.co.gresearch.spark:spark-extension_$scala_compat:$version test-release.py\n\nif [ \"$scala_compat\" == \"2.12\" ]\nthen\n    echo \"Testing Python with whl package\"\n    if [ ! -e \"venv-$spark\" ]\n    then\n      python3 -m venv venv-$spark\n    fi\n    ./venv-$spark/bin/pip install \"pyspark~=$spark_compat.0\"\n    ./venv-$spark/bin/pip install python/dist/pyspark_extension-${version/-*/}.$spark_compat${version/*-SNAPSHOT/.dev0}-py3-none-any.whl\n    ./venv-$spark/bin/python3 test-release.py\nfi\n\n\necho -e \"\\u001b[32;1mSUCCESS\\u001b[0m\"\n"
  }
]